Changed: DB Params

This commit is contained in:
2025-03-20 12:35:13 +01:00
parent 8640a12439
commit b71b3d12ca
822 changed files with 134218 additions and 0 deletions

View File

@@ -0,0 +1,166 @@
package fmtcmd
import (
"bytes"
"fmt"
"io"
"log/slog"
"os"
"runtime"
"sync"
"time"
"github.com/a-h/templ/cmd/templ/imports"
"github.com/a-h/templ/cmd/templ/processor"
parser "github.com/a-h/templ/parser/v2"
"github.com/natefinch/atomic"
)
type Arguments struct {
FailIfChanged bool
ToStdout bool
StdinFilepath string
Files []string
WorkerCount int
}
func Run(log *slog.Logger, stdin io.Reader, stdout io.Writer, args Arguments) (err error) {
// If no files are provided, read from stdin and write to stdout.
if len(args.Files) == 0 {
out, _ := format(writeToWriter(stdout), readFromReader(stdin, args.StdinFilepath), true)
return out
}
process := func(fileName string) (error, bool) {
read := readFromFile(fileName)
write := writeToFile
if args.ToStdout {
write = writeToWriter(stdout)
}
writeIfUnchanged := args.ToStdout
return format(write, read, writeIfUnchanged)
}
dir := args.Files[0]
return NewFormatter(log, dir, process, args.WorkerCount, args.FailIfChanged).Run()
}
type Formatter struct {
Log *slog.Logger
Dir string
Process func(fileName string) (error, bool)
WorkerCount int
FailIfChange bool
}
func NewFormatter(log *slog.Logger, dir string, process func(fileName string) (error, bool), workerCount int, failIfChange bool) *Formatter {
f := &Formatter{
Log: log,
Dir: dir,
Process: process,
WorkerCount: workerCount,
FailIfChange: failIfChange,
}
if f.WorkerCount == 0 {
f.WorkerCount = runtime.NumCPU()
}
return f
}
func (f *Formatter) Run() (err error) {
changesMade := 0
start := time.Now()
results := make(chan processor.Result)
f.Log.Debug("Walking directory", slog.String("path", f.Dir))
go processor.Process(f.Dir, f.Process, f.WorkerCount, results)
var successCount, errorCount int
for r := range results {
if r.ChangesMade {
changesMade += 1
}
if r.Error != nil {
f.Log.Error(r.FileName, slog.Any("error", r.Error))
errorCount++
continue
}
f.Log.Debug(r.FileName, slog.Duration("duration", r.Duration))
successCount++
}
if f.FailIfChange && changesMade > 0 {
f.Log.Error("Templates were valid but not properly formatted", slog.Int("count", successCount+errorCount), slog.Int("changed", changesMade), slog.Int("errors", errorCount), slog.Duration("duration", time.Since(start)))
return fmt.Errorf("templates were not formatted properly")
}
f.Log.Info("Format Complete", slog.Int("count", successCount+errorCount), slog.Int("errors", errorCount), slog.Int("changed", changesMade), slog.Duration("duration", time.Since(start)))
if errorCount > 0 {
return fmt.Errorf("formatting failed")
}
return
}
type reader func() (fileName, src string, err error)
func readFromReader(r io.Reader, stdinFilepath string) func() (fileName, src string, err error) {
return func() (fileName, src string, err error) {
b, err := io.ReadAll(r)
if err != nil {
return "", "", fmt.Errorf("failed to read stdin: %w", err)
}
return stdinFilepath, string(b), nil
}
}
func readFromFile(name string) reader {
return func() (fileName, src string, err error) {
b, err := os.ReadFile(name)
if err != nil {
return "", "", fmt.Errorf("failed to read file %q: %w", fileName, err)
}
return name, string(b), nil
}
}
type writer func(fileName, tgt string) error
var mu sync.Mutex
func writeToWriter(w io.Writer) func(fileName, tgt string) error {
return func(fileName, tgt string) error {
mu.Lock()
defer mu.Unlock()
_, err := w.Write([]byte(tgt))
return err
}
}
func writeToFile(fileName, tgt string) error {
return atomic.WriteFile(fileName, bytes.NewBufferString(tgt))
}
func format(write writer, read reader, writeIfUnchanged bool) (err error, fileChanged bool) {
fileName, src, err := read()
if err != nil {
return err, false
}
t, err := parser.ParseString(src)
if err != nil {
return err, false
}
t.Filepath = fileName
t, err = imports.Process(t)
if err != nil {
return err, false
}
w := new(bytes.Buffer)
if err = t.Write(w); err != nil {
return fmt.Errorf("formatting error: %w", err), false
}
fileChanged = (src != w.String())
if !writeIfUnchanged && !fileChanged {
return nil, fileChanged
}
return write(fileName, w.String()), fileChanged
}

View File

@@ -0,0 +1,163 @@
package fmtcmd
import (
_ "embed"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/tools/txtar"
)
//go:embed testdata.txtar
var testDataTxTar []byte
type testProject struct {
dir string
cleanup func()
testFiles map[string]testFile
}
type testFile struct {
name string
input, expected string
}
func setupProjectDir() (tp testProject, err error) {
tp.dir, err = os.MkdirTemp("", "fmtcmd_test_*")
if err != nil {
return tp, fmt.Errorf("failed to make test dir: %w", err)
}
tp.testFiles = make(map[string]testFile)
testData := txtar.Parse(testDataTxTar)
for i := 0; i < len(testData.Files); i += 2 {
file := testData.Files[i]
err = os.WriteFile(filepath.Join(tp.dir, file.Name), file.Data, 0660)
if err != nil {
return tp, fmt.Errorf("failed to write file: %w", err)
}
tp.testFiles[file.Name] = testFile{
name: filepath.Join(tp.dir, file.Name),
input: string(file.Data),
expected: string(testData.Files[i+1].Data),
}
}
tp.cleanup = func() {
os.RemoveAll(tp.dir)
}
return tp, nil
}
func TestFormat(t *testing.T) {
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
t.Run("can format a single file from stdin to stdout", func(t *testing.T) {
tp, err := setupProjectDir()
if err != nil {
t.Fatalf("failed to setup project dir: %v", err)
}
defer tp.cleanup()
stdin := strings.NewReader(tp.testFiles["a.templ"].input)
stdout := new(strings.Builder)
if err = Run(log, stdin, stdout, Arguments{
ToStdout: true,
}); err != nil {
t.Fatalf("failed to run format command: %v", err)
}
if diff := cmp.Diff(tp.testFiles["a.templ"].expected, stdout.String()); diff != "" {
t.Error(diff)
}
})
t.Run("can process a single file to stdout", func(t *testing.T) {
tp, err := setupProjectDir()
if err != nil {
t.Fatalf("failed to setup project dir: %v", err)
}
defer tp.cleanup()
stdout := new(strings.Builder)
if err = Run(log, nil, stdout, Arguments{
ToStdout: true,
Files: []string{
tp.testFiles["a.templ"].name,
},
FailIfChanged: false,
}); err != nil {
t.Fatalf("failed to run format command: %v", err)
}
if diff := cmp.Diff(tp.testFiles["a.templ"].expected, stdout.String()); diff != "" {
t.Error(diff)
}
})
t.Run("can process a single file in place", func(t *testing.T) {
tp, err := setupProjectDir()
if err != nil {
t.Fatalf("failed to setup project dir: %v", err)
}
defer tp.cleanup()
if err = Run(log, nil, nil, Arguments{
Files: []string{
tp.testFiles["a.templ"].name,
},
FailIfChanged: false,
}); err != nil {
t.Fatalf("failed to run format command: %v", err)
}
data, err := os.ReadFile(tp.testFiles["a.templ"].name)
if err != nil {
t.Fatalf("failed to read file: %v", err)
}
if diff := cmp.Diff(tp.testFiles["a.templ"].expected, string(data)); diff != "" {
t.Error(diff)
}
})
t.Run("fails when fail flag used and change occurs", func(t *testing.T) {
tp, err := setupProjectDir()
if err != nil {
t.Fatalf("failed to setup project dir: %v", err)
}
defer tp.cleanup()
if err = Run(log, nil, nil, Arguments{
Files: []string{
tp.testFiles["a.templ"].name,
},
FailIfChanged: true,
}); err == nil {
t.Fatal("command should have exited with an error and did not")
}
data, err := os.ReadFile(tp.testFiles["a.templ"].name)
if err != nil {
t.Fatalf("failed to read file: %v", err)
}
if diff := cmp.Diff(tp.testFiles["a.templ"].expected, string(data)); diff != "" {
t.Error(diff)
}
})
t.Run("passes when fail flag used and no change occurs", func(t *testing.T) {
tp, err := setupProjectDir()
if err != nil {
t.Fatalf("failed to setup project dir: %v", err)
}
defer tp.cleanup()
if err = Run(log, nil, nil, Arguments{
Files: []string{
tp.testFiles["c.templ"].name,
},
FailIfChanged: true,
}); err != nil {
t.Fatalf("failed to run format command: %v", err)
}
data, err := os.ReadFile(tp.testFiles["c.templ"].name)
if err != nil {
t.Fatalf("failed to read file: %v", err)
}
if diff := cmp.Diff(tp.testFiles["c.templ"].expected, string(data)); diff != "" {
t.Error(diff)
}
})
}

View File

@@ -0,0 +1,54 @@
-- a.templ --
package test
templ a() {
<div><p class={templ.Class("mapped")}>A
</p></div>
}
-- a.templ --
package test
templ a() {
<div>
<p class={ templ.Class("mapped") }>
A
</p>
</div>
}
-- b.templ --
package test
templ b() {
<div><p>B
</p></div>
}
-- b.templ --
package test
templ b() {
<div>
<p>
B
</p>
</div>
}
-- c.templ --
package test
templ c() {
<div>
<p>
C
</p>
</div>
}
-- c.templ --
package test
templ c() {
<div>
<p>
C
</p>
</div>
}

View File

@@ -0,0 +1,403 @@
package generatecmd
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"regexp"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/a-h/templ"
"github.com/a-h/templ/cmd/templ/generatecmd/modcheck"
"github.com/a-h/templ/cmd/templ/generatecmd/proxy"
"github.com/a-h/templ/cmd/templ/generatecmd/run"
"github.com/a-h/templ/cmd/templ/generatecmd/watcher"
"github.com/a-h/templ/generator"
"github.com/cenkalti/backoff/v4"
"github.com/cli/browser"
"github.com/fsnotify/fsnotify"
)
const defaultWatchPattern = `(.+\.go$)|(.+\.templ$)|(.+_templ\.txt$)`
func NewGenerate(log *slog.Logger, args Arguments) (g *Generate, err error) {
g = &Generate{
Log: log,
Args: &args,
}
if g.Args.WorkerCount == 0 {
g.Args.WorkerCount = runtime.NumCPU()
}
if g.Args.WatchPattern == "" {
g.Args.WatchPattern = defaultWatchPattern
}
g.WatchPattern, err = regexp.Compile(g.Args.WatchPattern)
if err != nil {
return nil, fmt.Errorf("failed to compile watch pattern %q: %w", g.Args.WatchPattern, err)
}
return g, nil
}
type Generate struct {
Log *slog.Logger
Args *Arguments
WatchPattern *regexp.Regexp
}
type GenerationEvent struct {
Event fsnotify.Event
Updated bool
GoUpdated bool
TextUpdated bool
}
func (cmd Generate) Run(ctx context.Context) (err error) {
if cmd.Args.NotifyProxy {
return proxy.NotifyProxy(cmd.Args.ProxyBind, cmd.Args.ProxyPort)
}
if cmd.Args.Watch && cmd.Args.FileName != "" {
return fmt.Errorf("cannot watch a single file, remove the -f or -watch flag")
}
writingToWriter := cmd.Args.FileWriter != nil
if cmd.Args.FileName == "" && writingToWriter {
return fmt.Errorf("only a single file can be output to stdout, add the -f flag to specify the file to generate code for")
}
// Default to writing to files.
if cmd.Args.FileWriter == nil {
cmd.Args.FileWriter = FileWriter
}
if cmd.Args.PPROFPort > 0 {
go func() {
_ = http.ListenAndServe(fmt.Sprintf("localhost:%d", cmd.Args.PPROFPort), nil)
}()
}
// Use absolute path.
if !path.IsAbs(cmd.Args.Path) {
cmd.Args.Path, err = filepath.Abs(cmd.Args.Path)
if err != nil {
return fmt.Errorf("failed to get absolute path: %w", err)
}
}
// Configure generator.
var opts []generator.GenerateOpt
if cmd.Args.IncludeVersion {
opts = append(opts, generator.WithVersion(templ.Version()))
}
if cmd.Args.IncludeTimestamp {
opts = append(opts, generator.WithTimestamp(time.Now()))
}
// Check the version of the templ module.
if err := modcheck.Check(cmd.Args.Path); err != nil {
cmd.Log.Warn("templ version check: " + err.Error())
}
fseh := NewFSEventHandler(
cmd.Log,
cmd.Args.Path,
cmd.Args.Watch,
opts,
cmd.Args.GenerateSourceMapVisualisations,
cmd.Args.KeepOrphanedFiles,
cmd.Args.FileWriter,
cmd.Args.Lazy,
)
// If we're processing a single file, don't bother setting up the channels/multithreaing.
if cmd.Args.FileName != "" {
_, err = fseh.HandleEvent(ctx, fsnotify.Event{
Name: cmd.Args.FileName,
Op: fsnotify.Create,
})
return err
}
// Start timer.
start := time.Now()
// Create channels:
// For the initial filesystem walk and subsequent (optional) fsnotify events.
events := make(chan fsnotify.Event)
// Count of events currently being processed by the event handler.
var eventsWG sync.WaitGroup
// Used to check that the event handler has completed.
var eventHandlerWG sync.WaitGroup
// For errs from the watcher.
errs := make(chan error)
// Tracks whether errors occurred during the generation process.
var errorCount atomic.Int64
// For triggering actions after generation has completed.
postGeneration := make(chan *GenerationEvent, 256)
// Used to check that the post-generation handler has completed.
var postGenerationWG sync.WaitGroup
var postGenerationEventsWG sync.WaitGroup
// Waitgroup for the push process.
var pushHandlerWG sync.WaitGroup
// Start process to push events into the channel.
pushHandlerWG.Add(1)
go func() {
defer pushHandlerWG.Done()
defer close(events)
cmd.Log.Debug(
"Walking directory",
slog.String("path", cmd.Args.Path),
slog.Bool("devMode", cmd.Args.Watch),
)
if err := watcher.WalkFiles(ctx, cmd.Args.Path, cmd.WatchPattern, events); err != nil {
cmd.Log.Error("WalkFiles failed, exiting", slog.Any("error", err))
errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)}
return
}
if !cmd.Args.Watch {
cmd.Log.Debug("Dev mode not enabled, process can finish early")
return
}
cmd.Log.Info("Watching files")
rw, err := watcher.Recursive(ctx, cmd.Args.Path, cmd.WatchPattern, events, errs)
if err != nil {
cmd.Log.Error("Recursive watcher setup failed, exiting", slog.Any("error", err))
errs <- FatalError{Err: fmt.Errorf("failed to setup recursive watcher: %w", err)}
return
}
cmd.Log.Debug("Waiting for context to be cancelled to stop watching files")
<-ctx.Done()
cmd.Log.Debug("Context cancelled, closing watcher")
if err := rw.Close(); err != nil {
cmd.Log.Error("Failed to close watcher", slog.Any("error", err))
}
cmd.Log.Debug("Waiting for events to be processed")
eventsWG.Wait()
cmd.Log.Debug(
"All pending events processed, waiting for pending post-generation events to complete",
)
postGenerationEventsWG.Wait()
cmd.Log.Debug(
"All post-generation events processed, deleting watch mode text files",
slog.Int64("errorCount", errorCount.Load()),
)
fileEvents := make(chan fsnotify.Event)
go func() {
if err := watcher.WalkFiles(ctx, cmd.Args.Path, cmd.WatchPattern, fileEvents); err != nil {
cmd.Log.Error("Post dev mode WalkFiles failed", slog.Any("error", err))
errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)}
return
}
close(fileEvents)
}()
for event := range fileEvents {
if strings.HasSuffix(event.Name, "_templ.txt") {
if err = os.Remove(event.Name); err != nil {
cmd.Log.Warn("Failed to remove watch mode text file", slog.Any("error", err))
}
}
}
}()
// Start process to handle events.
eventHandlerWG.Add(1)
sem := make(chan struct{}, cmd.Args.WorkerCount)
go func() {
defer eventHandlerWG.Done()
defer close(postGeneration)
cmd.Log.Debug("Starting event handler")
for event := range events {
eventsWG.Add(1)
sem <- struct{}{}
go func(event fsnotify.Event) {
cmd.Log.Debug("Processing file", slog.String("file", event.Name))
defer eventsWG.Done()
defer func() { <-sem }()
r, err := fseh.HandleEvent(ctx, event)
if err != nil {
errs <- err
}
if !(r.GoUpdated || r.TextUpdated) {
cmd.Log.Debug("File not updated", slog.String("file", event.Name))
return
}
e := &GenerationEvent{
Event: event,
Updated: r.Updated,
GoUpdated: r.GoUpdated,
TextUpdated: r.TextUpdated,
}
cmd.Log.Debug("File updated", slog.String("file", event.Name))
postGeneration <- e
}(event)
}
// Wait for all events to be processed before closing.
eventsWG.Wait()
}()
// Start process to handle post-generation events.
var updates int
postGenerationWG.Add(1)
var firstPostGenerationExecuted bool
go func() {
defer close(errs)
defer postGenerationWG.Done()
cmd.Log.Debug("Starting post-generation handler")
timeout := time.NewTimer(time.Hour * 24 * 365)
var goUpdated, textUpdated bool
var p *proxy.Handler
for {
select {
case ge := <-postGeneration:
if ge == nil {
cmd.Log.Debug("Post-generation event channel closed, exiting")
return
}
goUpdated = goUpdated || ge.GoUpdated
textUpdated = textUpdated || ge.TextUpdated
if goUpdated || textUpdated {
updates++
}
// Reset timer.
if !timeout.Stop() {
<-timeout.C
}
timeout.Reset(time.Millisecond * 100)
case <-timeout.C:
if !goUpdated && !textUpdated {
// Nothing to process, reset timer and wait again.
timeout.Reset(time.Hour * 24 * 365)
break
}
postGenerationEventsWG.Add(1)
if cmd.Args.Command != "" && goUpdated {
cmd.Log.Debug("Executing command", slog.String("command", cmd.Args.Command))
if cmd.Args.Watch {
os.Setenv("TEMPL_DEV_MODE", "true")
}
if _, err := run.Run(ctx, cmd.Args.Path, cmd.Args.Command); err != nil {
cmd.Log.Error("Error executing command", slog.Any("error", err))
}
}
if !firstPostGenerationExecuted {
cmd.Log.Debug("First post-generation event received, starting proxy")
firstPostGenerationExecuted = true
p, err = cmd.StartProxy(ctx)
if err != nil {
cmd.Log.Error("Failed to start proxy", slog.Any("error", err))
}
}
// Send server-sent event.
if p != nil && (textUpdated || goUpdated) {
cmd.Log.Debug("Sending reload event")
p.SendSSE("message", "reload")
}
postGenerationEventsWG.Done()
// Reset timer.
timeout.Reset(time.Millisecond * 100)
textUpdated = false
goUpdated = false
}
}
}()
// Read errors.
for err := range errs {
if err == nil {
continue
}
if errors.Is(err, FatalError{}) {
cmd.Log.Debug("Fatal error, exiting")
return err
}
cmd.Log.Error("Error", slog.Any("error", err))
errorCount.Add(1)
}
// Wait for everything to complete.
cmd.Log.Debug("Waiting for push handler to complete")
pushHandlerWG.Wait()
cmd.Log.Debug("Waiting for event handler to complete")
eventHandlerWG.Wait()
cmd.Log.Debug("Waiting for post-generation handler to complete")
postGenerationWG.Wait()
if cmd.Args.Command != "" {
cmd.Log.Debug("Killing command", slog.String("command", cmd.Args.Command))
if err := run.KillAll(); err != nil {
cmd.Log.Error("Error killing command", slog.Any("error", err))
}
}
// Check for errors after everything has completed.
if errorCount.Load() > 0 {
return fmt.Errorf("generation completed with %d errors", errorCount.Load())
}
cmd.Log.Info(
"Complete",
slog.Int("updates", updates),
slog.Duration("duration", time.Since(start)),
)
return nil
}
func (cmd *Generate) StartProxy(ctx context.Context) (p *proxy.Handler, err error) {
if cmd.Args.Proxy == "" {
cmd.Log.Debug("No proxy URL specified, not starting proxy")
return nil, nil
}
var target *url.URL
target, err = url.Parse(cmd.Args.Proxy)
if err != nil {
return nil, FatalError{Err: fmt.Errorf("failed to parse proxy URL: %w", err)}
}
if cmd.Args.ProxyPort == 0 {
cmd.Args.ProxyPort = 7331
}
if cmd.Args.ProxyBind == "" {
cmd.Args.ProxyBind = "127.0.0.1"
}
p = proxy.New(cmd.Log, cmd.Args.ProxyBind, cmd.Args.ProxyPort, target)
go func() {
cmd.Log.Info("Proxying", slog.String("from", p.URL), slog.String("to", p.Target.String()))
if err := http.ListenAndServe(fmt.Sprintf("%s:%d", cmd.Args.ProxyBind, cmd.Args.ProxyPort), p); err != nil {
cmd.Log.Error("Proxy failed", slog.Any("error", err))
}
}()
if !cmd.Args.OpenBrowser {
cmd.Log.Debug("Not opening browser")
return p, nil
}
go func() {
cmd.Log.Debug("Waiting for proxy to be ready", slog.String("url", p.URL))
backoff := backoff.NewExponentialBackOff()
backoff.InitialInterval = time.Second
var client http.Client
client.Timeout = 1 * time.Second
for {
if _, err := client.Get(p.URL); err == nil {
break
}
d := backoff.NextBackOff()
cmd.Log.Debug(
"Proxy not ready, retrying",
slog.String("url", p.URL),
slog.Any("backoff", d),
)
time.Sleep(d)
}
if err := browser.OpenURL(p.URL); err != nil {
cmd.Log.Error("Failed to open browser", slog.Any("error", err))
}
}()
return p, nil
}

View File

