Changed: DB Params

This commit is contained in:
2025-03-20 12:35:13 +01:00
parent 8640a12439
commit b71b3d12ca
822 changed files with 134218 additions and 0 deletions

View File

@@ -0,0 +1,403 @@
package generatecmd
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"regexp"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/a-h/templ"
"github.com/a-h/templ/cmd/templ/generatecmd/modcheck"
"github.com/a-h/templ/cmd/templ/generatecmd/proxy"
"github.com/a-h/templ/cmd/templ/generatecmd/run"
"github.com/a-h/templ/cmd/templ/generatecmd/watcher"
"github.com/a-h/templ/generator"
"github.com/cenkalti/backoff/v4"
"github.com/cli/browser"
"github.com/fsnotify/fsnotify"
)
const defaultWatchPattern = `(.+\.go$)|(.+\.templ$)|(.+_templ\.txt$)`
func NewGenerate(log *slog.Logger, args Arguments) (g *Generate, err error) {
g = &Generate{
Log: log,
Args: &args,
}
if g.Args.WorkerCount == 0 {
g.Args.WorkerCount = runtime.NumCPU()
}
if g.Args.WatchPattern == "" {
g.Args.WatchPattern = defaultWatchPattern
}
g.WatchPattern, err = regexp.Compile(g.Args.WatchPattern)
if err != nil {
return nil, fmt.Errorf("failed to compile watch pattern %q: %w", g.Args.WatchPattern, err)
}
return g, nil
}
type Generate struct {
Log *slog.Logger
Args *Arguments
WatchPattern *regexp.Regexp
}
type GenerationEvent struct {
Event fsnotify.Event
Updated bool
GoUpdated bool
TextUpdated bool
}
func (cmd Generate) Run(ctx context.Context) (err error) {
if cmd.Args.NotifyProxy {
return proxy.NotifyProxy(cmd.Args.ProxyBind, cmd.Args.ProxyPort)
}
if cmd.Args.Watch && cmd.Args.FileName != "" {
return fmt.Errorf("cannot watch a single file, remove the -f or -watch flag")
}
writingToWriter := cmd.Args.FileWriter != nil
if cmd.Args.FileName == "" && writingToWriter {
return fmt.Errorf("only a single file can be output to stdout, add the -f flag to specify the file to generate code for")
}
// Default to writing to files.
if cmd.Args.FileWriter == nil {
cmd.Args.FileWriter = FileWriter
}
if cmd.Args.PPROFPort > 0 {
go func() {
_ = http.ListenAndServe(fmt.Sprintf("localhost:%d", cmd.Args.PPROFPort), nil)
}()
}
// Use absolute path.
if !path.IsAbs(cmd.Args.Path) {
cmd.Args.Path, err = filepath.Abs(cmd.Args.Path)
if err != nil {
return fmt.Errorf("failed to get absolute path: %w", err)
}
}
// Configure generator.
var opts []generator.GenerateOpt
if cmd.Args.IncludeVersion {
opts = append(opts, generator.WithVersion(templ.Version()))
}
if cmd.Args.IncludeTimestamp {
opts = append(opts, generator.WithTimestamp(time.Now()))
}
// Check the version of the templ module.
if err := modcheck.Check(cmd.Args.Path); err != nil {
cmd.Log.Warn("templ version check: " + err.Error())
}
fseh := NewFSEventHandler(
cmd.Log,
cmd.Args.Path,
cmd.Args.Watch,
opts,
cmd.Args.GenerateSourceMapVisualisations,
cmd.Args.KeepOrphanedFiles,
cmd.Args.FileWriter,
cmd.Args.Lazy,
)
// If we're processing a single file, don't bother setting up the channels/multithreaing.
if cmd.Args.FileName != "" {
_, err = fseh.HandleEvent(ctx, fsnotify.Event{
Name: cmd.Args.FileName,
Op: fsnotify.Create,
})
return err
}
// Start timer.
start := time.Now()
// Create channels:
// For the initial filesystem walk and subsequent (optional) fsnotify events.
events := make(chan fsnotify.Event)
// Count of events currently being processed by the event handler.
var eventsWG sync.WaitGroup
// Used to check that the event handler has completed.
var eventHandlerWG sync.WaitGroup
// For errs from the watcher.
errs := make(chan error)
// Tracks whether errors occurred during the generation process.
var errorCount atomic.Int64
// For triggering actions after generation has completed.
postGeneration := make(chan *GenerationEvent, 256)
// Used to check that the post-generation handler has completed.
var postGenerationWG sync.WaitGroup
var postGenerationEventsWG sync.WaitGroup
// Waitgroup for the push process.
var pushHandlerWG sync.WaitGroup
// Start process to push events into the channel.
pushHandlerWG.Add(1)
go func() {
defer pushHandlerWG.Done()
defer close(events)
cmd.Log.Debug(
"Walking directory",
slog.String("path", cmd.Args.Path),
slog.Bool("devMode", cmd.Args.Watch),
)
if err := watcher.WalkFiles(ctx, cmd.Args.Path, cmd.WatchPattern, events); err != nil {
cmd.Log.Error("WalkFiles failed, exiting", slog.Any("error", err))
errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)}
return
}
if !cmd.Args.Watch {
cmd.Log.Debug("Dev mode not enabled, process can finish early")
return
}
cmd.Log.Info("Watching files")
rw, err := watcher.Recursive(ctx, cmd.Args.Path, cmd.WatchPattern, events, errs)
if err != nil {
cmd.Log.Error("Recursive watcher setup failed, exiting", slog.Any("error", err))
errs <- FatalError{Err: fmt.Errorf("failed to setup recursive watcher: %w", err)}
return
}
cmd.Log.Debug("Waiting for context to be cancelled to stop watching files")
<-ctx.Done()
cmd.Log.Debug("Context cancelled, closing watcher")
if err := rw.Close(); err != nil {
cmd.Log.Error("Failed to close watcher", slog.Any("error", err))
}
cmd.Log.Debug("Waiting for events to be processed")
eventsWG.Wait()
cmd.Log.Debug(
"All pending events processed, waiting for pending post-generation events to complete",
)
postGenerationEventsWG.Wait()
cmd.Log.Debug(
"All post-generation events processed, deleting watch mode text files",
slog.Int64("errorCount", errorCount.Load()),
)
fileEvents := make(chan fsnotify.Event)
go func() {
if err := watcher.WalkFiles(ctx, cmd.Args.Path, cmd.WatchPattern, fileEvents); err != nil {
cmd.Log.Error("Post dev mode WalkFiles failed", slog.Any("error", err))
errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)}
return
}
close(fileEvents)
}()
for event := range fileEvents {
if strings.HasSuffix(event.Name, "_templ.txt") {
if err = os.Remove(event.Name); err != nil {
cmd.Log.Warn("Failed to remove watch mode text file", slog.Any("error", err))
}
}
}
}()
// Start process to handle events.
eventHandlerWG.Add(1)
sem := make(chan struct{}, cmd.Args.WorkerCount)
go func() {
defer eventHandlerWG.Done()
defer close(postGeneration)
cmd.Log.Debug("Starting event handler")
for event := range events {
eventsWG.Add(1)
sem <- struct{}{}
go func(event fsnotify.Event) {
cmd.Log.Debug("Processing file", slog.String("file", event.Name))
defer eventsWG.Done()
defer func() { <-sem }()
r, err := fseh.HandleEvent(ctx, event)
if err != nil {
errs <- err
}
if !(r.GoUpdated || r.TextUpdated) {
cmd.Log.Debug("File not updated", slog.String("file", event.Name))
return
}
e := &GenerationEvent{
Event: event,
Updated: r.Updated,
GoUpdated: r.GoUpdated,
TextUpdated: r.TextUpdated,
}
cmd.Log.Debug("File updated", slog.String("file", event.Name))
postGeneration <- e
}(event)
}
// Wait for all events to be processed before closing.
eventsWG.Wait()
}()
// Start process to handle post-generation events.
var updates int
postGenerationWG.Add(1)
var firstPostGenerationExecuted bool
go func() {
defer close(errs)
defer postGenerationWG.Done()
cmd.Log.Debug("Starting post-generation handler")
timeout := time.NewTimer(time.Hour * 24 * 365)
var goUpdated, textUpdated bool
var p *proxy.Handler
for {
select {
case ge := <-postGeneration:
if ge == nil {
cmd.Log.Debug("Post-generation event channel closed, exiting")
return
}
goUpdated = goUpdated || ge.GoUpdated
textUpdated = textUpdated || ge.TextUpdated
if goUpdated || textUpdated {
updates++
}
// Reset timer.
if !timeout.Stop() {
<-timeout.C
}
timeout.Reset(time.Millisecond * 100)
case <-timeout.C:
if !goUpdated && !textUpdated {
// Nothing to process, reset timer and wait again.
timeout.Reset(time.Hour * 24 * 365)
break
}
postGenerationEventsWG.Add(1)
if cmd.Args.Command != "" && goUpdated {
cmd.Log.Debug("Executing command", slog.String("command", cmd.Args.Command))
if cmd.Args.Watch {
os.Setenv("TEMPL_DEV_MODE", "true")
}
if _, err := run.Run(ctx, cmd.Args.Path, cmd.Args.Command); err != nil {
cmd.Log.Error("Error executing command", slog.Any("error", err))
}
}
if !firstPostGenerationExecuted {
cmd.Log.Debug("First post-generation event received, starting proxy")
firstPostGenerationExecuted = true
p, err = cmd.StartProxy(ctx)
if err != nil {
cmd.Log.Error("Failed to start proxy", slog.Any("error", err))
}
}
// Send server-sent event.
if p != nil && (textUpdated || goUpdated) {
cmd.Log.Debug("Sending reload event")
p.SendSSE("message", "reload")
}
postGenerationEventsWG.Done()
// Reset timer.
timeout.Reset(time.Millisecond * 100)
textUpdated = false
goUpdated = false
}
}
}()
// Read errors.
for err := range errs {
if err == nil {
continue
}
if errors.Is(err, FatalError{}) {
cmd.Log.Debug("Fatal error, exiting")
return err
}
cmd.Log.Error("Error", slog.Any("error", err))
errorCount.Add(1)
}
// Wait for everything to complete.
cmd.Log.Debug("Waiting for push handler to complete")
pushHandlerWG.Wait()
cmd.Log.Debug("Waiting for event handler to complete")
eventHandlerWG.Wait()
cmd.Log.Debug("Waiting for post-generation handler to complete")
postGenerationWG.Wait()
if cmd.Args.Command != "" {
cmd.Log.Debug("Killing command", slog.String("command", cmd.Args.Command))
if err := run.KillAll(); err != nil {
cmd.Log.Error("Error killing command", slog.Any("error", err))
}
}
// Check for errors after everything has completed.
if errorCount.Load() > 0 {
return fmt.Errorf("generation completed with %d errors", errorCount.Load())
}
cmd.Log.Info(
"Complete",
slog.Int("updates", updates),
slog.Duration("duration", time.Since(start)),
)
return nil
}
func (cmd *Generate) StartProxy(ctx context.Context) (p *proxy.Handler, err error) {
if cmd.Args.Proxy == "" {
cmd.Log.Debug("No proxy URL specified, not starting proxy")
return nil, nil
}
var target *url.URL
target, err = url.Parse(cmd.Args.Proxy)
if err != nil {
return nil, FatalError{Err: fmt.Errorf("failed to parse proxy URL: %w", err)}
}
if cmd.Args.ProxyPort == 0 {
cmd.Args.ProxyPort = 7331
}
if cmd.Args.ProxyBind == "" {
cmd.Args.ProxyBind = "127.0.0.1"
}
p = proxy.New(cmd.Log, cmd.Args.ProxyBind, cmd.Args.ProxyPort, target)
go func() {
cmd.Log.Info("Proxying", slog.String("from", p.URL), slog.String("to", p.Target.String()))
if err := http.ListenAndServe(fmt.Sprintf("%s:%d", cmd.Args.ProxyBind, cmd.Args.ProxyPort), p); err != nil {
cmd.Log.Error("Proxy failed", slog.Any("error", err))
}
}()
if !cmd.Args.OpenBrowser {
cmd.Log.Debug("Not opening browser")
return p, nil
}
go func() {
cmd.Log.Debug("Waiting for proxy to be ready", slog.String("url", p.URL))
backoff := backoff.NewExponentialBackOff()
backoff.InitialInterval = time.Second
var client http.Client
client.Timeout = 1 * time.Second
for {
if _, err := client.Get(p.URL); err == nil {
break
}
d := backoff.NextBackOff()
cmd.Log.Debug(
"Proxy not ready, retrying",
slog.String("url", p.URL),
slog.Any("backoff", d),
)
time.Sleep(d)
}
if err := browser.OpenURL(p.URL); err != nil {
cmd.Log.Error("Failed to open browser", slog.Any("error", err))
}
}()
return p, nil
}

