Changed: DB Params
This commit is contained in:
403
templ/cmd/templ/generatecmd/cmd.go
Normal file
403
templ/cmd/templ/generatecmd/cmd.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package generatecmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/modcheck"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/proxy"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/run"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/watcher"
|
||||
"github.com/a-h/templ/generator"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/cli/browser"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
const defaultWatchPattern = `(.+\.go$)|(.+\.templ$)|(.+_templ\.txt$)`
|
||||
|
||||
func NewGenerate(log *slog.Logger, args Arguments) (g *Generate, err error) {
|
||||
g = &Generate{
|
||||
Log: log,
|
||||
Args: &args,
|
||||
}
|
||||
if g.Args.WorkerCount == 0 {
|
||||
g.Args.WorkerCount = runtime.NumCPU()
|
||||
}
|
||||
if g.Args.WatchPattern == "" {
|
||||
g.Args.WatchPattern = defaultWatchPattern
|
||||
}
|
||||
g.WatchPattern, err = regexp.Compile(g.Args.WatchPattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile watch pattern %q: %w", g.Args.WatchPattern, err)
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
type Generate struct {
|
||||
Log *slog.Logger
|
||||
Args *Arguments
|
||||
WatchPattern *regexp.Regexp
|
||||
}
|
||||
|
||||
type GenerationEvent struct {
|
||||
Event fsnotify.Event
|
||||
Updated bool
|
||||
GoUpdated bool
|
||||
TextUpdated bool
|
||||
}
|
||||
|
||||
func (cmd Generate) Run(ctx context.Context) (err error) {
|
||||
if cmd.Args.NotifyProxy {
|
||||
return proxy.NotifyProxy(cmd.Args.ProxyBind, cmd.Args.ProxyPort)
|
||||
}
|
||||
if cmd.Args.Watch && cmd.Args.FileName != "" {
|
||||
return fmt.Errorf("cannot watch a single file, remove the -f or -watch flag")
|
||||
}
|
||||
writingToWriter := cmd.Args.FileWriter != nil
|
||||
if cmd.Args.FileName == "" && writingToWriter {
|
||||
return fmt.Errorf("only a single file can be output to stdout, add the -f flag to specify the file to generate code for")
|
||||
}
|
||||
// Default to writing to files.
|
||||
if cmd.Args.FileWriter == nil {
|
||||
cmd.Args.FileWriter = FileWriter
|
||||
}
|
||||
if cmd.Args.PPROFPort > 0 {
|
||||
go func() {
|
||||
_ = http.ListenAndServe(fmt.Sprintf("localhost:%d", cmd.Args.PPROFPort), nil)
|
||||
}()
|
||||
}
|
||||
|
||||
// Use absolute path.
|
||||
if !path.IsAbs(cmd.Args.Path) {
|
||||
cmd.Args.Path, err = filepath.Abs(cmd.Args.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Configure generator.
|
||||
var opts []generator.GenerateOpt
|
||||
if cmd.Args.IncludeVersion {
|
||||
opts = append(opts, generator.WithVersion(templ.Version()))
|
||||
}
|
||||
if cmd.Args.IncludeTimestamp {
|
||||
opts = append(opts, generator.WithTimestamp(time.Now()))
|
||||
}
|
||||
|
||||
// Check the version of the templ module.
|
||||
if err := modcheck.Check(cmd.Args.Path); err != nil {
|
||||
cmd.Log.Warn("templ version check: " + err.Error())
|
||||
}
|
||||
|
||||
fseh := NewFSEventHandler(
|
||||
cmd.Log,
|
||||
cmd.Args.Path,
|
||||
cmd.Args.Watch,
|
||||
opts,
|
||||
cmd.Args.GenerateSourceMapVisualisations,
|
||||
cmd.Args.KeepOrphanedFiles,
|
||||
cmd.Args.FileWriter,
|
||||
cmd.Args.Lazy,
|
||||
)
|
||||
|
||||
// If we're processing a single file, don't bother setting up the channels/multithreaing.
|
||||
if cmd.Args.FileName != "" {
|
||||
_, err = fseh.HandleEvent(ctx, fsnotify.Event{
|
||||
Name: cmd.Args.FileName,
|
||||
Op: fsnotify.Create,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Start timer.
|
||||
start := time.Now()
|
||||
|
||||
// Create channels:
|
||||
// For the initial filesystem walk and subsequent (optional) fsnotify events.
|
||||
events := make(chan fsnotify.Event)
|
||||
// Count of events currently being processed by the event handler.
|
||||
var eventsWG sync.WaitGroup
|
||||
// Used to check that the event handler has completed.
|
||||
var eventHandlerWG sync.WaitGroup
|
||||
// For errs from the watcher.
|
||||
errs := make(chan error)
|
||||
// Tracks whether errors occurred during the generation process.
|
||||
var errorCount atomic.Int64
|
||||
// For triggering actions after generation has completed.
|
||||
postGeneration := make(chan *GenerationEvent, 256)
|
||||
// Used to check that the post-generation handler has completed.
|
||||
var postGenerationWG sync.WaitGroup
|
||||
var postGenerationEventsWG sync.WaitGroup
|
||||
|
||||
// Waitgroup for the push process.
|
||||
var pushHandlerWG sync.WaitGroup
|
||||
|
||||
// Start process to push events into the channel.
|
||||
pushHandlerWG.Add(1)
|
||||
go func() {
|
||||
defer pushHandlerWG.Done()
|
||||
defer close(events)
|
||||
cmd.Log.Debug(
|
||||
"Walking directory",
|
||||
slog.String("path", cmd.Args.Path),
|
||||
slog.Bool("devMode", cmd.Args.Watch),
|
||||
)
|
||||
if err := watcher.WalkFiles(ctx, cmd.Args.Path, cmd.WatchPattern, events); err != nil {
|
||||
cmd.Log.Error("WalkFiles failed, exiting", slog.Any("error", err))
|
||||
errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)}
|
||||
return
|
||||
}
|
||||
if !cmd.Args.Watch {
|
||||
cmd.Log.Debug("Dev mode not enabled, process can finish early")
|
||||
return
|
||||
}
|
||||
cmd.Log.Info("Watching files")
|
||||
rw, err := watcher.Recursive(ctx, cmd.Args.Path, cmd.WatchPattern, events, errs)
|
||||
if err != nil {
|
||||
cmd.Log.Error("Recursive watcher setup failed, exiting", slog.Any("error", err))
|
||||
errs <- FatalError{Err: fmt.Errorf("failed to setup recursive watcher: %w", err)}
|
||||
return
|
||||
}
|
||||
cmd.Log.Debug("Waiting for context to be cancelled to stop watching files")
|
||||
<-ctx.Done()
|
||||
cmd.Log.Debug("Context cancelled, closing watcher")
|
||||
if err := rw.Close(); err != nil {
|
||||
cmd.Log.Error("Failed to close watcher", slog.Any("error", err))
|
||||
}
|
||||
cmd.Log.Debug("Waiting for events to be processed")
|
||||
eventsWG.Wait()
|
||||
cmd.Log.Debug(
|
||||
"All pending events processed, waiting for pending post-generation events to complete",
|
||||
)
|
||||
postGenerationEventsWG.Wait()
|
||||
cmd.Log.Debug(
|
||||
"All post-generation events processed, deleting watch mode text files",
|
||||
slog.Int64("errorCount", errorCount.Load()),
|
||||
)
|
||||
|
||||
fileEvents := make(chan fsnotify.Event)
|
||||
go func() {
|
||||
if err := watcher.WalkFiles(ctx, cmd.Args.Path, cmd.WatchPattern, fileEvents); err != nil {
|
||||
cmd.Log.Error("Post dev mode WalkFiles failed", slog.Any("error", err))
|
||||
errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)}
|
||||
return
|
||||
}
|
||||
close(fileEvents)
|
||||
}()
|
||||
for event := range fileEvents {
|
||||
if strings.HasSuffix(event.Name, "_templ.txt") {
|
||||
if err = os.Remove(event.Name); err != nil {
|
||||
cmd.Log.Warn("Failed to remove watch mode text file", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start process to handle events.
|
||||
eventHandlerWG.Add(1)
|
||||
sem := make(chan struct{}, cmd.Args.WorkerCount)
|
||||
go func() {
|
||||
defer eventHandlerWG.Done()
|
||||
defer close(postGeneration)
|
||||
cmd.Log.Debug("Starting event handler")
|
||||
for event := range events {
|
||||
eventsWG.Add(1)
|
||||
sem <- struct{}{}
|
||||
go func(event fsnotify.Event) {
|
||||
cmd.Log.Debug("Processing file", slog.String("file", event.Name))
|
||||
defer eventsWG.Done()
|
||||
defer func() { <-sem }()
|
||||
r, err := fseh.HandleEvent(ctx, event)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
}
|
||||
if !(r.GoUpdated || r.TextUpdated) {
|
||||
cmd.Log.Debug("File not updated", slog.String("file", event.Name))
|
||||
return
|
||||
}
|
||||
e := &GenerationEvent{
|
||||
Event: event,
|
||||
Updated: r.Updated,
|
||||
GoUpdated: r.GoUpdated,
|
||||
TextUpdated: r.TextUpdated,
|
||||
}
|
||||
cmd.Log.Debug("File updated", slog.String("file", event.Name))
|
||||
postGeneration <- e
|
||||
}(event)
|
||||
}
|
||||
// Wait for all events to be processed before closing.
|
||||
eventsWG.Wait()
|
||||
}()
|
||||
|
||||
// Start process to handle post-generation events.
|
||||
var updates int
|
||||
postGenerationWG.Add(1)
|
||||
var firstPostGenerationExecuted bool
|
||||
go func() {
|
||||
defer close(errs)
|
||||
defer postGenerationWG.Done()
|
||||
cmd.Log.Debug("Starting post-generation handler")
|
||||
timeout := time.NewTimer(time.Hour * 24 * 365)
|
||||
var goUpdated, textUpdated bool
|
||||
var p *proxy.Handler
|
||||
for {
|
||||
select {
|
||||
case ge := <-postGeneration:
|
||||
if ge == nil {
|
||||
cmd.Log.Debug("Post-generation event channel closed, exiting")
|
||||
return
|
||||
}
|
||||
goUpdated = goUpdated || ge.GoUpdated
|
||||
textUpdated = textUpdated || ge.TextUpdated
|
||||
if goUpdated || textUpdated {
|
||||
updates++
|
||||
}
|
||||
// Reset timer.
|
||||
if !timeout.Stop() {
|
||||
<-timeout.C
|
||||
}
|
||||
timeout.Reset(time.Millisecond * 100)
|
||||
case <-timeout.C:
|
||||
if !goUpdated && !textUpdated {
|
||||
// Nothing to process, reset timer and wait again.
|
||||
timeout.Reset(time.Hour * 24 * 365)
|
||||
break
|
||||
}
|
||||
postGenerationEventsWG.Add(1)
|
||||
if cmd.Args.Command != "" && goUpdated {
|
||||
cmd.Log.Debug("Executing command", slog.String("command", cmd.Args.Command))
|
||||
if cmd.Args.Watch {
|
||||
os.Setenv("TEMPL_DEV_MODE", "true")
|
||||
}
|
||||
if _, err := run.Run(ctx, cmd.Args.Path, cmd.Args.Command); err != nil {
|
||||
cmd.Log.Error("Error executing command", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
if !firstPostGenerationExecuted {
|
||||
cmd.Log.Debug("First post-generation event received, starting proxy")
|
||||
firstPostGenerationExecuted = true
|
||||
p, err = cmd.StartProxy(ctx)
|
||||
if err != nil {
|
||||
cmd.Log.Error("Failed to start proxy", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
// Send server-sent event.
|
||||
if p != nil && (textUpdated || goUpdated) {
|
||||
cmd.Log.Debug("Sending reload event")
|
||||
p.SendSSE("message", "reload")
|
||||
}
|
||||
postGenerationEventsWG.Done()
|
||||
// Reset timer.
|
||||
timeout.Reset(time.Millisecond * 100)
|
||||
textUpdated = false
|
||||
goUpdated = false
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Read errors.
|
||||
for err := range errs {
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, FatalError{}) {
|
||||
cmd.Log.Debug("Fatal error, exiting")
|
||||
return err
|
||||
}
|
||||
cmd.Log.Error("Error", slog.Any("error", err))
|
||||
errorCount.Add(1)
|
||||
}
|
||||
|
||||
// Wait for everything to complete.
|
||||
cmd.Log.Debug("Waiting for push handler to complete")
|
||||
pushHandlerWG.Wait()
|
||||
cmd.Log.Debug("Waiting for event handler to complete")
|
||||
eventHandlerWG.Wait()
|
||||
cmd.Log.Debug("Waiting for post-generation handler to complete")
|
||||
postGenerationWG.Wait()
|
||||
if cmd.Args.Command != "" {
|
||||
cmd.Log.Debug("Killing command", slog.String("command", cmd.Args.Command))
|
||||
if err := run.KillAll(); err != nil {
|
||||
cmd.Log.Error("Error killing command", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Check for errors after everything has completed.
|
||||
if errorCount.Load() > 0 {
|
||||
return fmt.Errorf("generation completed with %d errors", errorCount.Load())
|
||||
}
|
||||
|
||||
cmd.Log.Info(
|
||||
"Complete",
|
||||
slog.Int("updates", updates),
|
||||
slog.Duration("duration", time.Since(start)),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cmd *Generate) StartProxy(ctx context.Context) (p *proxy.Handler, err error) {
|
||||
if cmd.Args.Proxy == "" {
|
||||
cmd.Log.Debug("No proxy URL specified, not starting proxy")
|
||||
return nil, nil
|
||||
}
|
||||
var target *url.URL
|
||||
target, err = url.Parse(cmd.Args.Proxy)
|
||||
if err != nil {
|
||||
return nil, FatalError{Err: fmt.Errorf("failed to parse proxy URL: %w", err)}
|
||||
}
|
||||
if cmd.Args.ProxyPort == 0 {
|
||||
cmd.Args.ProxyPort = 7331
|
||||
}
|
||||
if cmd.Args.ProxyBind == "" {
|
||||
cmd.Args.ProxyBind = "127.0.0.1"
|
||||
}
|
||||
p = proxy.New(cmd.Log, cmd.Args.ProxyBind, cmd.Args.ProxyPort, target)
|
||||
go func() {
|
||||
cmd.Log.Info("Proxying", slog.String("from", p.URL), slog.String("to", p.Target.String()))
|
||||
if err := http.ListenAndServe(fmt.Sprintf("%s:%d", cmd.Args.ProxyBind, cmd.Args.ProxyPort), p); err != nil {
|
||||
cmd.Log.Error("Proxy failed", slog.Any("error", err))
|
||||
}
|
||||
}()
|
||||
if !cmd.Args.OpenBrowser {
|
||||
cmd.Log.Debug("Not opening browser")
|
||||
return p, nil
|
||||
}
|
||||
go func() {
|
||||
cmd.Log.Debug("Waiting for proxy to be ready", slog.String("url", p.URL))
|
||||
backoff := backoff.NewExponentialBackOff()
|
||||
backoff.InitialInterval = time.Second
|
||||
var client http.Client
|
||||
client.Timeout = 1 * time.Second
|
||||
for {
|
||||
if _, err := client.Get(p.URL); err == nil {
|
||||
break
|
||||
}
|
||||
d := backoff.NextBackOff()
|
||||
cmd.Log.Debug(
|
||||
"Proxy not ready, retrying",
|
||||
slog.String("url", p.URL),
|
||||
slog.Any("backoff", d),
|
||||
)
|
||||
time.Sleep(d)
|
||||
}
|
||||
if err := browser.OpenURL(p.URL); err != nil {
|
||||
cmd.Log.Error("Failed to open browser", slog.Any("error", err))
|
||||
}
|
||||
}()
|
||||
return p, nil
|
||||
}
|
366
templ/cmd/templ/generatecmd/eventhandler.go
Normal file
366
templ/cmd/templ/generatecmd/eventhandler.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package generatecmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"go/scanner"
|
||||
"go/token"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/visualize"
|
||||
"github.com/a-h/templ/generator"
|
||||
"github.com/a-h/templ/parser/v2"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
type FileWriterFunc func(name string, contents []byte) error
|
||||
|
||||
func FileWriter(fileName string, contents []byte) error {
|
||||
return os.WriteFile(fileName, contents, 0o644)
|
||||
}
|
||||
|
||||
func WriterFileWriter(w io.Writer) FileWriterFunc {
|
||||
return func(_ string, contents []byte) error {
|
||||
_, err := w.Write(contents)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func NewFSEventHandler(
|
||||
log *slog.Logger,
|
||||
dir string,
|
||||
devMode bool,
|
||||
genOpts []generator.GenerateOpt,
|
||||
genSourceMapVis bool,
|
||||
keepOrphanedFiles bool,
|
||||
fileWriter FileWriterFunc,
|
||||
lazy bool,
|
||||
) *FSEventHandler {
|
||||
if !path.IsAbs(dir) {
|
||||
dir, _ = filepath.Abs(dir)
|
||||
}
|
||||
fseh := &FSEventHandler{
|
||||
Log: log,
|
||||
dir: dir,
|
||||
fileNameToLastModTime: make(map[string]time.Time),
|
||||
fileNameToLastModTimeMutex: &sync.Mutex{},
|
||||
fileNameToError: make(map[string]struct{}),
|
||||
fileNameToErrorMutex: &sync.Mutex{},
|
||||
fileNameToOutput: make(map[string]generator.GeneratorOutput),
|
||||
fileNameToOutputMutex: &sync.Mutex{},
|
||||
devMode: devMode,
|
||||
hashes: make(map[string][sha256.Size]byte),
|
||||
hashesMutex: &sync.Mutex{},
|
||||
genOpts: genOpts,
|
||||
genSourceMapVis: genSourceMapVis,
|
||||
keepOrphanedFiles: keepOrphanedFiles,
|
||||
writer: fileWriter,
|
||||
lazy: lazy,
|
||||
}
|
||||
return fseh
|
||||
}
|
||||
|
||||
type FSEventHandler struct {
|
||||
Log *slog.Logger
|
||||
// dir is the root directory being processed.
|
||||
dir string
|
||||
fileNameToLastModTime map[string]time.Time
|
||||
fileNameToLastModTimeMutex *sync.Mutex
|
||||
fileNameToError map[string]struct{}
|
||||
fileNameToErrorMutex *sync.Mutex
|
||||
fileNameToOutput map[string]generator.GeneratorOutput
|
||||
fileNameToOutputMutex *sync.Mutex
|
||||
devMode bool
|
||||
hashes map[string][sha256.Size]byte
|
||||
hashesMutex *sync.Mutex
|
||||
genOpts []generator.GenerateOpt
|
||||
genSourceMapVis bool
|
||||
Errors []error
|
||||
keepOrphanedFiles bool
|
||||
writer func(string, []byte) error
|
||||
lazy bool
|
||||
}
|
||||
|
||||
type GenerateResult struct {
|
||||
// Updated indicates that the file was updated.
|
||||
Updated bool
|
||||
// GoUpdated indicates that Go expressions were updated.
|
||||
GoUpdated bool
|
||||
// TextUpdated indicates that text literals were updated.
|
||||
TextUpdated bool
|
||||
}
|
||||
|
||||
func (h *FSEventHandler) HandleEvent(ctx context.Context, event fsnotify.Event) (result GenerateResult, err error) {
|
||||
// Handle _templ.go files.
|
||||
if !event.Has(fsnotify.Remove) && strings.HasSuffix(event.Name, "_templ.go") {
|
||||
_, err = os.Stat(strings.TrimSuffix(event.Name, "_templ.go") + ".templ")
|
||||
if !os.IsNotExist(err) {
|
||||
return GenerateResult{}, err
|
||||
}
|
||||
// File is orphaned.
|
||||
if h.keepOrphanedFiles {
|
||||
return GenerateResult{}, nil
|
||||
}
|
||||
h.Log.Debug("Deleting orphaned Go file", slog.String("file", event.Name))
|
||||
if err = os.Remove(event.Name); err != nil {
|
||||
h.Log.Warn("Failed to remove orphaned file", slog.Any("error", err))
|
||||
}
|
||||
return GenerateResult{Updated: true, GoUpdated: true, TextUpdated: false}, nil
|
||||
}
|
||||
// Handle _templ.txt files.
|
||||
if !event.Has(fsnotify.Remove) && strings.HasSuffix(event.Name, "_templ.txt") {
|
||||
if h.devMode {
|
||||
// Don't delete the file in dev mode, ignore changes to it, since the .templ file
|
||||
// must have been updated in order to trigger a change in the _templ.txt file.
|
||||
return GenerateResult{Updated: false, GoUpdated: false, TextUpdated: false}, nil
|
||||
}
|
||||
h.Log.Debug("Deleting watch mode file", slog.String("file", event.Name))
|
||||
if err = os.Remove(event.Name); err != nil {
|
||||
h.Log.Warn("Failed to remove watch mode text file", slog.Any("error", err))
|
||||
return GenerateResult{}, nil
|
||||
}
|
||||
return GenerateResult{}, nil
|
||||
}
|
||||
|
||||
// If the file hasn't been updated since the last time we processed it, ignore it.
|
||||
lastModTime, updatedModTime := h.UpsertLastModTime(event.Name)
|
||||
if !updatedModTime {
|
||||
h.Log.Debug("Skipping file because it wasn't updated", slog.String("file", event.Name))
|
||||
return GenerateResult{}, nil
|
||||
}
|
||||
|
||||
// Process anything that isn't a templ file.
|
||||
if !strings.HasSuffix(event.Name, ".templ") {
|
||||
// If it's a Go file, mark it as updated.
|
||||
if strings.HasSuffix(event.Name, ".go") {
|
||||
result.GoUpdated = true
|
||||
}
|
||||
result.Updated = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Handle templ files.
|
||||
|
||||
// If the go file is newer than the templ file, skip generation, because it's up-to-date.
|
||||
if h.lazy && goFileIsUpToDate(event.Name, lastModTime) {
|
||||
h.Log.Debug("Skipping file because the Go file is up-to-date", slog.String("file", event.Name))
|
||||
return GenerateResult{}, nil
|
||||
}
|
||||
|
||||
// Start a processor.
|
||||
start := time.Now()
|
||||
var diag []parser.Diagnostic
|
||||
result, diag, err = h.generate(ctx, event.Name)
|
||||
if err != nil {
|
||||
h.SetError(event.Name, true)
|
||||
return result, fmt.Errorf("failed to generate code for %q: %w", event.Name, err)
|
||||
}
|
||||
if len(diag) > 0 {
|
||||
for _, d := range diag {
|
||||
h.Log.Warn(d.Message,
|
||||
slog.String("from", fmt.Sprintf("%d:%d", d.Range.From.Line, d.Range.From.Col)),
|
||||
slog.String("to", fmt.Sprintf("%d:%d", d.Range.To.Line, d.Range.To.Col)),
|
||||
)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
if errorCleared, errorCount := h.SetError(event.Name, false); errorCleared {
|
||||
h.Log.Info("Error cleared", slog.String("file", event.Name), slog.Int("errors", errorCount))
|
||||
}
|
||||
h.Log.Debug("Generated code", slog.String("file", event.Name), slog.Duration("in", time.Since(start)))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func goFileIsUpToDate(templFileName string, templFileLastMod time.Time) (upToDate bool) {
|
||||
goFileName := strings.TrimSuffix(templFileName, ".templ") + "_templ.go"
|
||||
goFileInfo, err := os.Stat(goFileName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return goFileInfo.ModTime().After(templFileLastMod)
|
||||
}
|
||||
|
||||
func (h *FSEventHandler) SetError(fileName string, hasError bool) (previouslyHadError bool, errorCount int) {
|
||||
h.fileNameToErrorMutex.Lock()
|
||||
defer h.fileNameToErrorMutex.Unlock()
|
||||
_, previouslyHadError = h.fileNameToError[fileName]
|
||||
delete(h.fileNameToError, fileName)
|
||||
if hasError {
|
||||
h.fileNameToError[fileName] = struct{}{}
|
||||
}
|
||||
return previouslyHadError, len(h.fileNameToError)
|
||||
}
|
||||
|
||||
func (h *FSEventHandler) UpsertLastModTime(fileName string) (modTime time.Time, updated bool) {
|
||||
fileInfo, err := os.Stat(fileName)
|
||||
if err != nil {
|
||||
return modTime, false
|
||||
}
|
||||
h.fileNameToLastModTimeMutex.Lock()
|
||||
defer h.fileNameToLastModTimeMutex.Unlock()
|
||||
previousModTime := h.fileNameToLastModTime[fileName]
|
||||
currentModTime := fileInfo.ModTime()
|
||||
if !currentModTime.After(previousModTime) {
|
||||
return currentModTime, false
|
||||
}
|
||||
h.fileNameToLastModTime[fileName] = currentModTime
|
||||
return currentModTime, true
|
||||
}
|
||||
|
||||
func (h *FSEventHandler) UpsertHash(fileName string, hash [sha256.Size]byte) (updated bool) {
|
||||
h.hashesMutex.Lock()
|
||||
defer h.hashesMutex.Unlock()
|
||||
lastHash := h.hashes[fileName]
|
||||
if lastHash == hash {
|
||||
return false
|
||||
}
|
||||
h.hashes[fileName] = hash
|
||||
return true
|
||||
}
|
||||
|
||||
// generate Go code for a single template.
|
||||
// If a basePath is provided, the filename included in error messages is relative to it.
|
||||
func (h *FSEventHandler) generate(ctx context.Context, fileName string) (result GenerateResult, diagnostics []parser.Diagnostic, err error) {
|
||||
t, err := parser.Parse(fileName)
|
||||
if err != nil {
|
||||
return GenerateResult{}, nil, fmt.Errorf("%s parsing error: %w", fileName, err)
|
||||
}
|
||||
targetFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.go"
|
||||
|
||||
// Only use relative filenames to the basepath for filenames in runtime error messages.
|
||||
absFilePath, err := filepath.Abs(fileName)
|
||||
if err != nil {
|
||||
return GenerateResult{}, nil, fmt.Errorf("failed to get absolute path for %q: %w", fileName, err)
|
||||
}
|
||||
relFilePath, err := filepath.Rel(h.dir, absFilePath)
|
||||
if err != nil {
|
||||
return GenerateResult{}, nil, fmt.Errorf("failed to get relative path for %q: %w", fileName, err)
|
||||
}
|
||||
// Convert Windows file paths to Unix-style for consistency.
|
||||
relFilePath = filepath.ToSlash(relFilePath)
|
||||
|
||||
var b bytes.Buffer
|
||||
generatorOutput, err := generator.Generate(t, &b, append(h.genOpts, generator.WithFileName(relFilePath))...)
|
||||
if err != nil {
|
||||
return GenerateResult{}, nil, fmt.Errorf("%s generation error: %w", fileName, err)
|
||||
}
|
||||
|
||||
formattedGoCode, err := format.Source(b.Bytes())
|
||||
if err != nil {
|
||||
err = remapErrorList(err, generatorOutput.SourceMap, fileName)
|
||||
return GenerateResult{}, nil, fmt.Errorf("%s source formatting error %w", fileName, err)
|
||||
}
|
||||
|
||||
// Hash output, and write out the file if the goCodeHash has changed.
|
||||
goCodeHash := sha256.Sum256(formattedGoCode)
|
||||
if h.UpsertHash(targetFileName, goCodeHash) {
|
||||
result.Updated = true
|
||||
if err = h.writer(targetFileName, formattedGoCode); err != nil {
|
||||
return result, nil, fmt.Errorf("failed to write target file %q: %w", targetFileName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the txt file if it has changed.
|
||||
if h.devMode {
|
||||
txtFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.txt"
|
||||
joined := strings.Join(generatorOutput.Literals, "\n")
|
||||
txtHash := sha256.Sum256([]byte(joined))
|
||||
if h.UpsertHash(txtFileName, txtHash) {
|
||||
result.TextUpdated = true
|
||||
if err = os.WriteFile(txtFileName, []byte(joined), 0o644); err != nil {
|
||||
return result, nil, fmt.Errorf("failed to write string literal file %q: %w", txtFileName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check whether the change would require a recompilation to take effect.
|
||||
h.fileNameToOutputMutex.Lock()
|
||||
defer h.fileNameToOutputMutex.Unlock()
|
||||
previous := h.fileNameToOutput[fileName]
|
||||
if generator.HasChanged(previous, generatorOutput) {
|
||||
result.GoUpdated = true
|
||||
}
|
||||
h.fileNameToOutput[fileName] = generatorOutput
|
||||
}
|
||||
|
||||
parsedDiagnostics, err := parser.Diagnose(t)
|
||||
if err != nil {
|
||||
return result, nil, fmt.Errorf("%s diagnostics error: %w", fileName, err)
|
||||
}
|
||||
|
||||
if h.genSourceMapVis {
|
||||
err = generateSourceMapVisualisation(ctx, fileName, targetFileName, generatorOutput.SourceMap)
|
||||
}
|
||||
|
||||
return result, parsedDiagnostics, err
|
||||
}
|
||||
|
||||
// Takes an error from the formatter and attempts to convert the positions reported in the target file to their positions
|
||||
// in the source file.
|
||||
func remapErrorList(err error, sourceMap *parser.SourceMap, fileName string) error {
|
||||
list, ok := err.(scanner.ErrorList)
|
||||
if !ok || len(list) == 0 {
|
||||
return err
|
||||
}
|
||||
for i, e := range list {
|
||||
// The positions in the source map are off by one line because of the package definition.
|
||||
srcPos, ok := sourceMap.SourcePositionFromTarget(uint32(e.Pos.Line-1), uint32(e.Pos.Column))
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
list[i].Pos = token.Position{
|
||||
Filename: fileName,
|
||||
Offset: int(srcPos.Index),
|
||||
Line: int(srcPos.Line) + 1,
|
||||
Column: int(srcPos.Col),
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
func generateSourceMapVisualisation(ctx context.Context, templFileName, goFileName string, sourceMap *parser.SourceMap) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
var templContents, goContents []byte
|
||||
var templErr, goErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
templContents, templErr = os.ReadFile(templFileName)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
goContents, goErr = os.ReadFile(goFileName)
|
||||
}()
|
||||
wg.Wait()
|
||||
if templErr != nil {
|
||||
return templErr
|
||||
}
|
||||
if goErr != nil {
|
||||
return templErr
|
||||
}
|
||||
|
||||
targetFileName := strings.TrimSuffix(templFileName, ".templ") + "_templ_sourcemap.html"
|
||||
w, err := os.Create(targetFileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s sourcemap visualisation error: %w", templFileName, err)
|
||||
}
|
||||
defer w.Close()
|
||||
b := bufio.NewWriter(w)
|
||||
defer b.Flush()
|
||||
|
||||
return visualize.HTML(templFileName, string(templContents), string(goContents), sourceMap).Render(ctx, b)
|
||||
}
|
23
templ/cmd/templ/generatecmd/fatalerror.go
Normal file
23
templ/cmd/templ/generatecmd/fatalerror.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package generatecmd
|
||||
|
||||
type FatalError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e FatalError) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
func (e FatalError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
func (e FatalError) Is(target error) bool {
|
||||
_, ok := target.(FatalError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e FatalError) As(target any) bool {
|
||||
_, ok := target.(*FatalError)
|
||||
return ok
|
||||
}
|
39
templ/cmd/templ/generatecmd/main.go
Normal file
39
templ/cmd/templ/generatecmd/main.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package generatecmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"log/slog"
|
||||
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
type Arguments struct {
|
||||
FileName string
|
||||
FileWriter FileWriterFunc
|
||||
Path string
|
||||
Watch bool
|
||||
WatchPattern string
|
||||
OpenBrowser bool
|
||||
Command string
|
||||
ProxyBind string
|
||||
ProxyPort int
|
||||
Proxy string
|
||||
NotifyProxy bool
|
||||
WorkerCount int
|
||||
GenerateSourceMapVisualisations bool
|
||||
IncludeVersion bool
|
||||
IncludeTimestamp bool
|
||||
// PPROFPort is the port to run the pprof server on.
|
||||
PPROFPort int
|
||||
KeepOrphanedFiles bool
|
||||
Lazy bool
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, log *slog.Logger, args Arguments) (err error) {
|
||||
g, err := NewGenerate(log, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return g.Run(ctx)
|
||||
}
|
170
templ/cmd/templ/generatecmd/main_test.go
Normal file
170
templ/cmd/templ/generatecmd/main_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package generatecmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/testproject"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
t.Run("can generate a file in place", func(t *testing.T) {
|
||||
// templ generate -f templates.templ
|
||||
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test project: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// Delete the templates_templ.go file to ensure it is generated.
|
||||
err = os.Remove(path.Join(dir, "templates_templ.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to remove templates_templ.go: %v", err)
|
||||
}
|
||||
|
||||
// Run the generate command.
|
||||
err = Run(context.Background(), log, Arguments{
|
||||
FileName: path.Join(dir, "templates.templ"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run generate command: %v", err)
|
||||
}
|
||||
|
||||
// Check the templates_templ.go file was created.
|
||||
_, err = os.Stat(path.Join(dir, "templates_templ.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("templates_templ.go was not created: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("can generate a file in watch mode", func(t *testing.T) {
|
||||
// templ generate -f templates.templ
|
||||
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test project: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// Delete the templates_templ.go file to ensure it is generated.
|
||||
err = os.Remove(path.Join(dir, "templates_templ.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to remove templates_templ.go: %v", err)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var eg errgroup.Group
|
||||
eg.Go(func() error {
|
||||
// Run the generate command.
|
||||
return Run(ctx, log, Arguments{
|
||||
Path: dir,
|
||||
Watch: true,
|
||||
})
|
||||
})
|
||||
|
||||
// Check the templates_templ.go file was created, with backoff.
|
||||
for i := 0; i < 5; i++ {
|
||||
time.Sleep(time.Second * time.Duration(i))
|
||||
_, err = os.Stat(path.Join(dir, "templates_templ.go"))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_, err = os.Stat(path.Join(dir, "templates_templ.txt"))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("template files were not created: %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
if err := eg.Wait(); err != nil {
|
||||
t.Fatalf("generate command failed: %v", err)
|
||||
}
|
||||
|
||||
// Check the templates_templ.txt file was removed.
|
||||
_, err = os.Stat(path.Join(dir, "templates_templ.txt"))
|
||||
if err == nil {
|
||||
t.Fatalf("templates_templ.txt was not removed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultWatchPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
matches bool
|
||||
}{
|
||||
{
|
||||
name: "empty file names do not match",
|
||||
input: "",
|
||||
matches: false,
|
||||
},
|
||||
{
|
||||
name: "*_templ.txt matches, Windows",
|
||||
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\strings_templ.txt`,
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*_templ.txt matches, Unix",
|
||||
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/strings_templ.txt",
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*.templ files match, Windows",
|
||||
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates.templ`,
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*.templ files match, Unix",
|
||||
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.templ",
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*_templ.go files match, Windows",
|
||||
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates_templ.go`,
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*_templ.go files match, Unix",
|
||||
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates_templ.go",
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*.go files match, Windows",
|
||||
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates.go`,
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*.go files match, Unix",
|
||||
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.go",
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*.css files do not match",
|
||||
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.css",
|
||||
matches: false,
|
||||
},
|
||||
}
|
||||
wpRegexp, err := regexp.Compile(defaultWatchPattern)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to compile default watch pattern: %v", err)
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if wpRegexp.MatchString(test.input) != test.matches {
|
||||
t.Fatalf("expected match of %q to be %v", test.input, test.matches)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
82
templ/cmd/templ/generatecmd/modcheck/modcheck.go
Normal file
82
templ/cmd/templ/generatecmd/modcheck/modcheck.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package modcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"golang.org/x/mod/modfile"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
// WalkUp the directory tree, starting at dir, until we find a directory containing
|
||||
// a go.mod file.
|
||||
func WalkUp(dir string) (string, error) {
|
||||
dir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
|
||||
var modFile string
|
||||
for {
|
||||
modFile = filepath.Join(dir, "go.mod")
|
||||
_, err := os.Stat(modFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("failed to stat go.mod file: %w", err)
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
// Move up.
|
||||
prev := dir
|
||||
dir = filepath.Dir(dir)
|
||||
if dir == prev {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// No file found.
|
||||
if modFile == "" {
|
||||
return dir, fmt.Errorf("could not find go.mod file")
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
func Check(dir string) error {
|
||||
dir, err := WalkUp(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Found a go.mod file.
|
||||
// Read it and find the templ version.
|
||||
modFile := filepath.Join(dir, "go.mod")
|
||||
m, err := os.ReadFile(modFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read go.mod file: %w", err)
|
||||
}
|
||||
|
||||
mf, err := modfile.Parse(modFile, m, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse go.mod file: %w", err)
|
||||
}
|
||||
if mf.Module.Mod.Path == "github.com/a-h/templ" {
|
||||
// The go.mod file is for templ itself.
|
||||
return nil
|
||||
}
|
||||
for _, r := range mf.Require {
|
||||
if r.Mod.Path == "github.com/a-h/templ" {
|
||||
cmp := semver.Compare(r.Mod.Version, templ.Version())
|
||||
if cmp < 0 {
|
||||
return fmt.Errorf("generator %v is newer than templ version %v found in go.mod file, consider running `go get -u github.com/a-h/templ` to upgrade", templ.Version(), r.Mod.Version)
|
||||
}
|
||||
if cmp > 0 {
|
||||
return fmt.Errorf("generator %v is older than templ version %v found in go.mod file, consider upgrading templ CLI", templ.Version(), r.Mod.Version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("templ not found in go.mod file, run `go get github.com/a-h/templ` to install it")
|
||||
}
|
47
templ/cmd/templ/generatecmd/modcheck/modcheck_test.go
Normal file
47
templ/cmd/templ/generatecmd/modcheck/modcheck_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package modcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/mod/modfile"
|
||||
)
|
||||
|
||||
func TestPatchGoVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
input: "go 1.20",
|
||||
expected: "1.20",
|
||||
},
|
||||
{
|
||||
input: "go 1.20.123",
|
||||
expected: "1.20.123",
|
||||
},
|
||||
{
|
||||
input: "go 1.20.1",
|
||||
expected: "1.20.1",
|
||||
},
|
||||
{
|
||||
input: "go 1.20rc1",
|
||||
expected: "1.20rc1",
|
||||
},
|
||||
{
|
||||
input: "go 1.15",
|
||||
expected: "1.15",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
input := "module github.com/a-h/templ\n\n" + string(test.input) + "\n" + "toolchain go1.27.9\n"
|
||||
mf, err := modfile.Parse("go.mod", []byte(input), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse go.mod: %v", err)
|
||||
}
|
||||
if test.expected != mf.Go.Version {
|
||||
t.Errorf("expected %q, got %q", test.expected, mf.Go.Version)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
284
templ/cmd/templ/generatecmd/proxy/proxy.go
Normal file
284
templ/cmd/templ/generatecmd/proxy/proxy.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
stdlog "log"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/PuerkitoBio/goquery"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/sse"
|
||||
"github.com/andybalholm/brotli"
|
||||
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
//go:embed script.js
|
||||
var script string
|
||||
|
||||
type Handler struct {
|
||||
log *slog.Logger
|
||||
URL string
|
||||
Target *url.URL
|
||||
p *httputil.ReverseProxy
|
||||
sse *sse.Handler
|
||||
}
|
||||
|
||||
func getScriptTag(nonce string) string {
|
||||
if nonce != "" {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(`<script src="/_templ/reload/script.js" nonce="`)
|
||||
sb.WriteString(html.EscapeString(nonce))
|
||||
sb.WriteString(`"></script>`)
|
||||
return sb.String()
|
||||
}
|
||||
return `<script src="/_templ/reload/script.js"></script>`
|
||||
}
|
||||
|
||||
func insertScriptTagIntoBody(nonce, body string) (updated string) {
|
||||
doc, err := goquery.NewDocumentFromReader(strings.NewReader(body))
|
||||
if err != nil {
|
||||
return strings.Replace(body, "</body>", getScriptTag(nonce)+"</body>", -1)
|
||||
}
|
||||
doc.Find("body").AppendHtml(getScriptTag(nonce))
|
||||
r, err := doc.Html()
|
||||
if err != nil {
|
||||
return strings.Replace(body, "</body>", getScriptTag(nonce)+"</body>", -1)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
type passthroughWriteCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (pwc passthroughWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const unsupportedContentEncoding = "Unsupported content encoding, hot reload script not inserted."
|
||||
|
||||
func (h *Handler) modifyResponse(r *http.Response) error {
|
||||
log := h.log.With(slog.String("url", r.Request.URL.String()))
|
||||
if r.Header.Get("templ-skip-modify") == "true" {
|
||||
log.Debug("Skipping response modification because templ-skip-modify header is set")
|
||||
return nil
|
||||
}
|
||||
if contentType := r.Header.Get("Content-Type"); !strings.HasPrefix(contentType, "text/html") {
|
||||
log.Debug("Skipping response modification because content type is not text/html", slog.String("content-type", contentType))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set up readers and writers.
|
||||
newReader := func(in io.Reader) (out io.Reader, err error) {
|
||||
return in, nil
|
||||
}
|
||||
newWriter := func(out io.Writer) io.WriteCloser {
|
||||
return passthroughWriteCloser{out}
|
||||
}
|
||||
switch r.Header.Get("Content-Encoding") {
|
||||
case "gzip":
|
||||
newReader = func(in io.Reader) (out io.Reader, err error) {
|
||||
return gzip.NewReader(in)
|
||||
}
|
||||
newWriter = func(out io.Writer) io.WriteCloser {
|
||||
return gzip.NewWriter(out)
|
||||
}
|
||||
case "br":
|
||||
newReader = func(in io.Reader) (out io.Reader, err error) {
|
||||
return brotli.NewReader(in), nil
|
||||
}
|
||||
newWriter = func(out io.Writer) io.WriteCloser {
|
||||
return brotli.NewWriter(out)
|
||||
}
|
||||
case "":
|
||||
log.Debug("No content encoding header found")
|
||||
default:
|
||||
h.log.Warn(unsupportedContentEncoding, slog.String("encoding", r.Header.Get("Content-Encoding")))
|
||||
}
|
||||
|
||||
// Read the encoded body.
|
||||
encr, err := newReader(r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.Body.Close()
|
||||
body, err := io.ReadAll(encr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update it.
|
||||
csp := r.Header.Get("Content-Security-Policy")
|
||||
updated := insertScriptTagIntoBody(parseNonce(csp), string(body))
|
||||
if log.Enabled(r.Request.Context(), slog.LevelDebug) {
|
||||
if len(updated) == len(body) {
|
||||
log.Debug("Reload script not inserted")
|
||||
} else {
|
||||
log.Debug("Reload script inserted")
|
||||
}
|
||||
}
|
||||
|
||||
// Encode the response.
|
||||
var buf bytes.Buffer
|
||||
encw := newWriter(&buf)
|
||||
_, err = encw.Write([]byte(updated))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = encw.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the response.
|
||||
r.Body = io.NopCloser(&buf)
|
||||
r.ContentLength = int64(buf.Len())
|
||||
r.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseNonce(csp string) (nonce string) {
|
||||
outer:
|
||||
for _, rawDirective := range strings.Split(csp, ";") {
|
||||
parts := strings.Fields(rawDirective)
|
||||
if len(parts) < 2 {
|
||||
continue
|
||||
}
|
||||
if parts[0] != "script-src" {
|
||||
continue
|
||||
}
|
||||
for _, source := range parts[1:] {
|
||||
source = strings.TrimPrefix(source, "'")
|
||||
source = strings.TrimSuffix(source, "'")
|
||||
if strings.HasPrefix(source, "nonce-") {
|
||||
nonce = source[6:]
|
||||
break outer
|
||||
}
|
||||
}
|
||||
}
|
||||
return nonce
|
||||
}
|
||||
|
||||
func New(log *slog.Logger, bind string, port int, target *url.URL) (h *Handler) {
|
||||
p := httputil.NewSingleHostReverseProxy(target)
|
||||
p.ErrorLog = stdlog.New(os.Stderr, "Proxy to target error: ", 0)
|
||||
p.Transport = &roundTripper{
|
||||
maxRetries: 20,
|
||||
initialDelay: 100 * time.Millisecond,
|
||||
backoffExponent: 1.5,
|
||||
}
|
||||
h = &Handler{
|
||||
log: log,
|
||||
URL: fmt.Sprintf("http://%s:%d", bind, port),
|
||||
Target: target,
|
||||
p: p,
|
||||
sse: sse.New(),
|
||||
}
|
||||
p.ModifyResponse = h.modifyResponse
|
||||
return h
|
||||
}
|
||||
|
||||
func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/_templ/reload/script.js" {
|
||||
// Provides a script that reloads the page.
|
||||
w.Header().Add("Content-Type", "text/javascript")
|
||||
_, err := io.WriteString(w, script)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to write script: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/_templ/reload/events" {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// Provides a list of messages including a reload message.
|
||||
p.sse.ServeHTTP(w, r)
|
||||
return
|
||||
case http.MethodPost:
|
||||
// Send a reload message to all connected clients.
|
||||
p.sse.Send("message", "reload")
|
||||
return
|
||||
}
|
||||
http.Error(w, "only GET or POST method allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
p.p.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (p *Handler) SendSSE(eventType string, data string) {
|
||||
p.sse.Send(eventType, data)
|
||||
}
|
||||
|
||||
type roundTripper struct {
|
||||
maxRetries int
|
||||
initialDelay time.Duration
|
||||
backoffExponent float64
|
||||
}
|
||||
|
||||
func (rt *roundTripper) setShouldSkipResponseModificationHeader(r *http.Request, resp *http.Response) {
|
||||
// Instruct the modifyResponse function to skip modifying the response if the
|
||||
// HTTP request has come from HTMX.
|
||||
if r.Header.Get("HX-Request") != "true" {
|
||||
return
|
||||
}
|
||||
resp.Header.Set("templ-skip-modify", "true")
|
||||
}
|
||||
|
||||
func (rt *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
// Read and buffer the body.
|
||||
var bodyBytes []byte
|
||||
if r.Body != nil && r.Body != http.NoBody {
|
||||
var err error
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.Body.Close()
|
||||
}
|
||||
|
||||
// Retry logic.
|
||||
var resp *http.Response
|
||||
var err error
|
||||
for retries := 0; retries < rt.maxRetries; retries++ {
|
||||
// Clone the request and set the body.
|
||||
req := r.Clone(r.Context())
|
||||
if bodyBytes != nil {
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
|
||||
// Execute the request.
|
||||
resp, err = http.DefaultTransport.RoundTrip(req)
|
||||
if err != nil {
|
||||
time.Sleep(rt.initialDelay * time.Duration(math.Pow(rt.backoffExponent, float64(retries))))
|
||||
continue
|
||||
}
|
||||
|
||||
rt.setShouldSkipResponseModificationHeader(r, resp)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("max retries reached: %q", r.URL.String())
|
||||
}
|
||||
|
||||
func NotifyProxy(host string, port int) error {
|
||||
urlStr := fmt.Sprintf("http://%s:%d/_templ/reload/events", host, port)
|
||||
req, err := http.NewRequest(http.MethodPost, urlStr, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = http.DefaultClient.Do(req)
|
||||
return err
|
||||
}
|
627
templ/cmd/templ/generatecmd/proxy/proxy_test.go
Normal file
627
templ/cmd/templ/generatecmd/proxy/proxy_test.go
Normal file
@@ -0,0 +1,627 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestRoundTripper(t *testing.T) {
|
||||
t.Run("if the HX-Request header is present, set the templ-skip-modify header on the response", func(t *testing.T) {
|
||||
rt := &roundTripper{}
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error creating request: %v", err)
|
||||
}
|
||||
req.Header.Set("HX-Request", "true")
|
||||
resp := &http.Response{Header: make(http.Header)}
|
||||
rt.setShouldSkipResponseModificationHeader(req, resp)
|
||||
if resp.Header.Get("templ-skip-modify") != "true" {
|
||||
t.Errorf("expected templ-skip-modify header to be true, got %v", resp.Header.Get("templ-skip-modify"))
|
||||
}
|
||||
})
|
||||
t.Run("if the HX-Request header is not present, do not set the templ-skip-modify header on the response", func(t *testing.T) {
|
||||
rt := &roundTripper{}
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error creating request: %v", err)
|
||||
}
|
||||
resp := &http.Response{Header: make(http.Header)}
|
||||
rt.setShouldSkipResponseModificationHeader(req, resp)
|
||||
if resp.Header.Get("templ-skip-modify") != "" {
|
||||
t.Errorf("expected templ-skip-modify header to be empty, got %v", resp.Header.Get("templ-skip-modify"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
t.Run("plain: non-html content is not modified", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`{"key": "value"}`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Content-Length", "16")
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != "16" {
|
||||
t.Errorf("expected content length to be 16, got %v", r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(`{"key": "value"}`, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("plain: if the response contains templ-skip-modify header, it is not modified", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`Hello`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html")
|
||||
r.Header.Set("Content-Length", "5")
|
||||
r.Header.Set("templ-skip-modify", "true")
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != "5" {
|
||||
t.Errorf("expected content length to be 5, got %v", r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(`Hello`, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("plain: body tags get the script inserted", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`<html><body></body></html>`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Length", "26")
|
||||
|
||||
expectedString := insertScriptTagIntoBody("", `<html><body></body></html>`)
|
||||
if !strings.Contains(expectedString, getScriptTag("")) {
|
||||
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
|
||||
}
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
|
||||
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("plain: body tags get the script inserted with nonce", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`<html><body></body></html>`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Length", "26")
|
||||
const nonce = "this-is-the-nonce"
|
||||
r.Header.Set("Content-Security-Policy", fmt.Sprintf("script-src 'nonce-%s'", nonce))
|
||||
|
||||
expectedString := insertScriptTagIntoBody(nonce, `<html><body></body></html>`)
|
||||
if !strings.Contains(expectedString, getScriptTag(nonce)) {
|
||||
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
|
||||
}
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
|
||||
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("plain: body tags get the script inserted ignoring js with body tags", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`<html><body><script>console.log("<body></body>")</script></body></html>`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Length", "26")
|
||||
|
||||
expectedString := insertScriptTagIntoBody("", `<html><body><script>console.log("<body></body>")</script></body></html>`)
|
||||
if !strings.Contains(expectedString, getScriptTag("")) {
|
||||
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
|
||||
}
|
||||
if !strings.Contains(expectedString, `console.log("<body></body>")`) {
|
||||
t.Fatalf("expected the script tag to be inserted, but mangled the html: %q", expectedString)
|
||||
}
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
|
||||
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("gzip: non-html content is not modified", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`{"key": "value"}`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
// It's not actually gzipped here, but it doesn't matter, it shouldn't get that far.
|
||||
r.Header.Set("Content-Encoding", "gzip")
|
||||
// Similarly, this is not the actual length of the gzipped content.
|
||||
r.Header.Set("Content-Length", "16")
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != "16" {
|
||||
t.Errorf("expected content length to be 16, got %v", r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(`{"key": "value"}`, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("gzip: body tags get the script inserted", func(t *testing.T) {
|
||||
// Arrange
|
||||
body := `<html><body></body></html>`
|
||||
var buf bytes.Buffer
|
||||
gzw := gzip.NewWriter(&buf)
|
||||
_, err := gzw.Write([]byte(body))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error writing gzip: %v", err)
|
||||
}
|
||||
gzw.Close()
|
||||
|
||||
expectedString := insertScriptTagIntoBody("", body)
|
||||
|
||||
var expectedBytes bytes.Buffer
|
||||
gzw = gzip.NewWriter(&expectedBytes)
|
||||
_, err = gzw.Write([]byte(expectedString))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error writing gzip: %v", err)
|
||||
}
|
||||
gzw.Close()
|
||||
expectedLength := len(expectedBytes.Bytes())
|
||||
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(&buf),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Encoding", "gzip")
|
||||
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err = h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", expectedLength) {
|
||||
t.Errorf("expected content length to be %d, got %v", expectedLength, r.Header.Get("Content-Length"))
|
||||
}
|
||||
|
||||
gr, err := gzip.NewReader(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
actualBody, err := io.ReadAll(gr)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("brotli: body tags get the script inserted", func(t *testing.T) {
|
||||
// Arrange
|
||||
body := `<html><body></body></html>`
|
||||
var buf bytes.Buffer
|
||||
brw := brotli.NewWriter(&buf)
|
||||
_, err := brw.Write([]byte(body))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error writing gzip: %v", err)
|
||||
}
|
||||
brw.Close()
|
||||
|
||||
expectedString := insertScriptTagIntoBody("", body)
|
||||
|
||||
var expectedBytes bytes.Buffer
|
||||
brw = brotli.NewWriter(&expectedBytes)
|
||||
_, err = brw.Write([]byte(expectedString))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error writing gzip: %v", err)
|
||||
}
|
||||
brw.Close()
|
||||
expectedLength := len(expectedBytes.Bytes())
|
||||
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(&buf),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Encoding", "br")
|
||||
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err = h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", expectedLength) {
|
||||
t.Errorf("expected content length to be %d, got %v", expectedLength, r.Header.Get("Content-Length"))
|
||||
}
|
||||
|
||||
actualBody, err := io.ReadAll(brotli.NewReader(r.Body))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("notify-proxy: sending POST request to /_templ/reload/events should receive reload sse event", func(t *testing.T) {
|
||||
// Arrange 1: create a test proxy server.
|
||||
dummyHandler := func(w http.ResponseWriter, r *http.Request) {}
|
||||
dummyServer := httptest.NewServer(http.HandlerFunc(dummyHandler))
|
||||
defer dummyServer.Close()
|
||||
|
||||
u, err := url.Parse(dummyServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error parsing URL: %v", err)
|
||||
}
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
handler := New(log, "0.0.0.0", 0, u)
|
||||
proxyServer := httptest.NewServer(handler)
|
||||
defer proxyServer.Close()
|
||||
|
||||
u2, err := url.Parse(proxyServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error parsing URL: %v", err)
|
||||
}
|
||||
port, err := strconv.Atoi(u2.Port())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error parsing port: %v", err)
|
||||
}
|
||||
|
||||
// Arrange 2: start a goroutine to listen for sse events.
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error)
|
||||
sseRespCh := make(chan string)
|
||||
sseListening := make(chan bool) // Coordination channel that ensures the SSE listener is started before notifying the proxy.
|
||||
go func() {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/_templ/reload/events", proxyServer.URL), nil)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
sseListening <- true
|
||||
lines := []string{}
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
if scanner.Text() == "data: reload" {
|
||||
sseRespCh <- strings.Join(lines, "\n")
|
||||
return
|
||||
}
|
||||
}
|
||||
err = scanner.Err()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// Act: notify the proxy.
|
||||
select { // Either SSE is listening or an error occurred.
|
||||
case <-sseListening:
|
||||
err = NotifyProxy(u2.Hostname(), port)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error notifying proxy: %v", err)
|
||||
}
|
||||
case err := <-errChan:
|
||||
if err == nil {
|
||||
t.Fatalf("unexpected sse response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Assert.
|
||||
select { // Either SSE has a expected response or an error or timeout occurred.
|
||||
case resp := <-sseRespCh:
|
||||
if !strings.Contains(resp, "event: message\ndata: reload") {
|
||||
t.Errorf("expected sse reload event to be received, got: %q", resp)
|
||||
}
|
||||
case err := <-errChan:
|
||||
if err == nil {
|
||||
t.Fatalf("unexpected sse response: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("timeout waiting for sse response")
|
||||
}
|
||||
})
|
||||
t.Run("unsupported encodings result in a warning", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader([]byte("<p>Data</p>"))),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Encoding", "weird-encoding")
|
||||
|
||||
// Act
|
||||
lh := newTestLogHandler(slog.LevelInfo)
|
||||
log := slog.New(lh)
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if len(lh.records) != 1 {
|
||||
var sb strings.Builder
|
||||
for _, record := range lh.records {
|
||||
sb.WriteString(record.Message)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
t.Fatalf("expected 1 log entry, but got %d: \n%s", len(lh.records), sb.String())
|
||||
}
|
||||
record := lh.records[0]
|
||||
if record.Message != unsupportedContentEncoding {
|
||||
t.Errorf("expected warning message %q, got %q", unsupportedContentEncoding, record.Message)
|
||||
}
|
||||
if record.Level != slog.LevelWarn {
|
||||
t.Errorf("expected warning, got level %v", record.Level)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newTestLogHandler(level slog.Level) *testLogHandler {
|
||||
return &testLogHandler{
|
||||
m: new(sync.Mutex),
|
||||
records: nil,
|
||||
level: level,
|
||||
}
|
||||
}
|
||||
|
||||
type testLogHandler struct {
|
||||
m *sync.Mutex
|
||||
records []slog.Record
|
||||
level slog.Level
|
||||
}
|
||||
|
||||
func (h *testLogHandler) Enabled(ctx context.Context, l slog.Level) bool {
|
||||
return l >= h.level
|
||||
}
|
||||
|
||||
func (h *testLogHandler) Handle(ctx context.Context, r slog.Record) error {
|
||||
h.m.Lock()
|
||||
defer h.m.Unlock()
|
||||
if r.Level < h.level {
|
||||
return nil
|
||||
}
|
||||
h.records = append(h.records, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *testLogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *testLogHandler) WithGroup(name string) slog.Handler {
|
||||
return h
|
||||
}
|
||||
|
||||
func TestParseNonce(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
csp string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty csp",
|
||||
csp: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "simple csp",
|
||||
csp: "script-src 'nonce-oLhVst3hTAcxI734qtB0J9Qc7W4qy09C'",
|
||||
expected: "oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
|
||||
},
|
||||
{
|
||||
name: "simple csp without single quote",
|
||||
csp: "script-src nonce-oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
|
||||
expected: "oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
|
||||
},
|
||||
{
|
||||
name: "complete csp",
|
||||
csp: "default-src 'self'; frame-ancestors 'self'; form-action 'self'; script-src 'strict-dynamic' 'nonce-4VOtk0Uo1l7pwtC';",
|
||||
expected: "4VOtk0Uo1l7pwtC",
|
||||
},
|
||||
{
|
||||
name: "mdn example 1",
|
||||
csp: "default-src 'self'",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "mdn example 2",
|
||||
csp: "default-src 'self' *.trusted.com",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "mdn example 3",
|
||||
csp: "default-src 'self'; img-src *; media-src media1.com media2.com; script-src userscripts.example.com",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "mdn example 3 multiple sources",
|
||||
csp: "default-src 'self'; img-src *; media-src media1.com media2.com; script-src userscripts.example.com foo.com 'strict-dynamic' 'nonce-4VOtk0Uo1l7pwtC'",
|
||||
expected: "4VOtk0Uo1l7pwtC",
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
nonce := parseNonce(tc.csp)
|
||||
if nonce != tc.expected {
|
||||
t.Errorf("expected nonce to be %s, but got %s", tc.expected, nonce)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
10
templ/cmd/templ/generatecmd/proxy/script.js
Normal file
10
templ/cmd/templ/generatecmd/proxy/script.js
Normal file
@@ -0,0 +1,10 @@
|
||||
(function() {
|
||||
let templ_reloadSrc = window.templ_reloadSrc || new EventSource("/_templ/reload/events");
|
||||
templ_reloadSrc.onmessage = (event) => {
|
||||
if (event && event.data === "reload") {
|
||||
window.location.reload();
|
||||
}
|
||||
};
|
||||
window.templ_reloadSrc = templ_reloadSrc;
|
||||
window.onbeforeunload = () => window.templ_reloadSrc.close();
|
||||
})();
|
108
templ/cmd/templ/generatecmd/run/run_test.go
Normal file
108
templ/cmd/templ/generatecmd/run/run_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package run_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/run"
|
||||
)
|
||||
|
||||
//go:embed testprogram/*
|
||||
var testprogram embed.FS
|
||||
|
||||
func TestGoRun(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping test in short mode.")
|
||||
}
|
||||
|
||||
// Copy testprogram to a temporary directory.
|
||||
dir, err := os.MkdirTemp("", "testprogram")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make test dir: %v", err)
|
||||
}
|
||||
files, err := testprogram.ReadDir("testprogram")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read embedded dir: %v", err)
|
||||
}
|
||||
for _, file := range files {
|
||||
srcFileName := "testprogram/" + file.Name()
|
||||
srcData, err := testprogram.ReadFile(srcFileName)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read src file %q: %v", srcFileName, err)
|
||||
}
|
||||
tgtFileName := filepath.Join(dir, file.Name())
|
||||
tgtFile, err := os.Create(tgtFileName)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create tgt file %q: %v", tgtFileName, err)
|
||||
}
|
||||
defer tgtFile.Close()
|
||||
if _, err := tgtFile.Write(srcData); err != nil {
|
||||
t.Fatalf("failed to write to tgt file %q: %v", tgtFileName, err)
|
||||
}
|
||||
}
|
||||
// Rename the go.mod.embed file to go.mod.
|
||||
if err := os.Rename(filepath.Join(dir, "go.mod.embed"), filepath.Join(dir, "go.mod")); err != nil {
|
||||
t.Fatalf("failed to rename go.mod.embed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd string
|
||||
}{
|
||||
{
|
||||
name: "Well behaved programs get shut down",
|
||||
cmd: "go run .",
|
||||
},
|
||||
{
|
||||
name: "Badly behaved programs get shut down",
|
||||
cmd: "go run . -badly-behaved",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cmd, err := run.Run(ctx, dir, tt.cmd)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run program: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
|
||||
if err := run.KillAll(); err != nil {
|
||||
t.Fatalf("failed to kill all: %v", err)
|
||||
}
|
||||
|
||||
// Check the parent process is no longer running.
|
||||
if err := cmd.Process.Signal(os.Signal(syscall.Signal(0))); err == nil {
|
||||
t.Fatalf("process %d is still running", pid)
|
||||
}
|
||||
// Check that the child was stopped.
|
||||
body, err := readResponse("http://localhost:7777")
|
||||
if err == nil {
|
||||
t.Fatalf("child process is still running: %s", body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func readResponse(url string) (body string, err error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
84
templ/cmd/templ/generatecmd/run/run_unix.go
Normal file
84
templ/cmd/templ/generatecmd/run/run_unix.go
Normal file
@@ -0,0 +1,84 @@
|
||||
//go:build unix
|
||||
|
||||
package run
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
m = &sync.Mutex{}
|
||||
running = map[string]*exec.Cmd{}
|
||||
)
|
||||
|
||||
func KillAll() (err error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
var errs []error
|
||||
for _, cmd := range running {
|
||||
if err := kill(cmd); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err))
|
||||
}
|
||||
}
|
||||
running = map[string]*exec.Cmd{}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func kill(cmd *exec.Cmd) (err error) {
|
||||
errs := make([]error, 4)
|
||||
errs[0] = ignoreExited(cmd.Process.Signal(syscall.SIGINT))
|
||||
errs[1] = ignoreExited(cmd.Process.Signal(syscall.SIGTERM))
|
||||
errs[2] = ignoreExited(cmd.Wait())
|
||||
errs[3] = ignoreExited(syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL))
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func ignoreExited(err error) error {
|
||||
if errors.Is(err, syscall.ESRCH) {
|
||||
return nil
|
||||
}
|
||||
// Ignore *exec.ExitError
|
||||
if _, ok := err.(*exec.ExitError); ok {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, workingDir string, input string) (cmd *exec.Cmd, err error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
cmd, ok := running[input]
|
||||
if ok {
|
||||
if err := kill(cmd); err != nil {
|
||||
return cmd, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err)
|
||||
}
|
||||
|
||||
delete(running, input)
|
||||
}
|
||||
parts := strings.Fields(input)
|
||||
executable := parts[0]
|
||||
args := []string{}
|
||||
if len(parts) > 1 {
|
||||
args = append(args, parts[1:]...)
|
||||
}
|
||||
|
||||
cmd = exec.CommandContext(ctx, executable, args...)
|
||||
// Wait for the process to finish gracefully before termination.
|
||||
cmd.WaitDelay = time.Second * 3
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Dir = workingDir
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
running[input] = cmd
|
||||
err = cmd.Start()
|
||||
return
|
||||
}
|
69
templ/cmd/templ/generatecmd/run/run_windows.go
Normal file
69
templ/cmd/templ/generatecmd/run/run_windows.go
Normal file
@@ -0,0 +1,69 @@
|
||||
//go:build windows
|
||||
|
||||
package run
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var m = &sync.Mutex{}
|
||||
var running = map[string]*exec.Cmd{}
|
||||
|
||||
func KillAll() (err error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
for _, cmd := range running {
|
||||
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
|
||||
kill.Stderr = os.Stderr
|
||||
kill.Stdout = os.Stdout
|
||||
err := kill.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
running = map[string]*exec.Cmd{}
|
||||
return
|
||||
}
|
||||
|
||||
func Stop(cmd *exec.Cmd) (err error) {
|
||||
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
|
||||
kill.Stderr = os.Stderr
|
||||
kill.Stdout = os.Stdout
|
||||
return kill.Run()
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, workingDir string, input string) (cmd *exec.Cmd, err error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
cmd, ok := running[input]
|
||||
if ok {
|
||||
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
|
||||
kill.Stderr = os.Stderr
|
||||
kill.Stdout = os.Stdout
|
||||
err := kill.Run()
|
||||
if err != nil {
|
||||
return cmd, err
|
||||
}
|
||||
delete(running, input)
|
||||
}
|
||||
parts := strings.Fields(input)
|
||||
executable := parts[0]
|
||||
args := []string{}
|
||||
if len(parts) > 1 {
|
||||
args = append(args, parts[1:]...)
|
||||
}
|
||||
|
||||
cmd = exec.Command(executable, args...)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Dir = workingDir
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
running[input] = cmd
|
||||
err = cmd.Start()
|
||||
return
|
||||
}
|
3
templ/cmd/templ/generatecmd/run/testprogram/go.mod.embed
Normal file
3
templ/cmd/templ/generatecmd/run/testprogram/go.mod.embed
Normal file
@@ -0,0 +1,3 @@
|
||||
module testprogram
|
||||
|
||||
go 1.23
|
63
templ/cmd/templ/generatecmd/run/testprogram/main.go
Normal file
63
templ/cmd/templ/generatecmd/run/testprogram/main.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// This is a test program. It is used only to test the behaviour of the run package.
|
||||
// The run package is supposed to be able to run and stop programs. Those programs may start
|
||||
// child processes, which should also be stopped when the parent program is stopped.
|
||||
|
||||
// For example, running `go run .` will compile an executable and run it.
|
||||
|
||||
// So, this program does nothing. It just waits for a signal to stop.
|
||||
|
||||
// In "Well behaved" mode, the program will stop when it receives a signal.
|
||||
// In "Badly behaved" mode, the program will ignore the signal and continue running.
|
||||
|
||||
// The run package should be able to stop the program in both cases.
|
||||
|
||||
var badlyBehavedFlag = flag.Bool("badly-behaved", false, "If set, the program will ignore the stop signal and continue running.")
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
mode := "Well behaved"
|
||||
if *badlyBehavedFlag {
|
||||
mode = "Badly behaved"
|
||||
}
|
||||
fmt.Printf("%s process %d started.\n", mode, os.Getpid())
|
||||
|
||||
// Start a web server on a known port so that we can check that this process is
|
||||
// not running, when it's been started as a child process, and we don't know
|
||||
// its pid.
|
||||
go func() {
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "%d", os.Getpid())
|
||||
})
|
||||
err := http.ListenAndServe("127.0.0.1:7777", nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Error running web server: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
if !*badlyBehavedFlag {
|
||||
signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-sigs:
|
||||
fmt.Printf("Process %d received signal. Stopping.\n", os.Getpid())
|
||||
return
|
||||
case <-time.After(1 * time.Second):
|
||||
fmt.Printf("Process %d still running...\n", os.Getpid())
|
||||
}
|
||||
}
|
||||
}
|
84
templ/cmd/templ/generatecmd/sse/server.go
Normal file
84
templ/cmd/templ/generatecmd/sse/server.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
func New() *Handler {
|
||||
return &Handler{
|
||||
m: new(sync.Mutex),
|
||||
requests: map[int64]chan event{},
|
||||
}
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
m *sync.Mutex
|
||||
counter int64
|
||||
requests map[int64]chan event
|
||||
}
|
||||
|
||||
type event struct {
|
||||
Type string
|
||||
Data string
|
||||
}
|
||||
|
||||
// Send an event to all connected clients.
|
||||
func (s *Handler) Send(eventType string, data string) {
|
||||
s.m.Lock()
|
||||
defer s.m.Unlock()
|
||||
for _, f := range s.requests {
|
||||
f := f
|
||||
go func(f chan event) {
|
||||
f <- event{
|
||||
Type: eventType,
|
||||
Data: data,
|
||||
}
|
||||
}(f)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
|
||||
id := atomic.AddInt64(&s.counter, 1)
|
||||
s.m.Lock()
|
||||
events := make(chan event)
|
||||
s.requests[id] = events
|
||||
s.m.Unlock()
|
||||
defer func() {
|
||||
s.m.Lock()
|
||||
defer s.m.Unlock()
|
||||
delete(s.requests, id)
|
||||
close(events)
|
||||
}()
|
||||
|
||||
timer := time.NewTimer(0)
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
if _, err := fmt.Fprintf(w, "event: message\ndata: ping\n\n"); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
timer.Reset(time.Second * 5)
|
||||
case e := <-events:
|
||||
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", e.Type, e.Data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
case <-r.Context().Done():
|
||||
break loop
|
||||
}
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
}
|
52
templ/cmd/templ/generatecmd/symlink/symlink_test.go
Normal file
52
templ/cmd/templ/generatecmd/symlink/symlink_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package symlink
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd"
|
||||
"github.com/a-h/templ/cmd/templ/testproject"
|
||||
)
|
||||
|
||||
func TestSymlink(t *testing.T) {
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
t.Run("can generate if root is symlink", func(t *testing.T) {
|
||||
// templ generate -f templates.templ
|
||||
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test project: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
symlinkPath := dir + "-symlink"
|
||||
err = os.Symlink(dir, symlinkPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create dir symlink: %v", err)
|
||||
}
|
||||
defer os.Remove(symlinkPath)
|
||||
|
||||
// Delete the templates_templ.go file to ensure it is generated.
|
||||
err = os.Remove(path.Join(symlinkPath, "templates_templ.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to remove templates_templ.go: %v", err)
|
||||
}
|
||||
|
||||
// Run the generate command.
|
||||
err = generatecmd.Run(context.Background(), log, generatecmd.Arguments{
|
||||
Path: symlinkPath,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run generate command: %v", err)
|
||||
}
|
||||
|
||||
// Check the templates_templ.go file was created.
|
||||
_, err = os.Stat(path.Join(symlinkPath, "templates_templ.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("templates_templ.go was not created: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
@@ -0,0 +1,101 @@
|
||||
package testeventhandler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/scanner"
|
||||
"go/token"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd"
|
||||
"github.com/a-h/templ/generator"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestErrorLocationMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawFileName string
|
||||
errorPositions []token.Position
|
||||
}{
|
||||
{
|
||||
name: "single error outputs location in srcFile",
|
||||
rawFileName: "single_error.templ.error",
|
||||
errorPositions: []token.Position{
|
||||
{Offset: 46, Line: 3, Column: 20},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple errors all output locations in srcFile",
|
||||
rawFileName: "multiple_errors.templ.error",
|
||||
errorPositions: []token.Position{
|
||||
{Offset: 41, Line: 3, Column: 15},
|
||||
{Offset: 101, Line: 7, Column: 22},
|
||||
{Offset: 126, Line: 10, Column: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
slog := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
|
||||
var fw generatecmd.FileWriterFunc
|
||||
fseh := generatecmd.NewFSEventHandler(slog, ".", false, []generator.GenerateOpt{}, false, false, fw, false)
|
||||
for _, test := range tests {
|
||||
// The raw files cannot end in .templ because they will cause the generator to fail. Instead,
|
||||
// we create a tmp file that ends in .templ only for the duration of the test.
|
||||
rawFile, err := os.Open(test.rawFileName)
|
||||
if err != nil {
|
||||
t.Errorf("%s: Failed to open file %s: %v", test.name, test.rawFileName, err)
|
||||
break
|
||||
}
|
||||
file, err := os.CreateTemp("", fmt.Sprintf("*%s.templ", test.rawFileName))
|
||||
if err != nil {
|
||||
t.Errorf("%s: Failed to create a tmp file at %s: %v", test.name, file.Name(), err)
|
||||
break
|
||||
}
|
||||
defer os.Remove(file.Name())
|
||||
if _, err = io.Copy(file, rawFile); err != nil {
|
||||
t.Errorf("%s: Failed to copy contents from raw file %s to tmp %s: %v", test.name, test.rawFileName, file.Name(), err)
|
||||
}
|
||||
|
||||
event := fsnotify.Event{Name: file.Name(), Op: fsnotify.Write}
|
||||
_, err = fseh.HandleEvent(context.Background(), event)
|
||||
if err == nil {
|
||||
t.Errorf("%s: no error was thrown", test.name)
|
||||
break
|
||||
}
|
||||
list, ok := err.(scanner.ErrorList)
|
||||
for !ok {
|
||||
err = errors.Unwrap(err)
|
||||
if err == nil {
|
||||
t.Errorf("%s: reached end of error wrapping before finding an ErrorList", test.name)
|
||||
break
|
||||
} else {
|
||||
list, ok = err.(scanner.ErrorList)
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if len(list) != len(test.errorPositions) {
|
||||
t.Errorf("%s: expected %d errors but got %d", test.name, len(test.errorPositions), len(list))
|
||||
break
|
||||
}
|
||||
for i, err := range list {
|
||||
test.errorPositions[i].Filename = file.Name()
|
||||
diff := cmp.Diff(test.errorPositions[i], err.Pos)
|
||||
if diff != "" {
|
||||
t.Error(diff)
|
||||
t.Error("expected:")
|
||||
t.Error(test.errorPositions[i])
|
||||
t.Error("actual:")
|
||||
t.Error(err.Pos)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,10 @@
|
||||
package testeventhandler
|
||||
|
||||
func invalid(a: string) string {
|
||||
return "foo"
|
||||
}
|
||||
|
||||
templ multipleError(a: string) {
|
||||
<div/>
|
||||
}
|
||||
l
|
@@ -0,0 +1,5 @@
|
||||
package testeventhandler
|
||||
|
||||
templ singleError(a: string) {
|
||||
<div/>
|
||||
}
|
485
templ/cmd/templ/generatecmd/testwatch/generate_test.go
Normal file
485
templ/cmd/templ/generatecmd/testwatch/generate_test.go
Normal file
@@ -0,0 +1,485 @@
|
||||
package testwatch
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/PuerkitoBio/goquery"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/modcheck"
|
||||
)
|
||||
|
||||
//go:embed testdata/*
|
||||
var testdata embed.FS
|
||||
|
||||
func createTestProject(moduleRoot string) (dir string, err error) {
|
||||
dir, err = os.MkdirTemp("", "templ_watch_test_*")
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to make test dir: %w", err)
|
||||
}
|
||||
files, err := testdata.ReadDir("testdata")
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to read embedded dir: %w", err)
|
||||
}
|
||||
for _, file := range files {
|
||||
src := filepath.Join("testdata", file.Name())
|
||||
data, err := testdata.ReadFile(src)
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
target := filepath.Join(dir, file.Name())
|
||||
if file.Name() == "go.mod.embed" {
|
||||
data = bytes.ReplaceAll(data, []byte("{moduleRoot}"), []byte(moduleRoot))
|
||||
target = filepath.Join(dir, "go.mod")
|
||||
}
|
||||
err = os.WriteFile(target, data, 0660)
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to copy file: %w", err)
|
||||
}
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
func replaceInFile(name, src, tgt string) error {
|
||||
data, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updated := strings.Replace(string(data), src, tgt, -1)
|
||||
return os.WriteFile(name, []byte(updated), 0660)
|
||||
}
|
||||
|
||||
func getPort() (port int, err error) {
|
||||
var a *net.TCPAddr
|
||||
if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||
var l *net.TCPListener
|
||||
if l, err = net.ListenTCP("tcp", a); err == nil {
|
||||
defer l.Close()
|
||||
return l.Addr().(*net.TCPAddr).Port, nil
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getHTML(url string) (doc *goquery.Document, err error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get %q: %w", url, err)
|
||||
}
|
||||
return goquery.NewDocumentFromReader(resp.Body)
|
||||
}
|
||||
|
||||
func TestCanAccessDirect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
args, teardown, err := Setup(false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
|
||||
// Assert.
|
||||
doc, err := getHTML(args.AppURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
countText := doc.Find(`div[data-testid="count"]`).Text()
|
||||
actualCount, err := strconv.Atoi(countText)
|
||||
if err != nil {
|
||||
t.Fatalf("got count %q instead of integer", countText)
|
||||
}
|
||||
if actualCount < 1 {
|
||||
t.Errorf("expected count >= 1, got %d", actualCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanAccessViaProxy(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
args, teardown, err := Setup(false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
|
||||
// Assert.
|
||||
doc, err := getHTML(args.ProxyURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
countText := doc.Find(`div[data-testid="count"]`).Text()
|
||||
actualCount, err := strconv.Atoi(countText)
|
||||
if err != nil {
|
||||
t.Fatalf("got count %q instead of integer", countText)
|
||||
}
|
||||
if actualCount < 1 {
|
||||
t.Errorf("expected count >= 1, got %d", actualCount)
|
||||
}
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
Type string
|
||||
Data string
|
||||
}
|
||||
|
||||
func readSSE(ctx context.Context, url string, sse chan<- Event) (err error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var e Event
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
sse <- e
|
||||
e = Event{}
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
e.Type = line[len("event: "):]
|
||||
}
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
e.Data = line[len("data: "):]
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func TestFileModificationsResultInSSEWithGzip(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
args, teardown, err := Setup(false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
|
||||
// Start the SSE check.
|
||||
events := make(chan Event)
|
||||
var eventsErr error
|
||||
go func() {
|
||||
eventsErr = readSSE(context.Background(), fmt.Sprintf("%s/_templ/reload/events", args.ProxyURL), events)
|
||||
}()
|
||||
|
||||
// Assert data is expected.
|
||||
doc, err := getHTML(args.ProxyURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Original" {
|
||||
t.Errorf("expected %q, got %q", "Original", text)
|
||||
}
|
||||
|
||||
// Change file.
|
||||
templFile := filepath.Join(args.AppDir, "templates.templ")
|
||||
err = replaceInFile(templFile,
|
||||
`<div data-testid="modification">Original</div>`,
|
||||
`<div data-testid="modification">Updated</div>`)
|
||||
if err != nil {
|
||||
t.Errorf("failed to replace text in file: %v", err)
|
||||
}
|
||||
|
||||
// Give the filesystem watcher a few seconds.
|
||||
var reloadCount int
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Data == "reload" {
|
||||
reloadCount++
|
||||
break loop
|
||||
}
|
||||
case <-time.After(time.Second * 5):
|
||||
break loop
|
||||
}
|
||||
}
|
||||
if reloadCount == 0 {
|
||||
t.Error("failed to receive SSE about update after 5 seconds")
|
||||
}
|
||||
|
||||
// Check to see if there were any errors.
|
||||
if eventsErr != nil {
|
||||
t.Errorf("error reading events: %v", err)
|
||||
}
|
||||
|
||||
// See results in browser immediately.
|
||||
doc, err = getHTML(args.ProxyURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Updated" {
|
||||
t.Errorf("expected %q, got %q", "Updated", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileModificationsResultInSSE(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
args, teardown, err := Setup(false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
|
||||
// Start the SSE check.
|
||||
events := make(chan Event)
|
||||
var eventsErr error
|
||||
go func() {
|
||||
eventsErr = readSSE(context.Background(), fmt.Sprintf("%s/_templ/reload/events", args.ProxyURL), events)
|
||||
}()
|
||||
|
||||
// Assert data is expected.
|
||||
doc, err := getHTML(args.ProxyURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Original" {
|
||||
t.Errorf("expected %q, got %q", "Original", text)
|
||||
}
|
||||
|
||||
// Change file.
|
||||
templFile := filepath.Join(args.AppDir, "templates.templ")
|
||||
err = replaceInFile(templFile,
|
||||
`<div data-testid="modification">Original</div>`,
|
||||
`<div data-testid="modification">Updated</div>`)
|
||||
if err != nil {
|
||||
t.Errorf("failed to replace text in file: %v", err)
|
||||
}
|
||||
|
||||
// Give the filesystem watcher a few seconds.
|
||||
var reloadCount int
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Data == "reload" {
|
||||
reloadCount++
|
||||
break loop
|
||||
}
|
||||
case <-time.After(time.Second * 5):
|
||||
break loop
|
||||
}
|
||||
}
|
||||
if reloadCount == 0 {
|
||||
t.Error("failed to receive SSE about update after 5 seconds")
|
||||
}
|
||||
|
||||
// Check to see if there were any errors.
|
||||
if eventsErr != nil {
|
||||
t.Errorf("error reading events: %v", err)
|
||||
}
|
||||
|
||||
// See results in browser immediately.
|
||||
doc, err = getHTML(args.ProxyURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Updated" {
|
||||
t.Errorf("expected %q, got %q", "Updated", text)
|
||||
}
|
||||
}
|
||||
|
||||
func NewTestArgs(modRoot, appDir string, appPort int, proxyBind string, proxyPort int) TestArgs {
|
||||
return TestArgs{
|
||||
ModRoot: modRoot,
|
||||
AppDir: appDir,
|
||||
AppPort: appPort,
|
||||
AppURL: fmt.Sprintf("http://localhost:%d", appPort),
|
||||
ProxyBind: proxyBind,
|
||||
ProxyPort: proxyPort,
|
||||
ProxyURL: fmt.Sprintf("http://%s:%d", proxyBind, proxyPort),
|
||||
}
|
||||
}
|
||||
|
||||
type TestArgs struct {
|
||||
ModRoot string
|
||||
AppDir string
|
||||
AppPort int
|
||||
AppURL string
|
||||
ProxyBind string
|
||||
ProxyPort int
|
||||
ProxyURL string
|
||||
}
|
||||
|
||||
func Setup(gzipEncoding bool) (args TestArgs, teardown func(t *testing.T), err error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return args, teardown, fmt.Errorf("could not find working dir: %w", err)
|
||||
}
|
||||
moduleRoot, err := modcheck.WalkUp(wd)
|
||||
if err != nil {
|
||||
return args, teardown, fmt.Errorf("could not find local templ go.mod file: %v", err)
|
||||
}
|
||||
|
||||
appDir, err := createTestProject(moduleRoot)
|
||||
if err != nil {
|
||||
return args, teardown, fmt.Errorf("failed to create test project: %v", err)
|
||||
}
|
||||
appPort, err := getPort()
|
||||
if err != nil {
|
||||
return args, teardown, fmt.Errorf("failed to get available port: %v", err)
|
||||
}
|
||||
proxyPort, err := getPort()
|
||||
if err != nil {
|
||||
return args, teardown, fmt.Errorf("failed to get available port: %v", err)
|
||||
}
|
||||
proxyBind := "localhost"
|
||||
|
||||
args = NewTestArgs(moduleRoot, appDir, appPort, proxyBind, proxyPort)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var cmdErr error
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
command := fmt.Sprintf("go run . -port %d", args.AppPort)
|
||||
if gzipEncoding {
|
||||
command += " -gzip true"
|
||||
}
|
||||
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
|
||||
cmdErr = generatecmd.Run(ctx, log, generatecmd.Arguments{
|
||||
Path: appDir,
|
||||
Watch: true,
|
||||
OpenBrowser: false,
|
||||
Command: command,
|
||||
ProxyBind: proxyBind,
|
||||
ProxyPort: proxyPort,
|
||||
Proxy: args.AppURL,
|
||||
NotifyProxy: false,
|
||||
WorkerCount: 0,
|
||||
GenerateSourceMapVisualisations: false,
|
||||
IncludeVersion: false,
|
||||
IncludeTimestamp: false,
|
||||
PPROFPort: 0,
|
||||
KeepOrphanedFiles: false,
|
||||
})
|
||||
}()
|
||||
|
||||
// Wait for server to start.
|
||||
if err = waitForURL(args.AppURL); err != nil {
|
||||
cancel()
|
||||
wg.Wait()
|
||||
return args, teardown, fmt.Errorf("failed to start app server, command error %v: %w", cmdErr, err)
|
||||
}
|
||||
if err = waitForURL(args.ProxyURL); err != nil {
|
||||
cancel()
|
||||
wg.Wait()
|
||||
return args, teardown, fmt.Errorf("failed to start proxy server, command error %v: %w", cmdErr, err)
|
||||
}
|
||||
|
||||
// Wait for exit.
|
||||
teardown = func(t *testing.T) {
|
||||
cancel()
|
||||
wg.Wait()
|
||||
if cmdErr != nil {
|
||||
t.Errorf("failed to run generate cmd: %v", err)
|
||||
}
|
||||
|
||||
if err = os.RemoveAll(appDir); err != nil {
|
||||
t.Fatalf("failed to remove test dir %q: %v", appDir, err)
|
||||
}
|
||||
}
|
||||
return args, teardown, err
|
||||
}
|
||||
|
||||
func waitForURL(url string) (err error) {
|
||||
var tries int
|
||||
for {
|
||||
time.Sleep(time.Second)
|
||||
if tries > 20 {
|
||||
return err
|
||||
}
|
||||
tries++
|
||||
var resp *http.Response
|
||||
resp, err = http.Get(url)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to get %q: %v\n", url, err)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
fmt.Printf("failed to get %q: %v\n", url, err)
|
||||
err = fmt.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateReturnsErrors(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("could not find working dir: %v", err)
|
||||
}
|
||||
moduleRoot, err := modcheck.WalkUp(wd)
|
||||
if err != nil {
|
||||
t.Fatalf("could not find local templ go.mod file: %v", err)
|
||||
}
|
||||
|
||||
appDir, err := createTestProject(moduleRoot)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test project: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err = os.RemoveAll(appDir); err != nil {
|
||||
t.Fatalf("failed to remove test dir %q: %v", appDir, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Break the HTML.
|
||||
templFile := filepath.Join(appDir, "templates.templ")
|
||||
err = replaceInFile(templFile,
|
||||
`<div data-testid="modification">Original</div>`,
|
||||
`<div data-testid="modification" -unclosed div-</div>`)
|
||||
if err != nil {
|
||||
t.Errorf("failed to replace text in file: %v", err)
|
||||
}
|
||||
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
|
||||
// Run.
|
||||
err = generatecmd.Run(context.Background(), log, generatecmd.Arguments{
|
||||
Path: appDir,
|
||||
Watch: false,
|
||||
IncludeVersion: false,
|
||||
IncludeTimestamp: false,
|
||||
KeepOrphanedFiles: false,
|
||||
})
|
||||
if err == nil {
|
||||
t.Errorf("expected generation error, got %v", err)
|
||||
}
|
||||
}
|
7
templ/cmd/templ/generatecmd/testwatch/testdata/go.mod.embed
vendored
Normal file
7
templ/cmd/templ/generatecmd/testwatch/testdata/go.mod.embed
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
module templ/testproject
|
||||
|
||||
go 1.23
|
||||
|
||||
require github.com/a-h/templ v0.2.513 // indirect
|
||||
|
||||
replace github.com/a-h/templ => {moduleRoot}
|
2
templ/cmd/templ/generatecmd/testwatch/testdata/go.sum
vendored
Normal file
2
templ/cmd/templ/generatecmd/testwatch/testdata/go.sum
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
81
templ/cmd/templ/generatecmd/testwatch/testdata/main.go
vendored
Normal file
81
templ/cmd/templ/generatecmd/testwatch/testdata/main.go
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
)
|
||||
|
||||
type GzipResponseWriter struct {
|
||||
w http.ResponseWriter
|
||||
}
|
||||
|
||||
func (w *GzipResponseWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
func (w *GzipResponseWriter) Write(b []byte) (int, error) {
|
||||
var buf bytes.Buffer
|
||||
gzw := gzip.NewWriter(&buf)
|
||||
defer gzw.Close()
|
||||
|
||||
_, err := gzw.Write(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = gzw.Close()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
|
||||
|
||||
return w.w.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
func (w *GzipResponseWriter) WriteHeader(statusCode int) {
|
||||
w.w.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
var flagPort = flag.Int("port", 0, "Set the HTTP listen port")
|
||||
var useGzip = flag.Bool("gzip", false, "Toggle gzip encoding")
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if *flagPort == 0 {
|
||||
fmt.Println("missing port flag")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var count int
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if useGzip != nil && *useGzip {
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w = &GzipResponseWriter{w: w}
|
||||
}
|
||||
|
||||
count++
|
||||
c := Page(count)
|
||||
h := templ.Handler(c)
|
||||
h.ErrorHandler = func(r *http.Request, err error) http.Handler {
|
||||
slog.Error("failed to render template", slog.Any("error", err))
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
})
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
err := http.ListenAndServe(fmt.Sprintf("localhost:%d", *flagPort), nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Error listening: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
17
templ/cmd/templ/generatecmd/testwatch/testdata/templates.templ
vendored
Normal file
17
templ/cmd/templ/generatecmd/testwatch/testdata/templates.templ
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
templ Page(count int) {
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>templ test page</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Count</h1>
|
||||
<div data-testid="count">{ fmt.Sprintf("%d", count) }</div>
|
||||
<div data-testid="modification">Original</div>
|
||||
</body>
|
||||
</html>
|
||||
}
|
55
templ/cmd/templ/generatecmd/testwatch/testdata/templates_templ.go
vendored
Normal file
55
templ/cmd/templ/generatecmd/testwatch/testdata/templates_templ.go
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
// Code generated by templ - DO NOT EDIT.
|
||||
|
||||
// templ: version: v0.3.833
|
||||
package main
|
||||
|
||||
//lint:file-ignore SA4006 This context is only used if a nested component is present.
|
||||
|
||||
import "github.com/a-h/templ"
|
||||
import templruntime "github.com/a-h/templ/runtime"
|
||||
|
||||
import "fmt"
|
||||
|
||||
func Page(count int) templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var1 == nil {
|
||||
templ_7745c5c3_Var1 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<!doctype html><html><head><title>templ test page</title></head><body><h1>Count</h1><div data-testid=\"count\">")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var2 string
|
||||
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(fmt.Sprintf("%d", count))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/generatecmd/testwatch/testdata/templates.templ`, Line: 13, Col: 54}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "</div><div data-testid=\"modification\">Original</div></body></html>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var _ = templruntime.GeneratedTemplate
|
166
templ/cmd/templ/generatecmd/watcher/watch.go
Normal file
166
templ/cmd/templ/generatecmd/watcher/watch.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
func Recursive(
|
||||
ctx context.Context,
|
||||
path string,
|
||||
watchPattern *regexp.Regexp,
|
||||
out chan fsnotify.Event,
|
||||
errors chan error,
|
||||
) (w *RecursiveWatcher, err error) {
|
||||
fsnw, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w = NewRecursiveWatcher(ctx, fsnw, watchPattern, out, errors)
|
||||
go w.loop()
|
||||
return w, w.Add(path)
|
||||
}
|
||||
|
||||
func NewRecursiveWatcher(ctx context.Context, w *fsnotify.Watcher, watchPattern *regexp.Regexp, events chan fsnotify.Event, errors chan error) *RecursiveWatcher {
|
||||
return &RecursiveWatcher{
|
||||
ctx: ctx,
|
||||
w: w,
|
||||
WatchPattern: watchPattern,
|
||||
Events: events,
|
||||
Errors: errors,
|
||||
timers: make(map[timerKey]*time.Timer),
|
||||
}
|
||||
}
|
||||
|
||||
// WalkFiles walks the file tree rooted at path, sending a Create event for each
|
||||
// file it encounters.
|
||||
func WalkFiles(ctx context.Context, path string, watchPattern *regexp.Regexp, out chan fsnotify.Event) (err error) {
|
||||
rootPath := path
|
||||
fileSystem := os.DirFS(rootPath)
|
||||
return fs.WalkDir(fileSystem, ".", func(path string, info os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
absPath, err := filepath.Abs(filepath.Join(rootPath, path))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if info.IsDir() && shouldSkipDir(absPath) {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if !watchPattern.MatchString(absPath) {
|
||||
return nil
|
||||
}
|
||||
out <- fsnotify.Event{
|
||||
Name: absPath,
|
||||
Op: fsnotify.Create,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type RecursiveWatcher struct {
|
||||
ctx context.Context
|
||||
w *fsnotify.Watcher
|
||||
WatchPattern *regexp.Regexp
|
||||
Events chan fsnotify.Event
|
||||
Errors chan error
|
||||
timerMu sync.Mutex
|
||||
timers map[timerKey]*time.Timer
|
||||
}
|
||||
|
||||
type timerKey struct {
|
||||
name string
|
||||
op fsnotify.Op
|
||||
}
|
||||
|
||||
func timerKeyFromEvent(event fsnotify.Event) timerKey {
|
||||
return timerKey{
|
||||
name: event.Name,
|
||||
op: event.Op,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *RecursiveWatcher) Close() error {
|
||||
return w.w.Close()
|
||||
}
|
||||
|
||||
func (w *RecursiveWatcher) loop() {
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case event, ok := <-w.w.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Has(fsnotify.Create) {
|
||||
if err := w.Add(event.Name); err != nil {
|
||||
w.Errors <- err
|
||||
}
|
||||
}
|
||||
// Only notify on templ related files.
|
||||
if !w.WatchPattern.MatchString(event.Name) {
|
||||
continue
|
||||
}
|
||||
tk := timerKeyFromEvent(event)
|
||||
w.timerMu.Lock()
|
||||
t, ok := w.timers[tk]
|
||||
w.timerMu.Unlock()
|
||||
if !ok {
|
||||
t = time.AfterFunc(100*time.Millisecond, func() {
|
||||
w.Events <- event
|
||||
})
|
||||
w.timerMu.Lock()
|
||||
w.timers[tk] = t
|
||||
w.timerMu.Unlock()
|
||||
continue
|
||||
}
|
||||
t.Reset(100 * time.Millisecond)
|
||||
case err, ok := <-w.w.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
w.Errors <- err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *RecursiveWatcher) Add(dir string) error {
|
||||
return filepath.WalkDir(dir, func(dir string, info os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if shouldSkipDir(dir) {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return w.w.Add(dir)
|
||||
})
|
||||
}
|
||||
|
||||
func shouldSkipDir(dir string) bool {
|
||||
if dir == "." {
|
||||
return false
|
||||
}
|
||||
if dir == "vendor" || dir == "node_modules" {
|
||||
return true
|
||||
}
|
||||
_, name := path.Split(dir)
|
||||
// These directories are ignored by the Go tool.
|
||||
if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
133
templ/cmd/templ/generatecmd/watcher/watch_test.go
Normal file
133
templ/cmd/templ/generatecmd/watcher/watch_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
func TestWatchDebouncesDuplicates(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w := &fsnotify.Watcher{
|
||||
Events: make(chan fsnotify.Event),
|
||||
}
|
||||
events := make(chan fsnotify.Event, 2)
|
||||
errors := make(chan error)
|
||||
watchPattern, err := regexp.Compile(".*")
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
|
||||
}
|
||||
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
|
||||
go func() {
|
||||
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
|
||||
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
|
||||
cancel()
|
||||
close(rw.w.Events)
|
||||
}()
|
||||
rw.loop()
|
||||
count := 0
|
||||
exp := time.After(300 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-rw.Events:
|
||||
count++
|
||||
case <-exp:
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 event, got %d", count)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchDoesNotDebounceDifferentEvents(t *testing.T) {
|
||||
tests := []struct {
|
||||
event1 fsnotify.Event
|
||||
event2 fsnotify.Event
|
||||
}{
|
||||
// Different files
|
||||
{fsnotify.Event{Name: "test.templ"}, fsnotify.Event{Name: "test2.templ"}},
|
||||
// Different operations
|
||||
{
|
||||
fsnotify.Event{Name: "test.templ", Op: fsnotify.Create},
|
||||
fsnotify.Event{Name: "test.templ", Op: fsnotify.Write},
|
||||
},
|
||||
// Different operations and files
|
||||
{
|
||||
fsnotify.Event{Name: "test.templ", Op: fsnotify.Create},
|
||||
fsnotify.Event{Name: "test2.templ", Op: fsnotify.Write},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w := &fsnotify.Watcher{
|
||||
Events: make(chan fsnotify.Event),
|
||||
}
|
||||
events := make(chan fsnotify.Event, 2)
|
||||
errors := make(chan error)
|
||||
watchPattern, err := regexp.Compile(".*")
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
|
||||
}
|
||||
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
|
||||
go func() {
|
||||
rw.w.Events <- test.event1
|
||||
rw.w.Events <- test.event2
|
||||
cancel()
|
||||
close(rw.w.Events)
|
||||
}()
|
||||
rw.loop()
|
||||
count := 0
|
||||
exp := time.After(300 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-rw.Events:
|
||||
count++
|
||||
case <-exp:
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 event, got %d", count)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchDoesNotDebounceSeparateEvents(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w := &fsnotify.Watcher{
|
||||
Events: make(chan fsnotify.Event),
|
||||
}
|
||||
events := make(chan fsnotify.Event, 2)
|
||||
errors := make(chan error)
|
||||
watchPattern, err := regexp.Compile(".*")
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
|
||||
}
|
||||
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
|
||||
go func() {
|
||||
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
|
||||
<-time.After(200 * time.Millisecond)
|
||||
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
|
||||
cancel()
|
||||
close(rw.w.Events)
|
||||
}()
|
||||
rw.loop()
|
||||
count := 0
|
||||
exp := time.After(500 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-rw.Events:
|
||||
count++
|
||||
case <-exp:
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 event, got %d", count)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user