@@ -0,0 +1,366 @@
package generatecmd
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"fmt"
"go/format"
"go/scanner"
"go/token"
"io"
"log/slog"
"os"
"path"
"path/filepath"
"strings"
"sync"
"time"
"github.com/a-h/templ/cmd/templ/visualize"
"github.com/a-h/templ/generator"
"github.com/a-h/templ/parser/v2"
"github.com/fsnotify/fsnotify"
)
type FileWriterFunc func(name string, contents []byte) error
func FileWriter(fileName string, contents []byte) error {
return os.WriteFile(fileName, contents, 0o644)
}
func WriterFileWriter(w io.Writer) FileWriterFunc {
return func(_ string, contents []byte) error {
_, err := w.Write(contents)
return err
}
}
func NewFSEventHandler(
log *slog.Logger,
dir string,
devMode bool,
genOpts []generator.GenerateOpt,
genSourceMapVis bool,
keepOrphanedFiles bool,
fileWriter FileWriterFunc,
lazy bool,
) *FSEventHandler {
if !path.IsAbs(dir) {
dir, _ = filepath.Abs(dir)
}
fseh := &FSEventHandler{
Log: log,
dir: dir,
fileNameToLastModTime: make(map[string]time.Time),
fileNameToLastModTimeMutex: &sync.Mutex{},
fileNameToError: make(map[string]struct{}),
fileNameToErrorMutex: &sync.Mutex{},
fileNameToOutput: make(map[string]generator.GeneratorOutput),
fileNameToOutputMutex: &sync.Mutex{},
devMode: devMode,
hashes: make(map[string][sha256.Size]byte),
hashesMutex: &sync.Mutex{},
genOpts: genOpts,
genSourceMapVis: genSourceMapVis,
keepOrphanedFiles: keepOrphanedFiles,
writer: fileWriter,
lazy: lazy,
}
return fseh
}
type FSEventHandler struct {
Log *slog.Logger
// dir is the root directory being processed.
dir string
fileNameToLastModTime map[string]time.Time
fileNameToLastModTimeMutex *sync.Mutex
fileNameToError map[string]struct{}
fileNameToErrorMutex *sync.Mutex
fileNameToOutput map[string]generator.GeneratorOutput
fileNameToOutputMutex *sync.Mutex
devMode bool
hashes map[string][sha256.Size]byte
hashesMutex *sync.Mutex
genOpts []generator.GenerateOpt
genSourceMapVis bool
Errors []error
keepOrphanedFiles bool
writer func(string, []byte) error
lazy bool
}
type GenerateResult struct {
// Updated indicates that the file was updated.
Updated bool
// GoUpdated indicates that Go expressions were updated.
GoUpdated bool
// TextUpdated indicates that text literals were updated.
TextUpdated bool
}
func (h *FSEventHandler) HandleEvent(ctx context.Context, event fsnotify.Event) (result GenerateResult, err error) {
// Handle _templ.go files.
if !event.Has(fsnotify.Remove) && strings.HasSuffix(event.Name, "_templ.go") {
_, err = os.Stat(strings.TrimSuffix(event.Name, "_templ.go") + ".templ")
if !os.IsNotExist(err) {
return GenerateResult{}, err
}
// File is orphaned.
if h.keepOrphanedFiles {
return GenerateResult{}, nil
}
h.Log.Debug("Deleting orphaned Go file", slog.String("file", event.Name))
if err = os.Remove(event.Name); err != nil {
h.Log.Warn("Failed to remove orphaned file", slog.Any("error", err))
}
return GenerateResult{Updated: true, GoUpdated: true, TextUpdated: false}, nil
}
// Handle _templ.txt files.
if !event.Has(fsnotify.Remove) && strings.HasSuffix(event.Name, "_templ.txt") {
if h.devMode {
// Don't delete the file in dev mode, ignore changes to it, since the .templ file
// must have been updated in order to trigger a change in the _templ.txt file.
return GenerateResult{Updated: false, GoUpdated: false, TextUpdated: false}, nil
}
h.Log.Debug("Deleting watch mode file", slog.String("file", event.Name))
if err = os.Remove(event.Name); err != nil {
h.Log.Warn("Failed to remove watch mode text file", slog.Any("error", err))
return GenerateResult{}, nil
}
return GenerateResult{}, nil
}
// If the file hasn't been updated since the last time we processed it, ignore it.
lastModTime, updatedModTime := h.UpsertLastModTime(event.Name)
if !updatedModTime {
h.Log.Debug("Skipping file because it wasn't updated", slog.String("file", event.Name))
return GenerateResult{}, nil
}
// Process anything that isn't a templ file.
if !strings.HasSuffix(event.Name, ".templ") {
// If it's a Go file, mark it as updated.
if strings.HasSuffix(event.Name, ".go") {
result.GoUpdated = true
}
result.Updated = true
return result, nil
}
// Handle templ files.
// If the go file is newer than the templ file, skip generation, because it's up-to-date.
if h.lazy && goFileIsUpToDate(event.Name, lastModTime) {
h.Log.Debug("Skipping file because the Go file is up-to-date", slog.String("file", event.Name))
return GenerateResult{}, nil
}
// Start a processor.
start := time.Now()
var diag []parser.Diagnostic
result, diag, err = h.generate(ctx, event.Name)
if err != nil {
h.SetError(event.Name, true)
return result, fmt.Errorf("failed to generate code for %q: %w", event.Name, err)
}
if len(diag) > 0 {
for _, d := range diag {
h.Log.Warn(d.Message,
slog.String("from", fmt.Sprintf("%d:%d", d.Range.From.Line, d.Range.From.Col)),
slog.String("to", fmt.Sprintf("%d:%d", d.Range.To.Line, d.Range.To.Col)),
)
}
return result, nil
}
if errorCleared, errorCount := h.SetError(event.Name, false); errorCleared {
h.Log.Info("Error cleared", slog.String("file", event.Name), slog.Int("errors", errorCount))
}
h.Log.Debug("Generated code", slog.String("file", event.Name), slog.Duration("in", time.Since(start)))
return result, nil
}
func goFileIsUpToDate(templFileName string, templFileLastMod time.Time) (upToDate bool) {
goFileName := strings.TrimSuffix(templFileName, ".templ") + "_templ.go"
goFileInfo, err := os.Stat(goFileName)
if err != nil {
return false
}
return goFileInfo.ModTime().After(templFileLastMod)
}
func (h *FSEventHandler) SetError(fileName string, hasError bool) (previouslyHadError bool, errorCount int) {
h.fileNameToErrorMutex.Lock()
defer h.fileNameToErrorMutex.Unlock()
_, previouslyHadError = h.fileNameToError[fileName]
delete(h.fileNameToError, fileName)
if hasError {
h.fileNameToError[fileName] = struct{}{}
}
return previouslyHadError, len(h.fileNameToError)
}
func (h *FSEventHandler) UpsertLastModTime(fileName string) (modTime time.Time, updated bool) {
fileInfo, err := os.Stat(fileName)
if err != nil {
return modTime, false
}
h.fileNameToLastModTimeMutex.Lock()
defer h.fileNameToLastModTimeMutex.Unlock()
previousModTime := h.fileNameToLastModTime[fileName]
currentModTime := fileInfo.ModTime()
if !currentModTime.After(previousModTime) {
return currentModTime, false
}
h.fileNameToLastModTime[fileName] = currentModTime
return currentModTime, true
}
func (h *FSEventHandler) UpsertHash(fileName string, hash [sha256.Size]byte) (updated bool) {
h.hashesMutex.Lock()
defer h.hashesMutex.Unlock()
lastHash := h.hashes[fileName]
if lastHash == hash {
return false
}
h.hashes[fileName] = hash
return true
}
// generate Go code for a single template.
// If a basePath is provided, the filename included in error messages is relative to it.
func (h *FSEventHandler) generate(ctx context.Context, fileName string) (result GenerateResult, diagnostics []parser.Diagnostic, err error) {
t, err := parser.Parse(fileName)
if err != nil {
return GenerateResult{}, nil, fmt.Errorf("%s parsing error: %w", fileName, err)
}
targetFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.go"
// Only use relative filenames to the basepath for filenames in runtime error messages.
absFilePath, err := filepath.Abs(fileName)
if err != nil {
return GenerateResult{}, nil, fmt.Errorf("failed to get absolute path for %q: %w", fileName, err)
}
relFilePath, err := filepath.Rel(h.dir, absFilePath)
if err != nil {
return GenerateResult{}, nil, fmt.Errorf("failed to get relative path for %q: %w", fileName, err)
}
// Convert Windows file paths to Unix-style for consistency.
relFilePath = filepath.ToSlash(relFilePath)
var b bytes.Buffer
generatorOutput, err := generator.Generate(t, &b, append(h.genOpts, generator.WithFileName(relFilePath))...)
if err != nil {
return GenerateResult{}, nil, fmt.Errorf("%s generation error: %w", fileName, err)
}
formattedGoCode, err := format.Source(b.Bytes())
if err != nil {
err = remapErrorList(err, generatorOutput.SourceMap, fileName)
return GenerateResult{}, nil, fmt.Errorf("%s source formatting error %w", fileName, err)
}
// Hash output, and write out the file if the goCodeHash has changed.
goCodeHash := sha256.Sum256(formattedGoCode)
if h.UpsertHash(targetFileName, goCodeHash) {
result.Updated = true
if err = h.writer(targetFileName, formattedGoCode); err != nil {
return result, nil, fmt.Errorf("failed to write target file %q: %w", targetFileName, err)
}
}
// Add the txt file if it has changed.
if h.devMode {
txtFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.txt"
joined := strings.Join(generatorOutput.Literals, "\n")
txtHash := sha256.Sum256([]byte(joined))
if h.UpsertHash(txtFileName, txtHash) {
result.TextUpdated = true
if err = os.WriteFile(txtFileName, []byte(joined), 0o644); err != nil {
return result, nil, fmt.Errorf("failed to write string literal file %q: %w", txtFileName, err)
}
}
// Check whether the change would require a recompilation to take effect.
h.fileNameToOutputMutex.Lock()
defer h.fileNameToOutputMutex.Unlock()
previous := h.fileNameToOutput[fileName]
if generator.HasChanged(previous, generatorOutput) {
result.GoUpdated = true
}
h.fileNameToOutput[fileName] = generatorOutput
}
parsedDiagnostics, err := parser.Diagnose(t)
if err != nil {
return result, nil, fmt.Errorf("%s diagnostics error: %w", fileName, err)
}
if h.genSourceMapVis {
err = generateSourceMapVisualisation(ctx, fileName, targetFileName, generatorOutput.SourceMap)
}
return result, parsedDiagnostics, err
}
// Takes an error from the formatter and attempts to convert the positions reported in the target file to their positions
// in the source file.
func remapErrorList(err error, sourceMap *parser.SourceMap, fileName string) error {
list, ok := err.(scanner.ErrorList)
if !ok || len(list) == 0 {
return err
}
for i, e := range list {
// The positions in the source map are off by one line because of the package definition.
srcPos, ok := sourceMap.SourcePositionFromTarget(uint32(e.Pos.Line-1), uint32(e.Pos.Column))
if !ok {
continue
}
list[i].Pos = token.Position{
Filename: fileName,
Offset: int(srcPos.Index),
Line: int(srcPos.Line) + 1,
Column: int(srcPos.Col),
}
}
return list
}
func generateSourceMapVisualisation(ctx context.Context, templFileName, goFileName string, sourceMap *parser.SourceMap) error {
if err := ctx.Err(); err != nil {
return err
}
var templContents, goContents []byte
var templErr, goErr error
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
templContents, templErr = os.ReadFile(templFileName)
}()
go func() {
defer wg.Done()
goContents, goErr = os.ReadFile(goFileName)
}()
wg.Wait()
if templErr != nil {
return templErr
}
if goErr != nil {
return templErr
}
targetFileName := strings.TrimSuffix(templFileName, ".templ") + "_templ_sourcemap.html"
w, err := os.Create(targetFileName)
if err != nil {
return fmt.Errorf("%s sourcemap visualisation error: %w", templFileName, err)
}
defer w.Close()
b := bufio.NewWriter(w)
defer b.Flush()
return visualize.HTML(templFileName, string(templContents), string(goContents), sourceMap).Render(ctx, b)
}

View File

@@ -0,0 +1,23 @@
package generatecmd
type FatalError struct {
Err error
}
func (e FatalError) Error() string {
return e.Err.Error()
}
func (e FatalError) Unwrap() error {
return e.Err
}
func (e FatalError) Is(target error) bool {
_, ok := target.(FatalError)
return ok
}
func (e FatalError) As(target any) bool {
_, ok := target.(*FatalError)
return ok
}

View File

@@ -0,0 +1,39 @@
package generatecmd
import (
"context"
_ "embed"
"log/slog"
_ "net/http/pprof"
)
type Arguments struct {
FileName string
FileWriter FileWriterFunc
Path string
Watch bool
WatchPattern string
OpenBrowser bool
Command string
ProxyBind string
ProxyPort int
Proxy string
NotifyProxy bool
WorkerCount int
GenerateSourceMapVisualisations bool
IncludeVersion bool
IncludeTimestamp bool
// PPROFPort is the port to run the pprof server on.
PPROFPort int
KeepOrphanedFiles bool
Lazy bool
}
func Run(ctx context.Context, log *slog.Logger, args Arguments) (err error) {
g, err := NewGenerate(log, args)
if err != nil {
return err
}
return g.Run(ctx)
}

View File

@@ -0,0 +1,170 @@
package generatecmd
import (
"context"
"io"
"log/slog"
"os"
"path"
"regexp"
"testing"
"time"
"github.com/a-h/templ/cmd/templ/testproject"
"golang.org/x/sync/errgroup"
)
func TestGenerate(t *testing.T) {
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
t.Run("can generate a file in place", func(t *testing.T) {
// templ generate -f templates.templ
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
if err != nil {
t.Fatalf("failed to create test project: %v", err)
}
defer os.RemoveAll(dir)
// Delete the templates_templ.go file to ensure it is generated.
err = os.Remove(path.Join(dir, "templates_templ.go"))
if err != nil {
t.Fatalf("failed to remove templates_templ.go: %v", err)
}
// Run the generate command.
err = Run(context.Background(), log, Arguments{
FileName: path.Join(dir, "templates.templ"),
})
if err != nil {
t.Fatalf("failed to run generate command: %v", err)
}
// Check the templates_templ.go file was created.
_, err = os.Stat(path.Join(dir, "templates_templ.go"))
if err != nil {
t.Fatalf("templates_templ.go was not created: %v", err)
}
})
t.Run("can generate a file in watch mode", func(t *testing.T) {
// templ generate -f templates.templ
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
if err != nil {
t.Fatalf("failed to create test project: %v", err)
}
defer os.RemoveAll(dir)
// Delete the templates_templ.go file to ensure it is generated.
err = os.Remove(path.Join(dir, "templates_templ.go"))
if err != nil {
t.Fatalf("failed to remove templates_templ.go: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
var eg errgroup.Group
eg.Go(func() error {
// Run the generate command.
return Run(ctx, log, Arguments{
Path: dir,
Watch: true,
})
})
// Check the templates_templ.go file was created, with backoff.
for i := 0; i < 5; i++ {
time.Sleep(time.Second * time.Duration(i))
_, err = os.Stat(path.Join(dir, "templates_templ.go"))
if err != nil {
continue
}
_, err = os.Stat(path.Join(dir, "templates_templ.txt"))
if err != nil {
continue
}
break
}
if err != nil {
t.Fatalf("template files were not created: %v", err)
}
cancel()
if err := eg.Wait(); err != nil {
t.Fatalf("generate command failed: %v", err)
}
// Check the templates_templ.txt file was removed.
_, err = os.Stat(path.Join(dir, "templates_templ.txt"))
if err == nil {
t.Fatalf("templates_templ.txt was not removed")
}
})
}
func TestDefaultWatchPattern(t *testing.T) {
tests := []struct {
name string
input string
matches bool
}{
{
name: "empty file names do not match",
input: "",
matches: false,
},
{
name: "*_templ.txt matches, Windows",
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\strings_templ.txt`,
matches: true,
},
{
name: "*_templ.txt matches, Unix",
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/strings_templ.txt",
matches: true,
},
{
name: "*.templ files match, Windows",
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates.templ`,
matches: true,
},
{
name: "*.templ files match, Unix",
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.templ",
matches: true,
},
{
name: "*_templ.go files match, Windows",
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates_templ.go`,
matches: true,
},
{
name: "*_templ.go files match, Unix",
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates_templ.go",
matches: true,
},
{
name: "*.go files match, Windows",
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates.go`,
matches: true,
},
{
name: "*.go files match, Unix",
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.go",
matches: true,
},
{
name: "*.css files do not match",
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.css",
matches: false,
},
}
wpRegexp, err := regexp.Compile(defaultWatchPattern)
if err != nil {
t.Fatalf("failed to compile default watch pattern: %v", err)
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
if wpRegexp.MatchString(test.input) != test.matches {
t.Fatalf("expected match of %q to be %v", test.input, test.matches)
}
})
}
}

View File

@@ -0,0 +1,82 @@
package modcheck
import (
"fmt"
"os"
"path/filepath"
"github.com/a-h/templ"
"golang.org/x/mod/modfile"
"golang.org/x/mod/semver"
)
// WalkUp the directory tree, starting at dir, until we find a directory containing
// a go.mod file.
func WalkUp(dir string) (string, error) {
dir, err := filepath.Abs(dir)
if err != nil {
return "", fmt.Errorf("failed to get absolute path: %w", err)
}
var modFile string
for {
modFile = filepath.Join(dir, "go.mod")
_, err := os.Stat(modFile)
if err != nil && !os.IsNotExist(err) {
return "", fmt.Errorf("failed to stat go.mod file: %w", err)
}
if os.IsNotExist(err) {
// Move up.
prev := dir
dir = filepath.Dir(dir)
if dir == prev {
break
}
continue
}
break
}
// No file found.
if modFile == "" {
return dir, fmt.Errorf("could not find go.mod file")
}
return dir, nil
}
func Check(dir string) error {
dir, err := WalkUp(dir)
if err != nil {
return err
}
// Found a go.mod file.
// Read it and find the templ version.
modFile := filepath.Join(dir, "go.mod")
m, err := os.ReadFile(modFile)
if err != nil {
return fmt.Errorf("failed to read go.mod file: %w", err)
}
mf, err := modfile.Parse(modFile, m, nil)
if err != nil {
return fmt.Errorf("failed to parse go.mod file: %w", err)
}
if mf.Module.Mod.Path == "github.com/a-h/templ" {
// The go.mod file is for templ itself.
return nil
}
for _, r := range mf.Require {
if r.Mod.Path == "github.com/a-h/templ" {
cmp := semver.Compare(r.Mod.Version, templ.Version())
if cmp < 0 {
return fmt.Errorf("generator %v is newer than templ version %v found in go.mod file, consider running `go get -u github.com/a-h/templ` to upgrade", templ.Version(), r.Mod.Version)
}
if cmp > 0 {
return fmt.Errorf("generator %v is older than templ version %v found in go.mod file, consider upgrading templ CLI", templ.Version(), r.Mod.Version)
}
return nil
}
}
return fmt.Errorf("templ not found in go.mod file, run `go get github.com/a-h/templ` to install it")
}

View File

@@ -0,0 +1,47 @@
package modcheck
import (
"testing"
"golang.org/x/mod/modfile"
)
func TestPatchGoVersion(t *testing.T) {
tests := []struct {
input string
expected string
}{
{
input: "go 1.20",
expected: "1.20",
},
{
input: "go 1.20.123",
expected: "1.20.123",
},
{
input: "go 1.20.1",
expected: "1.20.1",
},
{
input: "go 1.20rc1",
expected: "1.20rc1",
},
{
input: "go 1.15",
expected: "1.15",
},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
input := "module github.com/a-h/templ\n\n" + string(test.input) + "\n" + "toolchain go1.27.9\n"
mf, err := modfile.Parse("go.mod", []byte(input), nil)
if err != nil {
t.Fatalf("failed to parse go.mod: %v", err)
}
if test.expected != mf.Go.Version {
t.Errorf("expected %q, got %q", test.expected, mf.Go.Version)
}
})
}
}

View File