View File

@@ -0,0 +1,366 @@
package generatecmd
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"fmt"
"go/format"
"go/scanner"
"go/token"
"io"
"log/slog"
"os"
"path"
"path/filepath"
"strings"
"sync"
"time"
"github.com/a-h/templ/cmd/templ/visualize"
"github.com/a-h/templ/generator"
"github.com/a-h/templ/parser/v2"
"github.com/fsnotify/fsnotify"
)
type FileWriterFunc func(name string, contents []byte) error
func FileWriter(fileName string, contents []byte) error {
return os.WriteFile(fileName, contents, 0o644)
}
func WriterFileWriter(w io.Writer) FileWriterFunc {
return func(_ string, contents []byte) error {
_, err := w.Write(contents)
return err
}
}
func NewFSEventHandler(
log *slog.Logger,
dir string,
devMode bool,
genOpts []generator.GenerateOpt,
genSourceMapVis bool,
keepOrphanedFiles bool,
fileWriter FileWriterFunc,
lazy bool,
) *FSEventHandler {
if !path.IsAbs(dir) {
dir, _ = filepath.Abs(dir)
}
fseh := &FSEventHandler{
Log: log,
dir: dir,
fileNameToLastModTime: make(map[string]time.Time),
fileNameToLastModTimeMutex: &sync.Mutex{},
fileNameToError: make(map[string]struct{}),
fileNameToErrorMutex: &sync.Mutex{},
fileNameToOutput: make(map[string]generator.GeneratorOutput),
fileNameToOutputMutex: &sync.Mutex{},
devMode: devMode,
hashes: make(map[string][sha256.Size]byte),
hashesMutex: &sync.Mutex{},
genOpts: genOpts,
genSourceMapVis: genSourceMapVis,
keepOrphanedFiles: keepOrphanedFiles,
writer: fileWriter,
lazy: lazy,
}
return fseh
}
type FSEventHandler struct {
Log *slog.Logger
// dir is the root directory being processed.
dir string
fileNameToLastModTime map[string]time.Time
fileNameToLastModTimeMutex *sync.Mutex
fileNameToError map[string]struct{}
fileNameToErrorMutex *sync.Mutex
fileNameToOutput map[string]generator.GeneratorOutput
fileNameToOutputMutex *sync.Mutex
devMode bool
hashes map[string][sha256.Size]byte
hashesMutex *sync.Mutex
genOpts []generator.GenerateOpt
genSourceMapVis bool
Errors []error
keepOrphanedFiles bool
writer func(string, []byte) error
lazy bool
}
type GenerateResult struct {
// Updated indicates that the file was updated.
Updated bool
// GoUpdated indicates that Go expressions were updated.
GoUpdated bool
// TextUpdated indicates that text literals were updated.
TextUpdated bool
}
func (h *FSEventHandler) HandleEvent(ctx context.Context, event fsnotify.Event) (result GenerateResult, err error) {
// Handle _templ.go files.
if !event.Has(fsnotify.Remove) && strings.HasSuffix(event.Name, "_templ.go") {
_, err = os.Stat(strings.TrimSuffix(event.Name, "_templ.go") + ".templ")
if !os.IsNotExist(err) {
return GenerateResult{}, err
}
// File is orphaned.
if h.keepOrphanedFiles {
return GenerateResult{}, nil
}
h.Log.Debug("Deleting orphaned Go file", slog.String("file", event.Name))
if err = os.Remove(event.Name); err != nil {
h.Log.Warn("Failed to remove orphaned file", slog.Any("error", err))
}
return GenerateResult{Updated: true, GoUpdated: true, TextUpdated: false}, nil
}
// Handle _templ.txt files.
if !event.Has(fsnotify.Remove) && strings.HasSuffix(event.Name, "_templ.txt") {
if h.devMode {
// Don't delete the file in dev mode, ignore changes to it, since the .templ file
// must have been updated in order to trigger a change in the _templ.txt file.
return GenerateResult{Updated: false, GoUpdated: false, TextUpdated: false}, nil
}
h.Log.Debug("Deleting watch mode file", slog.String("file", event.Name))
if err = os.Remove(event.Name); err != nil {
h.Log.Warn("Failed to remove watch mode text file", slog.Any("error", err))
return GenerateResult{}, nil
}
return GenerateResult{}, nil
}
// If the file hasn't been updated since the last time we processed it, ignore it.
lastModTime, updatedModTime := h.UpsertLastModTime(event.Name)
if !updatedModTime {
h.Log.Debug("Skipping file because it wasn't updated", slog.String("file", event.Name))
return GenerateResult{}, nil
}
// Process anything that isn't a templ file.
if !strings.HasSuffix(event.Name, ".templ") {
// If it's a Go file, mark it as updated.
if strings.HasSuffix(event.Name, ".go") {
result.GoUpdated = true
}
result.Updated = true
return result, nil
}
// Handle templ files.
// If the go file is newer than the templ file, skip generation, because it's up-to-date.
if h.lazy && goFileIsUpToDate(event.Name, lastModTime) {
h.Log.Debug("Skipping file because the Go file is up-to-date", slog.String("file", event.Name))
return GenerateResult{}, nil
}
// Start a processor.
start := time.Now()
var diag []parser.Diagnostic
result, diag, err = h.generate(ctx, event.Name)
if err != nil {
h.SetError(event.Name, true)
return result, fmt.Errorf("failed to generate code for %q: %w", event.Name, err)
}
if len(diag) > 0 {
for _, d := range diag {
h.Log.Warn(d.Message,
slog.String("from", fmt.Sprintf("%d:%d", d.Range.From.Line, d.Range.From.Col)),
slog.String("to", fmt.Sprintf("%d:%d", d.Range.To.Line, d.Range.To.Col)),
)
}
return result, nil
}
if errorCleared, errorCount := h.SetError(event.Name, false); errorCleared {
h.Log.Info("Error cleared", slog.String("file", event.Name), slog.Int("errors", errorCount))
}
h.Log.Debug("Generated code", slog.String("file", event.Name), slog.Duration("in", time.Since(start)))
return result, nil
}
func goFileIsUpToDate(templFileName string, templFileLastMod time.Time) (upToDate bool) {
goFileName := strings.TrimSuffix(templFileName, ".templ") + "_templ.go"
goFileInfo, err := os.Stat(goFileName)
if err != nil {
return false
}
return goFileInfo.ModTime().After(templFileLastMod)
}
func (h *FSEventHandler) SetError(fileName string, hasError bool) (previouslyHadError bool, errorCount int) {
h.fileNameToErrorMutex.Lock()
defer h.fileNameToErrorMutex.Unlock()
_, previouslyHadError = h.fileNameToError[fileName]
delete(h.fileNameToError, fileName)
if hasError {
h.fileNameToError[fileName] = struct{}{}
}
return previouslyHadError, len(h.fileNameToError)
}
func (h *FSEventHandler) UpsertLastModTime(fileName string) (modTime time.Time, updated bool) {
fileInfo, err := os.Stat(fileName)
if err != nil {
return modTime, false
}
h.fileNameToLastModTimeMutex.Lock()
defer h.fileNameToLastModTimeMutex.Unlock()
previousModTime := h.fileNameToLastModTime[fileName]
currentModTime := fileInfo.ModTime()
if !currentModTime.After(previousModTime) {
return currentModTime, false
}
h.fileNameToLastModTime[fileName] = currentModTime
return currentModTime, true
}
func (h *FSEventHandler) UpsertHash(fileName string, hash [sha256.Size]byte) (updated bool) {
h.hashesMutex.Lock()
defer h.hashesMutex.Unlock()
lastHash := h.hashes[fileName]
if lastHash == hash {
return false
}
h.hashes[fileName] = hash
return true
}
// generate Go code for a single template.
// If a basePath is provided, the filename included in error messages is relative to it.
func (h *FSEventHandler) generate(ctx context.Context, fileName string) (result GenerateResult, diagnostics []parser.Diagnostic, err error) {
t, err := parser.Parse(fileName)
if err != nil {
return GenerateResult{}, nil, fmt.Errorf("%s parsing error: %w", fileName, err)
}
targetFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.go"
// Only use relative filenames to the basepath for filenames in runtime error messages.
absFilePath, err := filepath.Abs(fileName)
if err != nil {
return GenerateResult{}, nil, fmt.Errorf("failed to get absolute path for %q: %w", fileName, err)
}
relFilePath, err := filepath.Rel(h.dir, absFilePath)
if err != nil {
return GenerateResult{}, nil, fmt.Errorf("failed to get relative path for %q: %w", fileName, err)
}
// Convert Windows file paths to Unix-style for consistency.
relFilePath = filepath.ToSlash(relFilePath)
var b bytes.Buffer
generatorOutput, err := generator.Generate(t, &b, append(h.genOpts, generator.WithFileName(relFilePath))...)
if err != nil {
return GenerateResult{}, nil, fmt.Errorf("%s generation error: %w", fileName, err)
}
formattedGoCode, err := format.Source(b.Bytes())
if err != nil {
err = remapErrorList(err, generatorOutput.SourceMap, fileName)
return GenerateResult{}, nil, fmt.Errorf("%s source formatting error %w", fileName, err)
}
// Hash output, and write out the file if the goCodeHash has changed.
goCodeHash := sha256.Sum256(formattedGoCode)
if h.UpsertHash(targetFileName, goCodeHash) {
result.Updated = true
if err = h.writer(targetFileName, formattedGoCode); err != nil {
return result, nil, fmt.Errorf("failed to write target file %q: %w", targetFileName, err)
}
}
// Add the txt file if it has changed.
if h.devMode {
txtFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.txt"
joined := strings.Join(generatorOutput.Literals, "\n")
txtHash := sha256.Sum256([]byte(joined))
if h.UpsertHash(txtFileName, txtHash) {
result.TextUpdated = true
if err = os.WriteFile(txtFileName, []byte(joined), 0o644); err != nil {
return result, nil, fmt.Errorf("failed to write string literal file %q: %w", txtFileName, err)
}
}
// Check whether the change would require a recompilation to take effect.
h.fileNameToOutputMutex.Lock()
defer h.fileNameToOutputMutex.Unlock()
previous := h.fileNameToOutput[fileName]
if generator.HasChanged(previous, generatorOutput) {
result.GoUpdated = true
}
h.fileNameToOutput[fileName] = generatorOutput
}
parsedDiagnostics, err := parser.Diagnose(t)
if err != nil {
return result, nil, fmt.Errorf("%s diagnostics error: %w", fileName, err)
}
if h.genSourceMapVis {
err = generateSourceMapVisualisation(ctx, fileName, targetFileName, generatorOutput.SourceMap)
}
return result, parsedDiagnostics, err
}
// Takes an error from the formatter and attempts to convert the positions reported in the target file to their positions
// in the source file.
func remapErrorList(err error, sourceMap *parser.SourceMap, fileName string) error {
list, ok := err.(scanner.ErrorList)
if !ok || len(list) == 0 {
return err
}
for i, e := range list {
// The positions in the source map are off by one line because of the package definition.
srcPos, ok := sourceMap.SourcePositionFromTarget(uint32(e.Pos.Line-1), uint32(e.Pos.Column))
if !ok {
continue
}
list[i].Pos = token.Position{
Filename: fileName,
Offset: int(srcPos.Index),
Line: int(srcPos.Line) + 1,
Column: int(srcPos.Col),
}
}
return list
}
func generateSourceMapVisualisation(ctx context.Context, templFileName, goFileName string, sourceMap *parser.SourceMap) error {
if err := ctx.Err(); err != nil {
return err
}
var templContents, goContents []byte
var templErr, goErr error
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
templContents, templErr = os.ReadFile(templFileName)
}()
go func() {
defer wg.Done()
goContents, goErr = os.ReadFile(goFileName)
}()
wg.Wait()
if templErr != nil {
return templErr
}
if goErr != nil {
return templErr
}
targetFileName := strings.TrimSuffix(templFileName, ".templ") + "_templ_sourcemap.html"
w, err := os.Create(targetFileName)
if err != nil {
return fmt.Errorf("%s sourcemap visualisation error: %w", templFileName, err)
}
defer w.Close()
b := bufio.NewWriter(w)
defer b.Flush()
return visualize.HTML(templFileName, string(templContents), string(goContents), sourceMap).Render(ctx, b)
}

