package runtime import ( "errors" "fmt" "html" "maps" "reflect" "slices" "strings" "github.com/a-h/templ" "github.com/a-h/templ/safehtml" ) // SanitizeStyleAttributeValues renders a style attribute value. // The supported types are: // - string // - templ.SafeCSS // - map[string]string // - map[string]templ.SafeCSSProperty // - templ.KeyValue[string, string] - A map of key/values where the key is the CSS property name and the value is the CSS property value. // - templ.KeyValue[string, templ.SafeCSSProperty] - A map of key/values where the key is the CSS property name and the value is the CSS property value. // - templ.KeyValue[string, bool] - The bool determines whether the value should be included. // - templ.KeyValue[templ.SafeCSS, bool] - The bool determines whether the value should be included. // - func() (anyOfTheAboveTypes) // - func() (anyOfTheAboveTypes, error) // - []anyOfTheAboveTypes // // In the above, templ.SafeCSS and templ.SafeCSSProperty are types that are used to indicate that the value is safe to render as CSS without sanitization. // All other types are sanitized before rendering. // // If an error is returned by any function, or a non-nil error is included in the input, the error is returned. func SanitizeStyleAttributeValues(values ...any) (string, error) { if err := getJoinedErrorsFromValues(values...); err != nil { return "", err } sb := new(strings.Builder) for _, v := range values { if v == nil { continue } if err := sanitizeStyleAttributeValue(sb, v); err != nil { return "", err } } return sb.String(), nil } func sanitizeStyleAttributeValue(sb *strings.Builder, v any) error { // Process concrete types. switch v := v.(type) { case string: return processString(sb, v) case templ.SafeCSS: return processSafeCSS(sb, v) case map[string]string: return processStringMap(sb, v) case map[string]templ.SafeCSSProperty: return processSafeCSSPropertyMap(sb, v) case templ.KeyValue[string, string]: return processStringKV(sb, v) case templ.KeyValue[string, bool]: if v.Value { return processString(sb, v.Key) } return nil case templ.KeyValue[templ.SafeCSS, bool]: if v.Value { return processSafeCSS(sb, v.Key) } return nil } // Fall back to reflection. // Handle functions first using reflection. if handled, err := handleFuncWithReflection(sb, v); handled { return err } // Handle slices using reflection before concrete types. if handled, err := handleSliceWithReflection(sb, v); handled { return err } _, err := sb.WriteString(TemplUnsupportedStyleAttributeValue) return err } func processSafeCSS(sb *strings.Builder, v templ.SafeCSS) error { if v == "" { return nil } sb.WriteString(html.EscapeString(string(v))) if !strings.HasSuffix(string(v), ";") { sb.WriteRune(';') } return nil } func processString(sb *strings.Builder, v string) error { if v == "" { return nil } sanitized := strings.TrimSpace(safehtml.SanitizeStyleValue(v)) sb.WriteString(html.EscapeString(sanitized)) if !strings.HasSuffix(sanitized, ";") { sb.WriteRune(';') } return nil } var ErrInvalidStyleAttributeFunctionSignature = errors.New("invalid function signature, should be in the form func() (string, error)") // handleFuncWithReflection handles functions using reflection. func handleFuncWithReflection(sb *strings.Builder, v any) (bool, error) { rv := reflect.ValueOf(v) if rv.Kind() != reflect.Func { return false, nil } t := rv.Type() if t.NumIn() != 0 || (t.NumOut() != 1 && t.NumOut() != 2) { return false, ErrInvalidStyleAttributeFunctionSignature } // Check the types of the return values if t.NumOut() == 2 { // Ensure the second return value is of type `error` secondReturnType := t.Out(1) if !secondReturnType.Implements(reflect.TypeOf((*error)(nil)).Elem()) { return false, fmt.Errorf("second return value must be of type error, got %v", secondReturnType) } } results := rv.Call(nil) if t.NumOut() == 2 { // Check if the second return value is an error if errVal := results[1].Interface(); errVal != nil { if err, ok := errVal.(error); ok && err != nil { return true, err } } } return true, sanitizeStyleAttributeValue(sb, results[0].Interface()) } // handleSliceWithReflection handles slices using reflection. func handleSliceWithReflection(sb *strings.Builder, v any) (bool, error) { rv := reflect.ValueOf(v) if rv.Kind() != reflect.Slice { return false, nil } for i := 0; i < rv.Len(); i++ { elem := rv.Index(i).Interface() if err := sanitizeStyleAttributeValue(sb, elem); err != nil { return true, err } } return true, nil } // processStringMap processes a map[string]string. func processStringMap(sb *strings.Builder, m map[string]string) error { for _, name := range slices.Sorted(maps.Keys(m)) { name, value := safehtml.SanitizeCSS(name, m[name]) sb.WriteString(html.EscapeString(name)) sb.WriteRune(':') sb.WriteString(html.EscapeString(value)) sb.WriteRune(';') } return nil } // processSafeCSSPropertyMap processes a map[string]templ.SafeCSSProperty. func processSafeCSSPropertyMap(sb *strings.Builder, m map[string]templ.SafeCSSProperty) error { for _, name := range slices.Sorted(maps.Keys(m)) { sb.WriteString(html.EscapeString(safehtml.SanitizeCSSProperty(name))) sb.WriteRune(':') sb.WriteString(html.EscapeString(string(m[name]))) sb.WriteRune(';') } return nil } // processStringKV processes a templ.KeyValue[string, string]. func processStringKV(sb *strings.Builder, kv templ.KeyValue[string, string]) error { name, value := safehtml.SanitizeCSS(kv.Key, kv.Value) sb.WriteString(html.EscapeString(name)) sb.WriteRune(':') sb.WriteString(html.EscapeString(value)) sb.WriteRune(';') return nil } // getJoinedErrorsFromValues collects and joins errors from the input values. func getJoinedErrorsFromValues(values ...any) error { var errs []error for _, v := range values { if err, ok := v.(error); ok { errs = append(errs, err) } } return errors.Join(errs...) } // TemplUnsupportedStyleAttributeValue is the default value returned for unsupported types. var TemplUnsupportedStyleAttributeValue = "zTemplUnsupportedStyleAttributeValue:Invalid;"