@@ -0,0 +1,284 @@
package proxy
import (
"bytes"
"compress/gzip"
"fmt"
"html"
"io"
stdlog "log"
"log/slog"
"math"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/PuerkitoBio/goquery"
"github.com/a-h/templ/cmd/templ/generatecmd/sse"
"github.com/andybalholm/brotli"
_ "embed"
)
//go:embed script.js
var script string
type Handler struct {
log *slog.Logger
URL string
Target *url.URL
p *httputil.ReverseProxy
sse *sse.Handler
}
func getScriptTag(nonce string) string {
if nonce != "" {
var sb strings.Builder
sb.WriteString(`<script src="/_templ/reload/script.js" nonce="`)
sb.WriteString(html.EscapeString(nonce))
sb.WriteString(`"></script>`)
return sb.String()
}
return `<script src="/_templ/reload/script.js"></script>`
}
func insertScriptTagIntoBody(nonce, body string) (updated string) {
doc, err := goquery.NewDocumentFromReader(strings.NewReader(body))
if err != nil {
return strings.Replace(body, "</body>", getScriptTag(nonce)+"</body>", -1)
}
doc.Find("body").AppendHtml(getScriptTag(nonce))
r, err := doc.Html()
if err != nil {
return strings.Replace(body, "</body>", getScriptTag(nonce)+"</body>", -1)
}
return r
}
type passthroughWriteCloser struct {
io.Writer
}
func (pwc passthroughWriteCloser) Close() error {
return nil
}
const unsupportedContentEncoding = "Unsupported content encoding, hot reload script not inserted."
func (h *Handler) modifyResponse(r *http.Response) error {
log := h.log.With(slog.String("url", r.Request.URL.String()))
if r.Header.Get("templ-skip-modify") == "true" {
log.Debug("Skipping response modification because templ-skip-modify header is set")
return nil
}
if contentType := r.Header.Get("Content-Type"); !strings.HasPrefix(contentType, "text/html") {
log.Debug("Skipping response modification because content type is not text/html", slog.String("content-type", contentType))
return nil
}
// Set up readers and writers.
newReader := func(in io.Reader) (out io.Reader, err error) {
return in, nil
}
newWriter := func(out io.Writer) io.WriteCloser {
return passthroughWriteCloser{out}
}
switch r.Header.Get("Content-Encoding") {
case "gzip":
newReader = func(in io.Reader) (out io.Reader, err error) {
return gzip.NewReader(in)
}
newWriter = func(out io.Writer) io.WriteCloser {
return gzip.NewWriter(out)
}
case "br":
newReader = func(in io.Reader) (out io.Reader, err error) {
return brotli.NewReader(in), nil
}
newWriter = func(out io.Writer) io.WriteCloser {
return brotli.NewWriter(out)
}
case "":
log.Debug("No content encoding header found")
default:
h.log.Warn(unsupportedContentEncoding, slog.String("encoding", r.Header.Get("Content-Encoding")))
}
// Read the encoded body.
encr, err := newReader(r.Body)
if err != nil {
return err
}
defer r.Body.Close()
body, err := io.ReadAll(encr)
if err != nil {
return err
}
// Update it.
csp := r.Header.Get("Content-Security-Policy")
updated := insertScriptTagIntoBody(parseNonce(csp), string(body))
if log.Enabled(r.Request.Context(), slog.LevelDebug) {
if len(updated) == len(body) {
log.Debug("Reload script not inserted")
} else {
log.Debug("Reload script inserted")
}
}
// Encode the response.
var buf bytes.Buffer
encw := newWriter(&buf)
_, err = encw.Write([]byte(updated))
if err != nil {
return err
}
err = encw.Close()
if err != nil {
return err
}
// Update the response.
r.Body = io.NopCloser(&buf)
r.ContentLength = int64(buf.Len())
r.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
return nil
}
func parseNonce(csp string) (nonce string) {
outer:
for _, rawDirective := range strings.Split(csp, ";") {
parts := strings.Fields(rawDirective)
if len(parts) < 2 {
continue
}
if parts[0] != "script-src" {
continue
}
for _, source := range parts[1:] {
source = strings.TrimPrefix(source, "'")
source = strings.TrimSuffix(source, "'")
if strings.HasPrefix(source, "nonce-") {
nonce = source[6:]
break outer
}
}
}
return nonce
}
func New(log *slog.Logger, bind string, port int, target *url.URL) (h *Handler) {
p := httputil.NewSingleHostReverseProxy(target)
p.ErrorLog = stdlog.New(os.Stderr, "Proxy to target error: ", 0)
p.Transport = &roundTripper{
maxRetries: 20,
initialDelay: 100 * time.Millisecond,
backoffExponent: 1.5,
}
h = &Handler{
log: log,
URL: fmt.Sprintf("http://%s:%d", bind, port),
Target: target,
p: p,
sse: sse.New(),
}
p.ModifyResponse = h.modifyResponse
return h
}
func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/_templ/reload/script.js" {
// Provides a script that reloads the page.
w.Header().Add("Content-Type", "text/javascript")
_, err := io.WriteString(w, script)
if err != nil {
fmt.Printf("failed to write script: %v\n", err)
}
return
}
if r.URL.Path == "/_templ/reload/events" {
switch r.Method {
case http.MethodGet:
// Provides a list of messages including a reload message.
p.sse.ServeHTTP(w, r)
return
case http.MethodPost:
// Send a reload message to all connected clients.
p.sse.Send("message", "reload")
return
}
http.Error(w, "only GET or POST method allowed", http.StatusMethodNotAllowed)
return
}
p.p.ServeHTTP(w, r)
}
func (p *Handler) SendSSE(eventType string, data string) {
p.sse.Send(eventType, data)
}
type roundTripper struct {
maxRetries int
initialDelay time.Duration
backoffExponent float64
}
func (rt *roundTripper) setShouldSkipResponseModificationHeader(r *http.Request, resp *http.Response) {
// Instruct the modifyResponse function to skip modifying the response if the
// HTTP request has come from HTMX.
if r.Header.Get("HX-Request") != "true" {
return
}
resp.Header.Set("templ-skip-modify", "true")
}
func (rt *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
// Read and buffer the body.
var bodyBytes []byte
if r.Body != nil && r.Body != http.NoBody {
var err error
bodyBytes, err = io.ReadAll(r.Body)
if err != nil {
return nil, err
}
r.Body.Close()
}
// Retry logic.
var resp *http.Response
var err error
for retries := 0; retries < rt.maxRetries; retries++ {
// Clone the request and set the body.
req := r.Clone(r.Context())
if bodyBytes != nil {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
// Execute the request.
resp, err = http.DefaultTransport.RoundTrip(req)
if err != nil {
time.Sleep(rt.initialDelay * time.Duration(math.Pow(rt.backoffExponent, float64(retries))))
continue
}
rt.setShouldSkipResponseModificationHeader(r, resp)
return resp, nil
}
return nil, fmt.Errorf("max retries reached: %q", r.URL.String())
}
func NotifyProxy(host string, port int) error {
urlStr := fmt.Sprintf("http://%s:%d/_templ/reload/events", host, port)
req, err := http.NewRequest(http.MethodPost, urlStr, nil)
if err != nil {
return err
}
_, err = http.DefaultClient.Do(req)
return err
}

View File

@@ -0,0 +1,627 @@
package proxy
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/andybalholm/brotli"
"github.com/google/go-cmp/cmp"
)
func TestRoundTripper(t *testing.T) {
t.Run("if the HX-Request header is present, set the templ-skip-modify header on the response", func(t *testing.T) {
rt := &roundTripper{}
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("unexpected error creating request: %v", err)
}
req.Header.Set("HX-Request", "true")
resp := &http.Response{Header: make(http.Header)}
rt.setShouldSkipResponseModificationHeader(req, resp)
if resp.Header.Get("templ-skip-modify") != "true" {
t.Errorf("expected templ-skip-modify header to be true, got %v", resp.Header.Get("templ-skip-modify"))
}
})
t.Run("if the HX-Request header is not present, do not set the templ-skip-modify header on the response", func(t *testing.T) {
rt := &roundTripper{}
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("unexpected error creating request: %v", err)
}
resp := &http.Response{Header: make(http.Header)}
rt.setShouldSkipResponseModificationHeader(req, resp)
if resp.Header.Get("templ-skip-modify") != "" {
t.Errorf("expected templ-skip-modify header to be empty, got %v", resp.Header.Get("templ-skip-modify"))
}
})
}
func TestProxy(t *testing.T) {
t.Run("plain: non-html content is not modified", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`{"key": "value"}`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Content-Length", "16")
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != "16" {
t.Errorf("expected content length to be 16, got %v", r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(`{"key": "value"}`, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("plain: if the response contains templ-skip-modify header, it is not modified", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`Hello`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html")
r.Header.Set("Content-Length", "5")
r.Header.Set("templ-skip-modify", "true")
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != "5" {
t.Errorf("expected content length to be 5, got %v", r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(`Hello`, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("plain: body tags get the script inserted", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`<html><body></body></html>`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Length", "26")
expectedString := insertScriptTagIntoBody("", `<html><body></body></html>`)
if !strings.Contains(expectedString, getScriptTag("")) {
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
}
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("plain: body tags get the script inserted with nonce", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`<html><body></body></html>`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Length", "26")
const nonce = "this-is-the-nonce"
r.Header.Set("Content-Security-Policy", fmt.Sprintf("script-src 'nonce-%s'", nonce))
expectedString := insertScriptTagIntoBody(nonce, `<html><body></body></html>`)
if !strings.Contains(expectedString, getScriptTag(nonce)) {
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
}
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("plain: body tags get the script inserted ignoring js with body tags", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`<html><body><script>console.log("<body></body>")</script></body></html>`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Length", "26")
expectedString := insertScriptTagIntoBody("", `<html><body><script>console.log("<body></body>")</script></body></html>`)
if !strings.Contains(expectedString, getScriptTag("")) {
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
}
if !strings.Contains(expectedString, `console.log("<body></body>")`) {
t.Fatalf("expected the script tag to be inserted, but mangled the html: %q", expectedString)
}
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("gzip: non-html content is not modified", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(strings.NewReader(`{"key": "value"}`)),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "application/json")
// It's not actually gzipped here, but it doesn't matter, it shouldn't get that far.
r.Header.Set("Content-Encoding", "gzip")
// Similarly, this is not the actual length of the gzipped content.
r.Header.Set("Content-Length", "16")
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != "16" {
t.Errorf("expected content length to be 16, got %v", r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(`{"key": "value"}`, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("gzip: body tags get the script inserted", func(t *testing.T) {
// Arrange
body := `<html><body></body></html>`
var buf bytes.Buffer
gzw := gzip.NewWriter(&buf)
_, err := gzw.Write([]byte(body))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
gzw.Close()
expectedString := insertScriptTagIntoBody("", body)
var expectedBytes bytes.Buffer
gzw = gzip.NewWriter(&expectedBytes)
_, err = gzw.Write([]byte(expectedString))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
gzw.Close()
expectedLength := len(expectedBytes.Bytes())
r := &http.Response{
Body: io.NopCloser(&buf),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Encoding", "gzip")
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err = h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", expectedLength) {
t.Errorf("expected content length to be %d, got %v", expectedLength, r.Header.Get("Content-Length"))
}
gr, err := gzip.NewReader(r.Body)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
actualBody, err := io.ReadAll(gr)
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("brotli: body tags get the script inserted", func(t *testing.T) {
// Arrange
body := `<html><body></body></html>`
var buf bytes.Buffer
brw := brotli.NewWriter(&buf)
_, err := brw.Write([]byte(body))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
brw.Close()
expectedString := insertScriptTagIntoBody("", body)
var expectedBytes bytes.Buffer
brw = brotli.NewWriter(&expectedBytes)
_, err = brw.Write([]byte(expectedString))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
brw.Close()
expectedLength := len(expectedBytes.Bytes())
r := &http.Response{
Body: io.NopCloser(&buf),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Encoding", "br")
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))
// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err = h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", expectedLength) {
t.Errorf("expected content length to be %d, got %v", expectedLength, r.Header.Get("Content-Length"))
}
actualBody, err := io.ReadAll(brotli.NewReader(r.Body))
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("notify-proxy: sending POST request to /_templ/reload/events should receive reload sse event", func(t *testing.T) {
// Arrange 1: create a test proxy server.
dummyHandler := func(w http.ResponseWriter, r *http.Request) {}
dummyServer := httptest.NewServer(http.HandlerFunc(dummyHandler))
defer dummyServer.Close()
u, err := url.Parse(dummyServer.URL)
if err != nil {
t.Fatalf("unexpected error parsing URL: %v", err)
}
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
handler := New(log, "0.0.0.0", 0, u)
proxyServer := httptest.NewServer(handler)
defer proxyServer.Close()
u2, err := url.Parse(proxyServer.URL)
if err != nil {
t.Fatalf("unexpected error parsing URL: %v", err)
}
port, err := strconv.Atoi(u2.Port())
if err != nil {
t.Fatalf("unexpected error parsing port: %v", err)
}
// Arrange 2: start a goroutine to listen for sse events.
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
errChan := make(chan error)
sseRespCh := make(chan string)
sseListening := make(chan bool) // Coordination channel that ensures the SSE listener is started before notifying the proxy.
go func() {
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/_templ/reload/events", proxyServer.URL), nil)
if err != nil {
errChan <- err
return
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
errChan <- err
return
}
defer resp.Body.Close()
sseListening <- true
lines := []string{}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
lines = append(lines, scanner.Text())
if scanner.Text() == "data: reload" {
sseRespCh <- strings.Join(lines, "\n")
return
}
}
err = scanner.Err()
if err != nil {
errChan <- err
return
}
}()
// Act: notify the proxy.
select { // Either SSE is listening or an error occurred.
case <-sseListening:
err = NotifyProxy(u2.Hostname(), port)
if err != nil {
t.Fatalf("unexpected error notifying proxy: %v", err)
}
case err := <-errChan:
if err == nil {
t.Fatalf("unexpected sse response: %v", err)
}
}
// Assert.
select { // Either SSE has a expected response or an error or timeout occurred.
case resp := <-sseRespCh:
if !strings.Contains(resp, "event: message\ndata: reload") {
t.Errorf("expected sse reload event to be received, got: %q", resp)
}
case err := <-errChan:
if err == nil {
t.Fatalf("unexpected sse response: %v", err)
}
case <-ctx.Done():
t.Fatalf("timeout waiting for sse response")
}
})
t.Run("unsupported encodings result in a warning", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte("<p>Data</p>"))),
Header: make(http.Header),
Request: &http.Request{
URL: &url.URL{
Scheme: "http",
Host: "example.com",
},
},
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Encoding", "weird-encoding")
// Act
lh := newTestLogHandler(slog.LevelInfo)
log := slog.New(lh)
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Assert
if len(lh.records) != 1 {
var sb strings.Builder
for _, record := range lh.records {
sb.WriteString(record.Message)
sb.WriteString("\n")
}
t.Fatalf("expected 1 log entry, but got %d: \n%s", len(lh.records), sb.String())
}
record := lh.records[0]
if record.Message != unsupportedContentEncoding {
t.Errorf("expected warning message %q, got %q", unsupportedContentEncoding, record.Message)
}
if record.Level != slog.LevelWarn {
t.Errorf("expected warning, got level %v", record.Level)
}
})
}
func newTestLogHandler(level slog.Level) *testLogHandler {
return &testLogHandler{
m: new(sync.Mutex),
records: nil,
level: level,
}
}
type testLogHandler struct {
m *sync.Mutex
records []slog.Record
level slog.Level
}
func (h *testLogHandler) Enabled(ctx context.Context, l slog.Level) bool {
return l >= h.level
}
func (h *testLogHandler) Handle(ctx context.Context, r slog.Record) error {
h.m.Lock()
defer h.m.Unlock()
if r.Level < h.level {
return nil
}
h.records = append(h.records, r)
return nil
}
func (h *testLogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return h
}
func (h *testLogHandler) WithGroup(name string) slog.Handler {
return h
}
func TestParseNonce(t *testing.T) {
for _, tc := range []struct {
name string
csp string
expected string
}{
{
name: "empty csp",
csp: "",
expected: "",
},
{
name: "simple csp",
csp: "script-src 'nonce-oLhVst3hTAcxI734qtB0J9Qc7W4qy09C'",
expected: "oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
},
{
name: "simple csp without single quote",
csp: "script-src nonce-oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
expected: "oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
},
{
name: "complete csp",
csp: "default-src 'self'; frame-ancestors 'self'; form-action 'self'; script-src 'strict-dynamic' 'nonce-4VOtk0Uo1l7pwtC';",
expected: "4VOtk0Uo1l7pwtC",
},
{
name: "mdn example 1",
csp: "default-src 'self'",
expected: "",
},
{
name: "mdn example 2",
csp: "default-src 'self' *.trusted.com",
expected: "",
},
{
name: "mdn example 3",
csp: "default-src 'self'; img-src *; media-src media1.com media2.com; script-src userscripts.example.com",
expected: "",
},
{
name: "mdn example 3 multiple sources",
csp: "default-src 'self'; img-src *; media-src media1.com media2.com; script-src userscripts.example.com foo.com 'strict-dynamic' 'nonce-4VOtk0Uo1l7pwtC'",
expected: "4VOtk0Uo1l7pwtC",
},
} {
t.Run(tc.name, func(t *testing.T) {
nonce := parseNonce(tc.csp)
if nonce != tc.expected {
t.Errorf("expected nonce to be %s, but got %s", tc.expected, nonce)
}
})
}
}

View File

@@ -0,0 +1,10 @@
(function() {
let templ_reloadSrc = window.templ_reloadSrc || new EventSource("/_templ/reload/events");
templ_reloadSrc.onmessage = (event) => {
if (event && event.data === "reload") {
window.location.reload();
}
};
window.templ_reloadSrc = templ_reloadSrc;
window.onbeforeunload = () => window.templ_reloadSrc.close();
})();

View File

@@ -0,0 +1,108 @@
package run_test
import (
"context"
"embed"
"io"
"net/http"
"os"
"path/filepath"
"syscall"
"testing"
"time"
"github.com/a-h/templ/cmd/templ/generatecmd/run"
)
//go:embed testprogram/*
var testprogram embed.FS
func TestGoRun(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode.")
}
// Copy testprogram to a temporary directory.
dir, err := os.MkdirTemp("", "testprogram")
if err != nil {
t.Fatalf("failed to make test dir: %v", err)
}
files, err := testprogram.ReadDir("testprogram")
if err != nil {
t.Fatalf("failed to read embedded dir: %v", err)
}
for _, file := range files {
srcFileName := "testprogram/" + file.Name()
srcData, err := testprogram.ReadFile(srcFileName)
if err != nil {
t.Fatalf("failed to read src file %q: %v", srcFileName, err)
}
tgtFileName := filepath.Join(dir, file.Name())
tgtFile, err := os.Create(tgtFileName)
if err != nil {
t.Fatalf("failed to create tgt file %q: %v", tgtFileName, err)
}
defer tgtFile.Close()
if _, err := tgtFile.Write(srcData); err != nil {
t.Fatalf("failed to write to tgt file %q: %v", tgtFileName, err)
}
}
// Rename the go.mod.embed file to go.mod.
if err := os.Rename(filepath.Join(dir, "go.mod.embed"), filepath.Join(dir, "go.mod")); err != nil {
t.Fatalf("failed to rename go.mod.embed: %v", err)
}
tests := []struct {
name string
cmd string
}{
{
name: "Well behaved programs get shut down",
cmd: "go run .",
},
{
name: "Badly behaved programs get shut down",
cmd: "go run . -badly-behaved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cmd, err := run.Run(ctx, dir, tt.cmd)
if err != nil {
t.Fatalf("failed to run program: %v", err)
}
time.Sleep(1 * time.Second)
pid := cmd.Process.Pid
if err := run.KillAll(); err != nil {
t.Fatalf("failed to kill all: %v", err)
}
// Check the parent process is no longer running.
if err := cmd.Process.Signal(os.Signal(syscall.Signal(0))); err == nil {
t.Fatalf("process %d is still running", pid)
}
// Check that the child was stopped.
body, err := readResponse("http://localhost:7777")
if err == nil {
t.Fatalf("child process is still running: %s", body)
}
})
}
}
func readResponse(url string) (body string, err error) {
resp, err := http.Get(url)
if err != nil {
return body, err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return body, err
}
return string(b), nil
}

View File

@@ -0,0 +1,84 @@
//go:build unix
package run
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"strings"
"sync"
"syscall"
"time"
)
var (
m = &sync.Mutex{}
running = map[string]*exec.Cmd{}
)
func KillAll() (err error) {
m.Lock()
defer m.Unlock()
var errs []error
for _, cmd := range running {
if err := kill(cmd); err != nil {
errs = append(errs, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err))
}
}
running = map[string]*exec.Cmd{}
return errors.Join(errs...)
}
func kill(cmd *exec.Cmd) (err error) {
errs := make([]error, 4)
errs[0] = ignoreExited(cmd.Process.Signal(syscall.SIGINT))
errs[1] = ignoreExited(cmd.Process.Signal(syscall.SIGTERM))
errs[2] = ignoreExited(cmd.Wait())
errs[3] = ignoreExited(syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL))
return errors.Join(errs...)
}
func ignoreExited(err error) error {
if errors.Is(err, syscall.ESRCH) {
return nil
}
// Ignore *exec.ExitError
if _, ok := err.(*exec.ExitError); ok {
return nil
}
return err
}
func Run(ctx context.Context, workingDir string, input string) (cmd *exec.Cmd, err error) {
m.Lock()
defer m.Unlock()
cmd, ok := running[input]
if ok {
if err := kill(cmd); err != nil {
return cmd, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err)
}
delete(running, input)
}
parts := strings.Fields(input)
executable := parts[0]
args := []string{}
if len(parts) > 1 {
args = append(args, parts[1:]...)
}
cmd = exec.CommandContext(ctx, executable, args...)
// Wait for the process to finish gracefully before termination.
cmd.WaitDelay = time.Second * 3
cmd.Env = os.Environ()
cmd.Dir = workingDir
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
running[input] = cmd
err = cmd.Start()
return
}

View File

@@ -0,0 +1,69 @@
//go:build windows
package run
import (
"context"
"os"
"os/exec"
"strconv"
"strings"
"sync"
)
var m = &sync.Mutex{}
var running = map[string]*exec.Cmd{}
func KillAll() (err error) {
m.Lock()
defer m.Unlock()
for _, cmd := range running {
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
kill.Stderr = os.Stderr
kill.Stdout = os.Stdout
err := kill.Run()
if err != nil {
return err
}
}
running = map[string]*exec.Cmd{}
return
}
func Stop(cmd *exec.Cmd) (err error) {
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
kill.Stderr = os.Stderr
kill.Stdout = os.Stdout
return kill.Run()
}
func Run(ctx context.Context, workingDir string, input string) (cmd *exec.Cmd, err error) {
m.Lock()
defer m.Unlock()
cmd, ok := running[input]
if ok {
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
kill.Stderr = os.Stderr
kill.Stdout = os.Stdout
err := kill.Run()
if err != nil {
return cmd, err
}
delete(running, input)
}
parts := strings.Fields(input)
executable := parts[0]
args := []string{}
if len(parts) > 1 {
args = append(args, parts[1:]...)
}
cmd = exec.Command(executable, args...)
cmd.Env = os.Environ()
cmd.Dir = workingDir
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
running[input] = cmd
err = cmd.Start()
return
}

View File

@@ -0,0 +1,3 @@
module testprogram
go 1.23

View File

@@ -0,0 +1,63 @@
package main
import (
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
// This is a test program. It is used only to test the behaviour of the run package.
// The run package is supposed to be able to run and stop programs. Those programs may start
// child processes, which should also be stopped when the parent program is stopped.
// For example, running `go run .` will compile an executable and run it.
// So, this program does nothing. It just waits for a signal to stop.
// In "Well behaved" mode, the program will stop when it receives a signal.
// In "Badly behaved" mode, the program will ignore the signal and continue running.
// The run package should be able to stop the program in both cases.
var badlyBehavedFlag = flag.Bool("badly-behaved", false, "If set, the program will ignore the stop signal and continue running.")
func main() {
flag.Parse()
mode := "Well behaved"
if *badlyBehavedFlag {
mode = "Badly behaved"
}
fmt.Printf("%s process %d started.\n", mode, os.Getpid())
// Start a web server on a known port so that we can check that this process is
// not running, when it's been started as a child process, and we don't know
// its pid.
go func() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%d", os.Getpid())
})
err := http.ListenAndServe("127.0.0.1:7777", nil)
if err != nil {
fmt.Printf("Error running web server: %v\n", err)
}
}()
sigs := make(chan os.Signal, 1)
if !*badlyBehavedFlag {
signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
}
for {
select {
case <-sigs:
fmt.Printf("Process %d received signal. Stopping.\n", os.Getpid())
return
case <-time.After(1 * time.Second):
fmt.Printf("Process %d still running...\n", os.Getpid())
}
}
}

View File

@@ -0,0 +1,84 @@
package sse
import (
_ "embed"
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
)
func New() *Handler {
return &Handler{
m: new(sync.Mutex),
requests: map[int64]chan event{},
}
}
type Handler struct {
m *sync.Mutex
counter int64
requests map[int64]chan event
}
type event struct {
Type string
Data string
}
// Send an event to all connected clients.
func (s *Handler) Send(eventType string, data string) {
s.m.Lock()
defer s.m.Unlock()
for _, f := range s.requests {
f := f
go func(f chan event) {
f <- event{
Type: eventType,
Data: data,
}
}(f)
}
}
func (s *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
id := atomic.AddInt64(&s.counter, 1)
s.m.Lock()
events := make(chan event)
s.requests[id] = events
s.m.Unlock()
defer func() {
s.m.Lock()
defer s.m.Unlock()
delete(s.requests, id)
close(events)
}()
timer := time.NewTimer(0)
loop:
for {
select {
case <-timer.C:
if _, err := fmt.Fprintf(w, "event: message\ndata: ping\n\n"); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
timer.Reset(time.Second * 5)
case e := <-events:
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", e.Type, e.Data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
case <-r.Context().Done():
break loop
}
w.(http.Flusher).Flush()
}
}

View File

@@ -0,0 +1,52 @@
package symlink
import (
"context"
"io"
"log/slog"
"os"
"path"
"testing"
"github.com/a-h/templ/cmd/templ/generatecmd"
"github.com/a-h/templ/cmd/templ/testproject"
)
func TestSymlink(t *testing.T) {
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
t.Run("can generate if root is symlink", func(t *testing.T) {
// templ generate -f templates.templ
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
if err != nil {
t.Fatalf("failed to create test project: %v", err)
}
defer os.RemoveAll(dir)
symlinkPath := dir + "-symlink"
err = os.Symlink(dir, symlinkPath)
if err != nil {
t.Fatalf("failed to create dir symlink: %v", err)
}
defer os.Remove(symlinkPath)
// Delete the templates_templ.go file to ensure it is generated.
err = os.Remove(path.Join(symlinkPath, "templates_templ.go"))
if err != nil {
t.Fatalf("failed to remove templates_templ.go: %v", err)
}
// Run the generate command.
err = generatecmd.Run(context.Background(), log, generatecmd.Arguments{
Path: symlinkPath,
})
if err != nil {
t.Fatalf("failed to run generate command: %v", err)
}
// Check the templates_templ.go file was created.
_, err = os.Stat(path.Join(symlinkPath, "templates_templ.go"))
if err != nil {
t.Fatalf("templates_templ.go was not created: %v", err)
}
})
}

View File

@@ -0,0 +1,101 @@
package testeventhandler
import (
"context"
"errors"
"fmt"
"go/scanner"
"go/token"
"io"
"log/slog"
"os"
"testing"
"github.com/a-h/templ/cmd/templ/generatecmd"
"github.com/a-h/templ/generator"
"github.com/fsnotify/fsnotify"
"github.com/google/go-cmp/cmp"
)
func TestErrorLocationMapping(t *testing.T) {
tests := []struct {
name string
rawFileName string
errorPositions []token.Position
}{
{
name: "single error outputs location in srcFile",
rawFileName: "single_error.templ.error",
errorPositions: []token.Position{
{Offset: 46, Line: 3, Column: 20},
},
},
{
name: "multiple errors all output locations in srcFile",
rawFileName: "multiple_errors.templ.error",
errorPositions: []token.Position{
{Offset: 41, Line: 3, Column: 15},
{Offset: 101, Line: 7, Column: 22},
{Offset: 126, Line: 10, Column: 1},
},
},
}
slog := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
var fw generatecmd.FileWriterFunc
fseh := generatecmd.NewFSEventHandler(slog, ".", false, []generator.GenerateOpt{}, false, false, fw, false)
for _, test := range tests {
// The raw files cannot end in .templ because they will cause the generator to fail. Instead,
// we create a tmp file that ends in .templ only for the duration of the test.
rawFile, err := os.Open(test.rawFileName)
if err != nil {
t.Errorf("%s: Failed to open file %s: %v", test.name, test.rawFileName, err)
break
}
file, err := os.CreateTemp("", fmt.Sprintf("*%s.templ", test.rawFileName))
if err != nil {
t.Errorf("%s: Failed to create a tmp file at %s: %v", test.name, file.Name(), err)
break
}
defer os.Remove(file.Name())
if _, err = io.Copy(file, rawFile); err != nil {
t.Errorf("%s: Failed to copy contents from raw file %s to tmp %s: %v", test.name, test.rawFileName, file.Name(), err)
}
event := fsnotify.Event{Name: file.Name(), Op: fsnotify.Write}
_, err = fseh.HandleEvent(context.Background(), event)
if err == nil {
t.Errorf("%s: no error was thrown", test.name)
break
}
list, ok := err.(scanner.ErrorList)
for !ok {
err = errors.Unwrap(err)
if err == nil {
t.Errorf("%s: reached end of error wrapping before finding an ErrorList", test.name)
break
} else {
list, ok = err.(scanner.ErrorList)
}
}
if !ok {
break
}
if len(list) != len(test.errorPositions) {
t.Errorf("%s: expected %d errors but got %d", test.name, len(test.errorPositions), len(list))
break
}
for i, err := range list {
test.errorPositions[i].Filename = file.Name()
diff := cmp.Diff(test.errorPositions[i], err.Pos)
if diff != "" {
t.Error(diff)
t.Error("expected:")
t.Error(test.errorPositions[i])
t.Error("actual:")
t.Error(err.Pos)
}
}
}
}

View File

@@ -0,0 +1,10 @@
package testeventhandler
func invalid(a: string) string {
return "foo"
}
templ multipleError(a: string) {
<div/>
}
l

View File

@@ -0,0 +1,5 @@
package testeventhandler
templ singleError(a: string) {
<div/>
}

View File

@@ -0,0 +1,485 @@
package testwatch
import (
"bufio"
"bytes"
"context"
"embed"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/PuerkitoBio/goquery"
"github.com/a-h/templ/cmd/templ/generatecmd"
"github.com/a-h/templ/cmd/templ/generatecmd/modcheck"
)
//go:embed testdata/*
var testdata embed.FS
func createTestProject(moduleRoot string) (dir string, err error) {
dir, err = os.MkdirTemp("", "templ_watch_test_*")
if err != nil {
return dir, fmt.Errorf("failed to make test dir: %w", err)
}
files, err := testdata.ReadDir("testdata")
if err != nil {
return dir, fmt.Errorf("failed to read embedded dir: %w", err)
}
for _, file := range files {
src := filepath.Join("testdata", file.Name())
data, err := testdata.ReadFile(src)
if err != nil {
return dir, fmt.Errorf("failed to read file: %w", err)
}
target := filepath.Join(dir, file.Name())
if file.Name() == "go.mod.embed" {
data = bytes.ReplaceAll(data, []byte("{moduleRoot}"), []byte(moduleRoot))
target = filepath.Join(dir, "go.mod")
}
err = os.WriteFile(target, data, 0660)
if err != nil {
return dir, fmt.Errorf("failed to copy file: %w", err)
}
}
return dir, nil
}
func replaceInFile(name, src, tgt string) error {
data, err := os.ReadFile(name)
if err != nil {
return err
}
updated := strings.Replace(string(data), src, tgt, -1)
return os.WriteFile(name, []byte(updated), 0660)
}
func getPort() (port int, err error) {
var a *net.TCPAddr
if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
var l *net.TCPListener
if l, err = net.ListenTCP("tcp", a); err == nil {
defer l.Close()
return l.Addr().(*net.TCPAddr).Port, nil
}
}
return
}
func getHTML(url string) (doc *goquery.Document, err error) {
resp, err := http.Get(url)
if err != nil {
return nil, fmt.Errorf("failed to get %q: %w", url, err)
}
return goquery.NewDocumentFromReader(resp.Body)
}
func TestCanAccessDirect(t *testing.T) {
if testing.Short() {
return
}
args, teardown, err := Setup(false)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
// Assert.
doc, err := getHTML(args.AppURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
countText := doc.Find(`div[data-testid="count"]`).Text()
actualCount, err := strconv.Atoi(countText)
if err != nil {
t.Fatalf("got count %q instead of integer", countText)
}
if actualCount < 1 {
t.Errorf("expected count >= 1, got %d", actualCount)
}
}
func TestCanAccessViaProxy(t *testing.T) {
if testing.Short() {
return
}
args, teardown, err := Setup(false)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
// Assert.
doc, err := getHTML(args.ProxyURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
countText := doc.Find(`div[data-testid="count"]`).Text()
actualCount, err := strconv.Atoi(countText)
if err != nil {
t.Fatalf("got count %q instead of integer", countText)
}
if actualCount < 1 {
t.Errorf("expected count >= 1, got %d", actualCount)
}
}
type Event struct {
Type string
Data string
}
func readSSE(ctx context.Context, url string, sse chan<- Event) (err error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Connection", "keep-alive")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
var e Event
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
sse <- e
e = Event{}
continue
}
if strings.HasPrefix(line, "event: ") {
e.Type = line[len("event: "):]
}
if strings.HasPrefix(line, "data: ") {
e.Data = line[len("data: "):]
}
}
return scanner.Err()
}
func TestFileModificationsResultInSSEWithGzip(t *testing.T) {
if testing.Short() {
return
}
args, teardown, err := Setup(false)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
// Start the SSE check.
events := make(chan Event)
var eventsErr error
go func() {
eventsErr = readSSE(context.Background(), fmt.Sprintf("%s/_templ/reload/events", args.ProxyURL), events)
}()
// Assert data is expected.
doc, err := getHTML(args.ProxyURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Original" {
t.Errorf("expected %q, got %q", "Original", text)
}
// Change file.
templFile := filepath.Join(args.AppDir, "templates.templ")
err = replaceInFile(templFile,
`<div data-testid="modification">Original</div>`,
`<div data-testid="modification">Updated</div>`)
if err != nil {
t.Errorf("failed to replace text in file: %v", err)
}
// Give the filesystem watcher a few seconds.
var reloadCount int
loop:
for {
select {
case event := <-events:
if event.Data == "reload" {
reloadCount++
break loop
}
case <-time.After(time.Second * 5):
break loop
}
}
if reloadCount == 0 {
t.Error("failed to receive SSE about update after 5 seconds")
}
// Check to see if there were any errors.
if eventsErr != nil {
t.Errorf("error reading events: %v", err)
}
// See results in browser immediately.
doc, err = getHTML(args.ProxyURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Updated" {
t.Errorf("expected %q, got %q", "Updated", text)
}
}
func TestFileModificationsResultInSSE(t *testing.T) {
if testing.Short() {
return
}
args, teardown, err := Setup(false)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
// Start the SSE check.
events := make(chan Event)
var eventsErr error
go func() {
eventsErr = readSSE(context.Background(), fmt.Sprintf("%s/_templ/reload/events", args.ProxyURL), events)
}()
// Assert data is expected.
doc, err := getHTML(args.ProxyURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Original" {
t.Errorf("expected %q, got %q", "Original", text)
}
// Change file.
templFile := filepath.Join(args.AppDir, "templates.templ")
err = replaceInFile(templFile,
`<div data-testid="modification">Original</div>`,
`<div data-testid="modification">Updated</div>`)
if err != nil {
t.Errorf("failed to replace text in file: %v", err)
}
// Give the filesystem watcher a few seconds.
var reloadCount int
loop:
for {
select {
case event := <-events:
if event.Data == "reload" {
reloadCount++
break loop
}
case <-time.After(time.Second * 5):
break loop
}
}
if reloadCount == 0 {
t.Error("failed to receive SSE about update after 5 seconds")
}
// Check to see if there were any errors.
if eventsErr != nil {
t.Errorf("error reading events: %v", err)
}
// See results in browser immediately.
doc, err = getHTML(args.ProxyURL)
if err != nil {
t.Fatalf("failed to read HTML: %v", err)
}
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Updated" {
t.Errorf("expected %q, got %q", "Updated", text)
}
}
func NewTestArgs(modRoot, appDir string, appPort int, proxyBind string, proxyPort int) TestArgs {
return TestArgs{
ModRoot: modRoot,
AppDir: appDir,
AppPort: appPort,
AppURL: fmt.Sprintf("http://localhost:%d", appPort),
ProxyBind: proxyBind,
ProxyPort: proxyPort,
ProxyURL: fmt.Sprintf("http://%s:%d", proxyBind, proxyPort),
}
}
type TestArgs struct {
ModRoot string
AppDir string
AppPort int
AppURL string
ProxyBind string
ProxyPort int
ProxyURL string
}
func Setup(gzipEncoding bool) (args TestArgs, teardown func(t *testing.T), err error) {
wd, err := os.Getwd()
if err != nil {
return args, teardown, fmt.Errorf("could not find working dir: %w", err)
}
moduleRoot, err := modcheck.WalkUp(wd)
if err != nil {
return args, teardown, fmt.Errorf("could not find local templ go.mod file: %v", err)
}
appDir, err := createTestProject(moduleRoot)
if err != nil {
return args, teardown, fmt.Errorf("failed to create test project: %v", err)
}
appPort, err := getPort()
if err != nil {
return args, teardown, fmt.Errorf("failed to get available port: %v", err)
}
proxyPort, err := getPort()
if err != nil {
return args, teardown, fmt.Errorf("failed to get available port: %v", err)
}
proxyBind := "localhost"
args = NewTestArgs(moduleRoot, appDir, appPort, proxyBind, proxyPort)
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
var cmdErr error
wg.Add(1)
go func() {
defer wg.Done()
command := fmt.Sprintf("go run . -port %d", args.AppPort)
if gzipEncoding {
command += " -gzip true"
}
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
cmdErr = generatecmd.Run(ctx, log, generatecmd.Arguments{
Path: appDir,
Watch: true,
OpenBrowser: false,
Command: command,
ProxyBind: proxyBind,
ProxyPort: proxyPort,
Proxy: args.AppURL,
NotifyProxy: false,
WorkerCount: 0,
GenerateSourceMapVisualisations: false,
IncludeVersion: false,
IncludeTimestamp: false,
PPROFPort: 0,
KeepOrphanedFiles: false,
})
}()
// Wait for server to start.
if err = waitForURL(args.AppURL); err != nil {
cancel()
wg.Wait()
return args, teardown, fmt.Errorf("failed to start app server, command error %v: %w", cmdErr, err)
}
if err = waitForURL(args.ProxyURL); err != nil {
cancel()
wg.Wait()
return args, teardown, fmt.Errorf("failed to start proxy server, command error %v: %w", cmdErr, err)
}
// Wait for exit.
teardown = func(t *testing.T) {
cancel()
wg.Wait()
if cmdErr != nil {
t.Errorf("failed to run generate cmd: %v", err)
}
if err = os.RemoveAll(appDir); err != nil {
t.Fatalf("failed to remove test dir %q: %v", appDir, err)
}
}
return args, teardown, err
}
func waitForURL(url string) (err error) {
var tries int
for {
time.Sleep(time.Second)
if tries > 20 {
return err
}
tries++
var resp *http.Response
resp, err = http.Get(url)
if err != nil {
fmt.Printf("failed to get %q: %v\n", url, err)
continue
}
if resp.StatusCode != http.StatusOK {
fmt.Printf("failed to get %q: %v\n", url, err)
err = fmt.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
continue
}
return nil
}
}
func TestGenerateReturnsErrors(t *testing.T) {
wd, err := os.Getwd()
if err != nil {
t.Fatalf("could not find working dir: %v", err)
}
moduleRoot, err := modcheck.WalkUp(wd)
if err != nil {
t.Fatalf("could not find local templ go.mod file: %v", err)
}
appDir, err := createTestProject(moduleRoot)
if err != nil {
t.Fatalf("failed to create test project: %v", err)
}
defer func() {
if err = os.RemoveAll(appDir); err != nil {
t.Fatalf("failed to remove test dir %q: %v", appDir, err)
}
}()
// Break the HTML.
templFile := filepath.Join(appDir, "templates.templ")
err = replaceInFile(templFile,
`<div data-testid="modification">Original</div>`,
`<div data-testid="modification" -unclosed div-</div>`)
if err != nil {
t.Errorf("failed to replace text in file: %v", err)
}
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
// Run.
err = generatecmd.Run(context.Background(), log, generatecmd.Arguments{
Path: appDir,
Watch: false,
IncludeVersion: false,
IncludeTimestamp: false,
KeepOrphanedFiles: false,
})
if err == nil {
t.Errorf("expected generation error, got %v", err)
}
}

View File

@@ -0,0 +1,7 @@
module templ/testproject
go 1.23
require github.com/a-h/templ v0.2.513 // indirect
replace github.com/a-h/templ => {moduleRoot}

View File

@@ -0,0 +1,2 @@
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=

View File

@@ -0,0 +1,81 @@
package main
import (
"bytes"
"compress/gzip"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"strconv"
"github.com/a-h/templ"
)
type GzipResponseWriter struct {
w http.ResponseWriter
}
func (w *GzipResponseWriter) Header() http.Header {
return w.w.Header()
}
func (w *GzipResponseWriter) Write(b []byte) (int, error) {
var buf bytes.Buffer
gzw := gzip.NewWriter(&buf)
defer gzw.Close()
_, err := gzw.Write(b)
if err != nil {
return 0, err
}
err = gzw.Close()
if err != nil {
return 0, err
}
w.w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
return w.w.Write(buf.Bytes())
}
func (w *GzipResponseWriter) WriteHeader(statusCode int) {
w.w.WriteHeader(statusCode)
}
var flagPort = flag.Int("port", 0, "Set the HTTP listen port")
var useGzip = flag.Bool("gzip", false, "Toggle gzip encoding")
func main() {
flag.Parse()
if *flagPort == 0 {
fmt.Println("missing port flag")
os.Exit(1)
}
var count int
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if useGzip != nil && *useGzip {
w.Header().Set("Content-Encoding", "gzip")
w = &GzipResponseWriter{w: w}
}
count++
c := Page(count)
h := templ.Handler(c)
h.ErrorHandler = func(r *http.Request, err error) http.Handler {
slog.Error("failed to render template", slog.Any("error", err))
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
})
}
h.ServeHTTP(w, r)
})
err := http.ListenAndServe(fmt.Sprintf("localhost:%d", *flagPort), nil)
if err != nil {
fmt.Printf("Error listening: %v\n", err)
os.Exit(1)
}
}