View File

@@ -0,0 +1,23 @@
package generatecmd
type FatalError struct {
Err error
}
func (e FatalError) Error() string {
return e.Err.Error()
}
func (e FatalError) Unwrap() error {
return e.Err
}
func (e FatalError) Is(target error) bool {
_, ok := target.(FatalError)
return ok
}
func (e FatalError) As(target any) bool {
_, ok := target.(*FatalError)
return ok
}

View File

@@ -0,0 +1,39 @@
package generatecmd
import (
"context"
_ "embed"
"log/slog"
_ "net/http/pprof"
)
type Arguments struct {
FileName string
FileWriter FileWriterFunc
Path string
Watch bool
WatchPattern string
OpenBrowser bool
Command string
ProxyBind string
ProxyPort int
Proxy string
NotifyProxy bool
WorkerCount int
GenerateSourceMapVisualisations bool
IncludeVersion bool
IncludeTimestamp bool
// PPROFPort is the port to run the pprof server on.
PPROFPort int
KeepOrphanedFiles bool
Lazy bool
}
func Run(ctx context.Context, log *slog.Logger, args Arguments) (err error) {
g, err := NewGenerate(log, args)
if err != nil {
return err
}
return g.Run(ctx)
}

View File

@@ -0,0 +1,170 @@
package generatecmd
import (
"context"
"io"
"log/slog"
"os"
"path"
"regexp"
"testing"
"time"
"github.com/a-h/templ/cmd/templ/testproject"
"golang.org/x/sync/errgroup"
)
func TestGenerate(t *testing.T) {
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
t.Run("can generate a file in place", func(t *testing.T) {
// templ generate -f templates.templ
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
if err != nil {
t.Fatalf("failed to create test project: %v", err)
}
defer os.RemoveAll(dir)
// Delete the templates_templ.go file to ensure it is generated.
err = os.Remove(path.Join(dir, "templates_templ.go"))
if err != nil {
t.Fatalf("failed to remove templates_templ.go: %v", err)
}
// Run the generate command.
err = Run(context.Background(), log, Arguments{
FileName: path.Join(dir, "templates.templ"),
})
if err != nil {
t.Fatalf("failed to run generate command: %v", err)
}
// Check the templates_templ.go file was created.
_, err = os.Stat(path.Join(dir, "templates_templ.go"))
if err != nil {
t.Fatalf("templates_templ.go was not created: %v", err)
}
})
t.Run("can generate a file in watch mode", func(t *testing.T) {
// templ generate -f templates.templ
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
if err != nil {
t.Fatalf("failed to create test project: %v", err)
}
defer os.RemoveAll(dir)
// Delete the templates_templ.go file to ensure it is generated.
err = os.Remove(path.Join(dir, "templates_templ.go"))
if err != nil {
t.Fatalf("failed to remove templates_templ.go: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
var eg errgroup.Group
eg.Go(func() error {
// Run the generate command.
return Run(ctx, log, Arguments{
Path: dir,
Watch: true,
})
})
// Check the templates_templ.go file was created, with backoff.
for i := 0; i < 5; i++ {
time.Sleep(time.Second * time.Duration(i))
_, err = os.Stat(path.Join(dir, "templates_templ.go"))
if err != nil {
continue
}
_, err = os.Stat(path.Join(dir, "templates_templ.txt"))
if err != nil {
continue
}
break
}
if err != nil {
t.Fatalf("template files were not created: %v", err)
}
cancel()
if err := eg.Wait(); err != nil {
t.Fatalf("generate command failed: %v", err)
}
// Check the templates_templ.txt file was removed.
_, err = os.Stat(path.Join(dir, "templates_templ.txt"))
if err == nil {
t.Fatalf("templates_templ.txt was not removed")
}
})
}
func TestDefaultWatchPattern(t *testing.T) {
tests := []struct {
name string
input string
matches bool
}{
{
name: "empty file names do not match",
input: "",
matches: false,
},
{
name: "*_templ.txt matches, Windows",
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\strings_templ.txt`,
matches: true,
},
{
name: "*_templ.txt matches, Unix",
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/strings_templ.txt",
matches: true,
},
{
name: "*.templ files match, Windows",
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates.templ`,
matches: true,
},
{
name: "*.templ files match, Unix",
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.templ",
matches: true,
},
{
name: "*_templ.go files match, Windows",
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates_templ.go`,
matches: true,
},
{
name: "*_templ.go files match, Unix",
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates_templ.go",
matches: true,
},
{
name: "*.go files match, Windows",
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates.go`,
matches: true,
},
{
name: "*.go files match, Unix",
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.go",
matches: true,
},
{
name: "*.css files do not match",
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.css",
matches: false,
},
}
wpRegexp, err := regexp.Compile(defaultWatchPattern)
if err != nil {
t.Fatalf("failed to compile default watch pattern: %v", err)
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
if wpRegexp.MatchString(test.input) != test.matches {
t.Fatalf("expected match of %q to be %v", test.input, test.matches)
}
})
}
}

View File

@@ -0,0 +1,82 @@
package modcheck
import (
"fmt"
"os"
"path/filepath"
"github.com/a-h/templ"
"golang.org/x/mod/modfile"
"golang.org/x/mod/semver"
)
// WalkUp the directory tree, starting at dir, until we find a directory containing
// a go.mod file.
func WalkUp(dir string) (string, error) {
dir, err := filepath.Abs(dir)
if err != nil {
return "", fmt.Errorf("failed to get absolute path: %w", err)
}
var modFile string
for {
modFile = filepath.Join(dir, "go.mod")
_, err := os.Stat(modFile)
if err != nil && !os.IsNotExist(err) {
return "", fmt.Errorf("failed to stat go.mod file: %w", err)
}
if os.IsNotExist(err) {
// Move up.
prev := dir
dir = filepath.Dir(dir)
if dir == prev {
break
}
continue
}
break
}
// No file found.
if modFile == "" {
return dir, fmt.Errorf("could not find go.mod file")
}
return dir, nil
}
func Check(dir string) error {
dir, err := WalkUp(dir)
if err != nil {
return err
}
// Found a go.mod file.
// Read it and find the templ version.
modFile := filepath.Join(dir, "go.mod")
m, err := os.ReadFile(modFile)
if err != nil {
return fmt.Errorf("failed to read go.mod file: %w", err)
}
mf, err := modfile.Parse(modFile, m, nil)
if err != nil {
return fmt.Errorf("failed to parse go.mod file: %w", err)
}
if mf.Module.Mod.Path == "github.com/a-h/templ" {
// The go.mod file is for templ itself.
return nil
}
for _, r := range mf.Require {
if r.Mod.Path == "github.com/a-h/templ" {
cmp := semver.Compare(r.Mod.Version, templ.Version())
if cmp < 0 {
return fmt.Errorf("generator %v is newer than templ version %v found in go.mod file, consider running `go get -u github.com/a-h/templ` to upgrade", templ.Version(), r.Mod.Version)
}
if cmp > 0 {
return fmt.Errorf("generator %v is older than templ version %v found in go.mod file, consider upgrading templ CLI", templ.Version(), r.Mod.Version)
}
return nil
}
}
return fmt.Errorf("templ not found in go.mod file, run `go get github.com/a-h/templ` to install it")
}

