Changed: DB Params
This commit is contained in:
62
templ/runtime/buffer.go
Normal file
62
templ/runtime/buffer.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// DefaultBufferSize is the default size of buffers. It is set to 4KB by default, which is the
|
||||
// same as the default buffer size of bufio.Writer.
|
||||
var DefaultBufferSize = 4 * 1024 // 4KB
|
||||
|
||||
// Buffer is a wrapper around bufio.Writer that enables flushing and closing of
|
||||
// the underlying writer.
|
||||
type Buffer struct {
|
||||
Underlying io.Writer
|
||||
b *bufio.Writer
|
||||
}
|
||||
|
||||
// Write the contents of p into the buffer.
|
||||
func (b *Buffer) Write(p []byte) (n int, err error) {
|
||||
return b.b.Write(p)
|
||||
}
|
||||
|
||||
// Flush writes any buffered data to the underlying io.Writer and
|
||||
// calls the Flush method of the underlying http.Flusher if it implements it.
|
||||
func (b *Buffer) Flush() error {
|
||||
if err := b.b.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
if f, ok := b.Underlying.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the buffer and the underlying io.Writer if it implements io.Closer.
|
||||
func (b *Buffer) Close() error {
|
||||
if c, ok := b.Underlying.(io.Closer); ok {
|
||||
return c.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset sets the underlying io.Writer to w and resets the buffer.
|
||||
func (b *Buffer) Reset(w io.Writer) {
|
||||
if b.b == nil {
|
||||
b.b = bufio.NewWriterSize(b, DefaultBufferSize)
|
||||
}
|
||||
b.Underlying = w
|
||||
b.b.Reset(w)
|
||||
}
|
||||
|
||||
// Size returns the size of the underlying buffer in bytes.
|
||||
func (b *Buffer) Size() int {
|
||||
return b.b.Size()
|
||||
}
|
||||
|
||||
// WriteString writes the contents of s into the buffer.
|
||||
func (b *Buffer) WriteString(s string) (n int, err error) {
|
||||
return b.b.WriteString(s)
|
||||
}
|
79
templ/runtime/buffer_test.go
Normal file
79
templ/runtime/buffer_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var wasClosed bool
|
||||
|
||||
type closable struct {
|
||||
*httptest.ResponseRecorder
|
||||
}
|
||||
|
||||
func (c *closable) Close() error {
|
||||
wasClosed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBuffer(t *testing.T) {
|
||||
underlying := httptest.NewRecorder()
|
||||
w, _ := GetBuffer(&closable{underlying})
|
||||
t.Run("can write to a buffer", func(t *testing.T) {
|
||||
if _, err := w.Write([]byte("A")); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("can write a string to a buffer", func(t *testing.T) {
|
||||
if _, err := w.WriteString("A"); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("can flush a buffer", func(t *testing.T) {
|
||||
if err := w.Flush(); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("can close a buffer", func(t *testing.T) {
|
||||
if err := w.Close(); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if !wasClosed {
|
||||
t.Error("expected the underlying writer to be closed")
|
||||
}
|
||||
})
|
||||
t.Run("can get the size of a buffer", func(t *testing.T) {
|
||||
if w.Size() != DefaultBufferSize {
|
||||
t.Errorf("expected %d, got %d", DefaultBufferSize, w.Size())
|
||||
}
|
||||
})
|
||||
t.Run("can reset a buffer", func(t *testing.T) {
|
||||
w.Reset(underlying)
|
||||
})
|
||||
if underlying.Body.String() != "AA" {
|
||||
t.Errorf("expected %q, got %q", "AA", underlying.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
type failStream struct {
|
||||
}
|
||||
|
||||
var errTest = errors.New("test error")
|
||||
|
||||
func (f *failStream) Write(p []byte) (n int, err error) {
|
||||
return 0, errTest
|
||||
}
|
||||
|
||||
func (f *failStream) Close() error {
|
||||
return errTest
|
||||
}
|
||||
|
||||
func TestBufferErrors(t *testing.T) {
|
||||
w, _ := GetBuffer(&failStream{})
|
||||
t.Run("close errors are returned", func(t *testing.T) {
|
||||
if err := w.Close(); err != errTest {
|
||||
t.Errorf("expected %v, got %v", errTest, err)
|
||||
}
|
||||
})
|
||||
}
|
38
templ/runtime/bufferpool.go
Normal file
38
templ/runtime/bufferpool.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() any {
|
||||
return new(Buffer)
|
||||
},
|
||||
}
|
||||
|
||||
// GetBuffer creates and returns a new buffer if the writer is not already a buffer,
|
||||
// or returns the existing buffer if it is.
|
||||
func GetBuffer(w io.Writer) (b *Buffer, existing bool) {
|
||||
if w == nil {
|
||||
return nil, false
|
||||
}
|
||||
b, ok := w.(*Buffer)
|
||||
if ok {
|
||||
return b, true
|
||||
}
|
||||
b = bufferPool.Get().(*Buffer)
|
||||
b.Reset(w)
|
||||
return b, false
|
||||
}
|
||||
|
||||
// ReleaseBuffer flushes the buffer and returns it to the pool.
|
||||
func ReleaseBuffer(w io.Writer) (err error) {
|
||||
b, ok := w.(*Buffer)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
err = b.Flush()
|
||||
bufferPool.Put(b)
|
||||
return err
|
||||
}
|
59
templ/runtime/bufferpool_test.go
Normal file
59
templ/runtime/bufferpool_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBufferPool(t *testing.T) {
|
||||
t.Run("can get a buffer from the pool", func(t *testing.T) {
|
||||
w, existing := GetBuffer(new(bytes.Buffer))
|
||||
if w == nil {
|
||||
t.Error("expected a buffer, got nil")
|
||||
}
|
||||
if existing {
|
||||
t.Error("expected a new buffer, got an existing buffer")
|
||||
}
|
||||
err := ReleaseBuffer(w)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("can get an existing buffer from the pool", func(t *testing.T) {
|
||||
w, existing := GetBuffer(new(bytes.Buffer))
|
||||
if w == nil {
|
||||
t.Error("expected a buffer, got nil")
|
||||
}
|
||||
if existing {
|
||||
t.Error("expected a new buffer, got an existing buffer")
|
||||
}
|
||||
|
||||
w, existing = GetBuffer(w)
|
||||
if w == nil {
|
||||
t.Error("expected a buffer, got nil")
|
||||
}
|
||||
if !existing {
|
||||
t.Error("expected an existing buffer, got a new buffer")
|
||||
}
|
||||
|
||||
err := ReleaseBuffer(w)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("can release any writer without error", func(t *testing.T) {
|
||||
err := ReleaseBuffer(new(bytes.Buffer))
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("attempting to buffer a nil writer returns nil", func(t *testing.T) {
|
||||
w, existing := GetBuffer(nil)
|
||||
if w != nil {
|
||||
t.Error("expected nil, got a buffer")
|
||||
}
|
||||
if existing {
|
||||
t.Error("expected nil, got an existing buffer")
|
||||
}
|
||||
})
|
||||
}
|
8
templ/runtime/builder.go
Normal file
8
templ/runtime/builder.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package runtime
|
||||
|
||||
import "strings"
|
||||
|
||||
// GetBuilder returns a strings.Builder.
|
||||
func GetBuilder() (sb strings.Builder) {
|
||||
return sb
|
||||
}
|
11
templ/runtime/builder_test.go
Normal file
11
templ/runtime/builder_test.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package runtime
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetBuilder(t *testing.T) {
|
||||
sb := GetBuilder()
|
||||
sb.WriteString("test")
|
||||
if sb.String() != "test" {
|
||||
t.Errorf("expected \"test\", got %q", sb.String())
|
||||
}
|
||||
}
|
21
templ/runtime/runtime.go
Normal file
21
templ/runtime/runtime.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
)
|
||||
|
||||
// GeneratedComponentInput is used to avoid generated code needing to import the `context` and `io` packages.
|
||||
type GeneratedComponentInput struct {
|
||||
Context context.Context
|
||||
Writer io.Writer
|
||||
}
|
||||
|
||||
// GeneratedTemplate is used to avoid generated code needing to import the `context` and `io` packages.
|
||||
func GeneratedTemplate(f func(GeneratedComponentInput) error) templ.Component {
|
||||
return templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
|
||||
return f(GeneratedComponentInput{ctx, w})
|
||||
})
|
||||
}
|
22
templ/runtime/runtime_test.go
Normal file
22
templ/runtime/runtime_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGeneratedTemplate(t *testing.T) {
|
||||
f := func(input GeneratedComponentInput) error {
|
||||
_, err := input.Writer.Write([]byte("Hello, World!"))
|
||||
return err
|
||||
}
|
||||
sb := new(strings.Builder)
|
||||
err := GeneratedTemplate(f).Render(context.Background(), sb)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if sb.String() != "Hello, World!" {
|
||||
t.Errorf("expected \"Hello, World!\", got %q", sb.String())
|
||||
}
|
||||
}
|
217
templ/runtime/styleattribute.go
Normal file
217
templ/runtime/styleattribute.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"maps"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"github.com/a-h/templ/safehtml"
|
||||
)
|
||||
|
||||
// SanitizeStyleAttributeValues renders a style attribute value.
|
||||
// The supported types are:
|
||||
// - string
|
||||
// - templ.SafeCSS
|
||||
// - map[string]string
|
||||
// - map[string]templ.SafeCSSProperty
|
||||
// - templ.KeyValue[string, string] - A map of key/values where the key is the CSS property name and the value is the CSS property value.
|
||||
// - templ.KeyValue[string, templ.SafeCSSProperty] - A map of key/values where the key is the CSS property name and the value is the CSS property value.
|
||||
// - templ.KeyValue[string, bool] - The bool determines whether the value should be included.
|
||||
// - templ.KeyValue[templ.SafeCSS, bool] - The bool determines whether the value should be included.
|
||||
// - func() (anyOfTheAboveTypes)
|
||||
// - func() (anyOfTheAboveTypes, error)
|
||||
// - []anyOfTheAboveTypes
|
||||
//
|
||||
// In the above, templ.SafeCSS and templ.SafeCSSProperty are types that are used to indicate that the value is safe to render as CSS without sanitization.
|
||||
// All other types are sanitized before rendering.
|
||||
//
|
||||
// If an error is returned by any function, or a non-nil error is included in the input, the error is returned.
|
||||
func SanitizeStyleAttributeValues(values ...any) (string, error) {
|
||||
if err := getJoinedErrorsFromValues(values...); err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb := new(strings.Builder)
|
||||
for _, v := range values {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
if err := sanitizeStyleAttributeValue(sb, v); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func sanitizeStyleAttributeValue(sb *strings.Builder, v any) error {
|
||||
// Process concrete types.
|
||||
switch v := v.(type) {
|
||||
case string:
|
||||
return processString(sb, v)
|
||||
|
||||
case templ.SafeCSS:
|
||||
return processSafeCSS(sb, v)
|
||||
|
||||
case map[string]string:
|
||||
return processStringMap(sb, v)
|
||||
|
||||
case map[string]templ.SafeCSSProperty:
|
||||
return processSafeCSSPropertyMap(sb, v)
|
||||
|
||||
case templ.KeyValue[string, string]:
|
||||
return processStringKV(sb, v)
|
||||
|
||||
case templ.KeyValue[string, bool]:
|
||||
if v.Value {
|
||||
return processString(sb, v.Key)
|
||||
}
|
||||
return nil
|
||||
|
||||
case templ.KeyValue[templ.SafeCSS, bool]:
|
||||
if v.Value {
|
||||
return processSafeCSS(sb, v.Key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fall back to reflection.
|
||||
|
||||
// Handle functions first using reflection.
|
||||
if handled, err := handleFuncWithReflection(sb, v); handled {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle slices using reflection before concrete types.
|
||||
if handled, err := handleSliceWithReflection(sb, v); handled {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := sb.WriteString(TemplUnsupportedStyleAttributeValue)
|
||||
return err
|
||||
}
|
||||
|
||||
func processSafeCSS(sb *strings.Builder, v templ.SafeCSS) error {
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
sb.WriteString(html.EscapeString(string(v)))
|
||||
if !strings.HasSuffix(string(v), ";") {
|
||||
sb.WriteRune(';')
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processString(sb *strings.Builder, v string) error {
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
sanitized := strings.TrimSpace(safehtml.SanitizeStyleValue(v))
|
||||
sb.WriteString(html.EscapeString(sanitized))
|
||||
if !strings.HasSuffix(sanitized, ";") {
|
||||
sb.WriteRune(';')
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrInvalidStyleAttributeFunctionSignature = errors.New("invalid function signature, should be in the form func() (string, error)")
|
||||
|
||||
// handleFuncWithReflection handles functions using reflection.
|
||||
func handleFuncWithReflection(sb *strings.Builder, v any) (bool, error) {
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Func {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
t := rv.Type()
|
||||
if t.NumIn() != 0 || (t.NumOut() != 1 && t.NumOut() != 2) {
|
||||
return false, ErrInvalidStyleAttributeFunctionSignature
|
||||
}
|
||||
|
||||
// Check the types of the return values
|
||||
if t.NumOut() == 2 {
|
||||
// Ensure the second return value is of type `error`
|
||||
secondReturnType := t.Out(1)
|
||||
if !secondReturnType.Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
return false, fmt.Errorf("second return value must be of type error, got %v", secondReturnType)
|
||||
}
|
||||
}
|
||||
|
||||
results := rv.Call(nil)
|
||||
|
||||
if t.NumOut() == 2 {
|
||||
// Check if the second return value is an error
|
||||
if errVal := results[1].Interface(); errVal != nil {
|
||||
if err, ok := errVal.(error); ok && err != nil {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, sanitizeStyleAttributeValue(sb, results[0].Interface())
|
||||
}
|
||||
|
||||
// handleSliceWithReflection handles slices using reflection.
|
||||
func handleSliceWithReflection(sb *strings.Builder, v any) (bool, error) {
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Slice {
|
||||
return false, nil
|
||||
}
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
elem := rv.Index(i).Interface()
|
||||
if err := sanitizeStyleAttributeValue(sb, elem); err != nil {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// processStringMap processes a map[string]string.
|
||||
func processStringMap(sb *strings.Builder, m map[string]string) error {
|
||||
for _, name := range slices.Sorted(maps.Keys(m)) {
|
||||
name, value := safehtml.SanitizeCSS(name, m[name])
|
||||
sb.WriteString(html.EscapeString(name))
|
||||
sb.WriteRune(':')
|
||||
sb.WriteString(html.EscapeString(value))
|
||||
sb.WriteRune(';')
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// processSafeCSSPropertyMap processes a map[string]templ.SafeCSSProperty.
|
||||
func processSafeCSSPropertyMap(sb *strings.Builder, m map[string]templ.SafeCSSProperty) error {
|
||||
for _, name := range slices.Sorted(maps.Keys(m)) {
|
||||
sb.WriteString(html.EscapeString(safehtml.SanitizeCSSProperty(name)))
|
||||
sb.WriteRune(':')
|
||||
sb.WriteString(html.EscapeString(string(m[name])))
|
||||
sb.WriteRune(';')
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// processStringKV processes a templ.KeyValue[string, string].
|
||||
func processStringKV(sb *strings.Builder, kv templ.KeyValue[string, string]) error {
|
||||
name, value := safehtml.SanitizeCSS(kv.Key, kv.Value)
|
||||
sb.WriteString(html.EscapeString(name))
|
||||
sb.WriteRune(':')
|
||||
sb.WriteString(html.EscapeString(value))
|
||||
sb.WriteRune(';')
|
||||
return nil
|
||||
}
|
||||
|
||||
// getJoinedErrorsFromValues collects and joins errors from the input values.
|
||||
func getJoinedErrorsFromValues(values ...any) error {
|
||||
var errs []error
|
||||
for _, v := range values {
|
||||
if err, ok := v.(error); ok {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// TemplUnsupportedStyleAttributeValue is the default value returned for unsupported types.
|
||||
var TemplUnsupportedStyleAttributeValue = "zTemplUnsupportedStyleAttributeValue:Invalid;"
|
333
templ/runtime/styleattribute_test.go
Normal file
333
templ/runtime/styleattribute_test.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
var (
|
||||
err1 = errors.New("error 1")
|
||||
err2 = errors.New("error 2")
|
||||
)
|
||||
|
||||
func TestSanitizeStyleAttribute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []any
|
||||
expected string
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "errors are returned",
|
||||
input: []any{err1},
|
||||
expectedErr: err1,
|
||||
},
|
||||
{
|
||||
name: "multiple errors are joined and returned",
|
||||
input: []any{err1, err2},
|
||||
expectedErr: errors.Join(err1, err2),
|
||||
},
|
||||
{
|
||||
name: "functions that return errors return the error",
|
||||
input: []any{
|
||||
"color:red",
|
||||
func() (string, error) { return "", err1 },
|
||||
},
|
||||
expectedErr: err1,
|
||||
},
|
||||
|
||||
// string
|
||||
{
|
||||
name: "strings: are allowed",
|
||||
input: []any{"color:red;background-color:blue;"},
|
||||
expected: "color:red;background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "strings: have semi-colons appended if missing",
|
||||
input: []any{"color:red;background-color:blue"},
|
||||
expected: "color:red;background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "strings: empty strings are elided",
|
||||
input: []any{""},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "strings: are sanitized",
|
||||
input: []any{"</style><script>alert('xss')</script>"},
|
||||
expected: `\00003C/style>\00003Cscript>alert('xss')\00003C/script>;`,
|
||||
},
|
||||
|
||||
// templ.SafeCSS
|
||||
{
|
||||
name: "SafeCSS: is allowed",
|
||||
input: []any{templ.SafeCSS("color:red;background-color:blue;")},
|
||||
expected: "color:red;background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "SafeCSS: have semi-colons appended if missing",
|
||||
input: []any{templ.SafeCSS("color:red;background-color:blue")},
|
||||
expected: "color:red;background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "SafeCSS: empty strings are elided",
|
||||
input: []any{templ.SafeCSS("")},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "SafeCSS: is escaped, but not sanitized",
|
||||
input: []any{templ.SafeCSS("</style>")},
|
||||
expected: `</style>;`,
|
||||
},
|
||||
|
||||
// map[string]string
|
||||
{
|
||||
name: "map[string]string: is allowed",
|
||||
input: []any{map[string]string{"color": "red", "background-color": "blue"}},
|
||||
expected: "background-color:blue;color:red;",
|
||||
},
|
||||
{
|
||||
name: "map[string]string: keys are sorted",
|
||||
input: []any{map[string]string{"z-index": "1", "color": "red", "background-color": "blue"}},
|
||||
expected: "background-color:blue;color:red;z-index:1;",
|
||||
},
|
||||
{
|
||||
name: "map[string]string: empty names are invalid",
|
||||
input: []any{map[string]string{"": "red", "background-color": "blue"}},
|
||||
expected: "zTemplUnsafeCSSPropertyName:zTemplUnsafeCSSPropertyValue;background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "map[string]string: keys and values are sanitized",
|
||||
input: []any{map[string]string{"color": "</style>", "background-color": "blue"}},
|
||||
expected: "background-color:blue;color:zTemplUnsafeCSSPropertyValue;",
|
||||
},
|
||||
|
||||
// map[string]templ.SafeCSSProperty
|
||||
{
|
||||
name: "map[string]templ.SafeCSSProperty: is allowed",
|
||||
input: []any{map[string]templ.SafeCSSProperty{"color": "red", "background-color": "blue"}},
|
||||
expected: "background-color:blue;color:red;",
|
||||
},
|
||||
{
|
||||
name: "map[string]templ.SafeCSSProperty: keys are sorted",
|
||||
input: []any{map[string]templ.SafeCSSProperty{"z-index": "1", "color": "red", "background-color": "blue"}},
|
||||
expected: "background-color:blue;color:red;z-index:1;",
|
||||
},
|
||||
{
|
||||
name: "map[string]templ.SafeCSSProperty: empty names are invalid",
|
||||
input: []any{map[string]templ.SafeCSSProperty{"": "red", "background-color": "blue"}},
|
||||
expected: "zTemplUnsafeCSSPropertyName:red;background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "map[string]templ.SafeCSSProperty: keys are sanitized, but not values",
|
||||
input: []any{map[string]templ.SafeCSSProperty{"color": "</style>", "</style>": "blue"}},
|
||||
expected: "zTemplUnsafeCSSPropertyName:blue;color:</style>;",
|
||||
},
|
||||
|
||||
// templ.KeyValue[string, string]
|
||||
{
|
||||
name: "KeyValue[string, string]: is allowed",
|
||||
input: []any{templ.KV("color", "red"), templ.KV("background-color", "blue")},
|
||||
expected: "color:red;background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "KeyValue[string, string]: keys and values are sanitized",
|
||||
input: []any{templ.KV("color", "</style>"), templ.KV("</style>", "blue")},
|
||||
expected: "color:zTemplUnsafeCSSPropertyValue;zTemplUnsafeCSSPropertyName:zTemplUnsafeCSSPropertyValue;",
|
||||
},
|
||||
{
|
||||
name: "KeyValue[string, string]: empty names are invalid",
|
||||
input: []any{templ.KV("", "red"), templ.KV("background-color", "blue")},
|
||||
expected: "zTemplUnsafeCSSPropertyName:zTemplUnsafeCSSPropertyValue;background-color:blue;",
|
||||
},
|
||||
|
||||
// templ.KeyValue[string, templ.SafeCSSProperty]
|
||||
{
|
||||
name: "KeyValue[string, templ.SafeCSSProperty]: is allowed",
|
||||
input: []any{templ.KV("color", "red"), templ.KV("background-color", "blue")},
|
||||
expected: "color:red;background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "KeyValue[string, templ.SafeCSSProperty]: keys are sanitized, but not values",
|
||||
input: []any{templ.KV("color", "</style>"), templ.KV("</style>", "blue")},
|
||||
expected: "color:zTemplUnsafeCSSPropertyValue;zTemplUnsafeCSSPropertyName:zTemplUnsafeCSSPropertyValue;",
|
||||
},
|
||||
{
|
||||
name: "KeyValue[string, templ.SafeCSSProperty]: empty names are invalid",
|
||||
input: []any{templ.KV("", "red"), templ.KV("background-color", "blue")},
|
||||
expected: "zTemplUnsafeCSSPropertyName:zTemplUnsafeCSSPropertyValue;background-color:blue;",
|
||||
},
|
||||
|
||||
// templ.KeyValue[string, bool]
|
||||
{
|
||||
name: "KeyValue[string, bool]: is allowed",
|
||||
input: []any{templ.KV("color:red", true), templ.KV("background-color:blue", true), templ.KV("color:blue", false)},
|
||||
expected: "color:red;background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "KeyValue[string, bool]: false values are elided",
|
||||
input: []any{templ.KV("color:red", false), templ.KV("background-color:blue", true)},
|
||||
expected: "background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "KeyValue[string, bool]: keys are sanitized as per strings",
|
||||
input: []any{templ.KV("</style>", true), templ.KV("background-color:blue", true)},
|
||||
expected: "\\00003C/style>;background-color:blue;",
|
||||
},
|
||||
|
||||
// templ.KeyValue[templ.SafeCSS, bool]
|
||||
{
|
||||
name: "KeyValue[templ.SafeCSS, bool]: is allowed",
|
||||
input: []any{templ.KV(templ.SafeCSS("color:red"), true), templ.KV(templ.SafeCSS("background-color:blue"), true), templ.KV(templ.SafeCSS("color:blue"), false)},
|
||||
expected: "color:red;background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "KeyValue[templ.SafeCSS, bool]: false values are elided",
|
||||
input: []any{templ.KV(templ.SafeCSS("color:red"), false), templ.KV(templ.SafeCSS("background-color:blue"), true)},
|
||||
expected: "background-color:blue;",
|
||||
},
|
||||
{
|
||||
name: "KeyValue[templ.SafeCSS, bool]: keys are not sanitized",
|
||||
input: []any{templ.KV(templ.SafeCSS("</style>"), true), templ.KV(templ.SafeCSS("background-color:blue"), true)},
|
||||
expected: "</style>;background-color:blue;",
|
||||
},
|
||||
|
||||
// Functions.
|
||||
{
|
||||
name: "func: string",
|
||||
input: []any{
|
||||
func() string { return "color:red" },
|
||||
},
|
||||
expected: `color:red;`,
|
||||
},
|
||||
{
|
||||
name: "func: string, error - success",
|
||||
input: []any{
|
||||
func() (string, error) { return "color:blue", nil },
|
||||
},
|
||||
expected: `color:blue;`,
|
||||
},
|
||||
{
|
||||
name: "func: string, error - error",
|
||||
input: []any{
|
||||
func() (string, error) { return "", err1 },
|
||||
},
|
||||
expectedErr: err1,
|
||||
},
|
||||
{
|
||||
name: "func: invalid signature",
|
||||
input: []any{
|
||||
func() (string, string) { return "color:blue", "color:blue" },
|
||||
},
|
||||
expected: TemplUnsupportedStyleAttributeValue,
|
||||
},
|
||||
{
|
||||
name: "func: only one or two return values are allowed",
|
||||
input: []any{
|
||||
func() (string, string, string) { return "color:blue", "color:blue", "color:blue" },
|
||||
},
|
||||
expected: TemplUnsupportedStyleAttributeValue,
|
||||
},
|
||||
|
||||
// Slices.
|
||||
{
|
||||
name: "slices: mixed types are allowed",
|
||||
input: []any{
|
||||
[]any{
|
||||
"color:red",
|
||||
templ.KV("text-decoration: underline", true),
|
||||
map[string]string{"background": "blue"},
|
||||
},
|
||||
},
|
||||
expected: `color:red;text-decoration: underline;background:blue;`,
|
||||
},
|
||||
{
|
||||
name: "slices: nested slices are allowed",
|
||||
input: []any{
|
||||
[]any{
|
||||
[]string{"color:red", "font-size:12px"},
|
||||
[]templ.SafeCSS{"margin:0", "padding:0"},
|
||||
},
|
||||
},
|
||||
expected: `color:red;font-size:12px;margin:0;padding:0;`,
|
||||
},
|
||||
|
||||
// Edge cases.
|
||||
{
|
||||
name: "edge: nil input",
|
||||
input: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "edge: empty input",
|
||||
input: []any{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "edge: unsupported type",
|
||||
input: []any{42},
|
||||
expected: TemplUnsupportedStyleAttributeValue,
|
||||
},
|
||||
{
|
||||
name: "edge: nil input",
|
||||
input: []any{nil},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual, err := SanitizeStyleAttributeValues(tt.input...)
|
||||
|
||||
if tt.expectedErr != nil {
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got nil")
|
||||
}
|
||||
if diff := cmp.Diff(tt.expectedErr.Error(), err.Error()); diff != "" {
|
||||
t.Errorf("error mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expected, actual); diff != "" {
|
||||
t.Errorf("result mismatch (-want +got):\n%s", diff)
|
||||
t.Logf("Actual result: %q", actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkSanitizeAttributeValues(b *testing.B, input ...any) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
if _, err := SanitizeStyleAttributeValues(input...); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSanitizeAttributeValuesErr(b *testing.B) { benchmarkSanitizeAttributeValues(b, err1) }
|
||||
func BenchmarkSanitizeAttributeValuesString(b *testing.B) {
|
||||
benchmarkSanitizeAttributeValues(b, "color:red;background-color:blue;")
|
||||
}
|
||||
func BenchmarkSanitizeAttributeValuesStringSanitized(b *testing.B) {
|
||||
benchmarkSanitizeAttributeValues(b, "</style><script>alert('xss')</script>")
|
||||
}
|
||||
func BenchmarkSanitizeAttributeValuesSafeCSS(b *testing.B) {
|
||||
benchmarkSanitizeAttributeValues(b, templ.SafeCSS("color:red;background-color:blue;"))
|
||||
}
|
||||
func BenchmarkSanitizeAttributeValuesMap(b *testing.B) {
|
||||
benchmarkSanitizeAttributeValues(b, map[string]string{"color": "red", "background-color": "blue"})
|
||||
}
|
||||
func BenchmarkSanitizeAttributeValuesKV(b *testing.B) {
|
||||
benchmarkSanitizeAttributeValues(b, templ.KV("color", "red"), templ.KV("background-color", "blue"))
|
||||
}
|
||||
func BenchmarkSanitizeAttributeValuesFunc(b *testing.B) {
|
||||
benchmarkSanitizeAttributeValues(b, func() string { return "color:red" })
|
||||
}
|
104
templ/runtime/watchmode.go
Normal file
104
templ/runtime/watchmode.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var developmentMode = os.Getenv("TEMPL_DEV_MODE") == "true"
|
||||
|
||||
// WriteString writes the string to the writer. If development mode is enabled
|
||||
// s is replaced with the string at the index in the _templ.txt file.
|
||||
func WriteString(w io.Writer, index int, s string) (err error) {
|
||||
if developmentMode {
|
||||
_, path, _, _ := runtime.Caller(1)
|
||||
if !strings.HasSuffix(path, "_templ.go") {
|
||||
return errors.New("templ: attempt to use WriteString from a non templ file")
|
||||
}
|
||||
txtFilePath := strings.Replace(path, "_templ.go", "_templ.txt", 1)
|
||||
|
||||
literals, err := getWatchedStrings(txtFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("templ: failed to cache strings: %w", err)
|
||||
}
|
||||
|
||||
if index > len(literals) {
|
||||
return fmt.Errorf("templ: failed to find line %d in %s", index, txtFilePath)
|
||||
}
|
||||
|
||||
s, err = strconv.Unquote(`"` + literals[index-1] + `"`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err = io.WriteString(w, s)
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
watchModeCache = map[string]watchState{}
|
||||
watchStateMutex sync.Mutex
|
||||
)
|
||||
|
||||
type watchState struct {
|
||||
modTime time.Time
|
||||
strings []string
|
||||
}
|
||||
|
||||
func getWatchedStrings(txtFilePath string) ([]string, error) {
|
||||
watchStateMutex.Lock()
|
||||
defer watchStateMutex.Unlock()
|
||||
|
||||
state, cached := watchModeCache[txtFilePath]
|
||||
if !cached {
|
||||
return cacheStrings(txtFilePath)
|
||||
}
|
||||
|
||||
if time.Since(state.modTime) < time.Millisecond*100 {
|
||||
return state.strings, nil
|
||||
}
|
||||
|
||||
info, err := os.Stat(txtFilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("templ: failed to stat %s: %w", txtFilePath, err)
|
||||
}
|
||||
|
||||
if !info.ModTime().After(state.modTime) {
|
||||
return state.strings, nil
|
||||
}
|
||||
|
||||
return cacheStrings(txtFilePath)
|
||||
}
|
||||
|
||||
func cacheStrings(txtFilePath string) ([]string, error) {
|
||||
txtFile, err := os.Open(txtFilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("templ: failed to open %s: %w", txtFilePath, err)
|
||||
}
|
||||
defer txtFile.Close()
|
||||
|
||||
info, err := txtFile.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("templ: failed to stat %s: %w", txtFilePath, err)
|
||||
}
|
||||
|
||||
all, err := io.ReadAll(txtFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("templ: failed to read %s: %w", txtFilePath, err)
|
||||
}
|
||||
|
||||
literals := strings.Split(string(all), "\n")
|
||||
watchModeCache[txtFilePath] = watchState{
|
||||
modTime: info.ModTime(),
|
||||
strings: literals,
|
||||
}
|
||||
|
||||
return literals, nil
|
||||
}
|
Reference in New Issue
Block a user