View File

@@ -0,0 +1,17 @@
package main
import "fmt"
templ Page(count int) {
<!DOCTYPE html>
<html>
<head>
<title>templ test page</title>
</head>
<body>
<h1>Count</h1>
<div data-testid="count">{ fmt.Sprintf("%d", count) }</div>
<div data-testid="modification">Original</div>
</body>
</html>
}

View File

@@ -0,0 +1,55 @@
// Code generated by templ - DO NOT EDIT.
// templ: version: v0.3.833
package main
//lint:file-ignore SA4006 This context is only used if a nested component is present.
import "github.com/a-h/templ"
import templruntime "github.com/a-h/templ/runtime"
import "fmt"
func Page(count int) templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
if templ_7745c5c3_Var1 == nil {
templ_7745c5c3_Var1 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<!doctype html><html><head><title>templ test page</title></head><body><h1>Count</h1><div data-testid=\"count\">")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var2 string
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(fmt.Sprintf("%d", count))
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/generatecmd/testwatch/testdata/templates.templ`, Line: 13, Col: 54}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "</div><div data-testid=\"modification\">Original</div></body></html>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
var _ = templruntime.GeneratedTemplate

View File

@@ -0,0 +1,166 @@
package watcher
import (
"context"
"io/fs"
"os"
"path"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
)
func Recursive(
ctx context.Context,
path string,
watchPattern *regexp.Regexp,
out chan fsnotify.Event,
errors chan error,
) (w *RecursiveWatcher, err error) {
fsnw, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
w = NewRecursiveWatcher(ctx, fsnw, watchPattern, out, errors)
go w.loop()
return w, w.Add(path)
}
func NewRecursiveWatcher(ctx context.Context, w *fsnotify.Watcher, watchPattern *regexp.Regexp, events chan fsnotify.Event, errors chan error) *RecursiveWatcher {
return &RecursiveWatcher{
ctx: ctx,
w: w,
WatchPattern: watchPattern,
Events: events,
Errors: errors,
timers: make(map[timerKey]*time.Timer),
}
}
// WalkFiles walks the file tree rooted at path, sending a Create event for each
// file it encounters.
func WalkFiles(ctx context.Context, path string, watchPattern *regexp.Regexp, out chan fsnotify.Event) (err error) {
rootPath := path
fileSystem := os.DirFS(rootPath)
return fs.WalkDir(fileSystem, ".", func(path string, info os.DirEntry, err error) error {
if err != nil {
return nil
}
absPath, err := filepath.Abs(filepath.Join(rootPath, path))
if err != nil {
return nil
}
if info.IsDir() && shouldSkipDir(absPath) {
return filepath.SkipDir
}
if !watchPattern.MatchString(absPath) {
return nil
}
out <- fsnotify.Event{
Name: absPath,
Op: fsnotify.Create,
}
return nil
})
}
type RecursiveWatcher struct {
ctx context.Context
w *fsnotify.Watcher
WatchPattern *regexp.Regexp
Events chan fsnotify.Event
Errors chan error
timerMu sync.Mutex
timers map[timerKey]*time.Timer
}
type timerKey struct {
name string
op fsnotify.Op
}
func timerKeyFromEvent(event fsnotify.Event) timerKey {
return timerKey{
name: event.Name,
op: event.Op,
}
}
func (w *RecursiveWatcher) Close() error {
return w.w.Close()
}
func (w *RecursiveWatcher) loop() {
for {
select {
case <-w.ctx.Done():
return
case event, ok := <-w.w.Events:
if !ok {
return
}
if event.Has(fsnotify.Create) {
if err := w.Add(event.Name); err != nil {
w.Errors <- err
}
}
// Only notify on templ related files.
if !w.WatchPattern.MatchString(event.Name) {
continue
}
tk := timerKeyFromEvent(event)
w.timerMu.Lock()
t, ok := w.timers[tk]
w.timerMu.Unlock()
if !ok {
t = time.AfterFunc(100*time.Millisecond, func() {
w.Events <- event
})
w.timerMu.Lock()
w.timers[tk] = t
w.timerMu.Unlock()
continue
}
t.Reset(100 * time.Millisecond)
case err, ok := <-w.w.Errors:
if !ok {
return
}
w.Errors <- err
}
}
}
func (w *RecursiveWatcher) Add(dir string) error {
return filepath.WalkDir(dir, func(dir string, info os.DirEntry, err error) error {
if err != nil {
return nil
}
if !info.IsDir() {
return nil
}
if shouldSkipDir(dir) {
return filepath.SkipDir
}
return w.w.Add(dir)
})
}
func shouldSkipDir(dir string) bool {
if dir == "." {
return false
}
if dir == "vendor" || dir == "node_modules" {
return true
}
_, name := path.Split(dir)
// These directories are ignored by the Go tool.
if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") {
return true
}
return false
}

View File

@@ -0,0 +1,133 @@
package watcher
import (
"context"
"fmt"
"regexp"
"testing"
"time"
"github.com/fsnotify/fsnotify"
)
func TestWatchDebouncesDuplicates(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
w := &fsnotify.Watcher{
Events: make(chan fsnotify.Event),
}
events := make(chan fsnotify.Event, 2)
errors := make(chan error)
watchPattern, err := regexp.Compile(".*")
if err != nil {
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
}
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
go func() {
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
cancel()
close(rw.w.Events)
}()
rw.loop()
count := 0
exp := time.After(300 * time.Millisecond)
for {
select {
case <-rw.Events:
count++
case <-exp:
if count != 1 {
t.Errorf("expected 1 event, got %d", count)
}
return
}
}
}
func TestWatchDoesNotDebounceDifferentEvents(t *testing.T) {
tests := []struct {
event1 fsnotify.Event
event2 fsnotify.Event
}{
// Different files
{fsnotify.Event{Name: "test.templ"}, fsnotify.Event{Name: "test2.templ"}},
// Different operations
{
fsnotify.Event{Name: "test.templ", Op: fsnotify.Create},
fsnotify.Event{Name: "test.templ", Op: fsnotify.Write},
},
// Different operations and files
{
fsnotify.Event{Name: "test.templ", Op: fsnotify.Create},
fsnotify.Event{Name: "test2.templ", Op: fsnotify.Write},
},
}
for _, test := range tests {
ctx, cancel := context.WithCancel(context.Background())
w := &fsnotify.Watcher{
Events: make(chan fsnotify.Event),
}
events := make(chan fsnotify.Event, 2)
errors := make(chan error)
watchPattern, err := regexp.Compile(".*")
if err != nil {
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
}
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
go func() {
rw.w.Events <- test.event1
rw.w.Events <- test.event2
cancel()
close(rw.w.Events)
}()
rw.loop()
count := 0
exp := time.After(300 * time.Millisecond)
for {
select {
case <-rw.Events:
count++
case <-exp:
if count != 2 {
t.Errorf("expected 2 event, got %d", count)
}
return
}
}
}
}
func TestWatchDoesNotDebounceSeparateEvents(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
w := &fsnotify.Watcher{
Events: make(chan fsnotify.Event),
}
events := make(chan fsnotify.Event, 2)
errors := make(chan error)
watchPattern, err := regexp.Compile(".*")
if err != nil {
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
}
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
go func() {
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
<-time.After(200 * time.Millisecond)
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
cancel()
close(rw.w.Events)
}()
rw.loop()
count := 0
exp := time.After(500 * time.Millisecond)
for {
select {
case <-rw.Events:
count++
case <-exp:
if count != 2 {
t.Errorf("expected 2 event, got %d", count)
}
return
}
}
}

View File

@@ -0,0 +1,174 @@
package imports
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/token"
"path"
"slices"
"strconv"
"strings"
goparser "go/parser"
"golang.org/x/sync/errgroup"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/imports"
"github.com/a-h/templ/generator"
"github.com/a-h/templ/parser/v2"
)
var internalImports = []string{"github.com/a-h/templ", "github.com/a-h/templ/runtime"}
func convertTemplToGoURI(templURI string) (isTemplFile bool, goURI string) {
base, fileName := path.Split(templURI)
if !strings.HasSuffix(fileName, ".templ") {
return
}
return true, base + (strings.TrimSuffix(fileName, ".templ") + "_templ.go")
}
var fset = token.NewFileSet()
func updateImports(name, src string) (updated []*ast.ImportSpec, err error) {
// Apply auto imports.
updatedGoCode, err := imports.Process(name, []byte(src), nil)
if err != nil {
return updated, fmt.Errorf("failed to process go code %q: %w", src, err)
}
// Get updated imports.
gofile, err := goparser.ParseFile(fset, name, updatedGoCode, goparser.ImportsOnly)
if err != nil {
return updated, fmt.Errorf("failed to get imports from updated go code: %w", err)
}
for _, imp := range gofile.Imports {
if !slices.Contains(internalImports, strings.Trim(imp.Path.Value, "\"")) {
updated = append(updated, imp)
}
}
return updated, nil
}
func Process(t parser.TemplateFile) (parser.TemplateFile, error) {
if t.Filepath == "" {
return t, nil
}
isTemplFile, fileName := convertTemplToGoURI(t.Filepath)
if !isTemplFile {
return t, fmt.Errorf("invalid filepath: %s", t.Filepath)
}
// The first node always contains existing imports.
// If there isn't one, create it.
if len(t.Nodes) == 0 {
t.Nodes = append(t.Nodes, parser.TemplateFileGoExpression{})
}
// If there is one, ensure it is a Go expression.
if _, ok := t.Nodes[0].(parser.TemplateFileGoExpression); !ok {
t.Nodes = append([]parser.TemplateFileNode{parser.TemplateFileGoExpression{}}, t.Nodes...)
}
// Find all existing imports.
importsNode := t.Nodes[0].(parser.TemplateFileGoExpression)
// Generate code.
gw := bytes.NewBuffer(nil)
var updatedImports []*ast.ImportSpec
var eg errgroup.Group
eg.Go(func() (err error) {
if _, err := generator.Generate(t, gw); err != nil {
return fmt.Errorf("failed to generate go code: %w", err)
}
updatedImports, err = updateImports(fileName, gw.String())
if err != nil {
return fmt.Errorf("failed to get imports from generated go code: %w", err)
}
return nil
})
var firstGoNodeInTemplate *ast.File
// Update the template with the imports.
// Ensure that there is a Go expression to add the imports to as the first node.
eg.Go(func() (err error) {
firstGoNodeInTemplate, err = goparser.ParseFile(fset, fileName, t.Package.Expression.Value+"\n"+importsNode.Expression.Value, goparser.AllErrors|goparser.ParseComments)
if err != nil {
return fmt.Errorf("failed to parse imports section: %w", err)
}
return nil
})
// Wait for completion of both parts.
if err := eg.Wait(); err != nil {
return t, err
}
// Delete unused imports.
for _, imp := range firstGoNodeInTemplate.Imports {
if !containsImport(updatedImports, imp) {
name, path, err := getImportDetails(imp)
if err != nil {
return t, err
}
astutil.DeleteNamedImport(fset, firstGoNodeInTemplate, name, path)
}
}
// Add imports, if there are any to add.
for _, imp := range updatedImports {
if !containsImport(firstGoNodeInTemplate.Imports, imp) {
name, path, err := getImportDetails(imp)
if err != nil {
return t, err
}
astutil.AddNamedImport(fset, firstGoNodeInTemplate, name, path)
}
}
// Edge case: reinsert the import to use import syntax without parentheses.
if len(firstGoNodeInTemplate.Imports) == 1 {
name, path, err := getImportDetails(firstGoNodeInTemplate.Imports[0])
if err != nil {
return t, err
}
astutil.DeleteNamedImport(fset, firstGoNodeInTemplate, name, path)
astutil.AddNamedImport(fset, firstGoNodeInTemplate, name, path)
}
// Write out the Go code with the imports.
updatedGoCode := new(strings.Builder)
err := format.Node(updatedGoCode, fset, firstGoNodeInTemplate)
if err != nil {
return t, fmt.Errorf("failed to write updated go code: %w", err)
}
// Remove the package statement from the node, by cutting the first line of the file.
importsNode.Expression.Value = strings.TrimSpace(strings.SplitN(updatedGoCode.String(), "\n", 2)[1])
if len(updatedImports) == 0 && importsNode.Expression.Value == "" {
t.Nodes = t.Nodes[1:]
return t, nil
}
t.Nodes[0] = importsNode
return t, nil
}
func getImportDetails(imp *ast.ImportSpec) (name, importPath string, err error) {
if imp.Name != nil {
name = imp.Name.Name
}
if imp.Path != nil {
importPath, err = strconv.Unquote(imp.Path.Value)
if err != nil {
err = fmt.Errorf("failed to unquote package path %s: %w", imp.Path.Value, err)
return
}
}
return name, importPath, nil
}
func containsImport(imports []*ast.ImportSpec, spec *ast.ImportSpec) bool {
for _, imp := range imports {
if imp.Path.Value == spec.Path.Value {
return true
}
}
return false
}

View File

@@ -0,0 +1,154 @@
package imports
import (
"bytes"
"os"
"path"
"path/filepath"
"strings"
"testing"
"github.com/a-h/templ/cmd/templ/testproject"
"github.com/a-h/templ/parser/v2"
"github.com/google/go-cmp/cmp"
"golang.org/x/tools/txtar"
)
func TestFormatting(t *testing.T) {
files, _ := filepath.Glob("testdata/*.txtar")
if len(files) == 0 {
t.Errorf("no test files found")
}
for _, file := range files {
t.Run(filepath.Base(file), func(t *testing.T) {
a, err := txtar.ParseFile(file)
if err != nil {
t.Fatalf("failed to parse txtar file: %v", err)
}
if len(a.Files) != 2 {
t.Fatalf("expected 2 files, got %d", len(a.Files))
}
template, err := parser.ParseString(clean(a.Files[0].Data))
if err != nil {
t.Fatalf("failed to parse %v", err)
}
template.Filepath = a.Files[0].Name
tf, err := Process(template)
if err != nil {
t.Fatalf("failed to process file: %v", err)
}
expected := string(a.Files[1].Data)
actual := new(strings.Builder)
if err := tf.Write(actual); err != nil {
t.Fatalf("failed to write template file: %v", err)
}
if diff := cmp.Diff(expected, actual.String()); diff != "" {
t.Errorf("%s:\n%s", file, diff)
t.Errorf("expected:\n%s", showWhitespace(expected))
t.Errorf("actual:\n%s", showWhitespace(actual.String()))
}
})
}
}
func showWhitespace(s string) string {
s = strings.ReplaceAll(s, "\n", "⏎\n")
s = strings.ReplaceAll(s, "\t", "→")
s = strings.ReplaceAll(s, " ", "·")
return s
}
func clean(b []byte) string {
b = bytes.ReplaceAll(b, []byte("$\n"), []byte("\n"))
b = bytes.TrimSuffix(b, []byte("\n"))
return string(b)
}
func TestImport(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
return
}
tests := []struct {
name string
src string
assertions func(t *testing.T, updated string)
}{
{
name: "un-named imports are removed",
src: `package main
import "fmt"
import "github.com/a-h/templ/cmd/templ/testproject/css-classes"
templ Page(count int) {
{ fmt.Sprintf("%d", count) }
{ cssclasses.Header }
}
`,
assertions: func(t *testing.T, updated string) {
if count := strings.Count(updated, "github.com/a-h/templ/cmd/templ/testproject/css-classes"); count != 0 {
t.Errorf("expected un-named import to be removed, but got %d instance of it", count)
}
},
},
{
name: "named imports are retained",
src: `package main
import "fmt"
import cssclasses "github.com/a-h/templ/cmd/templ/testproject/css-classes"
templ Page(count int) {
{ fmt.Sprintf("%d", count) }
{ cssclasses.Header }
}
`,
assertions: func(t *testing.T, updated string) {
if count := strings.Count(updated, "cssclasses \"github.com/a-h/templ/cmd/templ/testproject/css-classes\""); count != 1 {
t.Errorf("expected named import to be retained, got %d instances of it", count)
}
if count := strings.Count(updated, "github.com/a-h/templ/cmd/templ/testproject/css-classes"); count != 1 {
t.Errorf("expected one import, got %d", count)
}
},
},
}
for _, test := range tests {
// Create test project.
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
if err != nil {
t.Fatalf("failed to create test project: %v", err)
}
defer os.RemoveAll(dir)
// Load the templates.templ file.
filePath := path.Join(dir, "templates.templ")
err = os.WriteFile(filePath, []byte(test.src), 0660)
if err != nil {
t.Fatalf("failed to write file: %v", err)
}
// Parse the new file.
template, err := parser.Parse(filePath)
if err != nil {
t.Fatalf("failed to parse %v", err)
}
template.Filepath = filePath
tf, err := Process(template)
if err != nil {
t.Fatalf("failed to process file: %v", err)
}
// Write it back out after processing.
buf := new(strings.Builder)
if err := tf.Write(buf); err != nil {
t.Fatalf("failed to write template file: %v", err)
}
// Assert.
test.assertions(t, buf.String())
}
}

View File

@@ -0,0 +1,12 @@
-- fmt_templ.templ --
package test
// Comment on variable or function.
var x = fmt.Sprintf("Hello")
-- fmt_templ.templ --
package test
import "fmt"
// Comment on variable or function.
var x = fmt.Sprintf("Hello")

View File

@@ -0,0 +1,28 @@
-- fmt_templ.templ --
// Comments before.
/*
Some more comments
*/
package test
templ test() {
<div>Hello</div>
}
// Comment on variable or function.
var x = fmt.Sprintf("Hello")
-- fmt_templ.templ --
// Comments before.
/*
Some more comments
*/
package test
import "fmt"
templ test() {
<div>Hello</div>
}
// Comment on variable or function.
var x = fmt.Sprintf("Hello")

View File

@@ -0,0 +1,14 @@
-- fmt.templ --
package test
import "strconv"
templ Hello() {
<div>Hello</div>
}
-- fmt.templ --
package test
templ Hello() {
<div>Hello</div>
}

View File

@@ -0,0 +1,15 @@
-- fmt_templ.templ --
package test
const x = 123
var x = fmt.Sprintf("Hello")
-- fmt_templ.templ --
package test
import "fmt"
const x = 123
var x = fmt.Sprintf("Hello")

View File

@@ -0,0 +1,22 @@
-- fmt.templ --
package test
import (
"strings"
"fmt"
"strconv"
)
var _, _ = fmt.Print(strings.Contains(strconv.Quote("Hello"), ""))
-- fmt.templ --
package test
import (
"fmt"
"strings"
"strconv"
)
var _, _ = fmt.Print(strings.Contains(strconv.Quote("Hello"), ""))

View File

@@ -0,0 +1,21 @@
-- fmt.templ --
package test
import (
"fmt"
"strconv"
)
var _, _ = fmt.Print(strconv.Quote("Hello"))
-- fmt.templ --
package test
import (
"fmt"
"strconv"
)
var _, _ = fmt.Print(strconv.Quote("Hello"))

View File

@@ -0,0 +1,10 @@
-- fmt_templ.templ --
package test
var x = fmt.Sprintf("Hello")
-- fmt_templ.templ --
package test
import "fmt"
var x = fmt.Sprintf("Hello")

View File

@@ -0,0 +1,19 @@
-- fmt_templ.templ --
package test
import (
sconv "strconv"
)
// Comment on variable or function.
var x = fmt.Sprintf(sconv.Quote("Hello"))
-- fmt_templ.templ --
package test
import (
"fmt"
sconv "strconv"
)
// Comment on variable or function.
var x = fmt.Sprintf(sconv.Quote("Hello"))

View File

@@ -0,0 +1,16 @@
-- fmt_templ.templ --
package test
import (
sconv "strconv"
)
// Comment on variable or function.
var x = fmt.Sprintf("Hello")
-- fmt_templ.templ --
package test
import "fmt"
// Comment on variable or function.
var x = fmt.Sprintf("Hello")

View File

@@ -0,0 +1,12 @@
-- fmt.templ --
package test
templ Hello() {
<div>Hello</div>
}
-- fmt.templ --
package test
templ Hello() {
<div>Hello</div>
}

View File

@@ -0,0 +1,20 @@
-- fmt.templ --
package test
func test() {
// Do nothing.
}
templ Hello() {
<div>Hello</div>
}
-- fmt.templ --
package test
func test() {
// Do nothing.
}
templ Hello() {
<div>Hello</div>
}

View File

@@ -0,0 +1,14 @@
-- fmt.templ --
package test
templ Hello(name string) {
{ fmt.Sprintf("Hello, %s!", name) }
}
-- fmt.templ --
package test
import "fmt"
templ Hello(name string) {
{ fmt.Sprintf("Hello, %s!", name) }
}

View File

@@ -0,0 +1,21 @@
-- fmt.templ --
package test
templ Hello(name string) {
<div id={ strconv.Atoi("123") }>
{ fmt.Sprintf("Hello, %s!", name) }
</div>
}
-- fmt.templ --
package test
import (
"fmt"
"strconv"
)
templ Hello(name string) {
<div id={ strconv.Atoi("123") }>
{ fmt.Sprintf("Hello, %s!", name) }
</div>
}

View File

@@ -0,0 +1,157 @@
package infocmd
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"runtime"
"strings"
"github.com/a-h/templ"
"github.com/a-h/templ/cmd/templ/lspcmd/pls"
)
type Arguments struct {
JSON bool `flag:"json" help:"Output info as JSON."`
}
type Info struct {
OS struct {
GOOS string `json:"goos"`
GOARCH string `json:"goarch"`
} `json:"os"`
Go ToolInfo `json:"go"`
Gopls ToolInfo `json:"gopls"`
Templ ToolInfo `json:"templ"`
}
type ToolInfo struct {
Location string `json:"location"`
Version string `json:"version"`
OK bool `json:"ok"`
Message string `json:"message,omitempty"`
}
func getGoInfo() (d ToolInfo) {
// Find Go.
var err error
d.Location, err = exec.LookPath("go")
if err != nil {
d.Message = fmt.Sprintf("failed to find go: %v", err)
return
}
// Run go to find the version.
cmd := exec.Command(d.Location, "version")
v, err := cmd.Output()
if err != nil {
d.Message = fmt.Sprintf("failed to get go version, check that Go is installed: %v", err)
return
}
d.Version = strings.TrimSpace(string(v))
d.OK = true
return
}
func getGoplsInfo() (d ToolInfo) {
var err error
d.Location, err = pls.FindGopls()
if err != nil {
d.Message = fmt.Sprintf("failed to find gopls: %v", err)
return
}
cmd := exec.Command(d.Location, "version")
v, err := cmd.Output()
if err != nil {
d.Message = fmt.Sprintf("failed to get gopls version: %v", err)
return
}
d.Version = strings.TrimSpace(string(v))
d.OK = true
return
}
func getTemplInfo() (d ToolInfo) {
// Find templ.
var err error
d.Location, err = findTempl()
if err != nil {
d.Message = err.Error()
return
}
// Run templ to find the version.
cmd := exec.Command(d.Location, "version")
v, err := cmd.Output()
if err != nil {
d.Message = fmt.Sprintf("failed to get templ version: %v", err)
return
}
d.Version = strings.TrimSpace(string(v))
if d.Version != templ.Version() {
d.Message = fmt.Sprintf("version mismatch - you're running %q at the command line, but the version in the path is %q", templ.Version(), d.Version)
return
}
d.OK = true
return
}
func findTempl() (location string, err error) {
executableName := "templ"
if runtime.GOOS == "windows" {
executableName = "templ.exe"
}
executableName, err = exec.LookPath(executableName)
if err == nil {
// Found on the path.
return executableName, nil
}
// Unexpected error.
if !errors.Is(err, exec.ErrNotFound) {
return "", fmt.Errorf("unexpected error looking for templ: %w", err)
}
return "", fmt.Errorf("templ is not in the path (%q). You can install templ with `go install github.com/a-h/templ/cmd/templ@latest`", os.Getenv("PATH"))
}
func getInfo() (d Info) {
d.OS.GOOS = runtime.GOOS
d.OS.GOARCH = runtime.GOARCH
d.Go = getGoInfo()
d.Gopls = getGoplsInfo()
d.Templ = getTemplInfo()
return
}
func Run(ctx context.Context, log *slog.Logger, stdout io.Writer, args Arguments) (err error) {
info := getInfo()
if args.JSON {
enc := json.NewEncoder(stdout)
enc.SetIndent("", " ")
return enc.Encode(info)
}
log.Info("os", slog.String("goos", info.OS.GOOS), slog.String("goarch", info.OS.GOARCH))
logInfo(ctx, log, "go", info.Go)
logInfo(ctx, log, "gopls", info.Gopls)
logInfo(ctx, log, "templ", info.Templ)
return nil
}
func logInfo(ctx context.Context, log *slog.Logger, name string, ti ToolInfo) {
level := slog.LevelInfo
if !ti.OK {
level = slog.LevelError
}
args := []any{
slog.String("location", ti.Location),
slog.String("version", ti.Version),
}
if ti.Message != "" {
args = append(args, slog.String("message", ti.Message))
}
log.Log(ctx, level, name, args...)
}

View File

@@ -0,0 +1,130 @@
package httpdebug
import (
"encoding/json"
"io"
"log/slog"
"net/http"
"net/url"
"github.com/a-h/templ"
"github.com/a-h/templ/cmd/templ/lspcmd/proxy"
"github.com/a-h/templ/cmd/templ/visualize"
)
var log *slog.Logger
func NewHandler(l *slog.Logger, s *proxy.Server) http.Handler {
m := http.NewServeMux()
log = l
m.HandleFunc("/templ", func(w http.ResponseWriter, r *http.Request) {
uri := r.URL.Query().Get("uri")
c, ok := s.TemplSource.Get(uri)
if !ok {
Error(w, "uri not found", http.StatusNotFound)
return
}
String(w, c.String())
})
m.HandleFunc("/sourcemap", func(w http.ResponseWriter, r *http.Request) {
uri := r.URL.Query().Get("uri")
sm, ok := s.SourceMapCache.Get(uri)
if !ok {
Error(w, "uri not found", http.StatusNotFound)
return
}
JSON(w, sm.SourceLinesToTarget)
})
m.HandleFunc("/go", func(w http.ResponseWriter, r *http.Request) {
uri := r.URL.Query().Get("uri")
c, ok := s.GoSource[uri]
if !ok {
Error(w, "uri not found", http.StatusNotFound)
return
}
String(w, c)
})
m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
uri := r.URL.Query().Get("uri")
if uri == "" {
// List all URIs.
if err := list(s.TemplSource.URIs()).Render(r.Context(), w); err != nil {
Error(w, "failed to list URIs", http.StatusInternalServerError)
}
return
}
// Assume we've got a URI.
templSource, ok := s.TemplSource.Get(uri)
if !ok {
if !ok {
Error(w, "uri not found in document contents", http.StatusNotFound)
return
}
}
goSource, ok := s.GoSource[uri]
if !ok {
if !ok {
Error(w, "uri not found in document contents", http.StatusNotFound)
return
}
}
sm, ok := s.SourceMapCache.Get(uri)
if !ok {
Error(w, "uri not found", http.StatusNotFound)
return
}
if err := visualize.HTML(uri, templSource.String(), goSource, sm).Render(r.Context(), w); err != nil {
Error(w, "failed to visualize HTML", http.StatusInternalServerError)
}
})
return m
}
func getMapURL(uri string) templ.SafeURL {
return withQuery("/", uri)
}
func getSourceMapURL(uri string) templ.SafeURL {
return withQuery("/sourcemap", uri)
}
func getTemplURL(uri string) templ.SafeURL {
return withQuery("/templ", uri)
}
func getGoURL(uri string) templ.SafeURL {
return withQuery("/go", uri)
}
func withQuery(path, uri string) templ.SafeURL {
q := make(url.Values)
q.Set("uri", uri)
u := &url.URL{
Path: path,
RawPath: path,
RawQuery: q.Encode(),
}
return templ.SafeURL(u.String())
}
func JSON(w http.ResponseWriter, v any) {
w.Header().Set("Content-Type", "application/json")
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
if err := enc.Encode(v); err != nil {
log.Error("failed to write JSON response", slog.Any("error", err))
}
}
func String(w http.ResponseWriter, s string) {
if _, err := io.WriteString(w, s); err != nil {
log.Error("failed to write string response", slog.Any("error", err))
}
}
func Error(w http.ResponseWriter, msg string, status int) {
w.WriteHeader(status)
if _, err := io.WriteString(w, msg); err != nil {
log.Error("failed to write error response", slog.Any("error", err))
}
}

View File

@@ -0,0 +1,22 @@
package httpdebug
templ list(uris []string) {
<table>
<tr>
<th>File</th>
<th></th>
<th></th>
<th></th>
<th></th>
</tr>
for _, uri := range uris {
<tr>
<td>{ uri }</td>
<td><a href={ getMapURL(uri) }>Mapping</a></td>
<td><a href={ getSourceMapURL(uri) }>Source Map</a></td>
<td><a href={ getTemplURL(uri) }>Templ</a></td>
<td><a href={ getGoURL(uri) }>Go</a></td>
</tr>
}
</table>
}

View File

@@ -0,0 +1,99 @@
// Code generated by templ - DO NOT EDIT.
// templ: version: v0.3.833
package httpdebug
//lint:file-ignore SA4006 This context is only used if a nested component is present.
import "github.com/a-h/templ"
import templruntime "github.com/a-h/templ/runtime"
func list(uris []string) templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
if templ_7745c5c3_Var1 == nil {
templ_7745c5c3_Var1 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<table><tr><th>File</th><th></th><th></th><th></th><th></th></tr>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
for _, uri := range uris {
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "<tr><td>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var2 string
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(uri)
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/lspcmd/httpdebug/list.templ`, Line: 14, Col: 13}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 3, "</td><td><a href=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var3 templ.SafeURL = getMapURL(uri)
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(string(templ_7745c5c3_Var3)))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 4, "\">Mapping</a></td><td><a href=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var4 templ.SafeURL = getSourceMapURL(uri)
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(string(templ_7745c5c3_Var4)))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 5, "\">Source Map</a></td><td><a href=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var5 templ.SafeURL = getTemplURL(uri)
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(string(templ_7745c5c3_Var5)))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 6, "\">Templ</a></td><td><a href=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var6 templ.SafeURL = getGoURL(uri)
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(string(templ_7745c5c3_Var6)))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "\">Go</a></td></tr>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, "</table>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
var _ = templruntime.GeneratedTemplate