View File

@@ -0,0 +1,47 @@
package modcheck
import (
"testing"
"golang.org/x/mod/modfile"
)
func TestPatchGoVersion(t *testing.T) {
tests := []struct {
input string
expected string
}{
{
input: "go 1.20",
expected: "1.20",
},
{
input: "go 1.20.123",
expected: "1.20.123",
},
{
input: "go 1.20.1",
expected: "1.20.1",
},
{
input: "go 1.20rc1",
expected: "1.20rc1",
},
{
input: "go 1.15",
expected: "1.15",
},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
input := "module github.com/a-h/templ\n\n" + string(test.input) + "\n" + "toolchain go1.27.9\n"
mf, err := modfile.Parse("go.mod", []byte(input), nil)
if err != nil {
t.Fatalf("failed to parse go.mod: %v", err)
}
if test.expected != mf.Go.Version {
t.Errorf("expected %q, got %q", test.expected, mf.Go.Version)
}
})
}
}

View File

@@ -0,0 +1,284 @@
package proxy
import (
"bytes"
"compress/gzip"
"fmt"
"html"
"io"
stdlog "log"
"log/slog"
"math"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/PuerkitoBio/goquery"
"github.com/a-h/templ/cmd/templ/generatecmd/sse"
"github.com/andybalholm/brotli"
_ "embed"
)
//go:embed script.js
var script string
type Handler struct {
log *slog.Logger
URL string
Target *url.URL
p *httputil.ReverseProxy
sse *sse.Handler
}
func getScriptTag(nonce string) string {
if nonce != "" {
var sb strings.Builder
sb.WriteString(`<script src="/_templ/reload/script.js" nonce="`)
sb.WriteString(html.EscapeString(nonce))
sb.WriteString(`"></script>`)
return sb.String()
}
return `<script src="/_templ/reload/script.js"></script>`
}
func insertScriptTagIntoBody(nonce, body string) (updated string) {
doc, err := goquery.NewDocumentFromReader(strings.NewReader(body))
if err != nil {
return strings.Replace(body, "</body>", getScriptTag(nonce)+"</body>", -1)
}
doc.Find("body").AppendHtml(getScriptTag(nonce))
r, err := doc.Html()
if err != nil {
return strings.Replace(body, "</body>", getScriptTag(nonce)+"</body>", -1)
}
return r
}
type passthroughWriteCloser struct {
io.Writer
}
func (pwc passthroughWriteCloser) Close() error {
return nil
}
const unsupportedContentEncoding = "Unsupported content encoding, hot reload script not inserted."
func (h *Handler) modifyResponse(r *http.Response) error {
log := h.log.With(slog.String("url", r.Request.URL.String()))
if r.Header.Get("templ-skip-modify") == "true" {
log.Debug("Skipping response modification because templ-skip-modify header is set")
return nil
}
if contentType := r.Header.Get("Content-Type"); !strings.HasPrefix(contentType, "text/html") {
log.Debug("Skipping response modification because content type is not text/html", slog.String("content-type", contentType))
return nil
}
// Set up readers and writers.
newReader := func(in io.Reader) (out io.Reader, err error) {
return in, nil
}
newWriter := func(out io.Writer) io.WriteCloser {
return passthroughWriteCloser{out}
}
switch r.Header.Get("Content-Encoding") {
case "gzip":
newReader = func(in io.Reader) (out io.Reader, err error) {
return gzip.NewReader(in)
}
newWriter = func(out io.Writer) io.WriteCloser {
return gzip.NewWriter(out)
}
case "br":
newReader = func(in io.Reader) (out io.Reader, err error) {
return brotli.NewReader(in), nil
}
newWriter = func(out io.Writer) io.WriteCloser {
return brotli.NewWriter(out)
}
case "":
log.Debug("No content encoding header found")
default:
h.log.Warn(unsupportedContentEncoding, slog.String("encoding", r.Header.Get("Content-Encoding")))
}
// Read the encoded body.
encr, err := newReader(r.Body)
if err != nil {
return err
}
defer r.Body.Close()
body, err := io.ReadAll(encr)
if err != nil {
return err
}
// Update it.
csp := r.Header.Get("Content-Security-Policy")
updated := insertScriptTagIntoBody(parseNonce(csp), string(body))
if log.Enabled(r.Request.Context(), slog.LevelDebug) {
if len(updated) == len(body) {
log.Debug("Reload script not inserted")
} else {
log.Debug("Reload script inserted")
}
}
// Encode the response.
var buf bytes.Buffer
encw := newWriter(&buf)
_, err = encw.Write([]byte(updated))
if err != nil {
return err
}
err = encw.Close()
if err != nil {
return err
}
// Update the response.
r.Body = io.NopCloser(&buf)
r.ContentLength = int64(buf.Len())
r.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
return nil
}
func parseNonce(csp string) (nonce string) {
outer:
for _, rawDirective := range strings.Split(csp, ";") {
parts := strings.Fields(rawDirective)
if len(parts) < 2 {
continue
}
if parts[0] != "script-src" {
continue
}
for _, source := range parts[1:] {
source = strings.TrimPrefix(source, "'")
source = strings.TrimSuffix(source, "'")
if strings.HasPrefix(source, "nonce-") {
nonce = source[6:]
break outer
}
}
}
return nonce
}
func New(log *slog.Logger, bind string, port int, target *url.URL) (h *Handler) {
p := httputil.NewSingleHostReverseProxy(target)
p.ErrorLog = stdlog.New(os.Stderr, "Proxy to target error: ", 0)
p.Transport = &roundTripper{
maxRetries: 20,
initialDelay: 100 * time.Millisecond,
backoffExponent: 1.5,
}
h = &Handler{
log: log,
URL: fmt.Sprintf("http://%s:%d", bind, port),
Target: target,
p: p,
sse: sse.New(),
}
p.ModifyResponse = h.modifyResponse
return h
}
func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/_templ/reload/script.js" {
// Provides a script that reloads the page.
w.Header().Add("Content-Type", "text/javascript")
_, err := io.WriteString(w, script)
if err != nil {
fmt.Printf("failed to write script: %v\n", err)
}
return
}
if r.URL.Path == "/_templ/reload/events" {
switch r.Method {
case http.MethodGet:
// Provides a list of messages including a reload message.
p.sse.ServeHTTP(w, r)
return
case http.MethodPost:
// Send a reload message to all connected clients.
p.sse.Send("message", "reload")
return
}
http.Error(w, "only GET or POST method allowed", http.StatusMethodNotAllowed)
return
}
p.p.ServeHTTP(w, r)
}
func (p *Handler) SendSSE(eventType string, data string) {
p.sse.Send(eventType, data)
}
type roundTripper struct {
maxRetries int
initialDelay time.Duration
backoffExponent float64
}
func (rt *roundTripper) setShouldSkipResponseModificationHeader(r *http.Request, resp *http.Response) {
// Instruct the modifyResponse function to skip modifying the response if the
// HTTP request has come from HTMX.
if r.Header.Get("HX-Request") != "true" {
return
}
resp.Header.Set("templ-skip-modify", "true")
}
func (rt *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
// Read and buffer the body.
var bodyBytes []byte
if r.Body != nil && r.Body != http.NoBody {
var err error
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
return nil, err
}
r.Body.Close()
}
// Retry logic.
var resp *http.Response
var err error
for retries := 0; retries < rt.maxRetries; retries++ {
// Clone the request and set the body.
req := r.Clone(r.Context())
if bodyBytes != nil {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
// Execute the request.
resp, err = http.DefaultTransport.RoundTrip(req)
if err != nil {
time.Sleep(rt.initialDelay * time.Duration(math.Pow(rt.backoffExponent, float64(retries))))
continue
}
rt.setShouldSkipResponseModificationHeader(r, resp)
return resp, nil
}
return nil, fmt.Errorf("max retries reached: %q", r.URL.String())
}
func NotifyProxy(host string, port int) error {
urlStr := fmt.Sprintf("http://%s:%d/_templ/reload/events", host, port)
req, err := http.NewRequest(http.MethodPost, urlStr, nil)
if err != nil {
return err
}
_, err = http.DefaultClient.Do(req)
return err
}

View File