View File

@@ -0,0 +1,957 @@
package lspcmd
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"sync"
"testing"
"time"
"unicode/utf8"
"github.com/a-h/templ/cmd/templ/generatecmd/modcheck"
"github.com/a-h/templ/cmd/templ/lspcmd/lspdiff"
"github.com/a-h/templ/cmd/templ/testproject"
"github.com/a-h/templ/lsp/jsonrpc2"
"github.com/a-h/templ/lsp/protocol"
"github.com/a-h/templ/lsp/uri"
"github.com/google/go-cmp/cmp"
)
func TestCompletion(t *testing.T) {
if testing.Short() {
return
}
ctx, cancel := context.WithCancel(context.Background())
log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
ctx, appDir, _, server, teardown, err := Setup(ctx, log)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
defer cancel()
templFile, err := os.ReadFile(appDir + "/templates.templ")
if err != nil {
t.Errorf("failed to read file %q: %v", appDir+"/templates.templ", err)
return
}
err = server.DidOpen(ctx, &protocol.DidOpenTextDocumentParams{
TextDocument: protocol.TextDocumentItem{
URI: uri.URI("file://" + appDir + "/templates.templ"),
LanguageID: "templ",
Version: 1,
Text: string(templFile),
},
})
if err != nil {
t.Errorf("failed to register open file: %v", err)
return
}
log.Info("Calling completion")
globalSnippetsLen := 1
// Edit the file.
// Replace:
// <div data-testid="count">{ fmt.Sprintf("%d", count) }</div>
// With various tests:
// <div data-testid="count">{ f
tests := []struct {
line int
replacement string
cursor string
assert func(t *testing.T, cl *protocol.CompletionList) (msg string, ok bool)
}{
{
line: 13,
replacement: ` <div data-testid="count">{ `,
cursor: ` ^`,
assert: func(t *testing.T, actual *protocol.CompletionList) (msg string, ok bool) {
if actual != nil && len(actual.Items) != globalSnippetsLen {
return "expected completion list to be empty", false
}
return "", true
},
},
{
line: 13,
replacement: ` <div data-testid="count">{ fmt.`,
cursor: ` ^`,
assert: func(t *testing.T, actual *protocol.CompletionList) (msg string, ok bool) {
if !lspdiff.CompletionListContainsText(actual, "fmt.Sprintf") {
return fmt.Sprintf("expected fmt.Sprintf to be in the completion list, but got %#v", actual), false
}
return "", true
},
},
{
line: 13,
replacement: ` <div data-testid="count">{ fmt.Sprintf("%d",`,
cursor: ` ^`,
assert: func(t *testing.T, actual *protocol.CompletionList) (msg string, ok bool) {
if actual != nil && len(actual.Items) != globalSnippetsLen {
return "expected completion list to be empty", false
}
return "", true
},
},
}
for i, test := range tests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
// Edit the file.
updated := testproject.MustReplaceLine(string(templFile), test.line, test.replacement)
err = server.DidChange(ctx, &protocol.DidChangeTextDocumentParams{
TextDocument: protocol.VersionedTextDocumentIdentifier{
TextDocumentIdentifier: protocol.TextDocumentIdentifier{
URI: uri.URI("file://" + appDir + "/templates.templ"),
},
Version: int32(i + 2),
},
ContentChanges: []protocol.TextDocumentContentChangeEvent{
{
Range: nil,
Text: updated,
},
},
})
if err != nil {
t.Errorf("failed to change file: %v", err)
return
}
// Give CI/CD pipeline executors some time because they're often quite slow.
var ok bool
var msg string
for i := 0; i < 3; i++ {
actual, err := server.Completion(ctx, &protocol.CompletionParams{
Context: &protocol.CompletionContext{
TriggerCharacter: ".",
TriggerKind: protocol.CompletionTriggerKindTriggerCharacter,
},
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
TextDocument: protocol.TextDocumentIdentifier{
URI: uri.URI("file://" + appDir + "/templates.templ"),
},
// Positions are zero indexed.
Position: protocol.Position{
Line: uint32(test.line - 1),
Character: uint32(len(test.cursor) - 1),
},
},
})
if err != nil {
t.Errorf("failed to get completion: %v", err)
return
}
msg, ok = test.assert(t, actual)
if !ok {
break
}
time.Sleep(time.Millisecond * 500)
}
if !ok {
t.Error(msg)
}
})
}
log.Info("Completed test")
}
func TestHover(t *testing.T) {
if testing.Short() {
return
}
ctx, cancel := context.WithCancel(context.Background())
log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
ctx, appDir, _, server, teardown, err := Setup(ctx, log)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
defer cancel()
templFile, err := os.ReadFile(appDir + "/templates.templ")
if err != nil {
t.Fatalf("failed to read file %q: %v", appDir+"/templates.templ", err)
}
err = server.DidOpen(ctx, &protocol.DidOpenTextDocumentParams{
TextDocument: protocol.TextDocumentItem{
URI: uri.URI("file://" + appDir + "/templates.templ"),
LanguageID: "templ",
Version: 1,
Text: string(templFile),
},
})
if err != nil {
t.Errorf("failed to register open file: %v", err)
return
}
log.Info("Calling hover")
// Edit the file.
// Replace:
// <div data-testid="count">{ fmt.Sprintf("%d", count) }</div>
// With various tests:
// <div data-testid="count">{ f
tests := []struct {
line int
replacement string
cursor string
assert func(t *testing.T, hr *protocol.Hover) (msg string, ok bool)
}{
{
line: 13,
replacement: ` <div data-testid="count">{ fmt.Sprintf("%d", count) }</div>`,
cursor: ` ^`,
assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) {
expectedHover := protocol.Hover{
Contents: protocol.MarkupContent{
Kind: "markdown",
Value: "```go\npackage fmt\n```\n\n---\n\n[`fmt` on pkg.go.dev](https://pkg.go.dev/fmt)",
},
}
if diff := lspdiff.Hover(expectedHover, *actual); diff != "" {
return fmt.Sprintf("unexpected hover: %v\n\n: markdown: %#v", diff, actual.Contents.Value), false
}
return "", true
},
},
{
line: 13,
replacement: ` <div data-testid="count">{ fmt.Sprintf("%d", count) }</div>`,
cursor: ` ^`,
assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) {
expectedHover := protocol.Hover{
Contents: protocol.MarkupContent{
Kind: "markdown",
Value: "```go\nfunc fmt.Sprintf(format string, a ...any) string\n```\n\n---\n\nSprintf formats according to a format specifier and returns the resulting string.\n\n\n---\n\n[`fmt.Sprintf` on pkg.go.dev](https://pkg.go.dev/fmt#Sprintf)",
},
}
if actual == nil {
return "expected hover to be non-nil", false
}
if diff := lspdiff.Hover(expectedHover, *actual); diff != "" {
return fmt.Sprintf("unexpected hover: %v", diff), false
}
return "", true
},
},
{
line: 19,
replacement: `var nihao = "你好"`,
cursor: ` ^`,
assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) {
// There's nothing to hover, just want to make sure it doesn't panic.
return "", true
},
},
{
line: 19,
replacement: `var nihao = "你好"`,
cursor: ` ^`, // Your text editor might not render this well, but it's the hao.
assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) {
// There's nothing to hover, just want to make sure it doesn't panic.
return "", true
},
},
}
for i, test := range tests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
// Put the file back to the initial point.
err = server.DidChange(ctx, &protocol.DidChangeTextDocumentParams{
TextDocument: protocol.VersionedTextDocumentIdentifier{
TextDocumentIdentifier: protocol.TextDocumentIdentifier{
URI: uri.URI("file://" + appDir + "/templates.templ"),
},
Version: int32(i + 2),
},
ContentChanges: []protocol.TextDocumentContentChangeEvent{
{
Range: nil,
Text: string(templFile),
},
},
})
if err != nil {
t.Errorf("failed to change file: %v", err)
return
}
// Give CI/CD pipeline executors some time because they're often quite slow.
var ok bool
var msg string
for i := 0; i < 3; i++ {
lspCharIndex, err := runeIndexToUTF8ByteIndex(test.replacement, len(test.cursor)-1)
if err != nil {
t.Error(err)
}
actual, err := server.Hover(ctx, &protocol.HoverParams{
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
TextDocument: protocol.TextDocumentIdentifier{
URI: uri.URI("file://" + appDir + "/templates.templ"),
},
// Positions are zero indexed.
Position: protocol.Position{
Line: uint32(test.line - 1),
Character: lspCharIndex,
},
},
})
if err != nil {
t.Errorf("failed to hover: %v", err)
return
}
msg, ok = test.assert(t, actual)
if !ok {
break
}
time.Sleep(time.Millisecond * 500)
}
if !ok {
t.Error(msg)
}
})
}
}
func TestReferences(t *testing.T) {
if testing.Short() {
return
}
ctx, cancel := context.WithCancel(context.Background())
log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
ctx, appDir, _, server, teardown, err := Setup(ctx, log)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
return
}
defer teardown(t)
defer cancel()
log.Info("Calling References")
tests := []struct {
line int
character int
filename string
assert func(t *testing.T, l []protocol.Location) (msg string, ok bool)
}{
{
// this is the definition of the templ function in the templates.templ file.
line: 5,
character: 9,
filename: "/templates.templ",
assert: func(t *testing.T, actual []protocol.Location) (msg string, ok bool) {
expectedReference := []protocol.Location{
{
// This is the usage of the templ function in the main.go file.
URI: uri.URI("file://" + appDir + "/main.go"),
Range: protocol.Range{
Start: protocol.Position{
Line: uint32(24),
Character: uint32(7),
},
End: protocol.Position{
Line: uint32(24),
Character: uint32(11),
},
},
},
}
if diff := lspdiff.References(expectedReference, actual); diff != "" {
return fmt.Sprintf("Expected: %+v\nActual: %+v", expectedReference, actual), false
}
return "", true
},
},
{
// this is the definition of the struct in the templates.templ file.
line: 21,
character: 9,
filename: "/templates.templ",
assert: func(t *testing.T, actual []protocol.Location) (msg string, ok bool) {
expectedReference := []protocol.Location{
{
// This is the usage of the struct in the templates.templ file.
URI: uri.URI("file://" + appDir + "/templates.templ"),
Range: protocol.Range{
Start: protocol.Position{
Line: uint32(24),
Character: uint32(8),
},
End: protocol.Position{
Line: uint32(24),
Character: uint32(14),
},
},
},
}
if diff := lspdiff.References(expectedReference, actual); diff != "" {
return fmt.Sprintf("Expected: %+v\nActual: %+v", expectedReference, actual), false
}
return "", true
},
},
{
// this test is for inclusions from a remote file that has not been explicitly called with didOpen
line: 3,
character: 9,
filename: "/remotechild.templ",
assert: func(t *testing.T, actual []protocol.Location) (msg string, ok bool) {
expectedReference := []protocol.Location{
{
URI: uri.URI("file://" + appDir + "/remoteparent.templ"),
Range: protocol.Range{
Start: protocol.Position{
Line: uint32(3),
Character: uint32(2),
},
End: protocol.Position{
Line: uint32(3),
Character: uint32(8),
},
},
},
{
URI: uri.URI("file://" + appDir + "/remoteparent.templ"),
Range: protocol.Range{
Start: protocol.Position{
Line: uint32(7),
Character: uint32(2),
},
End: protocol.Position{
Line: uint32(7),
Character: uint32(8),
},
},
},
}
if diff := lspdiff.References(expectedReference, actual); diff != "" {
return fmt.Sprintf("Expected: %+v\nActual: %+v", expectedReference, actual), false
}
return "", true
},
},
}
for i, test := range tests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
// Give CI/CD pipeline executors some time because they're often quite slow.
var ok bool
var msg string
for i := 0; i < 3; i++ {
if err != nil {
t.Error(err)
return
}
actual, err := server.References(ctx, &protocol.ReferenceParams{
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
TextDocument: protocol.TextDocumentIdentifier{
URI: uri.URI("file://" + appDir + test.filename),
},
// Positions are zero indexed.
Position: protocol.Position{
Line: uint32(test.line - 1),
Character: uint32(test.character - 1),
},
},
})
if err != nil {
t.Errorf("failed to get references: %v", err)
return
}
msg, ok = test.assert(t, actual)
if !ok {
break
}
time.Sleep(time.Millisecond * 500)
}
if !ok {
t.Error(msg)
}
})
}
}
func TestCodeAction(t *testing.T) {
if testing.Short() {
return
}
ctx, cancel := context.WithCancel(context.Background())
log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
ctx, appDir, _, server, teardown, err := Setup(ctx, log)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
defer cancel()
templFile, err := os.ReadFile(appDir + "/templates.templ")
if err != nil {
t.Fatalf("failed to read file %q: %v", appDir+"/templates.templ", err)
}
err = server.DidOpen(ctx, &protocol.DidOpenTextDocumentParams{
TextDocument: protocol.TextDocumentItem{
URI: uri.URI("file://" + appDir + "/templates.templ"),
LanguageID: "templ",
Version: 1,
Text: string(templFile),
},
})
if err != nil {
t.Errorf("failed to register open file: %v", err)
return
}
log.Info("Calling codeAction")
tests := []struct {
line int
replacement string
cursor string
assert func(t *testing.T, hr []protocol.CodeAction) (msg string, ok bool)
}{
{
line: 25,
replacement: `var s = Struct{}`,
cursor: ` ^`,
assert: func(t *testing.T, actual []protocol.CodeAction) (msg string, ok bool) {
var expected []protocol.CodeAction
// To support code actions, update cmd/templ/lspcmd/proxy/server.go and add the
// Title (e.g. Organize Imports, or Fill Struct) to the supportedCodeActions map.
// Some Code Actions are simple edits, so all that is needed is for the server
// to remap the source code positions.
// However, other Code Actions are commands, where the arguments must be rewritten
// and will need to be handled individually.
if diff := lspdiff.CodeAction(expected, actual); diff != "" {
return fmt.Sprintf("unexpected codeAction: %v", diff), false
}
return "", true
},
},
}
for i, test := range tests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
// Put the file back to the initial point.
err = server.DidChange(ctx, &protocol.DidChangeTextDocumentParams{
TextDocument: protocol.VersionedTextDocumentIdentifier{
TextDocumentIdentifier: protocol.TextDocumentIdentifier{
URI: uri.URI("file://" + appDir + "/templates.templ"),
},
Version: int32(i + 2),
},
ContentChanges: []protocol.TextDocumentContentChangeEvent{
{
Range: nil,
Text: string(templFile),
},
},
})
if err != nil {
t.Errorf("failed to change file: %v", err)
return
}
// Give CI/CD pipeline executors some time because they're often quite slow.
var ok bool
var msg string
for i := 0; i < 3; i++ {
lspCharIndex, err := runeIndexToUTF8ByteIndex(test.replacement, len(test.cursor)-1)
if err != nil {
t.Error(err)
}
actual, err := server.CodeAction(ctx, &protocol.CodeActionParams{
TextDocument: protocol.TextDocumentIdentifier{
URI: uri.URI("file://" + appDir + "/templates.templ"),
},
Range: protocol.Range{
Start: protocol.Position{
Line: uint32(test.line - 1),
Character: lspCharIndex,
},
End: protocol.Position{
Line: uint32(test.line - 1),
Character: lspCharIndex + 1,
},
},
})
if err != nil {
t.Errorf("failed code action: %v", err)
return
}
msg, ok = test.assert(t, actual)
if !ok {
break
}
time.Sleep(time.Millisecond * 500)
}
if !ok {
t.Error(msg)
}
})
}
}
func TestDocumentSymbol(t *testing.T) {
if testing.Short() {
return
}
ctx, cancel := context.WithCancel(context.Background())
log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
ctx, appDir, _, server, teardown, err := Setup(ctx, log)
if err != nil {
t.Fatalf("failed to setup test: %v", err)
}
defer teardown(t)
defer cancel()
tests := []struct {
uri string
expect []protocol.SymbolInformationOrDocumentSymbol
}{
{
uri: "file://" + appDir + "/templates.templ",
expect: []protocol.SymbolInformationOrDocumentSymbol{
{
SymbolInformation: &protocol.SymbolInformation{
Name: "Page",
Kind: protocol.SymbolKindFunction,
Location: protocol.Location{
Range: protocol.Range{
Start: protocol.Position{Line: 11, Character: 0},
End: protocol.Position{Line: 50, Character: 1},
},
},
},
},
{
SymbolInformation: &protocol.SymbolInformation{
Name: "nihao",
Kind: protocol.SymbolKindVariable,
Location: protocol.Location{
Range: protocol.Range{
Start: protocol.Position{Line: 18, Character: 4},
End: protocol.Position{Line: 18, Character: 16},
},
},
},
},
{
SymbolInformation: &protocol.SymbolInformation{
Name: "Struct",
Kind: protocol.SymbolKindStruct,
Location: protocol.Location{
Range: protocol.Range{
Start: protocol.Position{Line: 20, Character: 5},
End: protocol.Position{Line: 22, Character: 1},
},
},
},
},
{
SymbolInformation: &protocol.SymbolInformation{
Name: "s",
Kind: protocol.SymbolKindVariable,
Location: protocol.Location{
Range: protocol.Range{
Start: protocol.Position{Line: 24, Character: 4},
End: protocol.Position{Line: 24, Character: 16},
},
},
},
},
},
},
{
uri: "file://" + appDir + "/remoteparent.templ",
expect: []protocol.SymbolInformationOrDocumentSymbol{
{
SymbolInformation: &protocol.SymbolInformation{
Name: "RemoteInclusionTest",
Kind: protocol.SymbolKindFunction,
Location: protocol.Location{
Range: protocol.Range{
Start: protocol.Position{Line: 9, Character: 0},
End: protocol.Position{Line: 35, Character: 1},
},
},
},
},
{
SymbolInformation: &protocol.SymbolInformation{
Name: "Remote2",
Kind: protocol.SymbolKindFunction,
Location: protocol.Location{
Range: protocol.Range{
Start: protocol.Position{Line: 37, Character: 0},
End: protocol.Position{Line: 63, Character: 1},
},
},
},
},
},
},
}
for i, test := range tests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
actual, err := server.DocumentSymbol(ctx, &protocol.DocumentSymbolParams{
TextDocument: protocol.TextDocumentIdentifier{
URI: uri.URI(test.uri),
},
})
if err != nil {
t.Errorf("failed to get document symbol: %v", err)
}
// Set expected URI.
for i, v := range test.expect {
if v.SymbolInformation != nil {
v.SymbolInformation.Location.URI = uri.URI(test.uri)
test.expect[i] = v
}
}
if err != nil {
t.Errorf("failed to convert expect to any slice: %v", err)
}
diff := cmp.Diff(test.expect, actual)
if diff != "" {
t.Errorf("unexpected document symbol: %v", diff)
}
})
}
}
func runeIndexToUTF8ByteIndex(s string, runeIndex int) (lspChar uint32, err error) {
for i, r := range []rune(s) {
if i == runeIndex {
break
}
l := utf8.RuneLen(r)
if l < 0 {
return 0, fmt.Errorf("invalid rune in string at index %d", runeIndex)
}
lspChar += uint32(l)
}
return lspChar, nil
}
func NewTestClient(log *slog.Logger) TestClient {
return TestClient{
log: log,
}
}
type TestClient struct {
log *slog.Logger
}
func (tc TestClient) Progress(ctx context.Context, params *protocol.ProgressParams) (err error) {
tc.log.Info("client: Received Progress", slog.Any("params", params))
return nil
}
func (tc TestClient) WorkDoneProgressCreate(ctx context.Context, params *protocol.WorkDoneProgressCreateParams) (err error) {
tc.log.Info("client: Received WorkDoneProgressCreate", slog.Any("params", params))
return nil
}
func (tc TestClient) LogMessage(ctx context.Context, params *protocol.LogMessageParams) (err error) {
tc.log.Info("client: Received LogMessage", slog.Any("params", params))
return nil
}
func (tc TestClient) PublishDiagnostics(ctx context.Context, params *protocol.PublishDiagnosticsParams) (err error) {
tc.log.Info("client: Received PublishDiagnostics", slog.Any("params", params))
return nil
}
func (tc TestClient) ShowMessage(ctx context.Context, params *protocol.ShowMessageParams) (err error) {
tc.log.Info("client: Received ShowMessage", slog.Any("params", params))
return nil
}
func (tc TestClient) ShowMessageRequest(ctx context.Context, params *protocol.ShowMessageRequestParams) (result *protocol.MessageActionItem, err error) {
return nil, nil
}
func (tc TestClient) Telemetry(ctx context.Context, params any) (err error) {
tc.log.Info("client: Received Telemetry", slog.Any("params", params))
return nil
}
func (tc TestClient) RegisterCapability(ctx context.Context, params *protocol.RegistrationParams,
) (err error) {
tc.log.Info("client: Received RegisterCapability", slog.Any("params", params))
return nil
}
func (tc TestClient) UnregisterCapability(ctx context.Context, params *protocol.UnregistrationParams) (err error) {
tc.log.Info("client: Received UnregisterCapability", slog.Any("params", params))
return nil
}
func (tc TestClient) ApplyEdit(ctx context.Context, params *protocol.ApplyWorkspaceEditParams) (result *protocol.ApplyWorkspaceEditResponse, err error) {
tc.log.Info("client: Received ApplyEdit", slog.Any("params", params))
return nil, nil
}
func (tc TestClient) Configuration(ctx context.Context, params *protocol.ConfigurationParams) (result []any, err error) {
tc.log.Info("client: Received Configuration", slog.Any("params", params))
return nil, nil
}
func (tc TestClient) WorkspaceFolders(ctx context.Context) (result []protocol.WorkspaceFolder, err error) {
tc.log.Info("client: Received WorkspaceFolders")
return nil, nil
}
func Setup(ctx context.Context, log *slog.Logger) (clientCtx context.Context, appDir string, client protocol.Client, server protocol.Server, teardown func(t *testing.T), err error) {
wd, err := os.Getwd()
if err != nil {
return ctx, appDir, client, server, teardown, fmt.Errorf("could not find working dir: %w", err)
}
moduleRoot, err := modcheck.WalkUp(wd)
if err != nil {
return ctx, appDir, client, server, teardown, fmt.Errorf("could not find local templ go.mod file: %v", err)
}
appDir, err = testproject.Create(moduleRoot)
if err != nil {
return ctx, appDir, client, server, teardown, fmt.Errorf("failed to create test project: %v", err)
}
var wg sync.WaitGroup
var cmdErr error
// Copy from the LSP to the Client, and vice versa.
fromClient, toLSP := io.Pipe()
fromLSP, toClient := io.Pipe()
clientStream := jsonrpc2.NewStream(newStdRwc(log, "clientStream", toLSP, fromLSP))
serverStream := jsonrpc2.NewStream(newStdRwc(log, "serverStream", toClient, fromClient))
// Create the client that the server needs.
client = NewTestClient(log)
ctx, _, server = protocol.NewClient(ctx, client, clientStream, log)
wg.Add(1)
go func() {
defer wg.Done()
log.Info("Running")
// Create the server that the client needs.
cmdErr = run(ctx, log, serverStream, Arguments{})
if cmdErr != nil {
log.Error("Failed to run", slog.Any("error", cmdErr))
}
log.Info("Stopped")
}()
// Initialize.
ir, err := server.Initialize(ctx, &protocol.InitializeParams{
ClientInfo: &protocol.ClientInfo{},
Capabilities: protocol.ClientCapabilities{
Workspace: &protocol.WorkspaceClientCapabilities{
ApplyEdit: true,
WorkspaceEdit: &protocol.WorkspaceClientCapabilitiesWorkspaceEdit{
DocumentChanges: true,
},
WorkspaceFolders: true,
FileOperations: &protocol.WorkspaceClientCapabilitiesFileOperations{
DidCreate: true,
WillCreate: true,
DidRename: true,
WillRename: true,
DidDelete: true,
WillDelete: true,
},
},
TextDocument: &protocol.TextDocumentClientCapabilities{
Synchronization: &protocol.TextDocumentSyncClientCapabilities{
DidSave: true,
},
Completion: &protocol.CompletionTextDocumentClientCapabilities{
CompletionItem: &protocol.CompletionTextDocumentClientCapabilitiesItem{
SnippetSupport: true,
DeprecatedSupport: true,
InsertReplaceSupport: true,
},
},
Hover: &protocol.HoverTextDocumentClientCapabilities{},
SignatureHelp: &protocol.SignatureHelpTextDocumentClientCapabilities{},
Declaration: &protocol.DeclarationTextDocumentClientCapabilities{},
Definition: &protocol.DefinitionTextDocumentClientCapabilities{},
TypeDefinition: &protocol.TypeDefinitionTextDocumentClientCapabilities{},
Implementation: &protocol.ImplementationTextDocumentClientCapabilities{},
References: &protocol.ReferencesTextDocumentClientCapabilities{},
DocumentHighlight: &protocol.DocumentHighlightClientCapabilities{},
DocumentSymbol: &protocol.DocumentSymbolClientCapabilities{},
CodeAction: &protocol.CodeActionClientCapabilities{},
CodeLens: &protocol.CodeLensClientCapabilities{},
Formatting: &protocol.DocumentFormattingClientCapabilities{},
RangeFormatting: &protocol.DocumentRangeFormattingClientCapabilities{},
OnTypeFormatting: &protocol.DocumentOnTypeFormattingClientCapabilities{},
PublishDiagnostics: &protocol.PublishDiagnosticsClientCapabilities{},
Rename: &protocol.RenameClientCapabilities{},
FoldingRange: &protocol.FoldingRangeClientCapabilities{},
SelectionRange: &protocol.SelectionRangeClientCapabilities{},
CallHierarchy: &protocol.CallHierarchyClientCapabilities{},
SemanticTokens: &protocol.SemanticTokensClientCapabilities{},
LinkedEditingRange: &protocol.LinkedEditingRangeClientCapabilities{},
},
Window: &protocol.WindowClientCapabilities{},
General: &protocol.GeneralClientCapabilities{},
Experimental: nil,
},
WorkspaceFolders: []protocol.WorkspaceFolder{
{
URI: "file://" + appDir,
Name: "templ-test",
},
},
})
if err != nil {
log.Error("Failed to init", slog.Any("error", err))
}
if ir.ServerInfo.Name != "templ-lsp" {
return ctx, appDir, client, server, teardown, fmt.Errorf("expected server name to be templ-lsp, got %q", ir.ServerInfo.Name)
}
// Confirm initialization.
log.Info("Confirming initialization...")
if err = server.Initialized(ctx, &protocol.InitializedParams{}); err != nil {
return ctx, appDir, client, server, teardown, fmt.Errorf("failed to confirm initialization: %v", err)
}
log.Info("Initialized")
// Wait for exit.
teardown = func(t *testing.T) {
log.Info("Tearing down LSP")
wg.Wait()
if cmdErr != nil {
t.Errorf("failed to run lsp cmd: %v", err)
}
if err = os.RemoveAll(appDir); err != nil {
t.Errorf("failed to remove test dir %q: %v", appDir, err)
}
}
return ctx, appDir, client, server, teardown, err
}

View File

@@ -0,0 +1,42 @@
package lspdiff
import (
"github.com/a-h/templ/lsp/protocol"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
// This package provides a way to compare LSP protocol messages, ignoring irrelevant fields.
func Hover(expected, actual protocol.Hover) string {
return cmp.Diff(expected, actual,
cmpopts.IgnoreFields(protocol.Hover{}, "Range"),
cmpopts.IgnoreFields(protocol.MarkupContent{}, "Kind"),
)
}
func CodeAction(expected, actual []protocol.CodeAction) string {
return cmp.Diff(expected, actual)
}
func CompletionList(expected, actual *protocol.CompletionList) string {
return cmp.Diff(expected, actual,
cmpopts.IgnoreFields(protocol.CompletionList{}, "IsIncomplete"),
)
}
func References(expected, actual []protocol.Location) string {
return cmp.Diff(expected, actual)
}
func CompletionListContainsText(cl *protocol.CompletionList, text string) bool {
if cl == nil {
return false
}
for _, item := range cl.Items {
if item.Label == text {
return true
}
}
return false
}

View File

@@ -0,0 +1,131 @@
package lspcmd
import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"os/signal"
"github.com/a-h/templ/cmd/templ/lspcmd/httpdebug"
"github.com/a-h/templ/cmd/templ/lspcmd/pls"
"github.com/a-h/templ/cmd/templ/lspcmd/proxy"
"github.com/a-h/templ/lsp/jsonrpc2"
"github.com/a-h/templ/lsp/protocol"
_ "net/http/pprof"
)
type Arguments struct {
Log string
GoplsLog string
GoplsRPCTrace bool
// PPROF sets whether to start a profiling server on localhost:9999
PPROF bool
// HTTPDebug sets the HTTP endpoint to listen on. Leave empty for no web debug.
HTTPDebug string
}
func Run(stdin io.Reader, stdout, stderr io.Writer, args Arguments) (err error) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt)
defer func() {
signal.Stop(signalChan)
cancel()
}()
if args.PPROF {
go func() {
_ = http.ListenAndServe("localhost:9999", nil)
}()
}
go func() {
select {
case <-signalChan: // First signal, cancel context.
cancel()
case <-ctx.Done():
}
<-signalChan // Second signal, hard exit.
os.Exit(2)
}()
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
if args.Log != "" {
file, err := os.OpenFile(args.Log, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
return fmt.Errorf("failed to open log file: %w", err)
}
defer file.Close()
// Create a new logger with a file writer
log = slog.New(slog.NewJSONHandler(file, nil))
log.Debug("Logging to file", slog.String("file", args.Log))
}
templStream := jsonrpc2.NewStream(newStdRwc(log, "templStream", stdout, stdin))
return run(ctx, log, templStream, args)
}
func run(ctx context.Context, log *slog.Logger, templStream jsonrpc2.Stream, args Arguments) (err error) {
log.Info("lsp: starting up...")
defer func() {
if r := recover(); r != nil {
log.Error("handled panic", slog.Any("recovered", r))
}
}()
log.Info("lsp: starting gopls...")
rwc, err := pls.NewGopls(ctx, log, pls.Options{
Log: args.GoplsLog,
RPCTrace: args.GoplsRPCTrace,
})
if err != nil {
log.Error("failed to start gopls", slog.Any("error", err))
os.Exit(1)
}
cache := proxy.NewSourceMapCache()
diagnosticCache := proxy.NewDiagnosticCache()
log.Info("creating gopls client")
clientProxy, clientInit := proxy.NewClient(log, cache, diagnosticCache)
_, goplsConn, goplsServer := protocol.NewClient(ctx, clientProxy, jsonrpc2.NewStream(rwc), log)
defer goplsConn.Close()
log.Info("creating proxy")
// Create the proxy to sit between.
serverProxy := proxy.NewServer(log, goplsServer, cache, diagnosticCache)
// Create templ server.
log.Info("creating templ server")
_, templConn, templClient := protocol.NewServer(context.Background(), serverProxy, templStream, log)
defer templConn.Close()
// Allow both the server and the client to initiate outbound requests.
clientInit(templClient)
// Start the web server if required.
if args.HTTPDebug != "" {
log.Info("starting debug http server", slog.String("addr", args.HTTPDebug))
h := httpdebug.NewHandler(log, serverProxy)
go func() {
if err := http.ListenAndServe(args.HTTPDebug, h); err != nil {
log.Error("web server failed", slog.Any("error", err))
}
}()
}
log.Info("listening")
select {
case <-ctx.Done():
log.Info("context closed")
case <-templConn.Done():
log.Info("templConn closed")
case <-goplsConn.Done():
log.Info("goplsConn closed")
}
log.Info("shutdown complete")
return
}

View File

@@ -0,0 +1,124 @@
package pls
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"path"
"runtime"
)
// Options for the gopls client.
type Options struct {
Log string
RPCTrace bool
}
// AsArguments converts the options into command line arguments for gopls.
func (opts Options) AsArguments() []string {
var args []string
if opts.Log != "" {
args = append(args, "-logfile", opts.Log)
}
if opts.RPCTrace {
args = append(args, "-rpc.trace")
}
return args
}
func FindGopls() (location string, err error) {
executableName := "gopls"
if runtime.GOOS == "windows" {
executableName = "gopls.exe"
}
pathLocation, err := exec.LookPath(executableName)
if err == nil {
// Found on the path.
return pathLocation, nil
}
// Unexpected error.
if !errors.Is(err, exec.ErrNotFound) {
return "", fmt.Errorf("unexpected error looking for gopls: %w", err)
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("unexpected error looking for gopls: %w", err)
}
// Probe standard locations.
locations := []string{
path.Join(home, "go", "bin", executableName),
path.Join(home, ".local", "bin", executableName),
}
for _, location := range locations {
_, err = os.Stat(location)
if err != nil {
continue
}
// Found in a standard location.
return location, nil
}
return "", fmt.Errorf("cannot find gopls on the path (%q), in $HOME/go/bin or $HOME/.local/bin/gopls. You can install gopls with `go install golang.org/x/tools/gopls@latest`", os.Getenv("PATH"))
}
// NewGopls starts gopls and opens up a jsonrpc2 connection to it.
func NewGopls(ctx context.Context, log *slog.Logger, opts Options) (rwc io.ReadWriteCloser, err error) {
location, err := FindGopls()
if err != nil {
return nil, err
}
cmd := exec.Command(location, opts.AsArguments()...)
return newProcessReadWriteCloser(log, cmd)
}
// newProcessReadWriteCloser creates a processReadWriteCloser to allow stdin/stdout to be used as
// a JSON RPC 2.0 transport.
func newProcessReadWriteCloser(logger *slog.Logger, cmd *exec.Cmd) (rwc processReadWriteCloser, err error) {
stdin, err := cmd.StdinPipe()
if err != nil {
return
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return
}
rwc = processReadWriteCloser{
in: stdin,
out: stdout,
}
go func() {
if err := cmd.Run(); err != nil {
logger.Error("gopls command error", slog.Any("error", err))
}
}()
return
}
type processReadWriteCloser struct {
in io.WriteCloser
out io.ReadCloser
}
func (prwc processReadWriteCloser) Read(p []byte) (n int, err error) {
return prwc.out.Read(p)
}
func (prwc processReadWriteCloser) Write(p []byte) (n int, err error) {
return prwc.in.Write(p)
}
func (prwc processReadWriteCloser) Close() error {
errInClose := prwc.in.Close()
errOutClose := prwc.out.Close()
if errInClose != nil || errOutClose != nil {
return fmt.Errorf("error closing process - in: %v, out: %v", errInClose, errOutClose)
}
return nil
}

View File

@@ -0,0 +1,143 @@
package proxy
import (
"context"
"fmt"
"log/slog"
"strings"
lsp "github.com/a-h/templ/lsp/protocol"
)
// Client is responsible for rewriting messages that are
// originated from gopls, and are sent to the client.
//
// Since `gopls` is working on Go files, and this is the `templ` LSP,
// the job of this code is to rewrite incoming requests to adjust the
// file name from `*_templ.go` to `*.templ`, and to remap the char
// positions where required.
type Client struct {
Log *slog.Logger
Target lsp.Client
SourceMapCache *SourceMapCache
DiagnosticCache *DiagnosticCache
}
func NewClient(log *slog.Logger, cache *SourceMapCache, diagnosticCache *DiagnosticCache) (c *Client, init func(lsp.Client)) {
c = &Client{
Log: log,
SourceMapCache: cache,
DiagnosticCache: diagnosticCache,
}
return c, func(target lsp.Client) {
c.Target = target
}
}
func (p Client) Progress(ctx context.Context, params *lsp.ProgressParams) (err error) {
p.Log.Info("client <- server: Progress")
return p.Target.Progress(ctx, params)
}
func (p Client) WorkDoneProgressCreate(ctx context.Context, params *lsp.WorkDoneProgressCreateParams) (err error) {
p.Log.Info("client <- server: WorkDoneProgressCreate")
return p.Target.WorkDoneProgressCreate(ctx, params)
}
func (p Client) LogMessage(ctx context.Context, params *lsp.LogMessageParams) (err error) {
p.Log.Info("client <- server: LogMessage", slog.String("message", params.Message))
return p.Target.LogMessage(ctx, params)
}
func (p Client) PublishDiagnostics(ctx context.Context, params *lsp.PublishDiagnosticsParams) (err error) {
p.Log.Info("client <- server: PublishDiagnostics")
if strings.HasSuffix(string(params.URI), "go.mod") {
p.Log.Info("client <- server: PublishDiagnostics: skipping go.mod diagnostics")
return nil
}
// Log diagnostics.
for i, diagnostic := range params.Diagnostics {
p.Log.Info(fmt.Sprintf("client <- server: PublishDiagnostics: [%d]", i), slog.Any("diagnostic", diagnostic))
}
// Get the sourcemap from the cache.
uri := strings.TrimSuffix(string(params.URI), "_templ.go") + ".templ"
sourceMap, ok := p.SourceMapCache.Get(uri)
if !ok {
p.Log.Error("unable to complete because the sourcemap for the URI doesn't exist in the cache", slog.String("uri", uri))
return fmt.Errorf("unable to complete because the sourcemap for %q doesn't exist in the cache, has the didOpen notification been sent yet?", uri)
}
params.URI = lsp.DocumentURI(uri)
// Rewrite the positions.
for i := 0; i < len(params.Diagnostics); i++ {
item := params.Diagnostics[i]
start, ok := sourceMap.SourcePositionFromTarget(item.Range.Start.Line, item.Range.Start.Character)
if !ok {
continue
}
if item.Range.Start.Line == item.Range.End.Line {
length := item.Range.End.Character - item.Range.Start.Character
item.Range.Start.Line = start.Line
item.Range.Start.Character = start.Col
item.Range.End.Line = start.Line
item.Range.End.Character = start.Col + length
params.Diagnostics[i] = item
p.Log.Info(fmt.Sprintf("diagnostic [%d] rewritten", i), slog.Any("diagnostic", item))
continue
}
end, ok := sourceMap.SourcePositionFromTarget(item.Range.End.Line, item.Range.End.Character)
if !ok {
continue
}
item.Range.Start.Line = start.Line
item.Range.Start.Character = start.Col
item.Range.End.Line = end.Line
item.Range.End.Character = end.Col
params.Diagnostics[i] = item
p.Log.Info(fmt.Sprintf("diagnostic [%d] rewritten", i), slog.Any("diagnostic", item))
}
params.Diagnostics = p.DiagnosticCache.AddTemplDiagnostics(uri, params.Diagnostics)
err = p.Target.PublishDiagnostics(ctx, params)
return err
}
func (p Client) ShowMessage(ctx context.Context, params *lsp.ShowMessageParams) (err error) {
p.Log.Info("client <- server: ShowMessage", slog.String("message", params.Message))
if strings.HasPrefix(params.Message, "Do not edit this file!") {
return
}
return p.Target.ShowMessage(ctx, params)
}
func (p Client) ShowMessageRequest(ctx context.Context, params *lsp.ShowMessageRequestParams) (result *lsp.MessageActionItem, err error) {
p.Log.Info("client <- server: ShowMessageRequest", slog.String("message", params.Message))
return p.Target.ShowMessageRequest(ctx, params)
}
func (p Client) Telemetry(ctx context.Context, params any) (err error) {
p.Log.Info("client <- server: Telemetry")
return p.Target.Telemetry(ctx, params)
}
func (p Client) RegisterCapability(ctx context.Context, params *lsp.RegistrationParams) (err error) {
p.Log.Info("client <- server: RegisterCapability")
return p.Target.RegisterCapability(ctx, params)
}
func (p Client) UnregisterCapability(ctx context.Context, params *lsp.UnregistrationParams) (err error) {
p.Log.Info("client <- server: UnregisterCapability")
return p.Target.UnregisterCapability(ctx, params)
}
func (p Client) ApplyEdit(ctx context.Context, params *lsp.ApplyWorkspaceEditParams) (result *lsp.ApplyWorkspaceEditResponse, err error) {
p.Log.Info("client <- server: ApplyEdit")
return p.Target.ApplyEdit(ctx, params)
}
func (p Client) Configuration(ctx context.Context, params *lsp.ConfigurationParams) (result []any, err error) {
p.Log.Info("client <- server: Configuration")
return p.Target.Configuration(ctx, params)
}
func (p Client) WorkspaceFolders(ctx context.Context) (result []lsp.WorkspaceFolder, err error) {
p.Log.Info("client <- server: WorkspaceFolders")
return p.Target.WorkspaceFolders(ctx)
}

View File

@@ -0,0 +1,61 @@
package proxy
import (
"sync"
lsp "github.com/a-h/templ/lsp/protocol"
)
func NewDiagnosticCache() *DiagnosticCache {
return &DiagnosticCache{
m: &sync.Mutex{},
cache: make(map[string]fileDiagnostic),
}
}
type fileDiagnostic struct {
templDiagnostics []lsp.Diagnostic
goplsDiagnostics []lsp.Diagnostic
}
type DiagnosticCache struct {
m *sync.Mutex
cache map[string]fileDiagnostic
}
func zeroLengthSliceIfNil(diags []lsp.Diagnostic) []lsp.Diagnostic {
if diags == nil {
return make([]lsp.Diagnostic, 0)
}
return diags
}
func (dc *DiagnosticCache) AddTemplDiagnostics(uri string, goDiagnostics []lsp.Diagnostic) []lsp.Diagnostic {
goDiagnostics = zeroLengthSliceIfNil(goDiagnostics)
dc.m.Lock()
defer dc.m.Unlock()
diag := dc.cache[uri]
diag.goplsDiagnostics = goDiagnostics
diag.templDiagnostics = zeroLengthSliceIfNil(diag.templDiagnostics)
dc.cache[uri] = diag
return append(diag.templDiagnostics, goDiagnostics...)
}
func (dc *DiagnosticCache) ClearTemplDiagnostics(uri string) {
dc.m.Lock()
defer dc.m.Unlock()
diag := dc.cache[uri]
diag.templDiagnostics = make([]lsp.Diagnostic, 0)
dc.cache[uri] = diag
}
func (dc *DiagnosticCache) AddGoDiagnostics(uri string, templDiagnostics []lsp.Diagnostic) []lsp.Diagnostic {
templDiagnostics = zeroLengthSliceIfNil(templDiagnostics)
dc.m.Lock()
defer dc.m.Unlock()
diag := dc.cache[uri]
diag.templDiagnostics = templDiagnostics
diag.goplsDiagnostics = zeroLengthSliceIfNil(diag.goplsDiagnostics)
dc.cache[uri] = diag
return append(diag.goplsDiagnostics, templDiagnostics...)
}

View File

@@ -0,0 +1,215 @@
package proxy
import (
"fmt"
"log/slog"
"strings"
"sync"
lsp "github.com/a-h/templ/lsp/protocol"
)
// newDocumentContents creates a document content processing tool.
func newDocumentContents(log *slog.Logger) *DocumentContents {
return &DocumentContents{
m: new(sync.Mutex),
uriToContents: make(map[string]*Document),
log: log,
}
}
type DocumentContents struct {
m *sync.Mutex
uriToContents map[string]*Document
log *slog.Logger
}
// Set the contents of a document.
func (dc *DocumentContents) Set(uri string, d *Document) {
dc.m.Lock()
defer dc.m.Unlock()
dc.uriToContents[uri] = d
}
// Get the contents of a document.
func (dc *DocumentContents) Get(uri string) (d *Document, ok bool) {
dc.m.Lock()
defer dc.m.Unlock()
d, ok = dc.uriToContents[uri]
return
}
// Delete a document from memory.
func (dc *DocumentContents) Delete(uri string) {
dc.m.Lock()
defer dc.m.Unlock()
delete(dc.uriToContents, uri)
}
func (dc *DocumentContents) URIs() (uris []string) {
dc.m.Lock()
defer dc.m.Unlock()
uris = make([]string, len(dc.uriToContents))
var i int
for k := range dc.uriToContents {
uris[i] = k
i++
}
return uris
}
// Apply changes to the document from the client, and return a list of change requests to send back to the client.
func (dc *DocumentContents) Apply(uri string, changes []lsp.TextDocumentContentChangeEvent) (d *Document, err error) {
dc.m.Lock()
defer dc.m.Unlock()
var ok bool
d, ok = dc.uriToContents[uri]
if !ok {
err = fmt.Errorf("document not found")
return
}
for _, change := range changes {
d.Apply(change.Range, change.Text)
}
return
}
func NewDocument(log *slog.Logger, s string) *Document {
return &Document{
Log: log,
Lines: strings.Split(s, "\n"),
}
}
type Document struct {
Log *slog.Logger
Lines []string
}
func (d *Document) LineLengths() (lens []int) {
lens = make([]int, len(d.Lines))
for i, l := range d.Lines {
lens[i] = len(l)
}
return
}
func (d *Document) Len() (line, col int) {
line = len(d.Lines)
col = len(d.Lines[len(d.Lines)-1])
return
}
func (d *Document) Overwrite(fromLine, fromCol, toLine, toCol int, lines []string) {
suffix := d.Lines[toLine][toCol:]
toLen := d.LineLengths()[toLine]
d.Delete(fromLine, fromCol, toLine, toLen)
lines[len(lines)-1] = lines[len(lines)-1] + suffix
d.Insert(fromLine, fromCol, lines)
}
func (d *Document) Insert(line, col int, lines []string) {
prefix := d.Lines[line][:col]
suffix := d.Lines[line][col:]
lines[0] = prefix + lines[0]
d.Lines[line] = lines[0]
if len(lines) > 1 {
d.InsertLines(line+1, lines[1:])
}
d.Lines[line+len(lines)-1] = lines[len(lines)-1] + suffix
}
func (d *Document) InsertLines(i int, withLines []string) {
d.Lines = append(d.Lines[:i], append(withLines, d.Lines[i:]...)...)
}
func (d *Document) Delete(fromLine, fromCol, toLine, toCol int) {
prefix := d.Lines[fromLine][:fromCol]
suffix := d.Lines[toLine][toCol:]
// Delete intermediate lines.
deleteFrom := fromLine
deleteTo := fromLine + (toLine - fromLine)
d.DeleteLines(deleteFrom, deleteTo)
// Merge the contents of the final line.
d.Lines[fromLine] = prefix + suffix
}
func (d *Document) DeleteLines(i, j int) {
d.Lines = append(d.Lines[:i], d.Lines[j:]...)
}
func (d *Document) String() string {
return strings.Join(d.Lines, "\n")
}
func (d *Document) Replace(with string) {
d.Lines = strings.Split(with, "\n")
}
func (d *Document) Apply(r *lsp.Range, with string) {
withLines := strings.Split(with, "\n")
d.normalize(r)
if d.isWholeDocument(r) {
d.Lines = withLines
return
}
if d.isInsert(r, with) {
d.Insert(int(r.Start.Line), int(r.Start.Character), withLines)
return
}
if d.isDelete(r, with) {
d.Delete(int(r.Start.Line), int(r.Start.Character), int(r.End.Line), int(r.End.Character))
return
}
if d.isOverwrite(r, with) {
d.Overwrite(int(r.Start.Line), int(r.Start.Character), int(r.End.Line), int(r.End.Character), withLines)
}
}
func (d *Document) normalize(r *lsp.Range) {
if r == nil {
return
}
lens := d.LineLengths()
if r.Start.Line >= uint32(len(lens)) {
r.Start.Line = uint32(len(lens) - 1)
r.Start.Character = uint32(lens[r.Start.Line])
}
if r.Start.Character > uint32(lens[r.Start.Line]) {
r.Start.Character = uint32(lens[r.Start.Line])
}
if r.End.Line >= uint32(len(lens)) {
r.End.Line = uint32(len(lens) - 1)
r.End.Character = uint32(lens[r.End.Line])
}
if r.End.Character > uint32(lens[r.End.Line]) {
r.End.Character = uint32(lens[r.End.Line])
}
}
func (d *Document) isOverwrite(r *lsp.Range, with string) bool {
return (r.End.Line != r.Start.Line || r.Start.Character != r.End.Character) && with != ""
}
func (d *Document) isInsert(r *lsp.Range, with string) bool {
return r.End.Line == r.Start.Line && r.Start.Character == r.End.Character && with != ""
}
func (d *Document) isDelete(r *lsp.Range, with string) bool {
return (r.End.Line != r.Start.Line || r.Start.Character != r.End.Character) && with == ""
}
func (d *Document) isWholeDocument(r *lsp.Range) bool {
if r == nil {
return true
}
if r.Start.Line != 0 || r.Start.Character != 0 {
return false
}
l, c := d.Len()
return r.End.Line == uint32(l) || r.End.Character == uint32(c)
}