@@ -0,0 +1,627 @@
package proxy
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/andybalholm/brotli"
"github.com/google/go-cmp/cmp"
)
func TestRoundTripper(t *testing.T) {
t.Run("if the HX-Request header is present, set the templ-skip-modify header on the response", func(t *testing.T) {
rt := &roundTripper{}
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("unexpected error creating request: %v", err)
}
req.Header.Set("HX-Request", "true")
resp := &http.Response{Header: make(http.Header)}
rt.setShouldSkipResponseModificationHeader(req, resp)
if resp.Header.Get("templ-skip-modify") != "true" {
t.Errorf("expected templ-skip-modify header to be true, got %v", resp.Header.Get("templ-skip-modify"))
}
})
t.Run("if the HX-Request header is not present, do not set the templ-skip-modify header on the response", func(t *testing.T) {
rt := &roundTripper{}
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("unexpected error creating request: %v", err)
}
resp := &http.Response{Header: make(http.Header)}
rt.setShouldSkipResponseModificationHeader(req, resp)
if resp.Header.Get("templ-skip-modify") != "" {
t.Errorf("expected templ-skip-modify header to be empty, got %v", resp.Header.Get("templ-skip-modify"))
}
})
}
func TestProxy(t *testing.T) {
t.Run("plain: non-html content is not modified", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`{"key": "value"}`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Content-Length", "16")
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != "16" {
t.Errorf("expected content length to be 16, got %v", r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(`{"key": "value"}`, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("plain: if the response contains templ-skip-modify header, it is not modified", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`Hello`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html")
r.Header.Set("Content-Length", "5")
r.Header.Set("templ-skip-modify", "true")
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != "5" {
t.Errorf("expected content length to be 5, got %v", r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(`Hello`, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("plain: body tags get the script inserted", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`<html><body></body></html>`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Length", "26")
expectedString := insertScriptTagIntoBody("", `<html><body></body></html>`)
if !strings.Contains(expectedString, getScriptTag("")) {
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
}
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("plain: body tags get the script inserted with nonce", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`<html><body></body></html>`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Length", "26")
const nonce = "this-is-the-nonce"
r.Header.Set("Content-Security-Policy", fmt.Sprintf("script-src 'nonce-%s'", nonce))
expectedString := insertScriptTagIntoBody(nonce, `<html><body></body></html>`)
if !strings.Contains(expectedString, getScriptTag(nonce)) {
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
}
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("plain: body tags get the script inserted ignoring js with body tags", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`<html><body><script>console.log("<body></body>")</script></body></html>`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Length", "26")
expectedString := insertScriptTagIntoBody("", `<html><body><script>console.log("<body></body>")</script></body></html>`)
if !strings.Contains(expectedString, getScriptTag("")) {
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
}
if !strings.Contains(expectedString, `console.log("<body></body>")`) {
t.Fatalf("expected the script tag to be inserted, but mangled the html: %q", expectedString)
}
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("gzip: non-html content is not modified", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`{"key": "value"}`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "application/json")
// It's not actually gzipped here, but it doesn't matter, it shouldn't get that far.
r.Header.Set("Content-Encoding", "gzip")
// Similarly, this is not the actual length of the gzipped content.
r.Header.Set("Content-Length", "16")
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != "16" {
t.Errorf("expected content length to be 16, got %v", r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(`{"key": "value"}`, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("gzip: body tags get the script inserted", func(t *testing.T) {
// Arrange
body := `<html><body></body></html>`
var buf bytes.Buffer
gzw := gzip.NewWriter(&buf)
_, err := gzw.Write([]byte(body))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
gzw.Close()
expectedString := insertScriptTagIntoBody("", body)
var expectedBytes bytes.Buffer
gzw = gzip.NewWriter(&expectedBytes)
_, err = gzw.Write([]byte(expectedString))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
gzw.Close()
expectedLength := len(expectedBytes.Bytes())
r := &http.Response{
Body: io.NopCloser(&buf),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Encoding", "gzip")
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err = h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", expectedLength) {
t.Errorf("expected content length to be %d, got %v", expectedLength, r.Header.Get("Content-Length"))
}
gr, err := gzip.NewReader(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
actualBody, err := io.ReadAll(gr)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("brotli: body tags get the script inserted", func(t *testing.T) {
// Arrange
body := `<html><body></body></html>`
var buf bytes.Buffer
brw := brotli.NewWriter(&buf)
_, err := brw.Write([]byte(body))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
brw.Close()
expectedString := insertScriptTagIntoBody("", body)
var expectedBytes bytes.Buffer
brw = brotli.NewWriter(&expectedBytes)
_, err = brw.Write([]byte(expectedString))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
brw.Close()
expectedLength := len(expectedBytes.Bytes())
r := &http.Response{
Body: io.NopCloser(&buf),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Encoding", "br")
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err = h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", expectedLength) {
t.Errorf("expected content length to be %d, got %v", expectedLength, r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(brotli.NewReader(r.Body))
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("notify-proxy: sending POST request to /_templ/reload/events should receive reload sse event", func(t *testing.T) {
// Arrange 1: create a test proxy server.
dummyHandler := func(w http.ResponseWriter, r *http.Request) {}
dummyServer := httptest.NewServer(http.HandlerFunc(dummyHandler))
defer dummyServer.Close()
u, err := url.Parse(dummyServer.URL)
if err != nil {
t.Fatalf("unexpected error parsing URL: %v", err)
}
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
handler := New(log, "0.0.0.0", 0, u)
proxyServer := httptest.NewServer(handler)
defer proxyServer.Close()
u2, err := url.Parse(proxyServer.URL)
if err != nil {
t.Fatalf("unexpected error parsing URL: %v", err)
}
port, err := strconv.Atoi(u2.Port())
if err != nil {
t.Fatalf("unexpected error parsing port: %v", err)
}
// Arrange 2: start a goroutine to listen for sse events.
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
errChan := make(chan error)
sseRespCh := make(chan string)
sseListening := make(chan bool) // Coordination channel that ensures the SSE listener is started before notifying the proxy.
go func() {
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/_templ/reload/events", proxyServer.URL), nil)
if err != nil {
errChan <- err
return
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
errChan <- err
return
}
defer resp.Body.Close()
sseListening <- true
lines := []string{}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
lines = append(lines, scanner.Text())
if scanner.Text() == "data: reload" {
sseRespCh <- strings.Join(lines, "\n")
return
}
}
err = scanner.Err()
if err != nil {
errChan <- err
return
}
}()
// Act: notify the proxy.
select { // Either SSE is listening or an error occurred.
case <-sseListening:
err = NotifyProxy(u2.Hostname(), port)
if err != nil {
t.Fatalf("unexpected error notifying proxy: %v", err)
}
case err := <-errChan:
if err == nil {
t.Fatalf("unexpected sse response: %v", err)
}
}
// Assert.
select { // Either SSE has a expected response or an error or timeout occurred.
case resp := <-sseRespCh:
if !strings.Contains(resp, "event: message\ndata: reload") {
t.Errorf("expected sse reload event to be received, got: %q", resp)
}
case err := <-errChan:
if err == nil {
t.Fatalf("unexpected sse response: %v", err)
}
case <-ctx.Done():
t.Fatalf("timeout waiting for sse response")
}
})
t.Run("unsupported encodings result in a warning", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte("<p>Data</p>"))),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Encoding", "weird-encoding")
// Act
lh := newTestLogHandler(slog.LevelInfo)
log := slog.New(lh)
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if len(lh.records) != 1 {
var sb strings.Builder
for _, record := range lh.records {
sb.WriteString(record.Message)
sb.WriteString("\n")
}
t.Fatalf("expected 1 log entry, but got %d: \n%s", len(lh.records), sb.String())
}
record := lh.records[0]
if record.Message != unsupportedContentEncoding {
t.Errorf("expected warning message %q, got %q", unsupportedContentEncoding, record.Message)
}
if record.Level != slog.LevelWarn {
t.Errorf("expected warning, got level %v", record.Level)
}
})
}
func newTestLogHandler(level slog.Level) *testLogHandler {
return &testLogHandler{
m: new(sync.Mutex),
records: nil,
level: level,
}
}
type testLogHandler struct {
m *sync.Mutex
records []slog.Record
level slog.Level
}
func (h *testLogHandler) Enabled(ctx context.Context, l slog.Level) bool {
return l >= h.level
}
func (h *testLogHandler) Handle(ctx context.Context, r slog.Record) error {
h.m.Lock()
defer h.m.Unlock()
if r.Level < h.level {
return nil
}
h.records = append(h.records, r)
return nil
}
func (h *testLogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return h
}
func (h *testLogHandler) WithGroup(name string) slog.Handler {
return h
}
func TestParseNonce(t *testing.T) {
for _, tc := range []struct {
name string
csp string
expected string
}{
{
name: "empty csp",
csp: "",
expected: "",
},
{
name: "simple csp",
csp: "script-src 'nonce-oLhVst3hTAcxI734qtB0J9Qc7W4qy09C'",
expected: "oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
},
{
name: "simple csp without single quote",
csp: "script-src nonce-oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
expected: "oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
},
{
name: "complete csp",
csp: "default-src 'self'; frame-ancestors 'self'; form-action 'self'; script-src 'strict-dynamic' 'nonce-4VOtk0Uo1l7pwtC';",
expected: "4VOtk0Uo1l7pwtC",
},
{
name: "mdn example 1",
csp: "default-src 'self'",
expected: "",
},
{
name: "mdn example 2",
csp: "default-src 'self' *.trusted.com",
expected: "",
},
{
name: "mdn example 3",
csp: "default-src 'self'; img-src *; media-src media1.com media2.com; script-src userscripts.example.com",
expected: "",
},
{
name: "mdn example 3 multiple sources",
csp: "default-src 'self'; img-src *; media-src media1.com media2.com; script-src userscripts.example.com foo.com 'strict-dynamic' 'nonce-4VOtk0Uo1l7pwtC'",
expected: "4VOtk0Uo1l7pwtC",
},
} {
t.Run(tc.name, func(t *testing.T) {
nonce := parseNonce(tc.csp)
if nonce != tc.expected {
t.Errorf("expected nonce to be %s, but got %s", tc.expected, nonce)
}
})
}
}

View File

@@ -0,0 +1,10 @@
(function() {
let templ_reloadSrc = window.templ_reloadSrc || new EventSource("/_templ/reload/events");
templ_reloadSrc.onmessage = (event) => {
if (event && event.data === "reload") {
window.location.reload();
}
};
window.templ_reloadSrc = templ_reloadSrc;
window.onbeforeunload = () => window.templ_reloadSrc.close();
})();

View File

@@ -0,0 +1,108 @@
package run_test
import (
"context"
"embed"
"io"
"net/http"
"os"
"path/filepath"
"syscall"
"testing"
"time"
"github.com/a-h/templ/cmd/templ/generatecmd/run"
)
//go:embed testprogram/*
var testprogram embed.FS
func TestGoRun(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode.")
}
// Copy testprogram to a temporary directory.
dir, err := os.MkdirTemp("", "testprogram")
if err != nil {
t.Fatalf("failed to make test dir: %v", err)
}
files, err := testprogram.ReadDir("testprogram")
if err != nil {
t.Fatalf("failed to read embedded dir: %v", err)
}
for _, file := range files {
srcFileName := "testprogram/" + file.Name()
srcData, err := testprogram.ReadFile(srcFileName)
if err != nil {
t.Fatalf("failed to read src file %q: %v", srcFileName, err)
}
tgtFileName := filepath.Join(dir, file.Name())
tgtFile, err := os.Create(tgtFileName)
if err != nil {
t.Fatalf("failed to create tgt file %q: %v", tgtFileName, err)
}
defer tgtFile.Close()
if _, err := tgtFile.Write(srcData); err != nil {
t.Fatalf("failed to write to tgt file %q: %v", tgtFileName, err)
}
}
// Rename the go.mod.embed file to go.mod.
if err := os.Rename(filepath.Join(dir, "go.mod.embed"), filepath.Join(dir, "go.mod")); err != nil {
t.Fatalf("failed to rename go.mod.embed: %v", err)
}
tests := []struct {
name string
cmd string
}{
{
name: "Well behaved programs get shut down",
cmd: "go run .",
},
{
name: "Badly behaved programs get shut down",
cmd: "go run . -badly-behaved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cmd, err := run.Run(ctx, dir, tt.cmd)
if err != nil {
t.Fatalf("failed to run program: %v", err)
}
time.Sleep(1 * time.Second)
pid := cmd.Process.Pid
if err := run.KillAll(); err != nil {
t.Fatalf("failed to kill all: %v", err)
}
// Check the parent process is no longer running.
if err := cmd.Process.Signal(os.Signal(syscall.Signal(0))); err == nil {
t.Fatalf("process %d is still running", pid)
}
// Check that the child was stopped.
body, err := readResponse("http://localhost:7777")
if err == nil {
t.Fatalf("child process is still running: %s", body)
}
})
}
}
func readResponse(url string) (body string, err error) {
resp, err := http.Get(url)
if err != nil {
return body, err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return body, err
}
return string(b), nil
}

View File

@@ -0,0 +1,84 @@
//go:build unix
package run
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"strings"
"sync"
"syscall"
"time"
)
var (
m = &sync.Mutex{}
running = map[string]*exec.Cmd{}
)
func KillAll() (err error) {
m.Lock()
defer m.Unlock()
var errs []error
for _, cmd := range running {
if err := kill(cmd); err != nil {
errs = append(errs, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err))
}
}
running = map[string]*exec.Cmd{}
return errors.Join(errs...)
}
func kill(cmd *exec.Cmd) (err error) {
errs := make([]error, 4)
errs[0] = ignoreExited(cmd.Process.Signal(syscall.SIGINT))
errs[1] = ignoreExited(cmd.Process.Signal(syscall.SIGTERM))
errs[2] = ignoreExited(cmd.Wait())
errs[3] = ignoreExited(syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL))
return errors.Join(errs...)
}
func ignoreExited(err error) error {
if errors.Is(err, syscall.ESRCH) {
return nil
}
// Ignore *exec.ExitError
if _, ok := err.(*exec.ExitError); ok {
return nil
}
return err
}
func Run(ctx context.Context, workingDir string, input string) (cmd *exec.Cmd, err error) {
m.Lock()
defer m.Unlock()
cmd, ok := running[input]
if ok {
if err := kill(cmd); err != nil {
return cmd, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err)
}
delete(running, input)
}
parts := strings.Fields(input)
executable := parts[0]
args := []string{}
if len(parts) > 1 {
args = append(args, parts[1:]...)
}
cmd = exec.CommandContext(ctx, executable, args...)
// Wait for the process to finish gracefully before termination.
cmd.WaitDelay = time.Second * 3
cmd.Env = os.Environ()
cmd.Dir = workingDir
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
running[input] = cmd
err = cmd.Start()
return
}

View File

@@ -0,0 +1,69 @@
//go:build windows
package run
import (
"context"
"os"
"os/exec"
"strconv"
"strings"
"sync"
)
var m = &sync.Mutex{}
var running = map[string]*exec.Cmd{}
func KillAll() (err error) {
m.Lock()
defer m.Unlock()
for _, cmd := range running {
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
kill.Stderr = os.Stderr
kill.Stdout = os.Stdout
err := kill.Run()
if err != nil {
return err
}
}
running = map[string]*exec.Cmd{}
return
}
func Stop(cmd *exec.Cmd) (err error) {
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
kill.Stderr = os.Stderr
kill.Stdout = os.Stdout
return kill.Run()
}
func Run(ctx context.Context, workingDir string, input string) (cmd *exec.Cmd, err error) {
m.Lock()
defer m.Unlock()
cmd, ok := running[input]
if ok {
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
kill.Stderr = os.Stderr
kill.Stdout = os.Stdout
err := kill.Run()
if err != nil {
return cmd, err
}
delete(running, input)
}
parts := strings.Fields(input)
executable := parts[0]
args := []string{}
if len(parts) > 1 {
args = append(args, parts[1:]...)
}
cmd = exec.Command(executable, args...)
cmd.Env = os.Environ()
cmd.Dir = workingDir
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
running[input] = cmd
err = cmd.Start()
return
}

View File

@@ -0,0 +1,3 @@
module testprogram
go 1.23

View File

@@ -0,0 +1,63 @@
package main
import (
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
// This is a test program. It is used only to test the behaviour of the run package.
// The run package is supposed to be able to run and stop programs. Those programs may start
// child processes, which should also be stopped when the parent program is stopped.
// For example, running `go run .` will compile an executable and run it.
// So, this program does nothing. It just waits for a signal to stop.
// In "Well behaved" mode, the program will stop when it receives a signal.
// In "Badly behaved" mode, the program will ignore the signal and continue running.
// The run package should be able to stop the program in both cases.
var badlyBehavedFlag = flag.Bool("badly-behaved", false, "If set, the program will ignore the stop signal and continue running.")
func main() {
flag.Parse()
mode := "Well behaved"
if *badlyBehavedFlag {
mode = "Badly behaved"
}
fmt.Printf("%s process %d started.\n", mode, os.Getpid())
// Start a web server on a known port so that we can check that this process is
// not running, when it's been started as a child process, and we don't know
// its pid.
go func() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%d", os.Getpid())
})
err := http.ListenAndServe("127.0.0.1:7777", nil)
if err != nil {
fmt.Printf("Error running web server: %v\n", err)
}
}()
sigs := make(chan os.Signal, 1)
if !*badlyBehavedFlag {
signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
}
for {
select {
case <-sigs:
fmt.Printf("Process %d received signal. Stopping.\n", os.Getpid())
return
case <-time.After(1 * time.Second):
fmt.Printf("Process %d still running...\n", os.Getpid())
}
}
}

View File

@@ -0,0 +1,84 @@
package sse
import (
_ "embed"
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
)
func New() *Handler {
return &Handler{
m: new(sync.Mutex),
requests: map[int64]chan event{},
}
}
type Handler struct {
m *sync.Mutex
counter int64
requests map[int64]chan event
}
type event struct {
Type string
Data string
}
// Send an event to all connected clients.
func (s *Handler) Send(eventType string, data string) {
s.m.Lock()
defer s.m.Unlock()
for _, f := range s.requests {
f := f
go func(f chan event) {
f <- event{
Type: eventType,
Data: data,
}
}(f)
}
}
func (s *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
id := atomic.AddInt64(&s.counter, 1)
s.m.Lock()
events := make(chan event)
s.requests[id] = events
s.m.Unlock()
defer func() {
s.m.Lock()
defer s.m.Unlock()
delete(s.requests, id)
close(events)
}()
timer := time.NewTimer(0)
loop:
for {
select {
case <-timer.C:
if _, err := fmt.Fprintf(w, "event: message\ndata: ping\n\n"); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
timer.Reset(time.Second * 5)
case e := <-events:
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", e.Type, e.Data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
case <-r.Context().Done():
break loop
}
w.(http.Flusher).Flush()
}
}

View File

@@ -0,0 +1,52 @@
package symlink
import (
"context"
"io"
"log/slog"
"os"
"path"
"testing"
"github.com/a-h/templ/cmd/templ/generatecmd"
"github.com/a-h/templ/cmd/templ/testproject"
)
func TestSymlink(t *testing.T) {
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
t.Run("can generate if root is symlink", func(t *testing.T) {
// templ generate -f templates.templ
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
if err != nil {
t.Fatalf("failed to create test project: %v", err)
}
defer os.RemoveAll(dir)
symlinkPath := dir + "-symlink"
err = os.Symlink(dir, symlinkPath)
if err != nil {
t.Fatalf("failed to create dir symlink: %v", err)
}
defer os.Remove(symlinkPath)
// Delete the templates_templ.go file to ensure it is generated.
err = os.Remove(path.Join(symlinkPath, "templates_templ.go"))
if err != nil {
t.Fatalf("failed to remove templates_templ.go: %v", err)
}
// Run the generate command.
err = generatecmd.Run(context.Background(), log, generatecmd.Arguments{
Path: symlinkPath,
})
if err != nil {
t.Fatalf("failed to run generate command: %v", err)
}
// Check the templates_templ.go file was created.
_, err = os.Stat(path.Join(symlinkPath, "templates_templ.go"))
if err != nil {
t.Fatalf("templates_templ.go was not created: %v", err)
}
})
}

View File

@@ -0,0 +1,101 @@
package testeventhandler
import (
"context"
"errors"
"fmt"
"go/scanner"
"go/token"
"io"
"log/slog"
"os"
"testing"
"github.com/a-h/templ/cmd/templ/generatecmd"
"github.com/a-h/templ/generator"
"github.com/fsnotify/fsnotify"
"github.com/google/go-cmp/cmp"
)
func TestErrorLocationMapping(t *testing.T) {
tests := []struct {
name string
rawFileName string
errorPositions []token.Position
}{
{
name: "single error outputs location in srcFile",
rawFileName: "single_error.templ.error",
errorPositions: []token.Position{
{Offset: 46, Line: 3, Column: 20},
},
},
{
name: "multiple errors all output locations in srcFile",
rawFileName: "multiple_errors.templ.error",
errorPositions: []token.Position{
{Offset: 41, Line: 3, Column: 15},
{Offset: 101, Line: 7, Column: 22},
{Offset: 126, Line: 10, Column: 1},
},
},
}
slog := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
var fw generatecmd.FileWriterFunc
fseh := generatecmd.NewFSEventHandler(slog, ".", false, []generator.GenerateOpt{}, false, false, fw, false)
for _, test := range tests {
// The raw files cannot end in .templ because they will cause the generator to fail. Instead,
// we create a tmp file that ends in .templ only for the duration of the test.
rawFile, err := os.Open(test.rawFileName)
if err != nil {
t.Errorf("%s: Failed to open file %s: %v", test.name, test.rawFileName, err)
break
}
file, err := os.CreateTemp("", fmt.Sprintf("*%s.templ", test.rawFileName))
if err != nil {
t.Errorf("%s: Failed to create a tmp file at %s: %v", test.name, file.Name(), err)
break
}
defer os.Remove(file.Name())
if _, err = io.Copy(file, rawFile); err != nil {
t.Errorf("%s: Failed to copy contents from raw file %s to tmp %s: %v", test.name, test.rawFileName, file.Name(), err)
}
event := fsnotify.Event{Name: file.Name(), Op: fsnotify.Write}
_, err = fseh.HandleEvent(context.Background(), event)
if err == nil {
t.Errorf("%s: no error was thrown", test.name)
break
}
list, ok := err.(scanner.ErrorList)
for !ok {
err = errors.Unwrap(err)
if err == nil {
t.Errorf("%s: reached end of error wrapping before finding an ErrorList", test.name)
break
} else {
list, ok = err.(scanner.ErrorList)
}
}
if !ok {
break
}
if len(list) != len(test.errorPositions) {
t.Errorf("%s: expected %d errors but got %d", test.name, len(test.errorPositions), len(list))
break
}
for i, err := range list {
test.errorPositions[i].Filename = file.Name()
diff := cmp.Diff(test.errorPositions[i], err.Pos)
if diff != "" {
t.Error(diff)
t.Error("expected:")
t.Error(test.errorPositions[i])
t.Error("actual:")
t.Error(err.Pos)
}
}
}
}

View File

@@ -0,0 +1,10 @@
package testeventhandler
func invalid(a: string) string {
return "foo"
}
templ multipleError(a: string) {
<div/>
}
l

View File

@@ -0,0 +1,5 @@
package testeventhandler
templ singleError(a: string) {
<div/>
}

View File

@@ -0,0 +1,485 @@
package testwatch
import (
"bufio"
"bytes"
"context"
"embed"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/PuerkitoBio/goquery"
"github.com/a-h/templ/cmd/templ/generatecmd"
"github.com/a-h/templ/cmd/templ/generatecmd/modcheck"
)
//go:embed testdata/*
var testdata embed.FS
func createTestProject(moduleRoot string) (dir string, err error) {
dir, err = os.MkdirTemp("", "templ_watch_test_*")
if err != nil {
return dir, fmt.Errorf("failed to make test dir: %w", err)
}
files, err := testdata.ReadDir("testdata")
if err != nil {
return dir, fmt.Errorf("failed to read embedded dir: %w", err)
}
for _, file := range files {
src := filepath.Join("testdata", file.Name())
data, err := testdata.ReadFile(src)
if err != nil {
return dir, fmt.Errorf("failed to read file: %w", err)
}
target := filepath.Join(dir, file.Name())
if file.Name() == "go.mod.embed" {
data = bytes.ReplaceAll(data, []byte("{moduleRoot}"), []byte(moduleRoot))
target = filepath.Join(dir, "go.mod")
}
err = os.WriteFile(target, data, 0660)
if err != nil {
return dir, fmt.Errorf("failed to copy file: %w", err)
}
}
return dir, nil
}
func replaceInFile(name, src, tgt string) error {
data, err := os.ReadFile(name)
if err != nil {
return err
}
updated := strings.Replace(string(data), src, tgt, -1)
return os.WriteFile(name, []byte(updated), 0660)
}
func getPort() (port int, err error) {
var a *net.TCPAddr
if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
var l *net.TCPListener
if l, err = net.ListenTCP("tcp", a); err == nil {
defer l.Close()
return l.Addr().(*net.TCPAddr).Port, nil
}
}
return
}
func getHTML(url string) (doc *goquery.Document, err error) {
resp, err := http.Get(url)
if err != nil {
return nil, fmt.Errorf("failed to get %q: %w", url, err)
}
return goquery.NewDocumentFromReader(resp.Body)
}
func TestCanAccessDirect(t *testing.T) {
if testing.Short() {
return
}
args, teardown, err := Setup(false)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
// Assert.
doc, err := getHTML(args.AppURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
countText := doc.Find(`div[data-testid="count"]`).Text()
actualCount, err := strconv.Atoi(countText)
if err != nil {
t.Fatalf("got count %q instead of integer", countText)
}
if actualCount < 1 {
t.Errorf("expected count >= 1, got %d", actualCount)
}
}
func TestCanAccessViaProxy(t *testing.T) {
if testing.Short() {
return
}
args, teardown, err := Setup(false)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
// Assert.
doc, err := getHTML(args.ProxyURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
countText := doc.Find(`div[data-testid="count"]`).Text()
actualCount, err := strconv.Atoi(countText)
if err != nil {
t.Fatalf("got count %q instead of integer", countText)
}
if actualCount < 1 {
t.Errorf("expected count >= 1, got %d", actualCount)
}
}
type Event struct {
Type string
Data string
}
func readSSE(ctx context.Context, url string, sse chan<- Event) (err error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Connection", "keep-alive")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
var e Event
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
sse <- e
e = Event{}
continue
}
if strings.HasPrefix(line, "event: ") {
e.Type = line[len("event: "):]
}
if strings.HasPrefix(line, "data: ") {
e.Data = line[len("data: "):]
}
}
return scanner.Err()
}
func TestFileModificationsResultInSSEWithGzip(t *testing.T) {
if testing.Short() {
return
}
args, teardown, err := Setup(false)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
// Start the SSE check.
events := make(chan Event)
var eventsErr error
go func() {
eventsErr = readSSE(context.Background(), fmt.Sprintf("%s/_templ/reload/events", args.ProxyURL), events)
}()
// Assert data is expected.
doc, err := getHTML(args.ProxyURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Original" {
t.Errorf("expected %q, got %q", "Original", text)
}
// Change file.
templFile := filepath.Join(args.AppDir, "templates.templ")
err = replaceInFile(templFile,
`<div data-testid="modification">Original</div>`,
`<div data-testid="modification">Updated</div>`)
if err != nil {
t.Errorf("failed to replace text in file: %v", err)
}
// Give the filesystem watcher a few seconds.
var reloadCount int
loop:
for {
select {
case event := <-events:
if event.Data == "reload" {
reloadCount++
break loop
}
case <-time.After(time.Second * 5):
break loop
}
}
if reloadCount == 0 {
t.Error("failed to receive SSE about update after 5 seconds")
}
// Check to see if there were any errors.
if eventsErr != nil {
t.Errorf("error reading events: %v", err)
}
// See results in browser immediately.
doc, err = getHTML(args.ProxyURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Updated" {
t.Errorf("expected %q, got %q", "Updated", text)
}
}
func TestFileModificationsResultInSSE(t *testing.T) {
if testing.Short() {
return
}
args, teardown, err := Setup(false)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
// Start the SSE check.
events := make(chan Event)
var eventsErr error
go func() {
eventsErr = readSSE(context.Background(), fmt.Sprintf("%s/_templ/reload/events", args.ProxyURL), events)
}()
// Assert data is expected.
doc, err := getHTML(args.ProxyURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Original" {
t.Errorf("expected %q, got %q", "Original", text)
}
// Change file.
templFile := filepath.Join(args.AppDir, "templates.templ")
err = replaceInFile(templFile,
`<div data-testid="modification">Original</div>`,
`<div data-testid="modification">Updated</div>`)
if err != nil {
t.Errorf("failed to replace text in file: %v", err)
}
// Give the filesystem watcher a few seconds.
var reloadCount int
loop:
for {
select {
case event := <-events:
if event.Data == "reload" {
reloadCount++
break loop
}
case <-time.After(time.Second * 5):
break loop
}
}
if reloadCount == 0 {
t.Error("failed to receive SSE about update after 5 seconds")
}
// Check to see if there were any errors.
if eventsErr != nil {
t.Errorf("error reading events: %v", err)
}
// See results in browser immediately.
doc, err = getHTML(args.ProxyURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Updated" {
t.Errorf("expected %q, got %q", "Updated", text)
}
}
func NewTestArgs(modRoot, appDir string, appPort int, proxyBind string, proxyPort int) TestArgs {
return TestArgs{
ModRoot: modRoot,
AppDir: appDir,
AppPort: appPort,
AppURL: fmt.Sprintf("http://localhost:%d", appPort),
ProxyBind: proxyBind,
ProxyPort: proxyPort,
ProxyURL: fmt.Sprintf("http://%s:%d", proxyBind, proxyPort),
}
}
type TestArgs struct {
ModRoot string
AppDir string
AppPort int
AppURL string
ProxyBind string
ProxyPort int
ProxyURL string
}
func Setup(gzipEncoding bool) (args TestArgs, teardown func(t *testing.T), err error) {
wd, err := os.Getwd()
if err != nil {
return args, teardown, fmt.Errorf("could not find working dir: %w", err)
}
moduleRoot, err := modcheck.WalkUp(wd)
if err != nil {
return args, teardown, fmt.Errorf("could not find local templ go.mod file: %v", err)
}
appDir, err := createTestProject(moduleRoot)
if err != nil {
return args, teardown, fmt.Errorf("failed to create test project: %v", err)
}
appPort, err := getPort()
if err != nil {
return args, teardown, fmt.Errorf("failed to get available port: %v", err)
}
proxyPort, err := getPort()
if err != nil {
return args, teardown, fmt.Errorf("failed to get available port: %v", err)
}
proxyBind := "localhost"
args = NewTestArgs(moduleRoot, appDir, appPort, proxyBind, proxyPort)
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
var cmdErr error
wg.Add(1)
go func() {
defer wg.Done()
command := fmt.Sprintf("go run . -port %d", args.AppPort)
if gzipEncoding {
command += " -gzip true"
}
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
cmdErr = generatecmd.Run(ctx, log, generatecmd.Arguments{
Path: appDir,
Watch: true,
OpenBrowser: false,
Command: command,
ProxyBind: proxyBind,
ProxyPort: proxyPort,
Proxy: args.AppURL,
NotifyProxy: false,
WorkerCount: 0,
GenerateSourceMapVisualisations: false,
IncludeVersion: false,
IncludeTimestamp: false,
PPROFPort: 0,
KeepOrphanedFiles: false,
})
}()
// Wait for server to start.
if err = waitForURL(args.AppURL); err != nil {
cancel()
wg.Wait()
return args, teardown, fmt.Errorf("failed to start app server, command error %v: %w", cmdErr, err)
}
if err = waitForURL(args.ProxyURL); err != nil {
cancel()
wg.Wait()
return args, teardown, fmt.Errorf("failed to start proxy server, command error %v: %w", cmdErr, err)
}
// Wait for exit.
teardown = func(t *testing.T) {
cancel()
wg.Wait()
if cmdErr != nil {
t.Errorf("failed to run generate cmd: %v", err)
}
if err = os.RemoveAll(appDir); err != nil {
t.Fatalf("failed to remove test dir %q: %v", appDir, err)
}
}
return args, teardown, err
}
func waitForURL(url string) (err error) {
var tries int
for {
time.Sleep(time.Second)
if tries > 20 {
return err
}
tries++
var resp *http.Response
resp, err = http.Get(url)
if err != nil {
fmt.Printf("failed to get %q: %v\n", url, err)
continue
}
if resp.StatusCode != http.StatusOK {
fmt.Printf("failed to get %q: %v\n", url, err)
err = fmt.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
continue
}
return nil
}
}
func TestGenerateReturnsErrors(t *testing.T) {
wd, err := os.Getwd()
if err != nil {
t.Fatalf("could not find working dir: %v", err)
}
moduleRoot, err := modcheck.WalkUp(wd)
if err != nil {
t.Fatalf("could not find local templ go.mod file: %v", err)
}
appDir, err := createTestProject(moduleRoot)
if err != nil {
t.Fatalf("failed to create test project: %v", err)
}
defer func() {
if err = os.RemoveAll(appDir); err != nil {
t.Fatalf("failed to remove test dir %q: %v", appDir, err)
}
}()
// Break the HTML.
templFile := filepath.Join(appDir, "templates.templ")
err = replaceInFile(templFile,
`<div data-testid="modification">Original</div>`,
`<div data-testid="modification" -unclosed div-</div>`)
if err != nil {
t.Errorf("failed to replace text in file: %v", err)
}
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
// Run.
err = generatecmd.Run(context.Background(), log, generatecmd.Arguments{
Path: appDir,
Watch: false,
IncludeVersion: false,
IncludeTimestamp: false,
KeepOrphanedFiles: false,
})
if err == nil {
t.Errorf("expected generation error, got %v", err)
}
}

View File

@@ -0,0 +1,7 @@
module templ/testproject
go 1.23
require github.com/a-h/templ v0.2.513 // indirect
replace github.com/a-h/templ => {moduleRoot}

View File

@@ -0,0 +1,2 @@
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=

View File

@@ -0,0 +1,81 @@
package main
import (
"bytes"
"compress/gzip"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"strconv"
"github.com/a-h/templ"
)
type GzipResponseWriter struct {
w http.ResponseWriter
}
func (w *GzipResponseWriter) Header() http.Header {
return w.w.Header()
}
func (w *GzipResponseWriter) Write(b []byte) (int, error) {
var buf bytes.Buffer
gzw := gzip.NewWriter(&buf)
defer gzw.Close()
_, err := gzw.Write(b)
if err != nil {
return 0, err
}
err = gzw.Close()
if err != nil {
return 0, err
}
w.w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
return w.w.Write(buf.Bytes())
}
func (w *GzipResponseWriter) WriteHeader(statusCode int) {
w.w.WriteHeader(statusCode)
}
var flagPort = flag.Int("port", 0, "Set the HTTP listen port")
var useGzip = flag.Bool("gzip", false, "Toggle gzip encoding")
func main() {
flag.Parse()
if *flagPort == 0 {
fmt.Println("missing port flag")
os.Exit(1)
}
var count int
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if useGzip != nil && *useGzip {
w.Header().Set("Content-Encoding", "gzip")
w = &GzipResponseWriter{w: w}
}
count++
c := Page(count)
h := templ.Handler(c)
h.ErrorHandler = func(r *http.Request, err error) http.Handler {
slog.Error("failed to render template", slog.Any("error", err))
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
})
}
h.ServeHTTP(w, r)
})
err := http.ListenAndServe(fmt.Sprintf("localhost:%d", *flagPort), nil)
if err != nil {
fmt.Printf("Error listening: %v\n", err)
os.Exit(1)
}
}

View File

@@ -0,0 +1,17 @@
package main
import "fmt"
templ Page(count int) {
<!DOCTYPE html>
<html>
<head>
<title>templ test page</title>
</head>
<body>
<h1>Count</h1>
<div data-testid="count">{ fmt.Sprintf("%d", count) }</div>
<div data-testid="modification">Original</div>
</body>
</html>
}

View File

@@ -0,0 +1,55 @@
// Code generated by templ - DO NOT EDIT.
// templ: version: v0.3.833
package main
//lint:file-ignore SA4006 This context is only used if a nested component is present.
import "github.com/a-h/templ"
import templruntime "github.com/a-h/templ/runtime"
import "fmt"
func Page(count int) templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
if templ_7745c5c3_Var1 == nil {
templ_7745c5c3_Var1 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<!doctype html><html><head><title>templ test page</title></head><body><h1>Count</h1><div data-testid=\"count\">")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var2 string
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(fmt.Sprintf("%d", count))
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/generatecmd/testwatch/testdata/templates.templ`, Line: 13, Col: 54}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "</div><div data-testid=\"modification\">Original</div></body></html>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
var _ = templruntime.GeneratedTemplate

View File

@@ -0,0 +1,166 @@
package watcher
import (
"context"
"io/fs"
"os"
"path"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
)
func Recursive(
ctx context.Context,
path string,
watchPattern *regexp.Regexp,
out chan fsnotify.Event,
errors chan error,
) (w *RecursiveWatcher, err error) {
fsnw, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
w = NewRecursiveWatcher(ctx, fsnw, watchPattern, out, errors)
go w.loop()
return w, w.Add(path)
}
func NewRecursiveWatcher(ctx context.Context, w *fsnotify.Watcher, watchPattern *regexp.Regexp, events chan fsnotify.Event, errors chan error) *RecursiveWatcher {
return &RecursiveWatcher{
ctx: ctx,
w: w,
WatchPattern: watchPattern,
Events: events,
Errors: errors,
timers: make(map[timerKey]*time.Timer),
}
}
// WalkFiles walks the file tree rooted at path, sending a Create event for each
// file it encounters.
func WalkFiles(ctx context.Context, path string, watchPattern *regexp.Regexp, out chan fsnotify.Event) (err error) {
rootPath := path
fileSystem := os.DirFS(rootPath)
return fs.WalkDir(fileSystem, ".", func(path string, info os.DirEntry, err error) error {
if err != nil {
return nil
}
absPath, err := filepath.Abs(filepath.Join(rootPath, path))
if err != nil {
return nil
}
if info.IsDir() && shouldSkipDir(absPath) {
return filepath.SkipDir
}
if !watchPattern.MatchString(absPath) {
return nil
}
out <- fsnotify.Event{
Name: absPath,
Op: fsnotify.Create,
}
return nil
})
}
type RecursiveWatcher struct {
ctx context.Context
w *fsnotify.Watcher
WatchPattern *regexp.Regexp
Events chan fsnotify.Event
Errors chan error
timerMu sync.Mutex
timers map[timerKey]*time.Timer
}
type timerKey struct {
name string
op fsnotify.Op
}
func timerKeyFromEvent(event fsnotify.Event) timerKey {
return timerKey{
name: event.Name,
op: event.Op,
}
}
func (w *RecursiveWatcher) Close() error {
return w.w.Close()
}
func (w *RecursiveWatcher) loop() {
for {
select {
case <-w.ctx.Done():
return
case event, ok := <-w.w.Events:
if !ok {
return
}
if event.Has(fsnotify.Create) {
if err := w.Add(event.Name); err != nil {
w.Errors <- err
}
}
// Only notify on templ related files.
if !w.WatchPattern.MatchString(event.Name) {
continue
}
tk := timerKeyFromEvent(event)
w.timerMu.Lock()
t, ok := w.timers[tk]
w.timerMu.Unlock()
if !ok {
t = time.AfterFunc(100*time.Millisecond, func() {
w.Events <- event
})
w.timerMu.Lock()
w.timers[tk] = t
w.timerMu.Unlock()
continue
}
t.Reset(100 * time.Millisecond)
case err, ok := <-w.w.Errors:
if !ok {
return
}
w.Errors <- err
}
}
}
func (w *RecursiveWatcher) Add(dir string) error {
return filepath.WalkDir(dir, func(dir string, info os.DirEntry, err error) error {
if err != nil {
return nil
}
if !info.IsDir() {
return nil
}
if shouldSkipDir(dir) {
return filepath.SkipDir
}
return w.w.Add(dir)
})
}
func shouldSkipDir(dir string) bool {
if dir == "." {
return false
}
if dir == "vendor" || dir == "node_modules" {
return true
}
_, name := path.Split(dir)
// These directories are ignored by the Go tool.
if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") {
return true
}
return false
}

View File

@@ -0,0 +1,133 @@
package watcher
import (
"context"
"fmt"
"regexp"
"testing"
"time"
"github.com/fsnotify/fsnotify"
)
func TestWatchDebouncesDuplicates(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
w := &fsnotify.Watcher{
Events: make(chan fsnotify.Event),
}
events := make(chan fsnotify.Event, 2)
errors := make(chan error)
watchPattern, err := regexp.Compile(".*")
if err != nil {
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
}
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
go func() {
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
cancel()
close(rw.w.Events)
}()
rw.loop()
count := 0
exp := time.After(300 * time.Millisecond)
for {
select {
case <-rw.Events:
count++
case <-exp:
if count != 1 {
t.Errorf("expected 1 event, got %d", count)
}
return
}
}
}
func TestWatchDoesNotDebounceDifferentEvents(t *testing.T) {
tests := []struct {
event1 fsnotify.Event
event2 fsnotify.Event
}{
// Different files
{fsnotify.Event{Name: "test.templ"}, fsnotify.Event{Name: "test2.templ"}},
// Different operations
{
fsnotify.Event{Name: "test.templ", Op: fsnotify.Create},
fsnotify.Event{Name: "test.templ", Op: fsnotify.Write},
},
// Different operations and files
{
fsnotify.Event{Name: "test.templ", Op: fsnotify.Create},
fsnotify.Event{Name: "test2.templ", Op: fsnotify.Write},
},
}
for _, test := range tests {
ctx, cancel := context.WithCancel(context.Background())
w := &fsnotify.Watcher{
Events: make(chan fsnotify.Event),
}
events := make(chan fsnotify.Event, 2)
errors := make(chan error)
watchPattern, err := regexp.Compile(".*")
if err != nil {
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
}
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
go func() {
rw.w.Events <- test.event1
rw.w.Events <- test.event2
cancel()
close(rw.w.Events)
}()
rw.loop()
count := 0
exp := time.After(300 * time.Millisecond)
for {
select {
case <-rw.Events:
count++
case <-exp:
if count != 2 {
t.Errorf("expected 2 event, got %d", count)
}
return
}
}
}
}
func TestWatchDoesNotDebounceSeparateEvents(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
w := &fsnotify.Watcher{
Events: make(chan fsnotify.Event),
}
events := make(chan fsnotify.Event, 2)
errors := make(chan error)
watchPattern, err := regexp.Compile(".*")
if err != nil {
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
}
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
go func() {
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
<-time.After(200 * time.Millisecond)
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
cancel()
close(rw.w.Events)
}()
rw.loop()
count := 0
exp := time.After(500 * time.Millisecond)
for {
select {
case <-rw.Events:
count++
case <-exp:
if count != 2 {
t.Errorf("expected 2 event, got %d", count)
}
return
}
}
}