View File

@@ -0,0 +1,571 @@
package proxy
import (
"log/slog"
"os"
"testing"
lsp "github.com/a-h/templ/lsp/protocol"
"github.com/google/go-cmp/cmp"
)
func TestDocument(t *testing.T) {
tests := []struct {
name string
start string
operations []func(d *Document)
expected string
}{
{
name: "Replace all content if the range is nil",
start: "0\n1\n2",
operations: []func(d *Document){
func(d *Document) {
d.Apply(nil, "replaced")
},
},
expected: "replaced",
},
{
name: "If the range matches the length of the file, all of it is replaced",
start: "0\n1\n2",
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 0,
},
End: lsp.Position{
Line: 2,
Character: 1,
},
}, "replaced")
},
},
expected: "replaced",
},
{
name: "Can insert new text",
start: ``,
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 0,
},
End: lsp.Position{
Line: 0,
Character: 0,
},
}, "abc")
},
},
expected: "abc",
},
{
name: "Can insert new text that ends with a newline",
start: ``,
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 0,
},
End: lsp.Position{
Line: 0,
Character: 0,
},
}, "abc\n")
},
},
expected: `abc
`,
},
{
name: "Can insert a new line at the end of existing text",
start: `abc
`,
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 3,
},
End: lsp.Position{
Line: 0,
Character: 3,
},
}, "\n")
},
},
expected: `abc
`,
},
{
name: "Can insert a word at the start of existing text",
start: `bc`,
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 0,
},
End: lsp.Position{
Line: 0,
Character: 0,
},
}, "a")
},
},
expected: `abc`,
},
{
name: "Can remove whole line",
start: "0\n1\n2",
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 1,
Character: 0,
},
End: lsp.Position{
Line: 2,
Character: 0,
},
}, "")
},
},
expected: "0\n2",
},
{
name: "Can remove line prefix",
start: "abcdef",
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 0,
},
End: lsp.Position{
Line: 0,
Character: 3,
},
}, "")
},
},
expected: "def",
},
{
name: "Can remove line substring",
start: "abcdef",
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 2,
},
End: lsp.Position{
Line: 0,
Character: 3,
},
}, "")
},
},
expected: "abdef",
},
{
name: "Can remove line suffix",
start: "abcdef",
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 4,
},
End: lsp.Position{
Line: 0,
Character: 6,
},
}, "")
},
},
expected: "abcd",
},
{
name: "Can remove across lines",
start: "0\n1\n22",
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 1,
Character: 0,
},
End: lsp.Position{
Line: 2,
Character: 1,
},
}, "")
},
},
expected: "0\n2",
},
{
name: "Can remove part of two lines",
start: "Line one\nLine two\nLine three",
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 4,
},
End: lsp.Position{
Line: 2,
Character: 4,
},
}, "")
},
},
expected: "Line three",
},
{
name: "Can remove all lines",
start: "0\n1\n2",
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 0,
},
End: lsp.Position{
Line: 2,
Character: 1,
},
}, "")
},
},
expected: "",
},
{
name: "Can replace line prefix",
start: "012345",
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 0,
},
End: lsp.Position{
Line: 0,
Character: 3,
},
}, "ABCDEFG")
},
},
expected: "ABCDEFG345",
},
{
name: "Can replace text across line boundaries",
start: "Line one\nLine two\nLine three",
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 4,
},
End: lsp.Position{
Line: 2,
Character: 4,
},
}, " one test\nNew Line 2\nNew line")
},
},
expected: "Line one test\nNew Line 2\nNew line three",
},
{
name: "Can add new line to end of single line",
start: `a`,
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 1,
},
End: lsp.Position{
Line: 0,
Character: 1,
},
}, "\nb")
},
},
expected: "a\nb",
},
{
name: "Exceeding the col and line count rounds down to the end of the file",
start: `a`,
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 200,
Character: 600,
},
End: lsp.Position{
Line: 300,
Character: 1200,
},
}, "\nb")
},
},
expected: "a\nb",
},
{
name: "Can remove a line and add it back from the end of the previous line (insert)",
start: "a\nb\nc",
operations: []func(d *Document){
func(d *Document) {
// Delete.
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 1,
Character: 0,
},
End: lsp.Position{
Line: 2,
Character: 0,
},
}, "")
// Put it back.
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 1,
},
End: lsp.Position{
Line: 0,
Character: 1,
},
}, "\nb")
},
},
expected: "a\nb\nc",
},
{
name: "Can remove a line and add it back from the end of the previous line (overwrite)",
start: "a\nb\nc",
operations: []func(d *Document){
func(d *Document) {
// Delete.
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 1,
Character: 0,
},
End: lsp.Position{
Line: 2,
Character: 0,
},
}, "")
// Put it back.
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 1,
},
End: lsp.Position{
Line: 1,
Character: 0,
},
}, "\nb\n")
},
},
expected: "a\nb\nc",
},
{
name: "Add new line with indent to the end of the line",
// Based on log entry.
// {"level":"info","ts":"2022-06-04T20:55:15+01:00","caller":"proxy/server.go:391","msg":"client -> server: DidChange","params":{"textDocument":{"uri":"file:///Users/adrian/github.com/a-h/templ/generator/test-call/template.templ","version":2},"contentChanges":[{"range":{"start":{"line":4,"character":21},"end":{"line":4,"character":21}},"text":"\n\t\t"}]}}
start: `package testcall
templ personTemplate(p person) {
<div>
<h1>{ p.name }</h1>
</div>
}
`,
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 4,
Character: 21,
},
End: lsp.Position{
Line: 4,
Character: 21,
},
}, "\n\t\t")
},
},
expected: `package testcall
templ personTemplate(p person) {
<div>
<h1>{ p.name }</h1>
</div>
}
`,
},
{
name: "Recreate error smaller",
// Based on log entry.
// {"level":"info","ts":"2022-06-04T20:55:15+01:00","caller":"proxy/server.go:391","msg":"client -> server: DidChange","params":{"textDocument":{"uri":"file:///Users/adrian/github.com/a-h/templ/generator/test-call/template.templ","version":2},"contentChanges":[{"range":{"start":{"line":4,"character":21},"end":{"line":4,"character":21}},"text":"\n\t\t"}]}}
start: "line1\n\t\tline2\nline3",
operations: []func(d *Document){
func(d *Document) {
// Remove \t\tline2
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 1,
Character: 0,
},
End: lsp.Position{
Line: 2,
Character: 0,
},
}, "")
// Put it back.
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 5,
},
End: lsp.Position{
Line: 1,
Character: 0,
},
},
"\n\t\tline2\n")
},
},
expected: "line1\n\t\tline2\nline3",
},
{
name: "Recreate error",
// Based on log entry.
// {"level":"info","ts":"2022-06-04T20:55:15+01:00","caller":"proxy/server.go:391","msg":"client -> server: DidChange","params":{"textDocument":{"uri":"file:///Users/adrian/github.com/a-h/templ/generator/test-call/template.templ","version":2},"contentChanges":[{"range":{"start":{"line":4,"character":21},"end":{"line":4,"character":21}},"text":"\n\t\t"}]}}
start: ` <footer data-testid="footerTemplate">
<div>&copy; { fmt.Sprintf("%d", time.Now().Year()) }</div>
</footer>
}
`,
operations: []func(d *Document){
func(d *Document) {
// Remove <div>&copy; { fmt.Sprintf("%d", time.Now().Year()) }</div>
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 1,
Character: 0,
},
End: lsp.Position{
Line: 2,
Character: 0,
},
}, "")
// Put it back.
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 38,
},
End: lsp.Position{
Line: 1,
Character: 0,
},
},
"\n\t\t<div>&copy; { fmt.Sprintf(\"%d\", time.Now().Year()) }</div>\n")
},
},
expected: ` <footer data-testid="footerTemplate">
<div>&copy; { fmt.Sprintf("%d", time.Now().Year()) }</div>
</footer>
}
`,
},
{
name: "Insert at start of line",
// Based on log entry.
// {"level":"info","ts":"2023-03-25T17:17:38Z","caller":"proxy/server.go:393","msg":"client -> server: DidChange","params":{"textDocument":{"uri":"file:///Users/adrian/github.com/a-h/templ/generator/test-call/template.templ","version":5},"contentChanges":[{"range":{"start":{"line":6,"character":0},"end":{"line":6,"character":0}},"text":"a"}]}}
start: `b`,
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 0,
},
End: lsp.Position{
Line: 0,
Character: 0,
},
}, "a")
},
},
expected: `ab`,
},
{
name: "Insert full new line",
start: `a
c
d`,
operations: []func(d *Document){
func(d *Document) {
d.Apply(&lsp.Range{
Start: lsp.Position{
Line: 1,
Character: 0,
},
End: lsp.Position{
Line: 1,
Character: 0,
},
}, "b\n")
},
},
expected: `a
b
c
d`,
},
}
for _, tt := range tests {
logger := slog.New(slog.NewJSONHandler(os.Stderr, nil))
t.Run(tt.name, func(t *testing.T) {
d := NewDocument(logger, tt.start)
for _, f := range tt.operations {
f(d)
}
actual := d.String()
if diff := cmp.Diff(tt.expected, actual); diff != "" {
t.Error(diff)
}
})
}
}

View File

@@ -0,0 +1,293 @@
package proxy
import (
"fmt"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestFindLastImport(t *testing.T) {
tests := []struct {
name string
templContents string
packageName string
expected string
}{
{
name: "if there are no imports, add a single line import",
templContents: `package main
templ example() {
}
`,
packageName: "strings",
expected: `package main
import "strings"
templ example() {
}
`,
},
{
name: "if there is an existing single-line imports, add one at the end",
templContents: `package main
import "strings"
templ example() {
}
`,
packageName: "fmt",
expected: `package main
import "strings"
import "fmt"
templ example() {
}
`,
},
{
name: "if there are multiple existing single-line imports, add one at the end",
templContents: `package main
import "strings"
import "fmt"
templ example() {
}
`,
packageName: "time",
expected: `package main
import "strings"
import "fmt"
import "time"
templ example() {
}
`,
},
{
name: "if there are existing multi-line imports, add one at the end",
templContents: `package main
import (
"strings"
)
templ example() {
}
`,
packageName: "fmt",
expected: `package main
import (
"strings"
"fmt"
)
templ example() {
}
`,
},
{
name: "ignore imports that happen after templates",
templContents: `package main
import "strings"
templ example() {
}
import "other"
`,
packageName: "fmt",
expected: `package main
import "strings"
import "fmt"
templ example() {
}
import "other"
`,
},
{
name: "ignore imports that happen after funcs in the file",
templContents: `package main
import "strings"
func example() {
}
import "other"
`,
packageName: "fmt",
expected: `package main
import "strings"
import "fmt"
func example() {
}
import "other"
`,
},
{
name: "ignore imports that happen after css expressions in the file",
templContents: `package main
import "strings"
css example() {
}
import "other"
`,
packageName: "fmt",
expected: `package main
import "strings"
import "fmt"
css example() {
}
import "other"
`,
},
{
name: "ignore imports that happen after script expressions in the file",
templContents: `package main
import "strings"
script example() {
}
import "other"
`,
packageName: "fmt",
expected: `package main
import "strings"
import "fmt"
script example() {
}
import "other"
`,
},
{
name: "ignore imports that happen after var expressions in the file",
templContents: `package main
import "strings"
var s string
import "other"
`,
packageName: "fmt",
expected: `package main
import "strings"
import "fmt"
var s string
import "other"
`,
},
{
name: "ignore imports that happen after const expressions in the file",
templContents: `package main
import "strings"
const s = "test"
import "other"
`,
packageName: "fmt",
expected: `package main
import "strings"
import "fmt"
const s = "test"
import "other"
`,
},
{
name: "ignore imports that happen after type expressions in the file",
templContents: `package main
import "strings"
type Value int
import "other"
`,
packageName: "fmt",
expected: `package main
import "strings"
import "fmt"
type Value int
import "other"
`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
lines := strings.Split(test.templContents, "\n")
imp := addImport(lines, fmt.Sprintf("%q", test.packageName))
textWithoutNewline := strings.TrimSuffix(imp.Text, "\n")
actualLines := append(lines[:imp.LineIndex], append([]string{textWithoutNewline}, lines[imp.LineIndex:]...)...)
actual := strings.Join(actualLines, "\n")
if diff := cmp.Diff(test.expected, actual); diff != "" {
t.Error(diff)
}
})
}
}
func TestGetPackageFromItemDetail(t *testing.T) {
tests := []struct {
input string
expected string
}{
{
input: `"fmt"`,
expected: `"fmt"`,
},
{
input: `func(state fmt.State, verb rune) string (from "fmt")`,
expected: `"fmt"`,
},
{
input: `non matching`,
expected: `non matching`,
},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
actual := getPackageFromItemDetail(test.input)
if test.expected != actual {
t.Errorf("expected %q, got %q", test.expected, actual)
}
})
}
}

View File

@@ -0,0 +1,24 @@
package proxy
import (
"path"
"strings"
lsp "github.com/a-h/templ/lsp/protocol"
)
func convertTemplToGoURI(templURI lsp.DocumentURI) (isTemplFile bool, goURI lsp.DocumentURI) {
base, fileName := path.Split(string(templURI))
if !strings.HasSuffix(fileName, ".templ") {
return
}
return true, lsp.DocumentURI(base + (strings.TrimSuffix(fileName, ".templ") + "_templ.go"))
}
func convertTemplGoToTemplURI(goURI lsp.DocumentURI) (isTemplGoFile bool, templURI lsp.DocumentURI) {
base, fileName := path.Split(string(goURI))
if !strings.HasSuffix(fileName, "_templ.go") {
return
}
return true, lsp.DocumentURI(base + (strings.TrimSuffix(fileName, "_templ.go") + ".templ"))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,111 @@
package proxy
import lsp "github.com/a-h/templ/lsp/protocol"
var htmlSnippets = []lsp.CompletionItem{
{
Label: "<?>",
InsertText: `${1}>
${0}
</${1}>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "a",
InsertText: `a href="${1:}">${2:}</a>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "button",
InsertText: `button type="button" ${1:}>${2:}</button>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "div",
InsertText: `div>
${0}
</div>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "p",
InsertText: `p>
${0}
</p>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "head",
InsertText: `head>
${0}
</head>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "body",
InsertText: `body>
${0}
</body>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "title",
InsertText: `title>${0}</title>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "h1",
InsertText: `h1>${0}</h1>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "h2",
InsertText: `h2>${0}</h2>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "h3",
InsertText: `h3>${0}</h3>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "h4",
InsertText: `h4>${0}</h4>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "h5",
InsertText: `h5>${0}</h5>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
{
Label: "h6",
InsertText: `h6>${0}</h6>`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
}
var snippet = []lsp.CompletionItem{
{
Label: "templ",
InsertText: `templ ${2:TemplateName}() {
${0}
}`,
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
InsertTextFormat: lsp.InsertTextFormatSnippet,
},
}

View File

@@ -0,0 +1,52 @@
package proxy
import (
"sync"
"github.com/a-h/templ/parser/v2"
)
// NewSourceMapCache creates a cache of .templ file URIs to the source map.
func NewSourceMapCache() *SourceMapCache {
return &SourceMapCache{
m: new(sync.Mutex),
uriToSourceMap: make(map[string]*parser.SourceMap),
}
}
// SourceMapCache is a cache of .templ file URIs to the source map.
type SourceMapCache struct {
m *sync.Mutex
uriToSourceMap map[string]*parser.SourceMap
}
func (fc *SourceMapCache) Set(uri string, m *parser.SourceMap) {
fc.m.Lock()
defer fc.m.Unlock()
fc.uriToSourceMap[uri] = m
}
func (fc *SourceMapCache) Get(uri string) (m *parser.SourceMap, ok bool) {
fc.m.Lock()
defer fc.m.Unlock()
m, ok = fc.uriToSourceMap[uri]
return
}
func (fc *SourceMapCache) Delete(uri string) {
fc.m.Lock()
defer fc.m.Unlock()
delete(fc.uriToSourceMap, uri)
}
func (fc *SourceMapCache) URIs() (uris []string) {
fc.m.Lock()
defer fc.m.Unlock()
uris = make([]string, len(fc.uriToSourceMap))
var i int
for k := range fc.uriToSourceMap {
uris[i] = k
i++
}
return uris
}

View File

@@ -0,0 +1,50 @@
package lspcmd
import (
"errors"
"io"
"log/slog"
)
// stdrwc (standard read/write closer) reads from stdin, and writes to stdout.
func newStdRwc(log *slog.Logger, name string, w io.Writer, r io.Reader) stdrwc {
return stdrwc{
log: log,
name: name,
w: w,
r: r,
}
}
type stdrwc struct {
log *slog.Logger
name string
w io.Writer
r io.Reader
}
func (s stdrwc) Read(p []byte) (int, error) {
return s.r.Read(p)
}
func (s stdrwc) Write(p []byte) (int, error) {
return s.w.Write(p)
}
func (s stdrwc) Close() error {
s.log.Info("rwc: closing", slog.String("name", s.name))
var errs []error
if closer, isCloser := s.r.(io.Closer); isCloser {
if err := closer.Close(); err != nil {
s.log.Error("rwc: error closing reader", slog.String("name", s.name), slog.Any("error", err))
errs = append(errs, err)
}
}
if closer, isCloser := s.w.(io.Closer); isCloser {
if err := closer.Close(); err != nil {
s.log.Error("rwc: error closing writer", slog.String("name", s.name), slog.Any("error", err))
errs = append(errs, err)
}
}
return errors.Join(errs...)
}

394
templ/cmd/templ/main.go Normal file
View File

@@ -0,0 +1,394 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"log/slog"
"os"
"os/signal"
"runtime"
"github.com/a-h/templ"
"github.com/a-h/templ/cmd/templ/fmtcmd"
"github.com/a-h/templ/cmd/templ/generatecmd"
"github.com/a-h/templ/cmd/templ/infocmd"
"github.com/a-h/templ/cmd/templ/lspcmd"
"github.com/a-h/templ/cmd/templ/sloghandler"
"github.com/fatih/color"
)
func main() {
code := run(os.Stdin, os.Stdout, os.Stderr, os.Args)
if code != 0 {
os.Exit(code)
}
}
const usageText = `usage: templ <command> [<args>...]
templ - build HTML UIs with Go
See docs at https://templ.guide
commands:
generate Generates Go code from templ files
fmt Formats templ files
lsp Starts a language server for templ files
info Displays information about the templ environment
version Prints the version
`
func run(stdin io.Reader, stdout, stderr io.Writer, args []string) (code int) {
if len(args) < 2 {
fmt.Fprint(stderr, usageText)
return 64 // EX_USAGE
}
switch args[1] {
case "info":
return infoCmd(stdout, stderr, args[2:])
case "generate":
return generateCmd(stdout, stderr, args[2:])
case "fmt":
return fmtCmd(stdin, stdout, stderr, args[2:])
case "lsp":
return lspCmd(stdin, stdout, stderr, args[2:])
case "version", "--version":
fmt.Fprintln(stdout, templ.Version())
return 0
case "help", "-help", "--help", "-h":
fmt.Fprint(stdout, usageText)
return 0
}
fmt.Fprint(stderr, usageText)
return 64 // EX_USAGE
}
func newLogger(logLevel string, verbose bool, stderr io.Writer) *slog.Logger {
if verbose {
logLevel = "debug"
}
level := slog.LevelInfo.Level()
switch logLevel {
case "debug":
level = slog.LevelDebug.Level()
case "warn":
level = slog.LevelWarn.Level()
case "error":
level = slog.LevelError.Level()
}
return slog.New(sloghandler.NewHandler(stderr, &slog.HandlerOptions{
AddSource: logLevel == "debug",
Level: level,
}))
}
const infoUsageText = `usage: templ info [<args>...]
Displays information about the templ environment.
Args:
-json
Output information in JSON format to stdout. (default false)
-v
Set log verbosity level to "debug". (default "info")
-log-level
Set log verbosity level. (default "info", options: "debug", "info", "warn", "error")
-help
Print help and exit.
`
func infoCmd(stdout, stderr io.Writer, args []string) (code int) {
cmd := flag.NewFlagSet("diagnose", flag.ExitOnError)
jsonFlag := cmd.Bool("json", false, "")
verboseFlag := cmd.Bool("v", false, "")
logLevelFlag := cmd.String("log-level", "info", "")
helpFlag := cmd.Bool("help", false, "")
err := cmd.Parse(args)
if err != nil {
fmt.Fprint(stderr, infoUsageText)
return 64 // EX_USAGE
}
if *helpFlag {
fmt.Fprint(stdout, infoUsageText)
return
}
log := newLogger(*logLevelFlag, *verboseFlag, stderr)
ctx, cancel := context.WithCancel(context.Background())
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt)
go func() {
<-signalChan
fmt.Fprintln(stderr, "Stopping...")
cancel()
}()
err = infocmd.Run(ctx, log, stdout, infocmd.Arguments{
JSON: *jsonFlag,
})
if err != nil {
color.New(color.FgRed).Fprint(stderr, "(✗) ")
fmt.Fprintln(stderr, "Command failed: "+err.Error())
return 1
}
return 0
}
const generateUsageText = `usage: templ generate [<args>...]
Generates Go code from templ files.
Args:
-path <path>
Generates code for all files in path. (default .)
-f <file>
Optionally generates code for a single file, e.g. -f header.templ
-stdout
Prints to stdout instead of writing generated files to the filesystem.
Only applicable when -f is used.
-source-map-visualisations
Set to true to generate HTML files to visualise the templ code and its corresponding Go code.
-include-version
Set to false to skip inclusion of the templ version in the generated code. (default true)
-include-timestamp
Set to true to include the current time in the generated code.
-watch
Set to true to watch the path for changes and regenerate code.
-watch-pattern <regexp>
Set the regexp pattern of files that will be watched for changes. (default: '(.+\.go$)|(.+\.templ$)|(.+_templ\.txt$)')
-cmd <cmd>
Set the command to run after generating code.
-proxy
Set the URL to proxy after generating code and executing the command.
-proxyport
The port the proxy will listen on. (default 7331)
-proxybind
The address the proxy will listen on. (default 127.0.0.1)
-notify-proxy
If present, the command will issue a reload event to the proxy 127.0.0.1:7331, or use proxyport and proxybind to specify a different address.
-w
Number of workers to use when generating code. (default runtime.NumCPUs)
-lazy
Only generate .go files if the source .templ file is newer.
-pprof
Port to run the pprof server on.
-keep-orphaned-files
Keeps orphaned generated templ files. (default false)
-v
Set log verbosity level to "debug". (default "info")
-log-level
Set log verbosity level. (default "info", options: "debug", "info", "warn", "error")
-help
Print help and exit.
Examples:
Generate code for all files in the current directory and subdirectories:
templ generate
Generate code for a single file:
templ generate -f header.templ
Watch the current directory and subdirectories for changes and regenerate code:
templ generate -watch
`
func generateCmd(stdout, stderr io.Writer, args []string) (code int) {
cmd := flag.NewFlagSet("generate", flag.ExitOnError)
fileNameFlag := cmd.String("f", "", "")
pathFlag := cmd.String("path", ".", "")
toStdoutFlag := cmd.Bool("stdout", false, "")
sourceMapVisualisationsFlag := cmd.Bool("source-map-visualisations", false, "")
includeVersionFlag := cmd.Bool("include-version", true, "")
includeTimestampFlag := cmd.Bool("include-timestamp", false, "")
watchFlag := cmd.Bool("watch", false, "")
watchPatternFlag := cmd.String("watch-pattern", "(.+\\.go$)|(.+\\.templ$)|(.+_templ\\.txt$)", "")
openBrowserFlag := cmd.Bool("open-browser", true, "")
cmdFlag := cmd.String("cmd", "", "")
proxyFlag := cmd.String("proxy", "", "")
proxyPortFlag := cmd.Int("proxyport", 7331, "")
proxyBindFlag := cmd.String("proxybind", "127.0.0.1", "")
notifyProxyFlag := cmd.Bool("notify-proxy", false, "")
workerCountFlag := cmd.Int("w", runtime.NumCPU(), "")
pprofPortFlag := cmd.Int("pprof", 0, "")
keepOrphanedFilesFlag := cmd.Bool("keep-orphaned-files", false, "")
verboseFlag := cmd.Bool("v", false, "")
logLevelFlag := cmd.String("log-level", "info", "")
lazyFlag := cmd.Bool("lazy", false, "")
helpFlag := cmd.Bool("help", false, "")
err := cmd.Parse(args)
if err != nil {
fmt.Fprint(stderr, generateUsageText)
return 64 // EX_USAGE
}
if *helpFlag {
fmt.Fprint(stdout, generateUsageText)
return
}
log := newLogger(*logLevelFlag, *verboseFlag, stderr)
ctx, cancel := context.WithCancel(context.Background())
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt)
go func() {
<-signalChan
fmt.Fprintln(stderr, "Stopping...")
cancel()
}()
var fw generatecmd.FileWriterFunc
if *toStdoutFlag {
fw = generatecmd.WriterFileWriter(stdout)
}
err = generatecmd.Run(ctx, log, generatecmd.Arguments{
FileName: *fileNameFlag,
Path: *pathFlag,
FileWriter: fw,
Watch: *watchFlag,
WatchPattern: *watchPatternFlag,
OpenBrowser: *openBrowserFlag,
Command: *cmdFlag,
Proxy: *proxyFlag,
ProxyPort: *proxyPortFlag,
ProxyBind: *proxyBindFlag,
NotifyProxy: *notifyProxyFlag,
WorkerCount: *workerCountFlag,
GenerateSourceMapVisualisations: *sourceMapVisualisationsFlag,
IncludeVersion: *includeVersionFlag,
IncludeTimestamp: *includeTimestampFlag,
PPROFPort: *pprofPortFlag,
KeepOrphanedFiles: *keepOrphanedFilesFlag,
Lazy: *lazyFlag,
})
if err != nil {
color.New(color.FgRed).Fprint(stderr, "(✗) ")
fmt.Fprintln(stderr, "Command failed: "+err.Error())
return 1
}
return 0
}
const fmtUsageText = `usage: templ fmt [<args> ...]
Format all files in directory:
templ fmt .
Format stdin to stdout:
templ fmt < header.templ
Format file or directory to stdout:
templ fmt -stdout FILE
Args:
-stdout
Prints to stdout instead of in-place format
-stdin-filepath
Provides the formatter with filepath context when using -stdout.
Required for organising imports.
-v
Set log verbosity level to "debug". (default "info")
-log-level
Set log verbosity level. (default "info", options: "debug", "info", "warn", "error")
-w
Number of workers to use when formatting code. (default runtime.NumCPUs).
-fail
Fails with exit code 1 if files are changed. (e.g. in CI)
-help
Print help and exit.
`
func fmtCmd(stdin io.Reader, stdout, stderr io.Writer, args []string) (code int) {
cmd := flag.NewFlagSet("fmt", flag.ExitOnError)
helpFlag := cmd.Bool("help", false, "")
workerCountFlag := cmd.Int("w", runtime.NumCPU(), "")
verboseFlag := cmd.Bool("v", false, "")
logLevelFlag := cmd.String("log-level", "info", "")
failIfChanged := cmd.Bool("fail", false, "")
stdoutFlag := cmd.Bool("stdout", false, "")
stdinFilepath := cmd.String("stdin-filepath", "", "")
err := cmd.Parse(args)
if err != nil {
fmt.Fprint(stderr, fmtUsageText)
return 64 // EX_USAGE
}
if *helpFlag {
fmt.Fprint(stdout, fmtUsageText)
return
}
log := newLogger(*logLevelFlag, *verboseFlag, stderr)
err = fmtcmd.Run(log, stdin, stdout, fmtcmd.Arguments{
ToStdout: *stdoutFlag,
Files: cmd.Args(),
WorkerCount: *workerCountFlag,
StdinFilepath: *stdinFilepath,
FailIfChanged: *failIfChanged,
})
if err != nil {
return 1
}
return 0
}
const lspUsageText = `usage: templ lsp [<args> ...]
Starts a language server for templ.
Args:
-log string
The file to log templ LSP output to, or leave empty to disable logging.
-goplsLog string
The file to log gopls output, or leave empty to disable logging.
-goplsRPCTrace
Set gopls to log input and output messages.
-help
Print help and exit.
-pprof
Enable pprof web server (default address is localhost:9999)
-http string
Enable http debug server by setting a listen address (e.g. localhost:7474)
`
func lspCmd(stdin io.Reader, stdout, stderr io.Writer, args []string) (code int) {
cmd := flag.NewFlagSet("lsp", flag.ExitOnError)
logFlag := cmd.String("log", "", "")
goplsLog := cmd.String("goplsLog", "", "")
goplsRPCTrace := cmd.Bool("goplsRPCTrace", false, "")
helpFlag := cmd.Bool("help", false, "")
pprofFlag := cmd.Bool("pprof", false, "")
httpDebugFlag := cmd.String("http", "", "")
err := cmd.Parse(args)
if err != nil {
fmt.Fprint(stderr, lspUsageText)
return 64 // EX_USAGE
}
if *helpFlag {
fmt.Fprint(stdout, lspUsageText)
return
}
err = lspcmd.Run(stdin, stdout, stderr, lspcmd.Arguments{
Log: *logFlag,
GoplsLog: *goplsLog,
GoplsRPCTrace: *goplsRPCTrace,
PPROF: *pprofFlag,
HTTPDebug: *httpDebugFlag,
})
if err != nil {
fmt.Fprintln(stderr, err.Error())
return 1
}
return 0
}

View File

@@ -0,0 +1,102 @@
package main
import (
"bytes"
"strings"
"testing"
"github.com/a-h/templ"
"github.com/google/go-cmp/cmp"
)
func TestMain(t *testing.T) {
tests := []struct {
name string
args []string
expectedStdout string
expectedStderr string
expectedCode int
}{
{
name: "no args prints usage",
args: []string{},
expectedStderr: usageText,
expectedCode: 64, // EX_USAGE
},
{
name: `"templ help" prints help`,
args: []string{"templ", "help"},
expectedStdout: usageText,
expectedCode: 0,
},
{
name: `"templ --help" prints help`,
args: []string{"templ", "--help"},
expectedStdout: usageText,
expectedCode: 0,
},
{
name: `"templ version" prints version`,
args: []string{"templ", "version"},
expectedStdout: templ.Version() + "\n",
expectedCode: 0,
},
{
name: `"templ --version" prints version`,
args: []string{"templ", "--version"},
expectedStdout: templ.Version() + "\n",
expectedCode: 0,
},
{
name: `"templ fmt --help" prints usage`,
args: []string{"templ", "fmt", "--help"},
expectedStdout: fmtUsageText,
expectedCode: 0,
},
{
name: `"templ generate --help" prints usage`,
args: []string{"templ", "generate", "--help"},
expectedStdout: generateUsageText,
expectedCode: 0,
},
{
name: `"templ lsp --help" prints usage`,
args: []string{"templ", "lsp", "--help"},
expectedStdout: lspUsageText,
expectedCode: 0,
},
{
name: `"templ info --help" prints usage`,
args: []string{"templ", "info", "--help"},
expectedStdout: infoUsageText,
expectedCode: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
stdin := strings.NewReader("")
stdout := bytes.NewBuffer(nil)
stderr := bytes.NewBuffer(nil)
actualCode := run(stdin, stdout, stderr, test.args)
if actualCode != test.expectedCode {
t.Errorf("expected code %v, got %v", test.expectedCode, actualCode)
}
if diff := cmp.Diff(test.expectedStdout, stdout.String()); diff != "" {
t.Error(diff)
t.Error("expected stdout:")
t.Error(test.expectedStdout)
t.Error("actual stdout:")
t.Error(stdout.String())
}
if diff := cmp.Diff(test.expectedStderr, stderr.String()); diff != "" {
t.Error(diff)
t.Error("expected stderr:")
t.Error(test.expectedStderr)
t.Error("actual stderr:")
t.Error(stderr.String())
}
})
}
}

View File

@@ -0,0 +1,80 @@
package processor
import (
"io/fs"
"path"
"path/filepath"
"strings"
"sync"
"time"
)
type Result struct {
FileName string
Duration time.Duration
Error error
ChangesMade bool
}
func Process(dir string, f func(fileName string) (error, bool), workerCount int, results chan<- Result) {
templates := make(chan string)
go func() {
defer close(templates)
if err := FindTemplates(dir, templates); err != nil {
results <- Result{Error: err}
}
}()
ProcessChannel(templates, dir, f, workerCount, results)
}
func shouldSkipDir(dir string) bool {
if dir == "." {
return false
}
if dir == "vendor" || dir == "node_modules" {
return true
}
_, name := path.Split(dir)
// These directories are ignored by the Go tool.
if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") {
return true
}
return false
}
func FindTemplates(srcPath string, output chan<- string) (err error) {
return filepath.WalkDir(srcPath, func(currentPath string, info fs.DirEntry, err error) error {
if err != nil {
return err
}
if info.IsDir() && shouldSkipDir(currentPath) {
return filepath.SkipDir
}
if !info.IsDir() && strings.HasSuffix(currentPath, ".templ") {
output <- currentPath
}
return nil
})
}
func ProcessChannel(templates <-chan string, dir string, f func(fileName string) (error, bool), workerCount int, results chan<- Result) {
defer close(results)
var wg sync.WaitGroup
wg.Add(workerCount)
for i := 0; i < workerCount; i++ {
go func() {
defer wg.Done()
for sourceFileName := range templates {
start := time.Now()
outErr, outChanged := f(sourceFileName)
results <- Result{
FileName: sourceFileName,
Error: outErr,
Duration: time.Since(start),
ChangesMade: outChanged,
}
}
}()
}
wg.Wait()
}

View File

@@ -0,0 +1,19 @@
package processor
import (
"os"
"testing"
)
func TestFindTemplates(t *testing.T) {
t.Run("returns an error if the directory does not exist", func(t *testing.T) {
output := make(chan string)
err := FindTemplates("nonexistent", output)
if err == nil {
t.Fatal("expected error, but got nil")
}
if !os.IsNotExist(err) {
t.Fatalf("expected os.IsNotExist(err) to be true, but got: %v", err)
}
})
}

View File

@@ -0,0 +1,101 @@
package sloghandler
import (
"context"
"io"
"log/slog"
"strings"
"sync"
"github.com/fatih/color"
)
var _ slog.Handler = &Handler{}
type Handler struct {
h slog.Handler
m *sync.Mutex
w io.Writer
}
var levelToIcon = map[slog.Level]string{
slog.LevelDebug: "(✓)",
slog.LevelInfo: "(✓)",
slog.LevelWarn: "(!)",
slog.LevelError: "(✗)",
}
var levelToColor = map[slog.Level]*color.Color{
slog.LevelDebug: color.New(color.FgCyan),
slog.LevelInfo: color.New(color.FgGreen),
slog.LevelWarn: color.New(color.FgYellow),
slog.LevelError: color.New(color.FgRed),
}
func NewHandler(w io.Writer, opts *slog.HandlerOptions) *Handler {
if opts == nil {
opts = &slog.HandlerOptions{}
}
return &Handler{
w: w,
h: slog.NewTextHandler(w, &slog.HandlerOptions{
Level: opts.Level,
AddSource: opts.AddSource,
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
if opts.ReplaceAttr != nil {
a = opts.ReplaceAttr(groups, a)
}
if a.Key == slog.LevelKey {
level, ok := levelToIcon[a.Value.Any().(slog.Level)]
if !ok {
level = a.Value.Any().(slog.Level).String()
}
a.Value = slog.StringValue(level)
return a
}
if a.Key == slog.TimeKey {
return slog.Attr{}
}
return a
},
}),
m: &sync.Mutex{},
}
}
func (h *Handler) Enabled(ctx context.Context, level slog.Level) bool {
return h.h.Enabled(ctx, level)
}
func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler {
return &Handler{h: h.h.WithAttrs(attrs), w: h.w, m: h.m}
}
func (h *Handler) WithGroup(name string) slog.Handler {
return &Handler{h: h.h.WithGroup(name), w: h.w, m: h.m}
}
var keyValueColor = color.New(color.Faint & color.FgBlack)
func (h *Handler) Handle(ctx context.Context, r slog.Record) (err error) {
var sb strings.Builder
sb.WriteString(levelToColor[r.Level].Sprint(levelToIcon[r.Level]))
sb.WriteString(" ")
sb.WriteString(r.Message)
if r.NumAttrs() != 0 {
sb.WriteString(" [")
r.Attrs(func(a slog.Attr) bool {
sb.WriteString(keyValueColor.Sprintf(" %s=%s", a.Key, a.Value.String()))
return true
})
sb.WriteString(" ]")
}
sb.WriteString("\n")
h.m.Lock()
defer h.m.Unlock()
_, err = io.WriteString(h.w, sb.String())
return err
}

View File

@@ -0,0 +1,3 @@
package cssclasses
const Header = "header"

View File

@@ -0,0 +1,7 @@
module templ/testproject
go 1.23
require github.com/a-h/templ v0.2.513 // indirect
replace github.com/a-h/templ => {moduleRoot}

View File

@@ -0,0 +1,2 @@
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=

View File

@@ -0,0 +1,33 @@
package main
import (
"flag"
"fmt"
"net/http"
"os"
"github.com/a-h/templ"
)
var flagPort = flag.Int("port", 0, "Set the HTTP listen port")
func main() {
flag.Parse()
if *flagPort == 0 {
fmt.Println("missing port flag")
os.Exit(1)
}
var count int
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
count++
c := Page(count)
templ.Handler(c).ServeHTTP(w, r)
})
err := http.ListenAndServe(fmt.Sprintf("localhost:%d", *flagPort), nil)
if err != nil {
fmt.Printf("Error listening: %v\n", err)
os.Exit(1)
}
}

View File

@@ -0,0 +1,5 @@
package main
templ Remote() {
<p>This is remote content</p>
}

View File

@@ -0,0 +1,40 @@
// Code generated by templ - DO NOT EDIT.
// templ: version: v0.3.833
package main
//lint:file-ignore SA4006 This context is only used if a nested component is present.
import "github.com/a-h/templ"
import templruntime "github.com/a-h/templ/runtime"
func Remote() templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
if templ_7745c5c3_Var1 == nil {
templ_7745c5c3_Var1 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<p>This is remote content</p>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
var _ = templruntime.GeneratedTemplate

View File

@@ -0,0 +1,9 @@
package main
templ RemoteInclusionTest() {
@Remote
}
templ Remote2() {
@Remote
}

View File

@@ -0,0 +1,69 @@
// Code generated by templ - DO NOT EDIT.
// templ: version: v0.3.833
package main
//lint:file-ignore SA4006 This context is only used if a nested component is present.
import "github.com/a-h/templ"
import templruntime "github.com/a-h/templ/runtime"
func RemoteInclusionTest() templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
if templ_7745c5c3_Var1 == nil {
templ_7745c5c3_Var1 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
templ_7745c5c3_Err = Remote.Render(ctx, templ_7745c5c3_Buffer)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
func Remote2() templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var2 := templ.GetChildren(ctx)
if templ_7745c5c3_Var2 == nil {
templ_7745c5c3_Var2 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
templ_7745c5c3_Err = Remote.Render(ctx, templ_7745c5c3_Buffer)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
var _ = templruntime.GeneratedTemplate

View File

@@ -0,0 +1,25 @@
package main
import "fmt"
templ Page(count int) {
<!DOCTYPE html>
<html>
<head>
<title>templ test page</title>
</head>
<body>
<h1>Count</h1>
<div data-testid="count">{ fmt.Sprintf("%d", count) }</div>
<div data-testid="modification">Original</div>
</body>
</html>
}
var nihao = "你好"
type Struct struct {
Count int
}
var s = Struct{}

View File

@@ -0,0 +1,63 @@
// Code generated by templ - DO NOT EDIT.
// templ: version: v0.3.833
package main
//lint:file-ignore SA4006 This context is only used if a nested component is present.
import "github.com/a-h/templ"
import templruntime "github.com/a-h/templ/runtime"
import "fmt"
func Page(count int) templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
if templ_7745c5c3_Var1 == nil {
templ_7745c5c3_Var1 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<!doctype html><html><head><title>templ test page</title></head><body><h1>Count</h1><div data-testid=\"count\">")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var2 string
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(fmt.Sprintf("%d", count))
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/testproject/testdata/templates.templ`, Line: 13, Col: 54}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "</div><div data-testid=\"modification\">Original</div></body></html>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
var nihao = "你好"
type Struct struct {
Count int
}
var s = Struct{}
var _ = templruntime.GeneratedTemplate

View File

@@ -0,0 +1,70 @@
package testproject
import (
"bytes"
"embed"
"fmt"
"os"
"path/filepath"
"strings"
)
//go:embed testdata/*
var testdata embed.FS
func Create(moduleRoot string) (dir string, err error) {
dir, err = os.MkdirTemp("", "templ_test_*")
if err != nil {
return dir, fmt.Errorf("failed to make test dir: %w", err)
}
files, err := testdata.ReadDir("testdata")
if err != nil {
return dir, fmt.Errorf("failed to read embedded dir: %w", err)
}
for _, file := range files {
if file.IsDir() {
if err = os.MkdirAll(filepath.Join(dir, file.Name()), 0777); err != nil {
return dir, fmt.Errorf("failed to create dir: %w", err)
}
continue
}
src := filepath.Join("testdata", file.Name())
data, err := testdata.ReadFile(src)
if err != nil {
return dir, fmt.Errorf("failed to read file: %w", err)
}
target := filepath.Join(dir, file.Name())
if file.Name() == "go.mod.embed" {
data = bytes.ReplaceAll(data, []byte("{moduleRoot}"), []byte(moduleRoot))
target = filepath.Join(dir, "go.mod")
}
err = os.WriteFile(target, data, 0660)
if err != nil {
return dir, fmt.Errorf("failed to copy file: %w", err)
}
}
files, err = testdata.ReadDir("testdata/css-classes")
if err != nil {
return dir, fmt.Errorf("failed to read embedded dir: %w", err)
}
for _, file := range files {
src := filepath.Join("testdata", "css-classes", file.Name())
data, err := testdata.ReadFile(src)
if err != nil {
return dir, fmt.Errorf("failed to read file: %w", err)
}
target := filepath.Join(dir, "css-classes", file.Name())
err = os.WriteFile(target, data, 0660)
if err != nil {
return dir, fmt.Errorf("failed to copy file: %w", err)
}
}
return dir, nil
}
func MustReplaceLine(file string, line int, replacement string) string {
lines := strings.Split(file, "\n")
lines[line-1] = replacement
return strings.Join(lines, "\n")
}

View File

@@ -0,0 +1,64 @@
package visualize
css row() {
display: flex;
}
css column() {
flex: 50%;
overflow-y: scroll;
max-height: 100vh;
}
css code() {
font-family: monospace;
}
templ combine(templFileName string, left, right templ.Component) {
<html>
<head>
<title>{ templFileName }- Source Map Visualisation</title>
<style type="text/css">
.mapped { background-color: green }
.highlighted { background-color: yellow }
</style>
</head>
<body>
<h1>{ templFileName }</h1>
<div class={ templ.Classes(row()) }>
<div class={ templ.Classes(column(), code()) }>
@left
</div>
<div class={ templ.Classes(column(), code()) }>
@right
</div>
</div>
</body>
</html>
}
script highlight(sourceId, targetId string) {
let items = document.getElementsByClassName(sourceId);
for(let i = 0; i < items.length; i ++) {
items[i].classList.add("highlighted");
}
items = document.getElementsByClassName(targetId);
for(let i = 0; i < items.length; i ++) {
items[i].classList.add("highlighted");
}
}
script removeHighlight(sourceId, targetId string) {
let items = document.getElementsByClassName(sourceId);
for(let i = 0; i < items.length; i ++) {
items[i].classList.remove("highlighted");
}
items = document.getElementsByClassName(targetId);
for(let i = 0; i < items.length; i ++) {
items[i].classList.remove("highlighted");
}
}
templ mappedCharacter(s string, sourceID, targetID string) {
<span class={ templ.Classes(templ.Class("mapped"), templ.Class(sourceID), templ.Class(targetID)) } onMouseOver={ highlight(sourceID, targetID) } onMouseOut={ removeHighlight(sourceID, targetID) }>{ s }</span>
}

View File

@@ -0,0 +1,296 @@
// Code generated by templ - DO NOT EDIT.
// templ: version: v0.3.833
package visualize
//lint:file-ignore SA4006 This context is only used if a nested component is present.
import "github.com/a-h/templ"
import templruntime "github.com/a-h/templ/runtime"
func row() templ.CSSClass {
templ_7745c5c3_CSSBuilder := templruntime.GetBuilder()
templ_7745c5c3_CSSBuilder.WriteString(`display:flex;`)
templ_7745c5c3_CSSID := templ.CSSID(`row`, templ_7745c5c3_CSSBuilder.String())
return templ.ComponentCSSClass{
ID: templ_7745c5c3_CSSID,
Class: templ.SafeCSS(`.` + templ_7745c5c3_CSSID + `{` + templ_7745c5c3_CSSBuilder.String() + `}`),
}
}
func column() templ.CSSClass {
templ_7745c5c3_CSSBuilder := templruntime.GetBuilder()
templ_7745c5c3_CSSBuilder.WriteString(`flex:50%;`)
templ_7745c5c3_CSSBuilder.WriteString(`overflow-y:scroll;`)
templ_7745c5c3_CSSBuilder.WriteString(`max-height:100vh;`)
templ_7745c5c3_CSSID := templ.CSSID(`column`, templ_7745c5c3_CSSBuilder.String())
return templ.ComponentCSSClass{
ID: templ_7745c5c3_CSSID,
Class: templ.SafeCSS(`.` + templ_7745c5c3_CSSID + `{` + templ_7745c5c3_CSSBuilder.String() + `}`),
}
}
func code() templ.CSSClass {
templ_7745c5c3_CSSBuilder := templruntime.GetBuilder()
templ_7745c5c3_CSSBuilder.WriteString(`font-family:monospace;`)
templ_7745c5c3_CSSID := templ.CSSID(`code`, templ_7745c5c3_CSSBuilder.String())
return templ.ComponentCSSClass{
ID: templ_7745c5c3_CSSID,
Class: templ.SafeCSS(`.` + templ_7745c5c3_CSSID + `{` + templ_7745c5c3_CSSBuilder.String() + `}`),
}
}
func combine(templFileName string, left, right templ.Component) templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
if templ_7745c5c3_Var1 == nil {
templ_7745c5c3_Var1 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<html><head><title>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var2 string
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(templFileName)
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 20, Col: 25}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "- Source Map Visualisation</title><style type=\"text/css\">\n\t\t\t\t.mapped { background-color: green }\n\t\t\t\t.highlighted { background-color: yellow }\n\t\t\t</style></head><body><h1>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var3 string
templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(templFileName)
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 27, Col: 22}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var3))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 3, "</h1>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var4 = []any{templ.Classes(row())}
templ_7745c5c3_Err = templ.RenderCSSItems(ctx, templ_7745c5c3_Buffer, templ_7745c5c3_Var4...)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 4, "<div class=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var5 string
templ_7745c5c3_Var5, templ_7745c5c3_Err = templ.JoinStringErrs(templ.CSSClasses(templ_7745c5c3_Var4).String())
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 1, Col: 0}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var5))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 5, "\">")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var6 = []any{templ.Classes(column(), code())}
templ_7745c5c3_Err = templ.RenderCSSItems(ctx, templ_7745c5c3_Buffer, templ_7745c5c3_Var6...)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 6, "<div class=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var7 string
templ_7745c5c3_Var7, templ_7745c5c3_Err = templ.JoinStringErrs(templ.CSSClasses(templ_7745c5c3_Var6).String())
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 1, Col: 0}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var7))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "\">")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = left.Render(ctx, templ_7745c5c3_Buffer)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, "</div>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var8 = []any{templ.Classes(column(), code())}
templ_7745c5c3_Err = templ.RenderCSSItems(ctx, templ_7745c5c3_Buffer, templ_7745c5c3_Var8...)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 9, "<div class=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var9 string
templ_7745c5c3_Var9, templ_7745c5c3_Err = templ.JoinStringErrs(templ.CSSClasses(templ_7745c5c3_Var8).String())
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 1, Col: 0}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var9))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 10, "\">")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = right.Render(ctx, templ_7745c5c3_Buffer)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 11, "</div></div></body></html>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
func highlight(sourceId, targetId string) templ.ComponentScript {
return templ.ComponentScript{
Name: `__templ_highlight_ae80`,
Function: `function __templ_highlight_ae80(sourceId, targetId){let items = document.getElementsByClassName(sourceId);
for(let i = 0; i < items.length; i ++) {
items[i].classList.add("highlighted");
}
items = document.getElementsByClassName(targetId);
for(let i = 0; i < items.length; i ++) {
items[i].classList.add("highlighted");
}
}`,
Call: templ.SafeScript(`__templ_highlight_ae80`, sourceId, targetId),
CallInline: templ.SafeScriptInline(`__templ_highlight_ae80`, sourceId, targetId),
}
}
func removeHighlight(sourceId, targetId string) templ.ComponentScript {
return templ.ComponentScript{
Name: `__templ_removeHighlight_58f2`,
Function: `function __templ_removeHighlight_58f2(sourceId, targetId){let items = document.getElementsByClassName(sourceId);
for(let i = 0; i < items.length; i ++) {
items[i].classList.remove("highlighted");
}
items = document.getElementsByClassName(targetId);
for(let i = 0; i < items.length; i ++) {
items[i].classList.remove("highlighted");
}
}`,
Call: templ.SafeScript(`__templ_removeHighlight_58f2`, sourceId, targetId),
CallInline: templ.SafeScriptInline(`__templ_removeHighlight_58f2`, sourceId, targetId),
}
}
func mappedCharacter(s string, sourceID, targetID string) templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var10 := templ.GetChildren(ctx)
if templ_7745c5c3_Var10 == nil {
templ_7745c5c3_Var10 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
var templ_7745c5c3_Var11 = []any{templ.Classes(templ.Class("mapped"), templ.Class(sourceID), templ.Class(targetID))}
templ_7745c5c3_Err = templ.RenderCSSItems(ctx, templ_7745c5c3_Buffer, templ_7745c5c3_Var11...)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templ.RenderScriptItems(ctx, templ_7745c5c3_Buffer, highlight(sourceID, targetID), removeHighlight(sourceID, targetID))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "<span class=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var12 string
templ_7745c5c3_Var12, templ_7745c5c3_Err = templ.JoinStringErrs(templ.CSSClasses(templ_7745c5c3_Var11).String())
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 1, Col: 0}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var12))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "\" onMouseOver=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var13 templ.ComponentScript = highlight(sourceID, targetID)
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ_7745c5c3_Var13.Call)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 14, "\" onMouseOut=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var14 templ.ComponentScript = removeHighlight(sourceID, targetID)
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ_7745c5c3_Var14.Call)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 15, "\">")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var15 string
templ_7745c5c3_Var15, templ_7745c5c3_Err = templ.JoinStringErrs(s)
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 63, Col: 200}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var15))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 16, "</span>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
var _ = templruntime.GeneratedTemplate

View File

@@ -0,0 +1,87 @@
package visualize
import (
"context"
"fmt"
"html"
"io"
"strconv"
"strings"
"github.com/a-h/templ"
"github.com/a-h/templ/parser/v2"
)
func HTML(templFileName string, templContents, goContents string, sourceMap *parser.SourceMap) templ.Component {
tl := templLines{contents: string(templContents), sourceMap: sourceMap}
gl := goLines{contents: string(goContents), sourceMap: sourceMap}
return combine(templFileName, tl, gl)
}
type templLines struct {
contents string
sourceMap *parser.SourceMap
}
func (tl templLines) Render(ctx context.Context, w io.Writer) (err error) {
templLines := strings.Split(tl.contents, "\n")
for lineIndex, line := range templLines {
if _, err = w.Write([]byte("<span>" + strconv.Itoa(lineIndex) + "&nbsp;</span>\n")); err != nil {
return
}
for colIndex, c := range line {
if tgt, ok := tl.sourceMap.TargetPositionFromSource(uint32(lineIndex), uint32(colIndex)); ok {
sourceID := fmt.Sprintf("src_%d_%d", lineIndex, colIndex)
targetID := fmt.Sprintf("tgt_%d_%d", tgt.Line, tgt.Col)
if err := mappedCharacter(string(c), sourceID, targetID).Render(ctx, w); err != nil {
return err
}
} else {
s := html.EscapeString(string(c))
s = strings.ReplaceAll(s, "\t", "&nbsp;")
s = strings.ReplaceAll(s, " ", "&nbsp;")
if _, err := w.Write([]byte(s)); err != nil {
return err
}
}
}
if _, err = w.Write([]byte("\n<br/>\n")); err != nil {
return
}
}
return nil
}
type goLines struct {
contents string
sourceMap *parser.SourceMap
}
func (gl goLines) Render(ctx context.Context, w io.Writer) (err error) {
templLines := strings.Split(gl.contents, "\n")
for lineIndex, line := range templLines {
if _, err = w.Write([]byte("<span>" + strconv.Itoa(lineIndex) + "&nbsp;</span>\n")); err != nil {
return
}
for colIndex, c := range line {
if src, ok := gl.sourceMap.SourcePositionFromTarget(uint32(lineIndex), uint32(colIndex)); ok {
sourceID := fmt.Sprintf("src_%d_%d", src.Line, src.Col)
targetID := fmt.Sprintf("tgt_%d_%d", lineIndex, colIndex)
if err := mappedCharacter(string(c), sourceID, targetID).Render(ctx, w); err != nil {
return err
}
} else {
s := html.EscapeString(string(c))
s = strings.ReplaceAll(s, "\t", "&nbsp;")
s = strings.ReplaceAll(s, " ", "&nbsp;")
if _, err := w.Write([]byte(s)); err != nil {
return err
}
}
}
if _, err = w.Write([]byte("\n<br/>\n")); err != nil {
return
}
}
return nil
}