Changed: DB Params
This commit is contained in:
166
templ/cmd/templ/fmtcmd/main.go
Normal file
166
templ/cmd/templ/fmtcmd/main.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package fmtcmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/imports"
|
||||
"github.com/a-h/templ/cmd/templ/processor"
|
||||
parser "github.com/a-h/templ/parser/v2"
|
||||
"github.com/natefinch/atomic"
|
||||
)
|
||||
|
||||
type Arguments struct {
|
||||
FailIfChanged bool
|
||||
ToStdout bool
|
||||
StdinFilepath string
|
||||
Files []string
|
||||
WorkerCount int
|
||||
}
|
||||
|
||||
func Run(log *slog.Logger, stdin io.Reader, stdout io.Writer, args Arguments) (err error) {
|
||||
// If no files are provided, read from stdin and write to stdout.
|
||||
if len(args.Files) == 0 {
|
||||
out, _ := format(writeToWriter(stdout), readFromReader(stdin, args.StdinFilepath), true)
|
||||
return out
|
||||
}
|
||||
process := func(fileName string) (error, bool) {
|
||||
read := readFromFile(fileName)
|
||||
write := writeToFile
|
||||
if args.ToStdout {
|
||||
write = writeToWriter(stdout)
|
||||
}
|
||||
writeIfUnchanged := args.ToStdout
|
||||
return format(write, read, writeIfUnchanged)
|
||||
}
|
||||
dir := args.Files[0]
|
||||
return NewFormatter(log, dir, process, args.WorkerCount, args.FailIfChanged).Run()
|
||||
}
|
||||
|
||||
type Formatter struct {
|
||||
Log *slog.Logger
|
||||
Dir string
|
||||
Process func(fileName string) (error, bool)
|
||||
WorkerCount int
|
||||
FailIfChange bool
|
||||
}
|
||||
|
||||
func NewFormatter(log *slog.Logger, dir string, process func(fileName string) (error, bool), workerCount int, failIfChange bool) *Formatter {
|
||||
f := &Formatter{
|
||||
Log: log,
|
||||
Dir: dir,
|
||||
Process: process,
|
||||
WorkerCount: workerCount,
|
||||
FailIfChange: failIfChange,
|
||||
}
|
||||
if f.WorkerCount == 0 {
|
||||
f.WorkerCount = runtime.NumCPU()
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Formatter) Run() (err error) {
|
||||
changesMade := 0
|
||||
start := time.Now()
|
||||
results := make(chan processor.Result)
|
||||
f.Log.Debug("Walking directory", slog.String("path", f.Dir))
|
||||
go processor.Process(f.Dir, f.Process, f.WorkerCount, results)
|
||||
var successCount, errorCount int
|
||||
for r := range results {
|
||||
if r.ChangesMade {
|
||||
changesMade += 1
|
||||
}
|
||||
if r.Error != nil {
|
||||
f.Log.Error(r.FileName, slog.Any("error", r.Error))
|
||||
errorCount++
|
||||
continue
|
||||
}
|
||||
f.Log.Debug(r.FileName, slog.Duration("duration", r.Duration))
|
||||
successCount++
|
||||
}
|
||||
|
||||
if f.FailIfChange && changesMade > 0 {
|
||||
f.Log.Error("Templates were valid but not properly formatted", slog.Int("count", successCount+errorCount), slog.Int("changed", changesMade), slog.Int("errors", errorCount), slog.Duration("duration", time.Since(start)))
|
||||
return fmt.Errorf("templates were not formatted properly")
|
||||
}
|
||||
|
||||
f.Log.Info("Format Complete", slog.Int("count", successCount+errorCount), slog.Int("errors", errorCount), slog.Int("changed", changesMade), slog.Duration("duration", time.Since(start)))
|
||||
|
||||
if errorCount > 0 {
|
||||
return fmt.Errorf("formatting failed")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type reader func() (fileName, src string, err error)
|
||||
|
||||
func readFromReader(r io.Reader, stdinFilepath string) func() (fileName, src string, err error) {
|
||||
return func() (fileName, src string, err error) {
|
||||
b, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to read stdin: %w", err)
|
||||
}
|
||||
return stdinFilepath, string(b), nil
|
||||
}
|
||||
}
|
||||
|
||||
func readFromFile(name string) reader {
|
||||
return func() (fileName, src string, err error) {
|
||||
b, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to read file %q: %w", fileName, err)
|
||||
}
|
||||
return name, string(b), nil
|
||||
}
|
||||
}
|
||||
|
||||
type writer func(fileName, tgt string) error
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
func writeToWriter(w io.Writer) func(fileName, tgt string) error {
|
||||
return func(fileName, tgt string) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
_, err := w.Write([]byte(tgt))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func writeToFile(fileName, tgt string) error {
|
||||
return atomic.WriteFile(fileName, bytes.NewBufferString(tgt))
|
||||
}
|
||||
|
||||
func format(write writer, read reader, writeIfUnchanged bool) (err error, fileChanged bool) {
|
||||
fileName, src, err := read()
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
t, err := parser.ParseString(src)
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
t.Filepath = fileName
|
||||
t, err = imports.Process(t)
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
w := new(bytes.Buffer)
|
||||
if err = t.Write(w); err != nil {
|
||||
return fmt.Errorf("formatting error: %w", err), false
|
||||
}
|
||||
|
||||
fileChanged = (src != w.String())
|
||||
|
||||
if !writeIfUnchanged && !fileChanged {
|
||||
return nil, fileChanged
|
||||
}
|
||||
return write(fileName, w.String()), fileChanged
|
||||
}
|
163
templ/cmd/templ/fmtcmd/main_test.go
Normal file
163
templ/cmd/templ/fmtcmd/main_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package fmtcmd
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"golang.org/x/tools/txtar"
|
||||
)
|
||||
|
||||
//go:embed testdata.txtar
|
||||
var testDataTxTar []byte
|
||||
|
||||
type testProject struct {
|
||||
dir string
|
||||
cleanup func()
|
||||
testFiles map[string]testFile
|
||||
}
|
||||
|
||||
type testFile struct {
|
||||
name string
|
||||
input, expected string
|
||||
}
|
||||
|
||||
func setupProjectDir() (tp testProject, err error) {
|
||||
tp.dir, err = os.MkdirTemp("", "fmtcmd_test_*")
|
||||
if err != nil {
|
||||
return tp, fmt.Errorf("failed to make test dir: %w", err)
|
||||
}
|
||||
tp.testFiles = make(map[string]testFile)
|
||||
testData := txtar.Parse(testDataTxTar)
|
||||
for i := 0; i < len(testData.Files); i += 2 {
|
||||
file := testData.Files[i]
|
||||
err = os.WriteFile(filepath.Join(tp.dir, file.Name), file.Data, 0660)
|
||||
if err != nil {
|
||||
return tp, fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
tp.testFiles[file.Name] = testFile{
|
||||
name: filepath.Join(tp.dir, file.Name),
|
||||
input: string(file.Data),
|
||||
expected: string(testData.Files[i+1].Data),
|
||||
}
|
||||
}
|
||||
tp.cleanup = func() {
|
||||
os.RemoveAll(tp.dir)
|
||||
}
|
||||
return tp, nil
|
||||
}
|
||||
|
||||
func TestFormat(t *testing.T) {
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
t.Run("can format a single file from stdin to stdout", func(t *testing.T) {
|
||||
tp, err := setupProjectDir()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup project dir: %v", err)
|
||||
}
|
||||
defer tp.cleanup()
|
||||
stdin := strings.NewReader(tp.testFiles["a.templ"].input)
|
||||
stdout := new(strings.Builder)
|
||||
if err = Run(log, stdin, stdout, Arguments{
|
||||
ToStdout: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to run format command: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tp.testFiles["a.templ"].expected, stdout.String()); diff != "" {
|
||||
t.Error(diff)
|
||||
}
|
||||
})
|
||||
t.Run("can process a single file to stdout", func(t *testing.T) {
|
||||
tp, err := setupProjectDir()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup project dir: %v", err)
|
||||
}
|
||||
defer tp.cleanup()
|
||||
stdout := new(strings.Builder)
|
||||
if err = Run(log, nil, stdout, Arguments{
|
||||
ToStdout: true,
|
||||
Files: []string{
|
||||
tp.testFiles["a.templ"].name,
|
||||
},
|
||||
FailIfChanged: false,
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to run format command: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tp.testFiles["a.templ"].expected, stdout.String()); diff != "" {
|
||||
t.Error(diff)
|
||||
}
|
||||
})
|
||||
t.Run("can process a single file in place", func(t *testing.T) {
|
||||
tp, err := setupProjectDir()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup project dir: %v", err)
|
||||
}
|
||||
defer tp.cleanup()
|
||||
if err = Run(log, nil, nil, Arguments{
|
||||
Files: []string{
|
||||
tp.testFiles["a.templ"].name,
|
||||
},
|
||||
FailIfChanged: false,
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to run format command: %v", err)
|
||||
}
|
||||
data, err := os.ReadFile(tp.testFiles["a.templ"].name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tp.testFiles["a.templ"].expected, string(data)); diff != "" {
|
||||
t.Error(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fails when fail flag used and change occurs", func(t *testing.T) {
|
||||
tp, err := setupProjectDir()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup project dir: %v", err)
|
||||
}
|
||||
defer tp.cleanup()
|
||||
if err = Run(log, nil, nil, Arguments{
|
||||
Files: []string{
|
||||
tp.testFiles["a.templ"].name,
|
||||
},
|
||||
FailIfChanged: true,
|
||||
}); err == nil {
|
||||
t.Fatal("command should have exited with an error and did not")
|
||||
}
|
||||
data, err := os.ReadFile(tp.testFiles["a.templ"].name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tp.testFiles["a.templ"].expected, string(data)); diff != "" {
|
||||
t.Error(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("passes when fail flag used and no change occurs", func(t *testing.T) {
|
||||
tp, err := setupProjectDir()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup project dir: %v", err)
|
||||
}
|
||||
defer tp.cleanup()
|
||||
if err = Run(log, nil, nil, Arguments{
|
||||
Files: []string{
|
||||
tp.testFiles["c.templ"].name,
|
||||
},
|
||||
FailIfChanged: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to run format command: %v", err)
|
||||
}
|
||||
data, err := os.ReadFile(tp.testFiles["c.templ"].name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tp.testFiles["c.templ"].expected, string(data)); diff != "" {
|
||||
t.Error(diff)
|
||||
}
|
||||
})
|
||||
}
|
54
templ/cmd/templ/fmtcmd/testdata.txtar
Normal file
54
templ/cmd/templ/fmtcmd/testdata.txtar
Normal file
@@ -0,0 +1,54 @@
|
||||
-- a.templ --
|
||||
package test
|
||||
|
||||
templ a() {
|
||||
<div><p class={templ.Class("mapped")}>A
|
||||
</p></div>
|
||||
}
|
||||
-- a.templ --
|
||||
package test
|
||||
|
||||
templ a() {
|
||||
<div>
|
||||
<p class={ templ.Class("mapped") }>
|
||||
A
|
||||
</p>
|
||||
</div>
|
||||
}
|
||||
-- b.templ --
|
||||
package test
|
||||
|
||||
templ b() {
|
||||
<div><p>B
|
||||
</p></div>
|
||||
}
|
||||
-- b.templ --
|
||||
package test
|
||||
|
||||
templ b() {
|
||||
<div>
|
||||
<p>
|
||||
B
|
||||
</p>
|
||||
</div>
|
||||
}
|
||||
-- c.templ --
|
||||
package test
|
||||
|
||||
templ c() {
|
||||
<div>
|
||||
<p>
|
||||
C
|
||||
</p>
|
||||
</div>
|
||||
}
|
||||
-- c.templ --
|
||||
package test
|
||||
|
||||
templ c() {
|
||||
<div>
|
||||
<p>
|
||||
C
|
||||
</p>
|
||||
</div>
|
||||
}
|
403
templ/cmd/templ/generatecmd/cmd.go
Normal file
403
templ/cmd/templ/generatecmd/cmd.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package generatecmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/modcheck"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/proxy"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/run"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/watcher"
|
||||
"github.com/a-h/templ/generator"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/cli/browser"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
const defaultWatchPattern = `(.+\.go$)|(.+\.templ$)|(.+_templ\.txt$)`
|
||||
|
||||
func NewGenerate(log *slog.Logger, args Arguments) (g *Generate, err error) {
|
||||
g = &Generate{
|
||||
Log: log,
|
||||
Args: &args,
|
||||
}
|
||||
if g.Args.WorkerCount == 0 {
|
||||
g.Args.WorkerCount = runtime.NumCPU()
|
||||
}
|
||||
if g.Args.WatchPattern == "" {
|
||||
g.Args.WatchPattern = defaultWatchPattern
|
||||
}
|
||||
g.WatchPattern, err = regexp.Compile(g.Args.WatchPattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile watch pattern %q: %w", g.Args.WatchPattern, err)
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
type Generate struct {
|
||||
Log *slog.Logger
|
||||
Args *Arguments
|
||||
WatchPattern *regexp.Regexp
|
||||
}
|
||||
|
||||
type GenerationEvent struct {
|
||||
Event fsnotify.Event
|
||||
Updated bool
|
||||
GoUpdated bool
|
||||
TextUpdated bool
|
||||
}
|
||||
|
||||
func (cmd Generate) Run(ctx context.Context) (err error) {
|
||||
if cmd.Args.NotifyProxy {
|
||||
return proxy.NotifyProxy(cmd.Args.ProxyBind, cmd.Args.ProxyPort)
|
||||
}
|
||||
if cmd.Args.Watch && cmd.Args.FileName != "" {
|
||||
return fmt.Errorf("cannot watch a single file, remove the -f or -watch flag")
|
||||
}
|
||||
writingToWriter := cmd.Args.FileWriter != nil
|
||||
if cmd.Args.FileName == "" && writingToWriter {
|
||||
return fmt.Errorf("only a single file can be output to stdout, add the -f flag to specify the file to generate code for")
|
||||
}
|
||||
// Default to writing to files.
|
||||
if cmd.Args.FileWriter == nil {
|
||||
cmd.Args.FileWriter = FileWriter
|
||||
}
|
||||
if cmd.Args.PPROFPort > 0 {
|
||||
go func() {
|
||||
_ = http.ListenAndServe(fmt.Sprintf("localhost:%d", cmd.Args.PPROFPort), nil)
|
||||
}()
|
||||
}
|
||||
|
||||
// Use absolute path.
|
||||
if !path.IsAbs(cmd.Args.Path) {
|
||||
cmd.Args.Path, err = filepath.Abs(cmd.Args.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Configure generator.
|
||||
var opts []generator.GenerateOpt
|
||||
if cmd.Args.IncludeVersion {
|
||||
opts = append(opts, generator.WithVersion(templ.Version()))
|
||||
}
|
||||
if cmd.Args.IncludeTimestamp {
|
||||
opts = append(opts, generator.WithTimestamp(time.Now()))
|
||||
}
|
||||
|
||||
// Check the version of the templ module.
|
||||
if err := modcheck.Check(cmd.Args.Path); err != nil {
|
||||
cmd.Log.Warn("templ version check: " + err.Error())
|
||||
}
|
||||
|
||||
fseh := NewFSEventHandler(
|
||||
cmd.Log,
|
||||
cmd.Args.Path,
|
||||
cmd.Args.Watch,
|
||||
opts,
|
||||
cmd.Args.GenerateSourceMapVisualisations,
|
||||
cmd.Args.KeepOrphanedFiles,
|
||||
cmd.Args.FileWriter,
|
||||
cmd.Args.Lazy,
|
||||
)
|
||||
|
||||
// If we're processing a single file, don't bother setting up the channels/multithreaing.
|
||||
if cmd.Args.FileName != "" {
|
||||
_, err = fseh.HandleEvent(ctx, fsnotify.Event{
|
||||
Name: cmd.Args.FileName,
|
||||
Op: fsnotify.Create,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Start timer.
|
||||
start := time.Now()
|
||||
|
||||
// Create channels:
|
||||
// For the initial filesystem walk and subsequent (optional) fsnotify events.
|
||||
events := make(chan fsnotify.Event)
|
||||
// Count of events currently being processed by the event handler.
|
||||
var eventsWG sync.WaitGroup
|
||||
// Used to check that the event handler has completed.
|
||||
var eventHandlerWG sync.WaitGroup
|
||||
// For errs from the watcher.
|
||||
errs := make(chan error)
|
||||
// Tracks whether errors occurred during the generation process.
|
||||
var errorCount atomic.Int64
|
||||
// For triggering actions after generation has completed.
|
||||
postGeneration := make(chan *GenerationEvent, 256)
|
||||
// Used to check that the post-generation handler has completed.
|
||||
var postGenerationWG sync.WaitGroup
|
||||
var postGenerationEventsWG sync.WaitGroup
|
||||
|
||||
// Waitgroup for the push process.
|
||||
var pushHandlerWG sync.WaitGroup
|
||||
|
||||
// Start process to push events into the channel.
|
||||
pushHandlerWG.Add(1)
|
||||
go func() {
|
||||
defer pushHandlerWG.Done()
|
||||
defer close(events)
|
||||
cmd.Log.Debug(
|
||||
"Walking directory",
|
||||
slog.String("path", cmd.Args.Path),
|
||||
slog.Bool("devMode", cmd.Args.Watch),
|
||||
)
|
||||
if err := watcher.WalkFiles(ctx, cmd.Args.Path, cmd.WatchPattern, events); err != nil {
|
||||
cmd.Log.Error("WalkFiles failed, exiting", slog.Any("error", err))
|
||||
errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)}
|
||||
return
|
||||
}
|
||||
if !cmd.Args.Watch {
|
||||
cmd.Log.Debug("Dev mode not enabled, process can finish early")
|
||||
return
|
||||
}
|
||||
cmd.Log.Info("Watching files")
|
||||
rw, err := watcher.Recursive(ctx, cmd.Args.Path, cmd.WatchPattern, events, errs)
|
||||
if err != nil {
|
||||
cmd.Log.Error("Recursive watcher setup failed, exiting", slog.Any("error", err))
|
||||
errs <- FatalError{Err: fmt.Errorf("failed to setup recursive watcher: %w", err)}
|
||||
return
|
||||
}
|
||||
cmd.Log.Debug("Waiting for context to be cancelled to stop watching files")
|
||||
<-ctx.Done()
|
||||
cmd.Log.Debug("Context cancelled, closing watcher")
|
||||
if err := rw.Close(); err != nil {
|
||||
cmd.Log.Error("Failed to close watcher", slog.Any("error", err))
|
||||
}
|
||||
cmd.Log.Debug("Waiting for events to be processed")
|
||||
eventsWG.Wait()
|
||||
cmd.Log.Debug(
|
||||
"All pending events processed, waiting for pending post-generation events to complete",
|
||||
)
|
||||
postGenerationEventsWG.Wait()
|
||||
cmd.Log.Debug(
|
||||
"All post-generation events processed, deleting watch mode text files",
|
||||
slog.Int64("errorCount", errorCount.Load()),
|
||||
)
|
||||
|
||||
fileEvents := make(chan fsnotify.Event)
|
||||
go func() {
|
||||
if err := watcher.WalkFiles(ctx, cmd.Args.Path, cmd.WatchPattern, fileEvents); err != nil {
|
||||
cmd.Log.Error("Post dev mode WalkFiles failed", slog.Any("error", err))
|
||||
errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)}
|
||||
return
|
||||
}
|
||||
close(fileEvents)
|
||||
}()
|
||||
for event := range fileEvents {
|
||||
if strings.HasSuffix(event.Name, "_templ.txt") {
|
||||
if err = os.Remove(event.Name); err != nil {
|
||||
cmd.Log.Warn("Failed to remove watch mode text file", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start process to handle events.
|
||||
eventHandlerWG.Add(1)
|
||||
sem := make(chan struct{}, cmd.Args.WorkerCount)
|
||||
go func() {
|
||||
defer eventHandlerWG.Done()
|
||||
defer close(postGeneration)
|
||||
cmd.Log.Debug("Starting event handler")
|
||||
for event := range events {
|
||||
eventsWG.Add(1)
|
||||
sem <- struct{}{}
|
||||
go func(event fsnotify.Event) {
|
||||
cmd.Log.Debug("Processing file", slog.String("file", event.Name))
|
||||
defer eventsWG.Done()
|
||||
defer func() { <-sem }()
|
||||
r, err := fseh.HandleEvent(ctx, event)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
}
|
||||
if !(r.GoUpdated || r.TextUpdated) {
|
||||
cmd.Log.Debug("File not updated", slog.String("file", event.Name))
|
||||
return
|
||||
}
|
||||
e := &GenerationEvent{
|
||||
Event: event,
|
||||
Updated: r.Updated,
|
||||
GoUpdated: r.GoUpdated,
|
||||
TextUpdated: r.TextUpdated,
|
||||
}
|
||||
cmd.Log.Debug("File updated", slog.String("file", event.Name))
|
||||
postGeneration <- e
|
||||
}(event)
|
||||
}
|
||||
// Wait for all events to be processed before closing.
|
||||
eventsWG.Wait()
|
||||
}()
|
||||
|
||||
// Start process to handle post-generation events.
|
||||
var updates int
|
||||
postGenerationWG.Add(1)
|
||||
var firstPostGenerationExecuted bool
|
||||
go func() {
|
||||
defer close(errs)
|
||||
defer postGenerationWG.Done()
|
||||
cmd.Log.Debug("Starting post-generation handler")
|
||||
timeout := time.NewTimer(time.Hour * 24 * 365)
|
||||
var goUpdated, textUpdated bool
|
||||
var p *proxy.Handler
|
||||
for {
|
||||
select {
|
||||
case ge := <-postGeneration:
|
||||
if ge == nil {
|
||||
cmd.Log.Debug("Post-generation event channel closed, exiting")
|
||||
return
|
||||
}
|
||||
goUpdated = goUpdated || ge.GoUpdated
|
||||
textUpdated = textUpdated || ge.TextUpdated
|
||||
if goUpdated || textUpdated {
|
||||
updates++
|
||||
}
|
||||
// Reset timer.
|
||||
if !timeout.Stop() {
|
||||
<-timeout.C
|
||||
}
|
||||
timeout.Reset(time.Millisecond * 100)
|
||||
case <-timeout.C:
|
||||
if !goUpdated && !textUpdated {
|
||||
// Nothing to process, reset timer and wait again.
|
||||
timeout.Reset(time.Hour * 24 * 365)
|
||||
break
|
||||
}
|
||||
postGenerationEventsWG.Add(1)
|
||||
if cmd.Args.Command != "" && goUpdated {
|
||||
cmd.Log.Debug("Executing command", slog.String("command", cmd.Args.Command))
|
||||
if cmd.Args.Watch {
|
||||
os.Setenv("TEMPL_DEV_MODE", "true")
|
||||
}
|
||||
if _, err := run.Run(ctx, cmd.Args.Path, cmd.Args.Command); err != nil {
|
||||
cmd.Log.Error("Error executing command", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
if !firstPostGenerationExecuted {
|
||||
cmd.Log.Debug("First post-generation event received, starting proxy")
|
||||
firstPostGenerationExecuted = true
|
||||
p, err = cmd.StartProxy(ctx)
|
||||
if err != nil {
|
||||
cmd.Log.Error("Failed to start proxy", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
// Send server-sent event.
|
||||
if p != nil && (textUpdated || goUpdated) {
|
||||
cmd.Log.Debug("Sending reload event")
|
||||
p.SendSSE("message", "reload")
|
||||
}
|
||||
postGenerationEventsWG.Done()
|
||||
// Reset timer.
|
||||
timeout.Reset(time.Millisecond * 100)
|
||||
textUpdated = false
|
||||
goUpdated = false
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Read errors.
|
||||
for err := range errs {
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, FatalError{}) {
|
||||
cmd.Log.Debug("Fatal error, exiting")
|
||||
return err
|
||||
}
|
||||
cmd.Log.Error("Error", slog.Any("error", err))
|
||||
errorCount.Add(1)
|
||||
}
|
||||
|
||||
// Wait for everything to complete.
|
||||
cmd.Log.Debug("Waiting for push handler to complete")
|
||||
pushHandlerWG.Wait()
|
||||
cmd.Log.Debug("Waiting for event handler to complete")
|
||||
eventHandlerWG.Wait()
|
||||
cmd.Log.Debug("Waiting for post-generation handler to complete")
|
||||
postGenerationWG.Wait()
|
||||
if cmd.Args.Command != "" {
|
||||
cmd.Log.Debug("Killing command", slog.String("command", cmd.Args.Command))
|
||||
if err := run.KillAll(); err != nil {
|
||||
cmd.Log.Error("Error killing command", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Check for errors after everything has completed.
|
||||
if errorCount.Load() > 0 {
|
||||
return fmt.Errorf("generation completed with %d errors", errorCount.Load())
|
||||
}
|
||||
|
||||
cmd.Log.Info(
|
||||
"Complete",
|
||||
slog.Int("updates", updates),
|
||||
slog.Duration("duration", time.Since(start)),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cmd *Generate) StartProxy(ctx context.Context) (p *proxy.Handler, err error) {
|
||||
if cmd.Args.Proxy == "" {
|
||||
cmd.Log.Debug("No proxy URL specified, not starting proxy")
|
||||
return nil, nil
|
||||
}
|
||||
var target *url.URL
|
||||
target, err = url.Parse(cmd.Args.Proxy)
|
||||
if err != nil {
|
||||
return nil, FatalError{Err: fmt.Errorf("failed to parse proxy URL: %w", err)}
|
||||
}
|
||||
if cmd.Args.ProxyPort == 0 {
|
||||
cmd.Args.ProxyPort = 7331
|
||||
}
|
||||
if cmd.Args.ProxyBind == "" {
|
||||
cmd.Args.ProxyBind = "127.0.0.1"
|
||||
}
|
||||
p = proxy.New(cmd.Log, cmd.Args.ProxyBind, cmd.Args.ProxyPort, target)
|
||||
go func() {
|
||||
cmd.Log.Info("Proxying", slog.String("from", p.URL), slog.String("to", p.Target.String()))
|
||||
if err := http.ListenAndServe(fmt.Sprintf("%s:%d", cmd.Args.ProxyBind, cmd.Args.ProxyPort), p); err != nil {
|
||||
cmd.Log.Error("Proxy failed", slog.Any("error", err))
|
||||
}
|
||||
}()
|
||||
if !cmd.Args.OpenBrowser {
|
||||
cmd.Log.Debug("Not opening browser")
|
||||
return p, nil
|
||||
}
|
||||
go func() {
|
||||
cmd.Log.Debug("Waiting for proxy to be ready", slog.String("url", p.URL))
|
||||
backoff := backoff.NewExponentialBackOff()
|
||||
backoff.InitialInterval = time.Second
|
||||
var client http.Client
|
||||
client.Timeout = 1 * time.Second
|
||||
for {
|
||||
if _, err := client.Get(p.URL); err == nil {
|
||||
break
|
||||
}
|
||||
d := backoff.NextBackOff()
|
||||
cmd.Log.Debug(
|
||||
"Proxy not ready, retrying",
|
||||
slog.String("url", p.URL),
|
||||
slog.Any("backoff", d),
|
||||
)
|
||||
time.Sleep(d)
|
||||
}
|
||||
if err := browser.OpenURL(p.URL); err != nil {
|
||||
cmd.Log.Error("Failed to open browser", slog.Any("error", err))
|
||||
}
|
||||
}()
|
||||
return p, nil
|
||||
}
|
366
templ/cmd/templ/generatecmd/eventhandler.go
Normal file
366
templ/cmd/templ/generatecmd/eventhandler.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package generatecmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"go/scanner"
|
||||
"go/token"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/visualize"
|
||||
"github.com/a-h/templ/generator"
|
||||
"github.com/a-h/templ/parser/v2"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
type FileWriterFunc func(name string, contents []byte) error
|
||||
|
||||
func FileWriter(fileName string, contents []byte) error {
|
||||
return os.WriteFile(fileName, contents, 0o644)
|
||||
}
|
||||
|
||||
func WriterFileWriter(w io.Writer) FileWriterFunc {
|
||||
return func(_ string, contents []byte) error {
|
||||
_, err := w.Write(contents)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func NewFSEventHandler(
|
||||
log *slog.Logger,
|
||||
dir string,
|
||||
devMode bool,
|
||||
genOpts []generator.GenerateOpt,
|
||||
genSourceMapVis bool,
|
||||
keepOrphanedFiles bool,
|
||||
fileWriter FileWriterFunc,
|
||||
lazy bool,
|
||||
) *FSEventHandler {
|
||||
if !path.IsAbs(dir) {
|
||||
dir, _ = filepath.Abs(dir)
|
||||
}
|
||||
fseh := &FSEventHandler{
|
||||
Log: log,
|
||||
dir: dir,
|
||||
fileNameToLastModTime: make(map[string]time.Time),
|
||||
fileNameToLastModTimeMutex: &sync.Mutex{},
|
||||
fileNameToError: make(map[string]struct{}),
|
||||
fileNameToErrorMutex: &sync.Mutex{},
|
||||
fileNameToOutput: make(map[string]generator.GeneratorOutput),
|
||||
fileNameToOutputMutex: &sync.Mutex{},
|
||||
devMode: devMode,
|
||||
hashes: make(map[string][sha256.Size]byte),
|
||||
hashesMutex: &sync.Mutex{},
|
||||
genOpts: genOpts,
|
||||
genSourceMapVis: genSourceMapVis,
|
||||
keepOrphanedFiles: keepOrphanedFiles,
|
||||
writer: fileWriter,
|
||||
lazy: lazy,
|
||||
}
|
||||
return fseh
|
||||
}
|
||||
|
||||
type FSEventHandler struct {
|
||||
Log *slog.Logger
|
||||
// dir is the root directory being processed.
|
||||
dir string
|
||||
fileNameToLastModTime map[string]time.Time
|
||||
fileNameToLastModTimeMutex *sync.Mutex
|
||||
fileNameToError map[string]struct{}
|
||||
fileNameToErrorMutex *sync.Mutex
|
||||
fileNameToOutput map[string]generator.GeneratorOutput
|
||||
fileNameToOutputMutex *sync.Mutex
|
||||
devMode bool
|
||||
hashes map[string][sha256.Size]byte
|
||||
hashesMutex *sync.Mutex
|
||||
genOpts []generator.GenerateOpt
|
||||
genSourceMapVis bool
|
||||
Errors []error
|
||||
keepOrphanedFiles bool
|
||||
writer func(string, []byte) error
|
||||
lazy bool
|
||||
}
|
||||
|
||||
type GenerateResult struct {
|
||||
// Updated indicates that the file was updated.
|
||||
Updated bool
|
||||
// GoUpdated indicates that Go expressions were updated.
|
||||
GoUpdated bool
|
||||
// TextUpdated indicates that text literals were updated.
|
||||
TextUpdated bool
|
||||
}
|
||||
|
||||
func (h *FSEventHandler) HandleEvent(ctx context.Context, event fsnotify.Event) (result GenerateResult, err error) {
|
||||
// Handle _templ.go files.
|
||||
if !event.Has(fsnotify.Remove) && strings.HasSuffix(event.Name, "_templ.go") {
|
||||
_, err = os.Stat(strings.TrimSuffix(event.Name, "_templ.go") + ".templ")
|
||||
if !os.IsNotExist(err) {
|
||||
return GenerateResult{}, err
|
||||
}
|
||||
// File is orphaned.
|
||||
if h.keepOrphanedFiles {
|
||||
return GenerateResult{}, nil
|
||||
}
|
||||
h.Log.Debug("Deleting orphaned Go file", slog.String("file", event.Name))
|
||||
if err = os.Remove(event.Name); err != nil {
|
||||
h.Log.Warn("Failed to remove orphaned file", slog.Any("error", err))
|
||||
}
|
||||
return GenerateResult{Updated: true, GoUpdated: true, TextUpdated: false}, nil
|
||||
}
|
||||
// Handle _templ.txt files.
|
||||
if !event.Has(fsnotify.Remove) && strings.HasSuffix(event.Name, "_templ.txt") {
|
||||
if h.devMode {
|
||||
// Don't delete the file in dev mode, ignore changes to it, since the .templ file
|
||||
// must have been updated in order to trigger a change in the _templ.txt file.
|
||||
return GenerateResult{Updated: false, GoUpdated: false, TextUpdated: false}, nil
|
||||
}
|
||||
h.Log.Debug("Deleting watch mode file", slog.String("file", event.Name))
|
||||
if err = os.Remove(event.Name); err != nil {
|
||||
h.Log.Warn("Failed to remove watch mode text file", slog.Any("error", err))
|
||||
return GenerateResult{}, nil
|
||||
}
|
||||
return GenerateResult{}, nil
|
||||
}
|
||||
|
||||
// If the file hasn't been updated since the last time we processed it, ignore it.
|
||||
lastModTime, updatedModTime := h.UpsertLastModTime(event.Name)
|
||||
if !updatedModTime {
|
||||
h.Log.Debug("Skipping file because it wasn't updated", slog.String("file", event.Name))
|
||||
return GenerateResult{}, nil
|
||||
}
|
||||
|
||||
// Process anything that isn't a templ file.
|
||||
if !strings.HasSuffix(event.Name, ".templ") {
|
||||
// If it's a Go file, mark it as updated.
|
||||
if strings.HasSuffix(event.Name, ".go") {
|
||||
result.GoUpdated = true
|
||||
}
|
||||
result.Updated = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Handle templ files.
|
||||
|
||||
// If the go file is newer than the templ file, skip generation, because it's up-to-date.
|
||||
if h.lazy && goFileIsUpToDate(event.Name, lastModTime) {
|
||||
h.Log.Debug("Skipping file because the Go file is up-to-date", slog.String("file", event.Name))
|
||||
return GenerateResult{}, nil
|
||||
}
|
||||
|
||||
// Start a processor.
|
||||
start := time.Now()
|
||||
var diag []parser.Diagnostic
|
||||
result, diag, err = h.generate(ctx, event.Name)
|
||||
if err != nil {
|
||||
h.SetError(event.Name, true)
|
||||
return result, fmt.Errorf("failed to generate code for %q: %w", event.Name, err)
|
||||
}
|
||||
if len(diag) > 0 {
|
||||
for _, d := range diag {
|
||||
h.Log.Warn(d.Message,
|
||||
slog.String("from", fmt.Sprintf("%d:%d", d.Range.From.Line, d.Range.From.Col)),
|
||||
slog.String("to", fmt.Sprintf("%d:%d", d.Range.To.Line, d.Range.To.Col)),
|
||||
)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
if errorCleared, errorCount := h.SetError(event.Name, false); errorCleared {
|
||||
h.Log.Info("Error cleared", slog.String("file", event.Name), slog.Int("errors", errorCount))
|
||||
}
|
||||
h.Log.Debug("Generated code", slog.String("file", event.Name), slog.Duration("in", time.Since(start)))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func goFileIsUpToDate(templFileName string, templFileLastMod time.Time) (upToDate bool) {
|
||||
goFileName := strings.TrimSuffix(templFileName, ".templ") + "_templ.go"
|
||||
goFileInfo, err := os.Stat(goFileName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return goFileInfo.ModTime().After(templFileLastMod)
|
||||
}
|
||||
|
||||
func (h *FSEventHandler) SetError(fileName string, hasError bool) (previouslyHadError bool, errorCount int) {
|
||||
h.fileNameToErrorMutex.Lock()
|
||||
defer h.fileNameToErrorMutex.Unlock()
|
||||
_, previouslyHadError = h.fileNameToError[fileName]
|
||||
delete(h.fileNameToError, fileName)
|
||||
if hasError {
|
||||
h.fileNameToError[fileName] = struct{}{}
|
||||
}
|
||||
return previouslyHadError, len(h.fileNameToError)
|
||||
}
|
||||
|
||||
func (h *FSEventHandler) UpsertLastModTime(fileName string) (modTime time.Time, updated bool) {
|
||||
fileInfo, err := os.Stat(fileName)
|
||||
if err != nil {
|
||||
return modTime, false
|
||||
}
|
||||
h.fileNameToLastModTimeMutex.Lock()
|
||||
defer h.fileNameToLastModTimeMutex.Unlock()
|
||||
previousModTime := h.fileNameToLastModTime[fileName]
|
||||
currentModTime := fileInfo.ModTime()
|
||||
if !currentModTime.After(previousModTime) {
|
||||
return currentModTime, false
|
||||
}
|
||||
h.fileNameToLastModTime[fileName] = currentModTime
|
||||
return currentModTime, true
|
||||
}
|
||||
|
||||
func (h *FSEventHandler) UpsertHash(fileName string, hash [sha256.Size]byte) (updated bool) {
|
||||
h.hashesMutex.Lock()
|
||||
defer h.hashesMutex.Unlock()
|
||||
lastHash := h.hashes[fileName]
|
||||
if lastHash == hash {
|
||||
return false
|
||||
}
|
||||
h.hashes[fileName] = hash
|
||||
return true
|
||||
}
|
||||
|
||||
// generate Go code for a single template.
|
||||
// If a basePath is provided, the filename included in error messages is relative to it.
|
||||
func (h *FSEventHandler) generate(ctx context.Context, fileName string) (result GenerateResult, diagnostics []parser.Diagnostic, err error) {
|
||||
t, err := parser.Parse(fileName)
|
||||
if err != nil {
|
||||
return GenerateResult{}, nil, fmt.Errorf("%s parsing error: %w", fileName, err)
|
||||
}
|
||||
targetFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.go"
|
||||
|
||||
// Only use relative filenames to the basepath for filenames in runtime error messages.
|
||||
absFilePath, err := filepath.Abs(fileName)
|
||||
if err != nil {
|
||||
return GenerateResult{}, nil, fmt.Errorf("failed to get absolute path for %q: %w", fileName, err)
|
||||
}
|
||||
relFilePath, err := filepath.Rel(h.dir, absFilePath)
|
||||
if err != nil {
|
||||
return GenerateResult{}, nil, fmt.Errorf("failed to get relative path for %q: %w", fileName, err)
|
||||
}
|
||||
// Convert Windows file paths to Unix-style for consistency.
|
||||
relFilePath = filepath.ToSlash(relFilePath)
|
||||
|
||||
var b bytes.Buffer
|
||||
generatorOutput, err := generator.Generate(t, &b, append(h.genOpts, generator.WithFileName(relFilePath))...)
|
||||
if err != nil {
|
||||
return GenerateResult{}, nil, fmt.Errorf("%s generation error: %w", fileName, err)
|
||||
}
|
||||
|
||||
formattedGoCode, err := format.Source(b.Bytes())
|
||||
if err != nil {
|
||||
err = remapErrorList(err, generatorOutput.SourceMap, fileName)
|
||||
return GenerateResult{}, nil, fmt.Errorf("%s source formatting error %w", fileName, err)
|
||||
}
|
||||
|
||||
// Hash output, and write out the file if the goCodeHash has changed.
|
||||
goCodeHash := sha256.Sum256(formattedGoCode)
|
||||
if h.UpsertHash(targetFileName, goCodeHash) {
|
||||
result.Updated = true
|
||||
if err = h.writer(targetFileName, formattedGoCode); err != nil {
|
||||
return result, nil, fmt.Errorf("failed to write target file %q: %w", targetFileName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the txt file if it has changed.
|
||||
if h.devMode {
|
||||
txtFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.txt"
|
||||
joined := strings.Join(generatorOutput.Literals, "\n")
|
||||
txtHash := sha256.Sum256([]byte(joined))
|
||||
if h.UpsertHash(txtFileName, txtHash) {
|
||||
result.TextUpdated = true
|
||||
if err = os.WriteFile(txtFileName, []byte(joined), 0o644); err != nil {
|
||||
return result, nil, fmt.Errorf("failed to write string literal file %q: %w", txtFileName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check whether the change would require a recompilation to take effect.
|
||||
h.fileNameToOutputMutex.Lock()
|
||||
defer h.fileNameToOutputMutex.Unlock()
|
||||
previous := h.fileNameToOutput[fileName]
|
||||
if generator.HasChanged(previous, generatorOutput) {
|
||||
result.GoUpdated = true
|
||||
}
|
||||
h.fileNameToOutput[fileName] = generatorOutput
|
||||
}
|
||||
|
||||
parsedDiagnostics, err := parser.Diagnose(t)
|
||||
if err != nil {
|
||||
return result, nil, fmt.Errorf("%s diagnostics error: %w", fileName, err)
|
||||
}
|
||||
|
||||
if h.genSourceMapVis {
|
||||
err = generateSourceMapVisualisation(ctx, fileName, targetFileName, generatorOutput.SourceMap)
|
||||
}
|
||||
|
||||
return result, parsedDiagnostics, err
|
||||
}
|
||||
|
||||
// Takes an error from the formatter and attempts to convert the positions reported in the target file to their positions
|
||||
// in the source file.
|
||||
func remapErrorList(err error, sourceMap *parser.SourceMap, fileName string) error {
|
||||
list, ok := err.(scanner.ErrorList)
|
||||
if !ok || len(list) == 0 {
|
||||
return err
|
||||
}
|
||||
for i, e := range list {
|
||||
// The positions in the source map are off by one line because of the package definition.
|
||||
srcPos, ok := sourceMap.SourcePositionFromTarget(uint32(e.Pos.Line-1), uint32(e.Pos.Column))
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
list[i].Pos = token.Position{
|
||||
Filename: fileName,
|
||||
Offset: int(srcPos.Index),
|
||||
Line: int(srcPos.Line) + 1,
|
||||
Column: int(srcPos.Col),
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
func generateSourceMapVisualisation(ctx context.Context, templFileName, goFileName string, sourceMap *parser.SourceMap) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
var templContents, goContents []byte
|
||||
var templErr, goErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
templContents, templErr = os.ReadFile(templFileName)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
goContents, goErr = os.ReadFile(goFileName)
|
||||
}()
|
||||
wg.Wait()
|
||||
if templErr != nil {
|
||||
return templErr
|
||||
}
|
||||
if goErr != nil {
|
||||
return templErr
|
||||
}
|
||||
|
||||
targetFileName := strings.TrimSuffix(templFileName, ".templ") + "_templ_sourcemap.html"
|
||||
w, err := os.Create(targetFileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s sourcemap visualisation error: %w", templFileName, err)
|
||||
}
|
||||
defer w.Close()
|
||||
b := bufio.NewWriter(w)
|
||||
defer b.Flush()
|
||||
|
||||
return visualize.HTML(templFileName, string(templContents), string(goContents), sourceMap).Render(ctx, b)
|
||||
}
|
23
templ/cmd/templ/generatecmd/fatalerror.go
Normal file
23
templ/cmd/templ/generatecmd/fatalerror.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package generatecmd
|
||||
|
||||
type FatalError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e FatalError) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
func (e FatalError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
func (e FatalError) Is(target error) bool {
|
||||
_, ok := target.(FatalError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e FatalError) As(target any) bool {
|
||||
_, ok := target.(*FatalError)
|
||||
return ok
|
||||
}
|
39
templ/cmd/templ/generatecmd/main.go
Normal file
39
templ/cmd/templ/generatecmd/main.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package generatecmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"log/slog"
|
||||
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
type Arguments struct {
|
||||
FileName string
|
||||
FileWriter FileWriterFunc
|
||||
Path string
|
||||
Watch bool
|
||||
WatchPattern string
|
||||
OpenBrowser bool
|
||||
Command string
|
||||
ProxyBind string
|
||||
ProxyPort int
|
||||
Proxy string
|
||||
NotifyProxy bool
|
||||
WorkerCount int
|
||||
GenerateSourceMapVisualisations bool
|
||||
IncludeVersion bool
|
||||
IncludeTimestamp bool
|
||||
// PPROFPort is the port to run the pprof server on.
|
||||
PPROFPort int
|
||||
KeepOrphanedFiles bool
|
||||
Lazy bool
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, log *slog.Logger, args Arguments) (err error) {
|
||||
g, err := NewGenerate(log, args)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return g.Run(ctx)
|
||||
}
|
170
templ/cmd/templ/generatecmd/main_test.go
Normal file
170
templ/cmd/templ/generatecmd/main_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package generatecmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/testproject"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
t.Run("can generate a file in place", func(t *testing.T) {
|
||||
// templ generate -f templates.templ
|
||||
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test project: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// Delete the templates_templ.go file to ensure it is generated.
|
||||
err = os.Remove(path.Join(dir, "templates_templ.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to remove templates_templ.go: %v", err)
|
||||
}
|
||||
|
||||
// Run the generate command.
|
||||
err = Run(context.Background(), log, Arguments{
|
||||
FileName: path.Join(dir, "templates.templ"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run generate command: %v", err)
|
||||
}
|
||||
|
||||
// Check the templates_templ.go file was created.
|
||||
_, err = os.Stat(path.Join(dir, "templates_templ.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("templates_templ.go was not created: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("can generate a file in watch mode", func(t *testing.T) {
|
||||
// templ generate -f templates.templ
|
||||
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test project: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// Delete the templates_templ.go file to ensure it is generated.
|
||||
err = os.Remove(path.Join(dir, "templates_templ.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to remove templates_templ.go: %v", err)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var eg errgroup.Group
|
||||
eg.Go(func() error {
|
||||
// Run the generate command.
|
||||
return Run(ctx, log, Arguments{
|
||||
Path: dir,
|
||||
Watch: true,
|
||||
})
|
||||
})
|
||||
|
||||
// Check the templates_templ.go file was created, with backoff.
|
||||
for i := 0; i < 5; i++ {
|
||||
time.Sleep(time.Second * time.Duration(i))
|
||||
_, err = os.Stat(path.Join(dir, "templates_templ.go"))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_, err = os.Stat(path.Join(dir, "templates_templ.txt"))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("template files were not created: %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
if err := eg.Wait(); err != nil {
|
||||
t.Fatalf("generate command failed: %v", err)
|
||||
}
|
||||
|
||||
// Check the templates_templ.txt file was removed.
|
||||
_, err = os.Stat(path.Join(dir, "templates_templ.txt"))
|
||||
if err == nil {
|
||||
t.Fatalf("templates_templ.txt was not removed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultWatchPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
matches bool
|
||||
}{
|
||||
{
|
||||
name: "empty file names do not match",
|
||||
input: "",
|
||||
matches: false,
|
||||
},
|
||||
{
|
||||
name: "*_templ.txt matches, Windows",
|
||||
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\strings_templ.txt`,
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*_templ.txt matches, Unix",
|
||||
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/strings_templ.txt",
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*.templ files match, Windows",
|
||||
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates.templ`,
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*.templ files match, Unix",
|
||||
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.templ",
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*_templ.go files match, Windows",
|
||||
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates_templ.go`,
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*_templ.go files match, Unix",
|
||||
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates_templ.go",
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*.go files match, Windows",
|
||||
input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates.go`,
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*.go files match, Unix",
|
||||
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.go",
|
||||
matches: true,
|
||||
},
|
||||
{
|
||||
name: "*.css files do not match",
|
||||
input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.css",
|
||||
matches: false,
|
||||
},
|
||||
}
|
||||
wpRegexp, err := regexp.Compile(defaultWatchPattern)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to compile default watch pattern: %v", err)
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if wpRegexp.MatchString(test.input) != test.matches {
|
||||
t.Fatalf("expected match of %q to be %v", test.input, test.matches)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
82
templ/cmd/templ/generatecmd/modcheck/modcheck.go
Normal file
82
templ/cmd/templ/generatecmd/modcheck/modcheck.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package modcheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"golang.org/x/mod/modfile"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
// WalkUp the directory tree, starting at dir, until we find a directory containing
|
||||
// a go.mod file.
|
||||
func WalkUp(dir string) (string, error) {
|
||||
dir, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
|
||||
var modFile string
|
||||
for {
|
||||
modFile = filepath.Join(dir, "go.mod")
|
||||
_, err := os.Stat(modFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("failed to stat go.mod file: %w", err)
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
// Move up.
|
||||
prev := dir
|
||||
dir = filepath.Dir(dir)
|
||||
if dir == prev {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// No file found.
|
||||
if modFile == "" {
|
||||
return dir, fmt.Errorf("could not find go.mod file")
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
func Check(dir string) error {
|
||||
dir, err := WalkUp(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Found a go.mod file.
|
||||
// Read it and find the templ version.
|
||||
modFile := filepath.Join(dir, "go.mod")
|
||||
m, err := os.ReadFile(modFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read go.mod file: %w", err)
|
||||
}
|
||||
|
||||
mf, err := modfile.Parse(modFile, m, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse go.mod file: %w", err)
|
||||
}
|
||||
if mf.Module.Mod.Path == "github.com/a-h/templ" {
|
||||
// The go.mod file is for templ itself.
|
||||
return nil
|
||||
}
|
||||
for _, r := range mf.Require {
|
||||
if r.Mod.Path == "github.com/a-h/templ" {
|
||||
cmp := semver.Compare(r.Mod.Version, templ.Version())
|
||||
if cmp < 0 {
|
||||
return fmt.Errorf("generator %v is newer than templ version %v found in go.mod file, consider running `go get -u github.com/a-h/templ` to upgrade", templ.Version(), r.Mod.Version)
|
||||
}
|
||||
if cmp > 0 {
|
||||
return fmt.Errorf("generator %v is older than templ version %v found in go.mod file, consider upgrading templ CLI", templ.Version(), r.Mod.Version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("templ not found in go.mod file, run `go get github.com/a-h/templ` to install it")
|
||||
}
|
47
templ/cmd/templ/generatecmd/modcheck/modcheck_test.go
Normal file
47
templ/cmd/templ/generatecmd/modcheck/modcheck_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package modcheck
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/mod/modfile"
|
||||
)
|
||||
|
||||
func TestPatchGoVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
input: "go 1.20",
|
||||
expected: "1.20",
|
||||
},
|
||||
{
|
||||
input: "go 1.20.123",
|
||||
expected: "1.20.123",
|
||||
},
|
||||
{
|
||||
input: "go 1.20.1",
|
||||
expected: "1.20.1",
|
||||
},
|
||||
{
|
||||
input: "go 1.20rc1",
|
||||
expected: "1.20rc1",
|
||||
},
|
||||
{
|
||||
input: "go 1.15",
|
||||
expected: "1.15",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
input := "module github.com/a-h/templ\n\n" + string(test.input) + "\n" + "toolchain go1.27.9\n"
|
||||
mf, err := modfile.Parse("go.mod", []byte(input), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse go.mod: %v", err)
|
||||
}
|
||||
if test.expected != mf.Go.Version {
|
||||
t.Errorf("expected %q, got %q", test.expected, mf.Go.Version)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
284
templ/cmd/templ/generatecmd/proxy/proxy.go
Normal file
284
templ/cmd/templ/generatecmd/proxy/proxy.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
stdlog "log"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/PuerkitoBio/goquery"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/sse"
|
||||
"github.com/andybalholm/brotli"
|
||||
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
//go:embed script.js
|
||||
var script string
|
||||
|
||||
type Handler struct {
|
||||
log *slog.Logger
|
||||
URL string
|
||||
Target *url.URL
|
||||
p *httputil.ReverseProxy
|
||||
sse *sse.Handler
|
||||
}
|
||||
|
||||
func getScriptTag(nonce string) string {
|
||||
if nonce != "" {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(`<script src="/_templ/reload/script.js" nonce="`)
|
||||
sb.WriteString(html.EscapeString(nonce))
|
||||
sb.WriteString(`"></script>`)
|
||||
return sb.String()
|
||||
}
|
||||
return `<script src="/_templ/reload/script.js"></script>`
|
||||
}
|
||||
|
||||
func insertScriptTagIntoBody(nonce, body string) (updated string) {
|
||||
doc, err := goquery.NewDocumentFromReader(strings.NewReader(body))
|
||||
if err != nil {
|
||||
return strings.Replace(body, "</body>", getScriptTag(nonce)+"</body>", -1)
|
||||
}
|
||||
doc.Find("body").AppendHtml(getScriptTag(nonce))
|
||||
r, err := doc.Html()
|
||||
if err != nil {
|
||||
return strings.Replace(body, "</body>", getScriptTag(nonce)+"</body>", -1)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
type passthroughWriteCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (pwc passthroughWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const unsupportedContentEncoding = "Unsupported content encoding, hot reload script not inserted."
|
||||
|
||||
func (h *Handler) modifyResponse(r *http.Response) error {
|
||||
log := h.log.With(slog.String("url", r.Request.URL.String()))
|
||||
if r.Header.Get("templ-skip-modify") == "true" {
|
||||
log.Debug("Skipping response modification because templ-skip-modify header is set")
|
||||
return nil
|
||||
}
|
||||
if contentType := r.Header.Get("Content-Type"); !strings.HasPrefix(contentType, "text/html") {
|
||||
log.Debug("Skipping response modification because content type is not text/html", slog.String("content-type", contentType))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set up readers and writers.
|
||||
newReader := func(in io.Reader) (out io.Reader, err error) {
|
||||
return in, nil
|
||||
}
|
||||
newWriter := func(out io.Writer) io.WriteCloser {
|
||||
return passthroughWriteCloser{out}
|
||||
}
|
||||
switch r.Header.Get("Content-Encoding") {
|
||||
case "gzip":
|
||||
newReader = func(in io.Reader) (out io.Reader, err error) {
|
||||
return gzip.NewReader(in)
|
||||
}
|
||||
newWriter = func(out io.Writer) io.WriteCloser {
|
||||
return gzip.NewWriter(out)
|
||||
}
|
||||
case "br":
|
||||
newReader = func(in io.Reader) (out io.Reader, err error) {
|
||||
return brotli.NewReader(in), nil
|
||||
}
|
||||
newWriter = func(out io.Writer) io.WriteCloser {
|
||||
return brotli.NewWriter(out)
|
||||
}
|
||||
case "":
|
||||
log.Debug("No content encoding header found")
|
||||
default:
|
||||
h.log.Warn(unsupportedContentEncoding, slog.String("encoding", r.Header.Get("Content-Encoding")))
|
||||
}
|
||||
|
||||
// Read the encoded body.
|
||||
encr, err := newReader(r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.Body.Close()
|
||||
body, err := io.ReadAll(encr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update it.
|
||||
csp := r.Header.Get("Content-Security-Policy")
|
||||
updated := insertScriptTagIntoBody(parseNonce(csp), string(body))
|
||||
if log.Enabled(r.Request.Context(), slog.LevelDebug) {
|
||||
if len(updated) == len(body) {
|
||||
log.Debug("Reload script not inserted")
|
||||
} else {
|
||||
log.Debug("Reload script inserted")
|
||||
}
|
||||
}
|
||||
|
||||
// Encode the response.
|
||||
var buf bytes.Buffer
|
||||
encw := newWriter(&buf)
|
||||
_, err = encw.Write([]byte(updated))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = encw.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the response.
|
||||
r.Body = io.NopCloser(&buf)
|
||||
r.ContentLength = int64(buf.Len())
|
||||
r.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseNonce(csp string) (nonce string) {
|
||||
outer:
|
||||
for _, rawDirective := range strings.Split(csp, ";") {
|
||||
parts := strings.Fields(rawDirective)
|
||||
if len(parts) < 2 {
|
||||
continue
|
||||
}
|
||||
if parts[0] != "script-src" {
|
||||
continue
|
||||
}
|
||||
for _, source := range parts[1:] {
|
||||
source = strings.TrimPrefix(source, "'")
|
||||
source = strings.TrimSuffix(source, "'")
|
||||
if strings.HasPrefix(source, "nonce-") {
|
||||
nonce = source[6:]
|
||||
break outer
|
||||
}
|
||||
}
|
||||
}
|
||||
return nonce
|
||||
}
|
||||
|
||||
func New(log *slog.Logger, bind string, port int, target *url.URL) (h *Handler) {
|
||||
p := httputil.NewSingleHostReverseProxy(target)
|
||||
p.ErrorLog = stdlog.New(os.Stderr, "Proxy to target error: ", 0)
|
||||
p.Transport = &roundTripper{
|
||||
maxRetries: 20,
|
||||
initialDelay: 100 * time.Millisecond,
|
||||
backoffExponent: 1.5,
|
||||
}
|
||||
h = &Handler{
|
||||
log: log,
|
||||
URL: fmt.Sprintf("http://%s:%d", bind, port),
|
||||
Target: target,
|
||||
p: p,
|
||||
sse: sse.New(),
|
||||
}
|
||||
p.ModifyResponse = h.modifyResponse
|
||||
return h
|
||||
}
|
||||
|
||||
func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/_templ/reload/script.js" {
|
||||
// Provides a script that reloads the page.
|
||||
w.Header().Add("Content-Type", "text/javascript")
|
||||
_, err := io.WriteString(w, script)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to write script: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/_templ/reload/events" {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// Provides a list of messages including a reload message.
|
||||
p.sse.ServeHTTP(w, r)
|
||||
return
|
||||
case http.MethodPost:
|
||||
// Send a reload message to all connected clients.
|
||||
p.sse.Send("message", "reload")
|
||||
return
|
||||
}
|
||||
http.Error(w, "only GET or POST method allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
p.p.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (p *Handler) SendSSE(eventType string, data string) {
|
||||
p.sse.Send(eventType, data)
|
||||
}
|
||||
|
||||
type roundTripper struct {
|
||||
maxRetries int
|
||||
initialDelay time.Duration
|
||||
backoffExponent float64
|
||||
}
|
||||
|
||||
func (rt *roundTripper) setShouldSkipResponseModificationHeader(r *http.Request, resp *http.Response) {
|
||||
// Instruct the modifyResponse function to skip modifying the response if the
|
||||
// HTTP request has come from HTMX.
|
||||
if r.Header.Get("HX-Request") != "true" {
|
||||
return
|
||||
}
|
||||
resp.Header.Set("templ-skip-modify", "true")
|
||||
}
|
||||
|
||||
func (rt *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
// Read and buffer the body.
|
||||
var bodyBytes []byte
|
||||
if r.Body != nil && r.Body != http.NoBody {
|
||||
var err error
|
||||
bodyBytes, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.Body.Close()
|
||||
}
|
||||
|
||||
// Retry logic.
|
||||
var resp *http.Response
|
||||
var err error
|
||||
for retries := 0; retries < rt.maxRetries; retries++ {
|
||||
// Clone the request and set the body.
|
||||
req := r.Clone(r.Context())
|
||||
if bodyBytes != nil {
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
|
||||
// Execute the request.
|
||||
resp, err = http.DefaultTransport.RoundTrip(req)
|
||||
if err != nil {
|
||||
time.Sleep(rt.initialDelay * time.Duration(math.Pow(rt.backoffExponent, float64(retries))))
|
||||
continue
|
||||
}
|
||||
|
||||
rt.setShouldSkipResponseModificationHeader(r, resp)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("max retries reached: %q", r.URL.String())
|
||||
}
|
||||
|
||||
func NotifyProxy(host string, port int) error {
|
||||
urlStr := fmt.Sprintf("http://%s:%d/_templ/reload/events", host, port)
|
||||
req, err := http.NewRequest(http.MethodPost, urlStr, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = http.DefaultClient.Do(req)
|
||||
return err
|
||||
}
|
627
templ/cmd/templ/generatecmd/proxy/proxy_test.go
Normal file
627
templ/cmd/templ/generatecmd/proxy/proxy_test.go
Normal file
@@ -0,0 +1,627 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestRoundTripper(t *testing.T) {
|
||||
t.Run("if the HX-Request header is present, set the templ-skip-modify header on the response", func(t *testing.T) {
|
||||
rt := &roundTripper{}
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error creating request: %v", err)
|
||||
}
|
||||
req.Header.Set("HX-Request", "true")
|
||||
resp := &http.Response{Header: make(http.Header)}
|
||||
rt.setShouldSkipResponseModificationHeader(req, resp)
|
||||
if resp.Header.Get("templ-skip-modify") != "true" {
|
||||
t.Errorf("expected templ-skip-modify header to be true, got %v", resp.Header.Get("templ-skip-modify"))
|
||||
}
|
||||
})
|
||||
t.Run("if the HX-Request header is not present, do not set the templ-skip-modify header on the response", func(t *testing.T) {
|
||||
rt := &roundTripper{}
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error creating request: %v", err)
|
||||
}
|
||||
resp := &http.Response{Header: make(http.Header)}
|
||||
rt.setShouldSkipResponseModificationHeader(req, resp)
|
||||
if resp.Header.Get("templ-skip-modify") != "" {
|
||||
t.Errorf("expected templ-skip-modify header to be empty, got %v", resp.Header.Get("templ-skip-modify"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
t.Run("plain: non-html content is not modified", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`{"key": "value"}`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Content-Length", "16")
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != "16" {
|
||||
t.Errorf("expected content length to be 16, got %v", r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(`{"key": "value"}`, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("plain: if the response contains templ-skip-modify header, it is not modified", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`Hello`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html")
|
||||
r.Header.Set("Content-Length", "5")
|
||||
r.Header.Set("templ-skip-modify", "true")
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != "5" {
|
||||
t.Errorf("expected content length to be 5, got %v", r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(`Hello`, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("plain: body tags get the script inserted", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`<html><body></body></html>`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Length", "26")
|
||||
|
||||
expectedString := insertScriptTagIntoBody("", `<html><body></body></html>`)
|
||||
if !strings.Contains(expectedString, getScriptTag("")) {
|
||||
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
|
||||
}
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
|
||||
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("plain: body tags get the script inserted with nonce", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`<html><body></body></html>`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Length", "26")
|
||||
const nonce = "this-is-the-nonce"
|
||||
r.Header.Set("Content-Security-Policy", fmt.Sprintf("script-src 'nonce-%s'", nonce))
|
||||
|
||||
expectedString := insertScriptTagIntoBody(nonce, `<html><body></body></html>`)
|
||||
if !strings.Contains(expectedString, getScriptTag(nonce)) {
|
||||
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
|
||||
}
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
|
||||
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("plain: body tags get the script inserted ignoring js with body tags", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`<html><body><script>console.log("<body></body>")</script></body></html>`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Length", "26")
|
||||
|
||||
expectedString := insertScriptTagIntoBody("", `<html><body><script>console.log("<body></body>")</script></body></html>`)
|
||||
if !strings.Contains(expectedString, getScriptTag("")) {
|
||||
t.Fatalf("expected the script tag to be inserted, but it wasn't: %q", expectedString)
|
||||
}
|
||||
if !strings.Contains(expectedString, `console.log("<body></body>")`) {
|
||||
t.Fatalf("expected the script tag to be inserted, but mangled the html: %q", expectedString)
|
||||
}
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", len(expectedString)) {
|
||||
t.Errorf("expected content length to be %d, got %v", len(expectedString), r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("gzip: non-html content is not modified", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(strings.NewReader(`{"key": "value"}`)),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
// It's not actually gzipped here, but it doesn't matter, it shouldn't get that far.
|
||||
r.Header.Set("Content-Encoding", "gzip")
|
||||
// Similarly, this is not the actual length of the gzipped content.
|
||||
r.Header.Set("Content-Length", "16")
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != "16" {
|
||||
t.Errorf("expected content length to be 16, got %v", r.Header.Get("Content-Length"))
|
||||
}
|
||||
actualBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(`{"key": "value"}`, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("gzip: body tags get the script inserted", func(t *testing.T) {
|
||||
// Arrange
|
||||
body := `<html><body></body></html>`
|
||||
var buf bytes.Buffer
|
||||
gzw := gzip.NewWriter(&buf)
|
||||
_, err := gzw.Write([]byte(body))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error writing gzip: %v", err)
|
||||
}
|
||||
gzw.Close()
|
||||
|
||||
expectedString := insertScriptTagIntoBody("", body)
|
||||
|
||||
var expectedBytes bytes.Buffer
|
||||
gzw = gzip.NewWriter(&expectedBytes)
|
||||
_, err = gzw.Write([]byte(expectedString))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error writing gzip: %v", err)
|
||||
}
|
||||
gzw.Close()
|
||||
expectedLength := len(expectedBytes.Bytes())
|
||||
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(&buf),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Encoding", "gzip")
|
||||
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err = h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", expectedLength) {
|
||||
t.Errorf("expected content length to be %d, got %v", expectedLength, r.Header.Get("Content-Length"))
|
||||
}
|
||||
|
||||
gr, err := gzip.NewReader(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
actualBody, err := io.ReadAll(gr)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("brotli: body tags get the script inserted", func(t *testing.T) {
|
||||
// Arrange
|
||||
body := `<html><body></body></html>`
|
||||
var buf bytes.Buffer
|
||||
brw := brotli.NewWriter(&buf)
|
||||
_, err := brw.Write([]byte(body))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error writing gzip: %v", err)
|
||||
}
|
||||
brw.Close()
|
||||
|
||||
expectedString := insertScriptTagIntoBody("", body)
|
||||
|
||||
var expectedBytes bytes.Buffer
|
||||
brw = brotli.NewWriter(&expectedBytes)
|
||||
_, err = brw.Write([]byte(expectedString))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error writing gzip: %v", err)
|
||||
}
|
||||
brw.Close()
|
||||
expectedLength := len(expectedBytes.Bytes())
|
||||
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(&buf),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Encoding", "br")
|
||||
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))
|
||||
|
||||
// Act
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err = h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", expectedLength) {
|
||||
t.Errorf("expected content length to be %d, got %v", expectedLength, r.Header.Get("Content-Length"))
|
||||
}
|
||||
|
||||
actualBody, err := io.ReadAll(brotli.NewReader(r.Body))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading response: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
|
||||
t.Errorf("unexpected response body (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("notify-proxy: sending POST request to /_templ/reload/events should receive reload sse event", func(t *testing.T) {
|
||||
// Arrange 1: create a test proxy server.
|
||||
dummyHandler := func(w http.ResponseWriter, r *http.Request) {}
|
||||
dummyServer := httptest.NewServer(http.HandlerFunc(dummyHandler))
|
||||
defer dummyServer.Close()
|
||||
|
||||
u, err := url.Parse(dummyServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error parsing URL: %v", err)
|
||||
}
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
handler := New(log, "0.0.0.0", 0, u)
|
||||
proxyServer := httptest.NewServer(handler)
|
||||
defer proxyServer.Close()
|
||||
|
||||
u2, err := url.Parse(proxyServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error parsing URL: %v", err)
|
||||
}
|
||||
port, err := strconv.Atoi(u2.Port())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error parsing port: %v", err)
|
||||
}
|
||||
|
||||
// Arrange 2: start a goroutine to listen for sse events.
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error)
|
||||
sseRespCh := make(chan string)
|
||||
sseListening := make(chan bool) // Coordination channel that ensures the SSE listener is started before notifying the proxy.
|
||||
go func() {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/_templ/reload/events", proxyServer.URL), nil)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
sseListening <- true
|
||||
lines := []string{}
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
if scanner.Text() == "data: reload" {
|
||||
sseRespCh <- strings.Join(lines, "\n")
|
||||
return
|
||||
}
|
||||
}
|
||||
err = scanner.Err()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// Act: notify the proxy.
|
||||
select { // Either SSE is listening or an error occurred.
|
||||
case <-sseListening:
|
||||
err = NotifyProxy(u2.Hostname(), port)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error notifying proxy: %v", err)
|
||||
}
|
||||
case err := <-errChan:
|
||||
if err == nil {
|
||||
t.Fatalf("unexpected sse response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Assert.
|
||||
select { // Either SSE has a expected response or an error or timeout occurred.
|
||||
case resp := <-sseRespCh:
|
||||
if !strings.Contains(resp, "event: message\ndata: reload") {
|
||||
t.Errorf("expected sse reload event to be received, got: %q", resp)
|
||||
}
|
||||
case err := <-errChan:
|
||||
if err == nil {
|
||||
t.Fatalf("unexpected sse response: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("timeout waiting for sse response")
|
||||
}
|
||||
})
|
||||
t.Run("unsupported encodings result in a warning", func(t *testing.T) {
|
||||
// Arrange
|
||||
r := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader([]byte("<p>Data</p>"))),
|
||||
Header: make(http.Header),
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
r.Header.Set("Content-Type", "text/html, charset=utf-8")
|
||||
r.Header.Set("Content-Encoding", "weird-encoding")
|
||||
|
||||
// Act
|
||||
lh := newTestLogHandler(slog.LevelInfo)
|
||||
log := slog.New(lh)
|
||||
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
|
||||
err := h.modifyResponse(r)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Assert
|
||||
if len(lh.records) != 1 {
|
||||
var sb strings.Builder
|
||||
for _, record := range lh.records {
|
||||
sb.WriteString(record.Message)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
t.Fatalf("expected 1 log entry, but got %d: \n%s", len(lh.records), sb.String())
|
||||
}
|
||||
record := lh.records[0]
|
||||
if record.Message != unsupportedContentEncoding {
|
||||
t.Errorf("expected warning message %q, got %q", unsupportedContentEncoding, record.Message)
|
||||
}
|
||||
if record.Level != slog.LevelWarn {
|
||||
t.Errorf("expected warning, got level %v", record.Level)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newTestLogHandler(level slog.Level) *testLogHandler {
|
||||
return &testLogHandler{
|
||||
m: new(sync.Mutex),
|
||||
records: nil,
|
||||
level: level,
|
||||
}
|
||||
}
|
||||
|
||||
type testLogHandler struct {
|
||||
m *sync.Mutex
|
||||
records []slog.Record
|
||||
level slog.Level
|
||||
}
|
||||
|
||||
func (h *testLogHandler) Enabled(ctx context.Context, l slog.Level) bool {
|
||||
return l >= h.level
|
||||
}
|
||||
|
||||
func (h *testLogHandler) Handle(ctx context.Context, r slog.Record) error {
|
||||
h.m.Lock()
|
||||
defer h.m.Unlock()
|
||||
if r.Level < h.level {
|
||||
return nil
|
||||
}
|
||||
h.records = append(h.records, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *testLogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *testLogHandler) WithGroup(name string) slog.Handler {
|
||||
return h
|
||||
}
|
||||
|
||||
func TestParseNonce(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
csp string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty csp",
|
||||
csp: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "simple csp",
|
||||
csp: "script-src 'nonce-oLhVst3hTAcxI734qtB0J9Qc7W4qy09C'",
|
||||
expected: "oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
|
||||
},
|
||||
{
|
||||
name: "simple csp without single quote",
|
||||
csp: "script-src nonce-oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
|
||||
expected: "oLhVst3hTAcxI734qtB0J9Qc7W4qy09C",
|
||||
},
|
||||
{
|
||||
name: "complete csp",
|
||||
csp: "default-src 'self'; frame-ancestors 'self'; form-action 'self'; script-src 'strict-dynamic' 'nonce-4VOtk0Uo1l7pwtC';",
|
||||
expected: "4VOtk0Uo1l7pwtC",
|
||||
},
|
||||
{
|
||||
name: "mdn example 1",
|
||||
csp: "default-src 'self'",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "mdn example 2",
|
||||
csp: "default-src 'self' *.trusted.com",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "mdn example 3",
|
||||
csp: "default-src 'self'; img-src *; media-src media1.com media2.com; script-src userscripts.example.com",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "mdn example 3 multiple sources",
|
||||
csp: "default-src 'self'; img-src *; media-src media1.com media2.com; script-src userscripts.example.com foo.com 'strict-dynamic' 'nonce-4VOtk0Uo1l7pwtC'",
|
||||
expected: "4VOtk0Uo1l7pwtC",
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
nonce := parseNonce(tc.csp)
|
||||
if nonce != tc.expected {
|
||||
t.Errorf("expected nonce to be %s, but got %s", tc.expected, nonce)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
10
templ/cmd/templ/generatecmd/proxy/script.js
Normal file
10
templ/cmd/templ/generatecmd/proxy/script.js
Normal file
@@ -0,0 +1,10 @@
|
||||
(function() {
|
||||
let templ_reloadSrc = window.templ_reloadSrc || new EventSource("/_templ/reload/events");
|
||||
templ_reloadSrc.onmessage = (event) => {
|
||||
if (event && event.data === "reload") {
|
||||
window.location.reload();
|
||||
}
|
||||
};
|
||||
window.templ_reloadSrc = templ_reloadSrc;
|
||||
window.onbeforeunload = () => window.templ_reloadSrc.close();
|
||||
})();
|
108
templ/cmd/templ/generatecmd/run/run_test.go
Normal file
108
templ/cmd/templ/generatecmd/run/run_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package run_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/run"
|
||||
)
|
||||
|
||||
//go:embed testprogram/*
|
||||
var testprogram embed.FS
|
||||
|
||||
func TestGoRun(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping test in short mode.")
|
||||
}
|
||||
|
||||
// Copy testprogram to a temporary directory.
|
||||
dir, err := os.MkdirTemp("", "testprogram")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make test dir: %v", err)
|
||||
}
|
||||
files, err := testprogram.ReadDir("testprogram")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read embedded dir: %v", err)
|
||||
}
|
||||
for _, file := range files {
|
||||
srcFileName := "testprogram/" + file.Name()
|
||||
srcData, err := testprogram.ReadFile(srcFileName)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read src file %q: %v", srcFileName, err)
|
||||
}
|
||||
tgtFileName := filepath.Join(dir, file.Name())
|
||||
tgtFile, err := os.Create(tgtFileName)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create tgt file %q: %v", tgtFileName, err)
|
||||
}
|
||||
defer tgtFile.Close()
|
||||
if _, err := tgtFile.Write(srcData); err != nil {
|
||||
t.Fatalf("failed to write to tgt file %q: %v", tgtFileName, err)
|
||||
}
|
||||
}
|
||||
// Rename the go.mod.embed file to go.mod.
|
||||
if err := os.Rename(filepath.Join(dir, "go.mod.embed"), filepath.Join(dir, "go.mod")); err != nil {
|
||||
t.Fatalf("failed to rename go.mod.embed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd string
|
||||
}{
|
||||
{
|
||||
name: "Well behaved programs get shut down",
|
||||
cmd: "go run .",
|
||||
},
|
||||
{
|
||||
name: "Badly behaved programs get shut down",
|
||||
cmd: "go run . -badly-behaved",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cmd, err := run.Run(ctx, dir, tt.cmd)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run program: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
|
||||
if err := run.KillAll(); err != nil {
|
||||
t.Fatalf("failed to kill all: %v", err)
|
||||
}
|
||||
|
||||
// Check the parent process is no longer running.
|
||||
if err := cmd.Process.Signal(os.Signal(syscall.Signal(0))); err == nil {
|
||||
t.Fatalf("process %d is still running", pid)
|
||||
}
|
||||
// Check that the child was stopped.
|
||||
body, err := readResponse("http://localhost:7777")
|
||||
if err == nil {
|
||||
t.Fatalf("child process is still running: %s", body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func readResponse(url string) (body string, err error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
84
templ/cmd/templ/generatecmd/run/run_unix.go
Normal file
84
templ/cmd/templ/generatecmd/run/run_unix.go
Normal file
@@ -0,0 +1,84 @@
|
||||
//go:build unix
|
||||
|
||||
package run
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
m = &sync.Mutex{}
|
||||
running = map[string]*exec.Cmd{}
|
||||
)
|
||||
|
||||
func KillAll() (err error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
var errs []error
|
||||
for _, cmd := range running {
|
||||
if err := kill(cmd); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err))
|
||||
}
|
||||
}
|
||||
running = map[string]*exec.Cmd{}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func kill(cmd *exec.Cmd) (err error) {
|
||||
errs := make([]error, 4)
|
||||
errs[0] = ignoreExited(cmd.Process.Signal(syscall.SIGINT))
|
||||
errs[1] = ignoreExited(cmd.Process.Signal(syscall.SIGTERM))
|
||||
errs[2] = ignoreExited(cmd.Wait())
|
||||
errs[3] = ignoreExited(syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL))
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func ignoreExited(err error) error {
|
||||
if errors.Is(err, syscall.ESRCH) {
|
||||
return nil
|
||||
}
|
||||
// Ignore *exec.ExitError
|
||||
if _, ok := err.(*exec.ExitError); ok {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, workingDir string, input string) (cmd *exec.Cmd, err error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
cmd, ok := running[input]
|
||||
if ok {
|
||||
if err := kill(cmd); err != nil {
|
||||
return cmd, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err)
|
||||
}
|
||||
|
||||
delete(running, input)
|
||||
}
|
||||
parts := strings.Fields(input)
|
||||
executable := parts[0]
|
||||
args := []string{}
|
||||
if len(parts) > 1 {
|
||||
args = append(args, parts[1:]...)
|
||||
}
|
||||
|
||||
cmd = exec.CommandContext(ctx, executable, args...)
|
||||
// Wait for the process to finish gracefully before termination.
|
||||
cmd.WaitDelay = time.Second * 3
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Dir = workingDir
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
running[input] = cmd
|
||||
err = cmd.Start()
|
||||
return
|
||||
}
|
69
templ/cmd/templ/generatecmd/run/run_windows.go
Normal file
69
templ/cmd/templ/generatecmd/run/run_windows.go
Normal file
@@ -0,0 +1,69 @@
|
||||
//go:build windows
|
||||
|
||||
package run
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var m = &sync.Mutex{}
|
||||
var running = map[string]*exec.Cmd{}
|
||||
|
||||
func KillAll() (err error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
for _, cmd := range running {
|
||||
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
|
||||
kill.Stderr = os.Stderr
|
||||
kill.Stdout = os.Stdout
|
||||
err := kill.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
running = map[string]*exec.Cmd{}
|
||||
return
|
||||
}
|
||||
|
||||
func Stop(cmd *exec.Cmd) (err error) {
|
||||
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
|
||||
kill.Stderr = os.Stderr
|
||||
kill.Stdout = os.Stdout
|
||||
return kill.Run()
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, workingDir string, input string) (cmd *exec.Cmd, err error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
cmd, ok := running[input]
|
||||
if ok {
|
||||
kill := exec.Command("TASKKILL", "/T", "/F", "/PID", strconv.Itoa(cmd.Process.Pid))
|
||||
kill.Stderr = os.Stderr
|
||||
kill.Stdout = os.Stdout
|
||||
err := kill.Run()
|
||||
if err != nil {
|
||||
return cmd, err
|
||||
}
|
||||
delete(running, input)
|
||||
}
|
||||
parts := strings.Fields(input)
|
||||
executable := parts[0]
|
||||
args := []string{}
|
||||
if len(parts) > 1 {
|
||||
args = append(args, parts[1:]...)
|
||||
}
|
||||
|
||||
cmd = exec.Command(executable, args...)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Dir = workingDir
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
running[input] = cmd
|
||||
err = cmd.Start()
|
||||
return
|
||||
}
|
3
templ/cmd/templ/generatecmd/run/testprogram/go.mod.embed
Normal file
3
templ/cmd/templ/generatecmd/run/testprogram/go.mod.embed
Normal file
@@ -0,0 +1,3 @@
|
||||
module testprogram
|
||||
|
||||
go 1.23
|
63
templ/cmd/templ/generatecmd/run/testprogram/main.go
Normal file
63
templ/cmd/templ/generatecmd/run/testprogram/main.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// This is a test program. It is used only to test the behaviour of the run package.
|
||||
// The run package is supposed to be able to run and stop programs. Those programs may start
|
||||
// child processes, which should also be stopped when the parent program is stopped.
|
||||
|
||||
// For example, running `go run .` will compile an executable and run it.
|
||||
|
||||
// So, this program does nothing. It just waits for a signal to stop.
|
||||
|
||||
// In "Well behaved" mode, the program will stop when it receives a signal.
|
||||
// In "Badly behaved" mode, the program will ignore the signal and continue running.
|
||||
|
||||
// The run package should be able to stop the program in both cases.
|
||||
|
||||
var badlyBehavedFlag = flag.Bool("badly-behaved", false, "If set, the program will ignore the stop signal and continue running.")
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
mode := "Well behaved"
|
||||
if *badlyBehavedFlag {
|
||||
mode = "Badly behaved"
|
||||
}
|
||||
fmt.Printf("%s process %d started.\n", mode, os.Getpid())
|
||||
|
||||
// Start a web server on a known port so that we can check that this process is
|
||||
// not running, when it's been started as a child process, and we don't know
|
||||
// its pid.
|
||||
go func() {
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "%d", os.Getpid())
|
||||
})
|
||||
err := http.ListenAndServe("127.0.0.1:7777", nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Error running web server: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
if !*badlyBehavedFlag {
|
||||
signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-sigs:
|
||||
fmt.Printf("Process %d received signal. Stopping.\n", os.Getpid())
|
||||
return
|
||||
case <-time.After(1 * time.Second):
|
||||
fmt.Printf("Process %d still running...\n", os.Getpid())
|
||||
}
|
||||
}
|
||||
}
|
84
templ/cmd/templ/generatecmd/sse/server.go
Normal file
84
templ/cmd/templ/generatecmd/sse/server.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
func New() *Handler {
|
||||
return &Handler{
|
||||
m: new(sync.Mutex),
|
||||
requests: map[int64]chan event{},
|
||||
}
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
m *sync.Mutex
|
||||
counter int64
|
||||
requests map[int64]chan event
|
||||
}
|
||||
|
||||
type event struct {
|
||||
Type string
|
||||
Data string
|
||||
}
|
||||
|
||||
// Send an event to all connected clients.
|
||||
func (s *Handler) Send(eventType string, data string) {
|
||||
s.m.Lock()
|
||||
defer s.m.Unlock()
|
||||
for _, f := range s.requests {
|
||||
f := f
|
||||
go func(f chan event) {
|
||||
f <- event{
|
||||
Type: eventType,
|
||||
Data: data,
|
||||
}
|
||||
}(f)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
|
||||
id := atomic.AddInt64(&s.counter, 1)
|
||||
s.m.Lock()
|
||||
events := make(chan event)
|
||||
s.requests[id] = events
|
||||
s.m.Unlock()
|
||||
defer func() {
|
||||
s.m.Lock()
|
||||
defer s.m.Unlock()
|
||||
delete(s.requests, id)
|
||||
close(events)
|
||||
}()
|
||||
|
||||
timer := time.NewTimer(0)
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
if _, err := fmt.Fprintf(w, "event: message\ndata: ping\n\n"); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
timer.Reset(time.Second * 5)
|
||||
case e := <-events:
|
||||
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", e.Type, e.Data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
case <-r.Context().Done():
|
||||
break loop
|
||||
}
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
}
|
52
templ/cmd/templ/generatecmd/symlink/symlink_test.go
Normal file
52
templ/cmd/templ/generatecmd/symlink/symlink_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package symlink
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd"
|
||||
"github.com/a-h/templ/cmd/templ/testproject"
|
||||
)
|
||||
|
||||
func TestSymlink(t *testing.T) {
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
t.Run("can generate if root is symlink", func(t *testing.T) {
|
||||
// templ generate -f templates.templ
|
||||
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test project: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
symlinkPath := dir + "-symlink"
|
||||
err = os.Symlink(dir, symlinkPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create dir symlink: %v", err)
|
||||
}
|
||||
defer os.Remove(symlinkPath)
|
||||
|
||||
// Delete the templates_templ.go file to ensure it is generated.
|
||||
err = os.Remove(path.Join(symlinkPath, "templates_templ.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to remove templates_templ.go: %v", err)
|
||||
}
|
||||
|
||||
// Run the generate command.
|
||||
err = generatecmd.Run(context.Background(), log, generatecmd.Arguments{
|
||||
Path: symlinkPath,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run generate command: %v", err)
|
||||
}
|
||||
|
||||
// Check the templates_templ.go file was created.
|
||||
_, err = os.Stat(path.Join(symlinkPath, "templates_templ.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("templates_templ.go was not created: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
@@ -0,0 +1,101 @@
|
||||
package testeventhandler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/scanner"
|
||||
"go/token"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd"
|
||||
"github.com/a-h/templ/generator"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestErrorLocationMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawFileName string
|
||||
errorPositions []token.Position
|
||||
}{
|
||||
{
|
||||
name: "single error outputs location in srcFile",
|
||||
rawFileName: "single_error.templ.error",
|
||||
errorPositions: []token.Position{
|
||||
{Offset: 46, Line: 3, Column: 20},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple errors all output locations in srcFile",
|
||||
rawFileName: "multiple_errors.templ.error",
|
||||
errorPositions: []token.Position{
|
||||
{Offset: 41, Line: 3, Column: 15},
|
||||
{Offset: 101, Line: 7, Column: 22},
|
||||
{Offset: 126, Line: 10, Column: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
slog := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
|
||||
var fw generatecmd.FileWriterFunc
|
||||
fseh := generatecmd.NewFSEventHandler(slog, ".", false, []generator.GenerateOpt{}, false, false, fw, false)
|
||||
for _, test := range tests {
|
||||
// The raw files cannot end in .templ because they will cause the generator to fail. Instead,
|
||||
// we create a tmp file that ends in .templ only for the duration of the test.
|
||||
rawFile, err := os.Open(test.rawFileName)
|
||||
if err != nil {
|
||||
t.Errorf("%s: Failed to open file %s: %v", test.name, test.rawFileName, err)
|
||||
break
|
||||
}
|
||||
file, err := os.CreateTemp("", fmt.Sprintf("*%s.templ", test.rawFileName))
|
||||
if err != nil {
|
||||
t.Errorf("%s: Failed to create a tmp file at %s: %v", test.name, file.Name(), err)
|
||||
break
|
||||
}
|
||||
defer os.Remove(file.Name())
|
||||
if _, err = io.Copy(file, rawFile); err != nil {
|
||||
t.Errorf("%s: Failed to copy contents from raw file %s to tmp %s: %v", test.name, test.rawFileName, file.Name(), err)
|
||||
}
|
||||
|
||||
event := fsnotify.Event{Name: file.Name(), Op: fsnotify.Write}
|
||||
_, err = fseh.HandleEvent(context.Background(), event)
|
||||
if err == nil {
|
||||
t.Errorf("%s: no error was thrown", test.name)
|
||||
break
|
||||
}
|
||||
list, ok := err.(scanner.ErrorList)
|
||||
for !ok {
|
||||
err = errors.Unwrap(err)
|
||||
if err == nil {
|
||||
t.Errorf("%s: reached end of error wrapping before finding an ErrorList", test.name)
|
||||
break
|
||||
} else {
|
||||
list, ok = err.(scanner.ErrorList)
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if len(list) != len(test.errorPositions) {
|
||||
t.Errorf("%s: expected %d errors but got %d", test.name, len(test.errorPositions), len(list))
|
||||
break
|
||||
}
|
||||
for i, err := range list {
|
||||
test.errorPositions[i].Filename = file.Name()
|
||||
diff := cmp.Diff(test.errorPositions[i], err.Pos)
|
||||
if diff != "" {
|
||||
t.Error(diff)
|
||||
t.Error("expected:")
|
||||
t.Error(test.errorPositions[i])
|
||||
t.Error("actual:")
|
||||
t.Error(err.Pos)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,10 @@
|
||||
package testeventhandler
|
||||
|
||||
func invalid(a: string) string {
|
||||
return "foo"
|
||||
}
|
||||
|
||||
templ multipleError(a: string) {
|
||||
<div/>
|
||||
}
|
||||
l
|
@@ -0,0 +1,5 @@
|
||||
package testeventhandler
|
||||
|
||||
templ singleError(a: string) {
|
||||
<div/>
|
||||
}
|
485
templ/cmd/templ/generatecmd/testwatch/generate_test.go
Normal file
485
templ/cmd/templ/generatecmd/testwatch/generate_test.go
Normal file
@@ -0,0 +1,485 @@
|
||||
package testwatch
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/PuerkitoBio/goquery"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/modcheck"
|
||||
)
|
||||
|
||||
//go:embed testdata/*
|
||||
var testdata embed.FS
|
||||
|
||||
func createTestProject(moduleRoot string) (dir string, err error) {
|
||||
dir, err = os.MkdirTemp("", "templ_watch_test_*")
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to make test dir: %w", err)
|
||||
}
|
||||
files, err := testdata.ReadDir("testdata")
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to read embedded dir: %w", err)
|
||||
}
|
||||
for _, file := range files {
|
||||
src := filepath.Join("testdata", file.Name())
|
||||
data, err := testdata.ReadFile(src)
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
target := filepath.Join(dir, file.Name())
|
||||
if file.Name() == "go.mod.embed" {
|
||||
data = bytes.ReplaceAll(data, []byte("{moduleRoot}"), []byte(moduleRoot))
|
||||
target = filepath.Join(dir, "go.mod")
|
||||
}
|
||||
err = os.WriteFile(target, data, 0660)
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to copy file: %w", err)
|
||||
}
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
func replaceInFile(name, src, tgt string) error {
|
||||
data, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updated := strings.Replace(string(data), src, tgt, -1)
|
||||
return os.WriteFile(name, []byte(updated), 0660)
|
||||
}
|
||||
|
||||
func getPort() (port int, err error) {
|
||||
var a *net.TCPAddr
|
||||
if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||
var l *net.TCPListener
|
||||
if l, err = net.ListenTCP("tcp", a); err == nil {
|
||||
defer l.Close()
|
||||
return l.Addr().(*net.TCPAddr).Port, nil
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getHTML(url string) (doc *goquery.Document, err error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get %q: %w", url, err)
|
||||
}
|
||||
return goquery.NewDocumentFromReader(resp.Body)
|
||||
}
|
||||
|
||||
func TestCanAccessDirect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
args, teardown, err := Setup(false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
|
||||
// Assert.
|
||||
doc, err := getHTML(args.AppURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
countText := doc.Find(`div[data-testid="count"]`).Text()
|
||||
actualCount, err := strconv.Atoi(countText)
|
||||
if err != nil {
|
||||
t.Fatalf("got count %q instead of integer", countText)
|
||||
}
|
||||
if actualCount < 1 {
|
||||
t.Errorf("expected count >= 1, got %d", actualCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanAccessViaProxy(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
args, teardown, err := Setup(false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
|
||||
// Assert.
|
||||
doc, err := getHTML(args.ProxyURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
countText := doc.Find(`div[data-testid="count"]`).Text()
|
||||
actualCount, err := strconv.Atoi(countText)
|
||||
if err != nil {
|
||||
t.Fatalf("got count %q instead of integer", countText)
|
||||
}
|
||||
if actualCount < 1 {
|
||||
t.Errorf("expected count >= 1, got %d", actualCount)
|
||||
}
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
Type string
|
||||
Data string
|
||||
}
|
||||
|
||||
func readSSE(ctx context.Context, url string, sse chan<- Event) (err error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var e Event
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
sse <- e
|
||||
e = Event{}
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
e.Type = line[len("event: "):]
|
||||
}
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
e.Data = line[len("data: "):]
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func TestFileModificationsResultInSSEWithGzip(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
args, teardown, err := Setup(false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
|
||||
// Start the SSE check.
|
||||
events := make(chan Event)
|
||||
var eventsErr error
|
||||
go func() {
|
||||
eventsErr = readSSE(context.Background(), fmt.Sprintf("%s/_templ/reload/events", args.ProxyURL), events)
|
||||
}()
|
||||
|
||||
// Assert data is expected.
|
||||
doc, err := getHTML(args.ProxyURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Original" {
|
||||
t.Errorf("expected %q, got %q", "Original", text)
|
||||
}
|
||||
|
||||
// Change file.
|
||||
templFile := filepath.Join(args.AppDir, "templates.templ")
|
||||
err = replaceInFile(templFile,
|
||||
`<div data-testid="modification">Original</div>`,
|
||||
`<div data-testid="modification">Updated</div>`)
|
||||
if err != nil {
|
||||
t.Errorf("failed to replace text in file: %v", err)
|
||||
}
|
||||
|
||||
// Give the filesystem watcher a few seconds.
|
||||
var reloadCount int
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Data == "reload" {
|
||||
reloadCount++
|
||||
break loop
|
||||
}
|
||||
case <-time.After(time.Second * 5):
|
||||
break loop
|
||||
}
|
||||
}
|
||||
if reloadCount == 0 {
|
||||
t.Error("failed to receive SSE about update after 5 seconds")
|
||||
}
|
||||
|
||||
// Check to see if there were any errors.
|
||||
if eventsErr != nil {
|
||||
t.Errorf("error reading events: %v", err)
|
||||
}
|
||||
|
||||
// See results in browser immediately.
|
||||
doc, err = getHTML(args.ProxyURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Updated" {
|
||||
t.Errorf("expected %q, got %q", "Updated", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileModificationsResultInSSE(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
args, teardown, err := Setup(false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
|
||||
// Start the SSE check.
|
||||
events := make(chan Event)
|
||||
var eventsErr error
|
||||
go func() {
|
||||
eventsErr = readSSE(context.Background(), fmt.Sprintf("%s/_templ/reload/events", args.ProxyURL), events)
|
||||
}()
|
||||
|
||||
// Assert data is expected.
|
||||
doc, err := getHTML(args.ProxyURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Original" {
|
||||
t.Errorf("expected %q, got %q", "Original", text)
|
||||
}
|
||||
|
||||
// Change file.
|
||||
templFile := filepath.Join(args.AppDir, "templates.templ")
|
||||
err = replaceInFile(templFile,
|
||||
`<div data-testid="modification">Original</div>`,
|
||||
`<div data-testid="modification">Updated</div>`)
|
||||
if err != nil {
|
||||
t.Errorf("failed to replace text in file: %v", err)
|
||||
}
|
||||
|
||||
// Give the filesystem watcher a few seconds.
|
||||
var reloadCount int
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Data == "reload" {
|
||||
reloadCount++
|
||||
break loop
|
||||
}
|
||||
case <-time.After(time.Second * 5):
|
||||
break loop
|
||||
}
|
||||
}
|
||||
if reloadCount == 0 {
|
||||
t.Error("failed to receive SSE about update after 5 seconds")
|
||||
}
|
||||
|
||||
// Check to see if there were any errors.
|
||||
if eventsErr != nil {
|
||||
t.Errorf("error reading events: %v", err)
|
||||
}
|
||||
|
||||
// See results in browser immediately.
|
||||
doc, err = getHTML(args.ProxyURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read HTML: %v", err)
|
||||
}
|
||||
if text := doc.Find(`div[data-testid="modification"]`).Text(); text != "Updated" {
|
||||
t.Errorf("expected %q, got %q", "Updated", text)
|
||||
}
|
||||
}
|
||||
|
||||
func NewTestArgs(modRoot, appDir string, appPort int, proxyBind string, proxyPort int) TestArgs {
|
||||
return TestArgs{
|
||||
ModRoot: modRoot,
|
||||
AppDir: appDir,
|
||||
AppPort: appPort,
|
||||
AppURL: fmt.Sprintf("http://localhost:%d", appPort),
|
||||
ProxyBind: proxyBind,
|
||||
ProxyPort: proxyPort,
|
||||
ProxyURL: fmt.Sprintf("http://%s:%d", proxyBind, proxyPort),
|
||||
}
|
||||
}
|
||||
|
||||
type TestArgs struct {
|
||||
ModRoot string
|
||||
AppDir string
|
||||
AppPort int
|
||||
AppURL string
|
||||
ProxyBind string
|
||||
ProxyPort int
|
||||
ProxyURL string
|
||||
}
|
||||
|
||||
func Setup(gzipEncoding bool) (args TestArgs, teardown func(t *testing.T), err error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return args, teardown, fmt.Errorf("could not find working dir: %w", err)
|
||||
}
|
||||
moduleRoot, err := modcheck.WalkUp(wd)
|
||||
if err != nil {
|
||||
return args, teardown, fmt.Errorf("could not find local templ go.mod file: %v", err)
|
||||
}
|
||||
|
||||
appDir, err := createTestProject(moduleRoot)
|
||||
if err != nil {
|
||||
return args, teardown, fmt.Errorf("failed to create test project: %v", err)
|
||||
}
|
||||
appPort, err := getPort()
|
||||
if err != nil {
|
||||
return args, teardown, fmt.Errorf("failed to get available port: %v", err)
|
||||
}
|
||||
proxyPort, err := getPort()
|
||||
if err != nil {
|
||||
return args, teardown, fmt.Errorf("failed to get available port: %v", err)
|
||||
}
|
||||
proxyBind := "localhost"
|
||||
|
||||
args = NewTestArgs(moduleRoot, appDir, appPort, proxyBind, proxyPort)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var cmdErr error
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
command := fmt.Sprintf("go run . -port %d", args.AppPort)
|
||||
if gzipEncoding {
|
||||
command += " -gzip true"
|
||||
}
|
||||
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
|
||||
cmdErr = generatecmd.Run(ctx, log, generatecmd.Arguments{
|
||||
Path: appDir,
|
||||
Watch: true,
|
||||
OpenBrowser: false,
|
||||
Command: command,
|
||||
ProxyBind: proxyBind,
|
||||
ProxyPort: proxyPort,
|
||||
Proxy: args.AppURL,
|
||||
NotifyProxy: false,
|
||||
WorkerCount: 0,
|
||||
GenerateSourceMapVisualisations: false,
|
||||
IncludeVersion: false,
|
||||
IncludeTimestamp: false,
|
||||
PPROFPort: 0,
|
||||
KeepOrphanedFiles: false,
|
||||
})
|
||||
}()
|
||||
|
||||
// Wait for server to start.
|
||||
if err = waitForURL(args.AppURL); err != nil {
|
||||
cancel()
|
||||
wg.Wait()
|
||||
return args, teardown, fmt.Errorf("failed to start app server, command error %v: %w", cmdErr, err)
|
||||
}
|
||||
if err = waitForURL(args.ProxyURL); err != nil {
|
||||
cancel()
|
||||
wg.Wait()
|
||||
return args, teardown, fmt.Errorf("failed to start proxy server, command error %v: %w", cmdErr, err)
|
||||
}
|
||||
|
||||
// Wait for exit.
|
||||
teardown = func(t *testing.T) {
|
||||
cancel()
|
||||
wg.Wait()
|
||||
if cmdErr != nil {
|
||||
t.Errorf("failed to run generate cmd: %v", err)
|
||||
}
|
||||
|
||||
if err = os.RemoveAll(appDir); err != nil {
|
||||
t.Fatalf("failed to remove test dir %q: %v", appDir, err)
|
||||
}
|
||||
}
|
||||
return args, teardown, err
|
||||
}
|
||||
|
||||
func waitForURL(url string) (err error) {
|
||||
var tries int
|
||||
for {
|
||||
time.Sleep(time.Second)
|
||||
if tries > 20 {
|
||||
return err
|
||||
}
|
||||
tries++
|
||||
var resp *http.Response
|
||||
resp, err = http.Get(url)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to get %q: %v\n", url, err)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
fmt.Printf("failed to get %q: %v\n", url, err)
|
||||
err = fmt.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateReturnsErrors(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("could not find working dir: %v", err)
|
||||
}
|
||||
moduleRoot, err := modcheck.WalkUp(wd)
|
||||
if err != nil {
|
||||
t.Fatalf("could not find local templ go.mod file: %v", err)
|
||||
}
|
||||
|
||||
appDir, err := createTestProject(moduleRoot)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test project: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err = os.RemoveAll(appDir); err != nil {
|
||||
t.Fatalf("failed to remove test dir %q: %v", appDir, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Break the HTML.
|
||||
templFile := filepath.Join(appDir, "templates.templ")
|
||||
err = replaceInFile(templFile,
|
||||
`<div data-testid="modification">Original</div>`,
|
||||
`<div data-testid="modification" -unclosed div-</div>`)
|
||||
if err != nil {
|
||||
t.Errorf("failed to replace text in file: %v", err)
|
||||
}
|
||||
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
|
||||
// Run.
|
||||
err = generatecmd.Run(context.Background(), log, generatecmd.Arguments{
|
||||
Path: appDir,
|
||||
Watch: false,
|
||||
IncludeVersion: false,
|
||||
IncludeTimestamp: false,
|
||||
KeepOrphanedFiles: false,
|
||||
})
|
||||
if err == nil {
|
||||
t.Errorf("expected generation error, got %v", err)
|
||||
}
|
||||
}
|
7
templ/cmd/templ/generatecmd/testwatch/testdata/go.mod.embed
vendored
Normal file
7
templ/cmd/templ/generatecmd/testwatch/testdata/go.mod.embed
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
module templ/testproject
|
||||
|
||||
go 1.23
|
||||
|
||||
require github.com/a-h/templ v0.2.513 // indirect
|
||||
|
||||
replace github.com/a-h/templ => {moduleRoot}
|
2
templ/cmd/templ/generatecmd/testwatch/testdata/go.sum
vendored
Normal file
2
templ/cmd/templ/generatecmd/testwatch/testdata/go.sum
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
81
templ/cmd/templ/generatecmd/testwatch/testdata/main.go
vendored
Normal file
81
templ/cmd/templ/generatecmd/testwatch/testdata/main.go
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
)
|
||||
|
||||
type GzipResponseWriter struct {
|
||||
w http.ResponseWriter
|
||||
}
|
||||
|
||||
func (w *GzipResponseWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
func (w *GzipResponseWriter) Write(b []byte) (int, error) {
|
||||
var buf bytes.Buffer
|
||||
gzw := gzip.NewWriter(&buf)
|
||||
defer gzw.Close()
|
||||
|
||||
_, err := gzw.Write(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = gzw.Close()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
|
||||
|
||||
return w.w.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
func (w *GzipResponseWriter) WriteHeader(statusCode int) {
|
||||
w.w.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
var flagPort = flag.Int("port", 0, "Set the HTTP listen port")
|
||||
var useGzip = flag.Bool("gzip", false, "Toggle gzip encoding")
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if *flagPort == 0 {
|
||||
fmt.Println("missing port flag")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var count int
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if useGzip != nil && *useGzip {
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w = &GzipResponseWriter{w: w}
|
||||
}
|
||||
|
||||
count++
|
||||
c := Page(count)
|
||||
h := templ.Handler(c)
|
||||
h.ErrorHandler = func(r *http.Request, err error) http.Handler {
|
||||
slog.Error("failed to render template", slog.Any("error", err))
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
})
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
err := http.ListenAndServe(fmt.Sprintf("localhost:%d", *flagPort), nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Error listening: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
17
templ/cmd/templ/generatecmd/testwatch/testdata/templates.templ
vendored
Normal file
17
templ/cmd/templ/generatecmd/testwatch/testdata/templates.templ
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
templ Page(count int) {
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>templ test page</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Count</h1>
|
||||
<div data-testid="count">{ fmt.Sprintf("%d", count) }</div>
|
||||
<div data-testid="modification">Original</div>
|
||||
</body>
|
||||
</html>
|
||||
}
|
55
templ/cmd/templ/generatecmd/testwatch/testdata/templates_templ.go
vendored
Normal file
55
templ/cmd/templ/generatecmd/testwatch/testdata/templates_templ.go
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
// Code generated by templ - DO NOT EDIT.
|
||||
|
||||
// templ: version: v0.3.833
|
||||
package main
|
||||
|
||||
//lint:file-ignore SA4006 This context is only used if a nested component is present.
|
||||
|
||||
import "github.com/a-h/templ"
|
||||
import templruntime "github.com/a-h/templ/runtime"
|
||||
|
||||
import "fmt"
|
||||
|
||||
func Page(count int) templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var1 == nil {
|
||||
templ_7745c5c3_Var1 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<!doctype html><html><head><title>templ test page</title></head><body><h1>Count</h1><div data-testid=\"count\">")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var2 string
|
||||
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(fmt.Sprintf("%d", count))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/generatecmd/testwatch/testdata/templates.templ`, Line: 13, Col: 54}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "</div><div data-testid=\"modification\">Original</div></body></html>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var _ = templruntime.GeneratedTemplate
|
166
templ/cmd/templ/generatecmd/watcher/watch.go
Normal file
166
templ/cmd/templ/generatecmd/watcher/watch.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
func Recursive(
|
||||
ctx context.Context,
|
||||
path string,
|
||||
watchPattern *regexp.Regexp,
|
||||
out chan fsnotify.Event,
|
||||
errors chan error,
|
||||
) (w *RecursiveWatcher, err error) {
|
||||
fsnw, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w = NewRecursiveWatcher(ctx, fsnw, watchPattern, out, errors)
|
||||
go w.loop()
|
||||
return w, w.Add(path)
|
||||
}
|
||||
|
||||
func NewRecursiveWatcher(ctx context.Context, w *fsnotify.Watcher, watchPattern *regexp.Regexp, events chan fsnotify.Event, errors chan error) *RecursiveWatcher {
|
||||
return &RecursiveWatcher{
|
||||
ctx: ctx,
|
||||
w: w,
|
||||
WatchPattern: watchPattern,
|
||||
Events: events,
|
||||
Errors: errors,
|
||||
timers: make(map[timerKey]*time.Timer),
|
||||
}
|
||||
}
|
||||
|
||||
// WalkFiles walks the file tree rooted at path, sending a Create event for each
|
||||
// file it encounters.
|
||||
func WalkFiles(ctx context.Context, path string, watchPattern *regexp.Regexp, out chan fsnotify.Event) (err error) {
|
||||
rootPath := path
|
||||
fileSystem := os.DirFS(rootPath)
|
||||
return fs.WalkDir(fileSystem, ".", func(path string, info os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
absPath, err := filepath.Abs(filepath.Join(rootPath, path))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if info.IsDir() && shouldSkipDir(absPath) {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if !watchPattern.MatchString(absPath) {
|
||||
return nil
|
||||
}
|
||||
out <- fsnotify.Event{
|
||||
Name: absPath,
|
||||
Op: fsnotify.Create,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type RecursiveWatcher struct {
|
||||
ctx context.Context
|
||||
w *fsnotify.Watcher
|
||||
WatchPattern *regexp.Regexp
|
||||
Events chan fsnotify.Event
|
||||
Errors chan error
|
||||
timerMu sync.Mutex
|
||||
timers map[timerKey]*time.Timer
|
||||
}
|
||||
|
||||
type timerKey struct {
|
||||
name string
|
||||
op fsnotify.Op
|
||||
}
|
||||
|
||||
func timerKeyFromEvent(event fsnotify.Event) timerKey {
|
||||
return timerKey{
|
||||
name: event.Name,
|
||||
op: event.Op,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *RecursiveWatcher) Close() error {
|
||||
return w.w.Close()
|
||||
}
|
||||
|
||||
func (w *RecursiveWatcher) loop() {
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case event, ok := <-w.w.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Has(fsnotify.Create) {
|
||||
if err := w.Add(event.Name); err != nil {
|
||||
w.Errors <- err
|
||||
}
|
||||
}
|
||||
// Only notify on templ related files.
|
||||
if !w.WatchPattern.MatchString(event.Name) {
|
||||
continue
|
||||
}
|
||||
tk := timerKeyFromEvent(event)
|
||||
w.timerMu.Lock()
|
||||
t, ok := w.timers[tk]
|
||||
w.timerMu.Unlock()
|
||||
if !ok {
|
||||
t = time.AfterFunc(100*time.Millisecond, func() {
|
||||
w.Events <- event
|
||||
})
|
||||
w.timerMu.Lock()
|
||||
w.timers[tk] = t
|
||||
w.timerMu.Unlock()
|
||||
continue
|
||||
}
|
||||
t.Reset(100 * time.Millisecond)
|
||||
case err, ok := <-w.w.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
w.Errors <- err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *RecursiveWatcher) Add(dir string) error {
|
||||
return filepath.WalkDir(dir, func(dir string, info os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if shouldSkipDir(dir) {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return w.w.Add(dir)
|
||||
})
|
||||
}
|
||||
|
||||
func shouldSkipDir(dir string) bool {
|
||||
if dir == "." {
|
||||
return false
|
||||
}
|
||||
if dir == "vendor" || dir == "node_modules" {
|
||||
return true
|
||||
}
|
||||
_, name := path.Split(dir)
|
||||
// These directories are ignored by the Go tool.
|
||||
if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
133
templ/cmd/templ/generatecmd/watcher/watch_test.go
Normal file
133
templ/cmd/templ/generatecmd/watcher/watch_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
func TestWatchDebouncesDuplicates(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w := &fsnotify.Watcher{
|
||||
Events: make(chan fsnotify.Event),
|
||||
}
|
||||
events := make(chan fsnotify.Event, 2)
|
||||
errors := make(chan error)
|
||||
watchPattern, err := regexp.Compile(".*")
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
|
||||
}
|
||||
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
|
||||
go func() {
|
||||
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
|
||||
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
|
||||
cancel()
|
||||
close(rw.w.Events)
|
||||
}()
|
||||
rw.loop()
|
||||
count := 0
|
||||
exp := time.After(300 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-rw.Events:
|
||||
count++
|
||||
case <-exp:
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 event, got %d", count)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchDoesNotDebounceDifferentEvents(t *testing.T) {
|
||||
tests := []struct {
|
||||
event1 fsnotify.Event
|
||||
event2 fsnotify.Event
|
||||
}{
|
||||
// Different files
|
||||
{fsnotify.Event{Name: "test.templ"}, fsnotify.Event{Name: "test2.templ"}},
|
||||
// Different operations
|
||||
{
|
||||
fsnotify.Event{Name: "test.templ", Op: fsnotify.Create},
|
||||
fsnotify.Event{Name: "test.templ", Op: fsnotify.Write},
|
||||
},
|
||||
// Different operations and files
|
||||
{
|
||||
fsnotify.Event{Name: "test.templ", Op: fsnotify.Create},
|
||||
fsnotify.Event{Name: "test2.templ", Op: fsnotify.Write},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w := &fsnotify.Watcher{
|
||||
Events: make(chan fsnotify.Event),
|
||||
}
|
||||
events := make(chan fsnotify.Event, 2)
|
||||
errors := make(chan error)
|
||||
watchPattern, err := regexp.Compile(".*")
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
|
||||
}
|
||||
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
|
||||
go func() {
|
||||
rw.w.Events <- test.event1
|
||||
rw.w.Events <- test.event2
|
||||
cancel()
|
||||
close(rw.w.Events)
|
||||
}()
|
||||
rw.loop()
|
||||
count := 0
|
||||
exp := time.After(300 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-rw.Events:
|
||||
count++
|
||||
case <-exp:
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 event, got %d", count)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchDoesNotDebounceSeparateEvents(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w := &fsnotify.Watcher{
|
||||
Events: make(chan fsnotify.Event),
|
||||
}
|
||||
events := make(chan fsnotify.Event, 2)
|
||||
errors := make(chan error)
|
||||
watchPattern, err := regexp.Compile(".*")
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
|
||||
}
|
||||
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
|
||||
go func() {
|
||||
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
|
||||
<-time.After(200 * time.Millisecond)
|
||||
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
|
||||
cancel()
|
||||
close(rw.w.Events)
|
||||
}()
|
||||
rw.loop()
|
||||
count := 0
|
||||
exp := time.After(500 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-rw.Events:
|
||||
count++
|
||||
case <-exp:
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 event, got %d", count)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
174
templ/cmd/templ/imports/process.go
Normal file
174
templ/cmd/templ/imports/process.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package imports
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/format"
|
||||
"go/token"
|
||||
"path"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
goparser "go/parser"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/tools/go/ast/astutil"
|
||||
"golang.org/x/tools/imports"
|
||||
|
||||
"github.com/a-h/templ/generator"
|
||||
"github.com/a-h/templ/parser/v2"
|
||||
)
|
||||
|
||||
var internalImports = []string{"github.com/a-h/templ", "github.com/a-h/templ/runtime"}
|
||||
|
||||
func convertTemplToGoURI(templURI string) (isTemplFile bool, goURI string) {
|
||||
base, fileName := path.Split(templURI)
|
||||
if !strings.HasSuffix(fileName, ".templ") {
|
||||
return
|
||||
}
|
||||
return true, base + (strings.TrimSuffix(fileName, ".templ") + "_templ.go")
|
||||
}
|
||||
|
||||
var fset = token.NewFileSet()
|
||||
|
||||
func updateImports(name, src string) (updated []*ast.ImportSpec, err error) {
|
||||
// Apply auto imports.
|
||||
updatedGoCode, err := imports.Process(name, []byte(src), nil)
|
||||
if err != nil {
|
||||
return updated, fmt.Errorf("failed to process go code %q: %w", src, err)
|
||||
}
|
||||
// Get updated imports.
|
||||
gofile, err := goparser.ParseFile(fset, name, updatedGoCode, goparser.ImportsOnly)
|
||||
if err != nil {
|
||||
return updated, fmt.Errorf("failed to get imports from updated go code: %w", err)
|
||||
}
|
||||
for _, imp := range gofile.Imports {
|
||||
if !slices.Contains(internalImports, strings.Trim(imp.Path.Value, "\"")) {
|
||||
updated = append(updated, imp)
|
||||
}
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func Process(t parser.TemplateFile) (parser.TemplateFile, error) {
|
||||
if t.Filepath == "" {
|
||||
return t, nil
|
||||
}
|
||||
isTemplFile, fileName := convertTemplToGoURI(t.Filepath)
|
||||
if !isTemplFile {
|
||||
return t, fmt.Errorf("invalid filepath: %s", t.Filepath)
|
||||
}
|
||||
|
||||
// The first node always contains existing imports.
|
||||
// If there isn't one, create it.
|
||||
if len(t.Nodes) == 0 {
|
||||
t.Nodes = append(t.Nodes, parser.TemplateFileGoExpression{})
|
||||
}
|
||||
// If there is one, ensure it is a Go expression.
|
||||
if _, ok := t.Nodes[0].(parser.TemplateFileGoExpression); !ok {
|
||||
t.Nodes = append([]parser.TemplateFileNode{parser.TemplateFileGoExpression{}}, t.Nodes...)
|
||||
}
|
||||
|
||||
// Find all existing imports.
|
||||
importsNode := t.Nodes[0].(parser.TemplateFileGoExpression)
|
||||
|
||||
// Generate code.
|
||||
gw := bytes.NewBuffer(nil)
|
||||
var updatedImports []*ast.ImportSpec
|
||||
var eg errgroup.Group
|
||||
eg.Go(func() (err error) {
|
||||
if _, err := generator.Generate(t, gw); err != nil {
|
||||
return fmt.Errorf("failed to generate go code: %w", err)
|
||||
}
|
||||
updatedImports, err = updateImports(fileName, gw.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get imports from generated go code: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
var firstGoNodeInTemplate *ast.File
|
||||
// Update the template with the imports.
|
||||
// Ensure that there is a Go expression to add the imports to as the first node.
|
||||
eg.Go(func() (err error) {
|
||||
firstGoNodeInTemplate, err = goparser.ParseFile(fset, fileName, t.Package.Expression.Value+"\n"+importsNode.Expression.Value, goparser.AllErrors|goparser.ParseComments)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse imports section: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Wait for completion of both parts.
|
||||
if err := eg.Wait(); err != nil {
|
||||
return t, err
|
||||
}
|
||||
// Delete unused imports.
|
||||
for _, imp := range firstGoNodeInTemplate.Imports {
|
||||
if !containsImport(updatedImports, imp) {
|
||||
name, path, err := getImportDetails(imp)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
astutil.DeleteNamedImport(fset, firstGoNodeInTemplate, name, path)
|
||||
}
|
||||
}
|
||||
// Add imports, if there are any to add.
|
||||
for _, imp := range updatedImports {
|
||||
if !containsImport(firstGoNodeInTemplate.Imports, imp) {
|
||||
name, path, err := getImportDetails(imp)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
astutil.AddNamedImport(fset, firstGoNodeInTemplate, name, path)
|
||||
}
|
||||
}
|
||||
// Edge case: reinsert the import to use import syntax without parentheses.
|
||||
if len(firstGoNodeInTemplate.Imports) == 1 {
|
||||
name, path, err := getImportDetails(firstGoNodeInTemplate.Imports[0])
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
astutil.DeleteNamedImport(fset, firstGoNodeInTemplate, name, path)
|
||||
astutil.AddNamedImport(fset, firstGoNodeInTemplate, name, path)
|
||||
}
|
||||
// Write out the Go code with the imports.
|
||||
updatedGoCode := new(strings.Builder)
|
||||
err := format.Node(updatedGoCode, fset, firstGoNodeInTemplate)
|
||||
if err != nil {
|
||||
return t, fmt.Errorf("failed to write updated go code: %w", err)
|
||||
}
|
||||
// Remove the package statement from the node, by cutting the first line of the file.
|
||||
importsNode.Expression.Value = strings.TrimSpace(strings.SplitN(updatedGoCode.String(), "\n", 2)[1])
|
||||
if len(updatedImports) == 0 && importsNode.Expression.Value == "" {
|
||||
t.Nodes = t.Nodes[1:]
|
||||
return t, nil
|
||||
}
|
||||
t.Nodes[0] = importsNode
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func getImportDetails(imp *ast.ImportSpec) (name, importPath string, err error) {
|
||||
if imp.Name != nil {
|
||||
name = imp.Name.Name
|
||||
}
|
||||
if imp.Path != nil {
|
||||
importPath, err = strconv.Unquote(imp.Path.Value)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to unquote package path %s: %w", imp.Path.Value, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
return name, importPath, nil
|
||||
}
|
||||
|
||||
func containsImport(imports []*ast.ImportSpec, spec *ast.ImportSpec) bool {
|
||||
for _, imp := range imports {
|
||||
if imp.Path.Value == spec.Path.Value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
154
templ/cmd/templ/imports/process_test.go
Normal file
154
templ/cmd/templ/imports/process_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package imports
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/testproject"
|
||||
"github.com/a-h/templ/parser/v2"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"golang.org/x/tools/txtar"
|
||||
)
|
||||
|
||||
func TestFormatting(t *testing.T) {
|
||||
files, _ := filepath.Glob("testdata/*.txtar")
|
||||
if len(files) == 0 {
|
||||
t.Errorf("no test files found")
|
||||
}
|
||||
for _, file := range files {
|
||||
t.Run(filepath.Base(file), func(t *testing.T) {
|
||||
a, err := txtar.ParseFile(file)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse txtar file: %v", err)
|
||||
}
|
||||
if len(a.Files) != 2 {
|
||||
t.Fatalf("expected 2 files, got %d", len(a.Files))
|
||||
}
|
||||
template, err := parser.ParseString(clean(a.Files[0].Data))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse %v", err)
|
||||
}
|
||||
template.Filepath = a.Files[0].Name
|
||||
tf, err := Process(template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to process file: %v", err)
|
||||
}
|
||||
expected := string(a.Files[1].Data)
|
||||
actual := new(strings.Builder)
|
||||
if err := tf.Write(actual); err != nil {
|
||||
t.Fatalf("failed to write template file: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expected, actual.String()); diff != "" {
|
||||
t.Errorf("%s:\n%s", file, diff)
|
||||
t.Errorf("expected:\n%s", showWhitespace(expected))
|
||||
t.Errorf("actual:\n%s", showWhitespace(actual.String()))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func showWhitespace(s string) string {
|
||||
s = strings.ReplaceAll(s, "\n", "⏎\n")
|
||||
s = strings.ReplaceAll(s, "\t", "→")
|
||||
s = strings.ReplaceAll(s, " ", "·")
|
||||
return s
|
||||
}
|
||||
|
||||
func clean(b []byte) string {
|
||||
b = bytes.ReplaceAll(b, []byte("$\n"), []byte("\n"))
|
||||
b = bytes.TrimSuffix(b, []byte("\n"))
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func TestImport(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
return
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
src string
|
||||
assertions func(t *testing.T, updated string)
|
||||
}{
|
||||
{
|
||||
name: "un-named imports are removed",
|
||||
src: `package main
|
||||
|
||||
import "fmt"
|
||||
import "github.com/a-h/templ/cmd/templ/testproject/css-classes"
|
||||
|
||||
templ Page(count int) {
|
||||
{ fmt.Sprintf("%d", count) }
|
||||
{ cssclasses.Header }
|
||||
}
|
||||
`,
|
||||
assertions: func(t *testing.T, updated string) {
|
||||
if count := strings.Count(updated, "github.com/a-h/templ/cmd/templ/testproject/css-classes"); count != 0 {
|
||||
t.Errorf("expected un-named import to be removed, but got %d instance of it", count)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "named imports are retained",
|
||||
src: `package main
|
||||
|
||||
import "fmt"
|
||||
import cssclasses "github.com/a-h/templ/cmd/templ/testproject/css-classes"
|
||||
|
||||
templ Page(count int) {
|
||||
{ fmt.Sprintf("%d", count) }
|
||||
{ cssclasses.Header }
|
||||
}
|
||||
`,
|
||||
assertions: func(t *testing.T, updated string) {
|
||||
if count := strings.Count(updated, "cssclasses \"github.com/a-h/templ/cmd/templ/testproject/css-classes\""); count != 1 {
|
||||
t.Errorf("expected named import to be retained, got %d instances of it", count)
|
||||
}
|
||||
if count := strings.Count(updated, "github.com/a-h/templ/cmd/templ/testproject/css-classes"); count != 1 {
|
||||
t.Errorf("expected one import, got %d", count)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
// Create test project.
|
||||
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test project: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// Load the templates.templ file.
|
||||
filePath := path.Join(dir, "templates.templ")
|
||||
err = os.WriteFile(filePath, []byte(test.src), 0660)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write file: %v", err)
|
||||
}
|
||||
|
||||
// Parse the new file.
|
||||
template, err := parser.Parse(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse %v", err)
|
||||
}
|
||||
template.Filepath = filePath
|
||||
tf, err := Process(template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to process file: %v", err)
|
||||
}
|
||||
|
||||
// Write it back out after processing.
|
||||
buf := new(strings.Builder)
|
||||
if err := tf.Write(buf); err != nil {
|
||||
t.Fatalf("failed to write template file: %v", err)
|
||||
}
|
||||
|
||||
// Assert.
|
||||
test.assertions(t, buf.String())
|
||||
}
|
||||
}
|
12
templ/cmd/templ/imports/testdata/comments.txtar
vendored
Normal file
12
templ/cmd/templ/imports/testdata/comments.txtar
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
28
templ/cmd/templ/imports/testdata/commentsbeforepackage.txtar
vendored
Normal file
28
templ/cmd/templ/imports/testdata/commentsbeforepackage.txtar
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
-- fmt_templ.templ --
|
||||
// Comments before.
|
||||
/*
|
||||
Some more comments
|
||||
*/
|
||||
package test
|
||||
|
||||
templ test() {
|
||||
<div>Hello</div>
|
||||
}
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
||||
-- fmt_templ.templ --
|
||||
// Comments before.
|
||||
/*
|
||||
Some more comments
|
||||
*/
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
templ test() {
|
||||
<div>Hello</div>
|
||||
}
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
14
templ/cmd/templ/imports/testdata/deleteimports.txtar
vendored
Normal file
14
templ/cmd/templ/imports/testdata/deleteimports.txtar
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import "strconv"
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
15
templ/cmd/templ/imports/testdata/extraspace.txtar
vendored
Normal file
15
templ/cmd/templ/imports/testdata/extraspace.txtar
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
const x = 123
|
||||
|
||||
|
||||
var x = fmt.Sprintf("Hello")
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
const x = 123
|
||||
|
||||
var x = fmt.Sprintf("Hello")
|
22
templ/cmd/templ/imports/testdata/groups.txtar
vendored
Normal file
22
templ/cmd/templ/imports/testdata/groups.txtar
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"fmt"
|
||||
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var _, _ = fmt.Print(strings.Contains(strconv.Quote("Hello"), ""))
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var _, _ = fmt.Print(strings.Contains(strconv.Quote("Hello"), ""))
|
21
templ/cmd/templ/imports/testdata/groupsmanynewlines.txtar
vendored
Normal file
21
templ/cmd/templ/imports/testdata/groupsmanynewlines.txtar
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var _, _ = fmt.Print(strconv.Quote("Hello"))
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var _, _ = fmt.Print(strconv.Quote("Hello"))
|
10
templ/cmd/templ/imports/testdata/header.txtar
vendored
Normal file
10
templ/cmd/templ/imports/testdata/header.txtar
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
var x = fmt.Sprintf("Hello")
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
var x = fmt.Sprintf("Hello")
|
19
templ/cmd/templ/imports/testdata/namedimportsadd.txtar
vendored
Normal file
19
templ/cmd/templ/imports/testdata/namedimportsadd.txtar
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
sconv "strconv"
|
||||
)
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf(sconv.Quote("Hello"))
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
sconv "strconv"
|
||||
)
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf(sconv.Quote("Hello"))
|
16
templ/cmd/templ/imports/testdata/namedimportsremoved.txtar
vendored
Normal file
16
templ/cmd/templ/imports/testdata/namedimportsremoved.txtar
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
sconv "strconv"
|
||||
)
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
12
templ/cmd/templ/imports/testdata/noimports.txtar
vendored
Normal file
12
templ/cmd/templ/imports/testdata/noimports.txtar
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
20
templ/cmd/templ/imports/testdata/noimportscode.txtar
vendored
Normal file
20
templ/cmd/templ/imports/testdata/noimportscode.txtar
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
func test() {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
func test() {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
14
templ/cmd/templ/imports/testdata/stringexp.txtar
vendored
Normal file
14
templ/cmd/templ/imports/testdata/stringexp.txtar
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
templ Hello(name string) {
|
||||
{ fmt.Sprintf("Hello, %s!", name) }
|
||||
}
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
templ Hello(name string) {
|
||||
{ fmt.Sprintf("Hello, %s!", name) }
|
||||
}
|
21
templ/cmd/templ/imports/testdata/twoimports.txtar
vendored
Normal file
21
templ/cmd/templ/imports/testdata/twoimports.txtar
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
templ Hello(name string) {
|
||||
<div id={ strconv.Atoi("123") }>
|
||||
{ fmt.Sprintf("Hello, %s!", name) }
|
||||
</div>
|
||||
}
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
templ Hello(name string) {
|
||||
<div id={ strconv.Atoi("123") }>
|
||||
{ fmt.Sprintf("Hello, %s!", name) }
|
||||
</div>
|
||||
}
|
157
templ/cmd/templ/infocmd/main.go
Normal file
157
templ/cmd/templ/infocmd/main.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package infocmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"github.com/a-h/templ/cmd/templ/lspcmd/pls"
|
||||
)
|
||||
|
||||
type Arguments struct {
|
||||
JSON bool `flag:"json" help:"Output info as JSON."`
|
||||
}
|
||||
|
||||
type Info struct {
|
||||
OS struct {
|
||||
GOOS string `json:"goos"`
|
||||
GOARCH string `json:"goarch"`
|
||||
} `json:"os"`
|
||||
Go ToolInfo `json:"go"`
|
||||
Gopls ToolInfo `json:"gopls"`
|
||||
Templ ToolInfo `json:"templ"`
|
||||
}
|
||||
|
||||
type ToolInfo struct {
|
||||
Location string `json:"location"`
|
||||
Version string `json:"version"`
|
||||
OK bool `json:"ok"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
func getGoInfo() (d ToolInfo) {
|
||||
// Find Go.
|
||||
var err error
|
||||
d.Location, err = exec.LookPath("go")
|
||||
if err != nil {
|
||||
d.Message = fmt.Sprintf("failed to find go: %v", err)
|
||||
return
|
||||
}
|
||||
// Run go to find the version.
|
||||
cmd := exec.Command(d.Location, "version")
|
||||
v, err := cmd.Output()
|
||||
if err != nil {
|
||||
d.Message = fmt.Sprintf("failed to get go version, check that Go is installed: %v", err)
|
||||
return
|
||||
}
|
||||
d.Version = strings.TrimSpace(string(v))
|
||||
d.OK = true
|
||||
return
|
||||
}
|
||||
|
||||
func getGoplsInfo() (d ToolInfo) {
|
||||
var err error
|
||||
d.Location, err = pls.FindGopls()
|
||||
if err != nil {
|
||||
d.Message = fmt.Sprintf("failed to find gopls: %v", err)
|
||||
return
|
||||
}
|
||||
cmd := exec.Command(d.Location, "version")
|
||||
v, err := cmd.Output()
|
||||
if err != nil {
|
||||
d.Message = fmt.Sprintf("failed to get gopls version: %v", err)
|
||||
return
|
||||
}
|
||||
d.Version = strings.TrimSpace(string(v))
|
||||
d.OK = true
|
||||
return
|
||||
}
|
||||
|
||||
func getTemplInfo() (d ToolInfo) {
|
||||
// Find templ.
|
||||
var err error
|
||||
d.Location, err = findTempl()
|
||||
if err != nil {
|
||||
d.Message = err.Error()
|
||||
return
|
||||
}
|
||||
// Run templ to find the version.
|
||||
cmd := exec.Command(d.Location, "version")
|
||||
v, err := cmd.Output()
|
||||
if err != nil {
|
||||
d.Message = fmt.Sprintf("failed to get templ version: %v", err)
|
||||
return
|
||||
}
|
||||
d.Version = strings.TrimSpace(string(v))
|
||||
if d.Version != templ.Version() {
|
||||
d.Message = fmt.Sprintf("version mismatch - you're running %q at the command line, but the version in the path is %q", templ.Version(), d.Version)
|
||||
return
|
||||
}
|
||||
d.OK = true
|
||||
return
|
||||
}
|
||||
|
||||
func findTempl() (location string, err error) {
|
||||
executableName := "templ"
|
||||
if runtime.GOOS == "windows" {
|
||||
executableName = "templ.exe"
|
||||
}
|
||||
executableName, err = exec.LookPath(executableName)
|
||||
if err == nil {
|
||||
// Found on the path.
|
||||
return executableName, nil
|
||||
}
|
||||
|
||||
// Unexpected error.
|
||||
if !errors.Is(err, exec.ErrNotFound) {
|
||||
return "", fmt.Errorf("unexpected error looking for templ: %w", err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("templ is not in the path (%q). You can install templ with `go install github.com/a-h/templ/cmd/templ@latest`", os.Getenv("PATH"))
|
||||
}
|
||||
|
||||
func getInfo() (d Info) {
|
||||
d.OS.GOOS = runtime.GOOS
|
||||
d.OS.GOARCH = runtime.GOARCH
|
||||
d.Go = getGoInfo()
|
||||
d.Gopls = getGoplsInfo()
|
||||
d.Templ = getTemplInfo()
|
||||
return
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, log *slog.Logger, stdout io.Writer, args Arguments) (err error) {
|
||||
info := getInfo()
|
||||
if args.JSON {
|
||||
enc := json.NewEncoder(stdout)
|
||||
enc.SetIndent("", " ")
|
||||
return enc.Encode(info)
|
||||
}
|
||||
log.Info("os", slog.String("goos", info.OS.GOOS), slog.String("goarch", info.OS.GOARCH))
|
||||
logInfo(ctx, log, "go", info.Go)
|
||||
logInfo(ctx, log, "gopls", info.Gopls)
|
||||
logInfo(ctx, log, "templ", info.Templ)
|
||||
return nil
|
||||
}
|
||||
|
||||
func logInfo(ctx context.Context, log *slog.Logger, name string, ti ToolInfo) {
|
||||
level := slog.LevelInfo
|
||||
if !ti.OK {
|
||||
level = slog.LevelError
|
||||
}
|
||||
args := []any{
|
||||
slog.String("location", ti.Location),
|
||||
slog.String("version", ti.Version),
|
||||
}
|
||||
if ti.Message != "" {
|
||||
args = append(args, slog.String("message", ti.Message))
|
||||
}
|
||||
log.Log(ctx, level, name, args...)
|
||||
}
|
130
templ/cmd/templ/lspcmd/httpdebug/handler.go
Normal file
130
templ/cmd/templ/lspcmd/httpdebug/handler.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package httpdebug
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"github.com/a-h/templ/cmd/templ/lspcmd/proxy"
|
||||
"github.com/a-h/templ/cmd/templ/visualize"
|
||||
)
|
||||
|
||||
var log *slog.Logger
|
||||
|
||||
func NewHandler(l *slog.Logger, s *proxy.Server) http.Handler {
|
||||
m := http.NewServeMux()
|
||||
log = l
|
||||
m.HandleFunc("/templ", func(w http.ResponseWriter, r *http.Request) {
|
||||
uri := r.URL.Query().Get("uri")
|
||||
c, ok := s.TemplSource.Get(uri)
|
||||
if !ok {
|
||||
Error(w, "uri not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
String(w, c.String())
|
||||
})
|
||||
m.HandleFunc("/sourcemap", func(w http.ResponseWriter, r *http.Request) {
|
||||
uri := r.URL.Query().Get("uri")
|
||||
sm, ok := s.SourceMapCache.Get(uri)
|
||||
if !ok {
|
||||
Error(w, "uri not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
JSON(w, sm.SourceLinesToTarget)
|
||||
})
|
||||
m.HandleFunc("/go", func(w http.ResponseWriter, r *http.Request) {
|
||||
uri := r.URL.Query().Get("uri")
|
||||
c, ok := s.GoSource[uri]
|
||||
if !ok {
|
||||
Error(w, "uri not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
String(w, c)
|
||||
})
|
||||
m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
uri := r.URL.Query().Get("uri")
|
||||
if uri == "" {
|
||||
// List all URIs.
|
||||
if err := list(s.TemplSource.URIs()).Render(r.Context(), w); err != nil {
|
||||
Error(w, "failed to list URIs", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Assume we've got a URI.
|
||||
templSource, ok := s.TemplSource.Get(uri)
|
||||
if !ok {
|
||||
if !ok {
|
||||
Error(w, "uri not found in document contents", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
}
|
||||
goSource, ok := s.GoSource[uri]
|
||||
if !ok {
|
||||
if !ok {
|
||||
Error(w, "uri not found in document contents", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
}
|
||||
sm, ok := s.SourceMapCache.Get(uri)
|
||||
if !ok {
|
||||
Error(w, "uri not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if err := visualize.HTML(uri, templSource.String(), goSource, sm).Render(r.Context(), w); err != nil {
|
||||
Error(w, "failed to visualize HTML", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
return m
|
||||
}
|
||||
|
||||
func getMapURL(uri string) templ.SafeURL {
|
||||
return withQuery("/", uri)
|
||||
}
|
||||
|
||||
func getSourceMapURL(uri string) templ.SafeURL {
|
||||
return withQuery("/sourcemap", uri)
|
||||
}
|
||||
|
||||
func getTemplURL(uri string) templ.SafeURL {
|
||||
return withQuery("/templ", uri)
|
||||
}
|
||||
|
||||
func getGoURL(uri string) templ.SafeURL {
|
||||
return withQuery("/go", uri)
|
||||
}
|
||||
|
||||
func withQuery(path, uri string) templ.SafeURL {
|
||||
q := make(url.Values)
|
||||
q.Set("uri", uri)
|
||||
u := &url.URL{
|
||||
Path: path,
|
||||
RawPath: path,
|
||||
RawQuery: q.Encode(),
|
||||
}
|
||||
return templ.SafeURL(u.String())
|
||||
}
|
||||
|
||||
func JSON(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(v); err != nil {
|
||||
log.Error("failed to write JSON response", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
func String(w http.ResponseWriter, s string) {
|
||||
if _, err := io.WriteString(w, s); err != nil {
|
||||
log.Error("failed to write string response", slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
func Error(w http.ResponseWriter, msg string, status int) {
|
||||
w.WriteHeader(status)
|
||||
if _, err := io.WriteString(w, msg); err != nil {
|
||||
log.Error("failed to write error response", slog.Any("error", err))
|
||||
}
|
||||
}
|
22
templ/cmd/templ/lspcmd/httpdebug/list.templ
Normal file
22
templ/cmd/templ/lspcmd/httpdebug/list.templ
Normal file
@@ -0,0 +1,22 @@
|
||||
package httpdebug
|
||||
|
||||
templ list(uris []string) {
|
||||
<table>
|
||||
<tr>
|
||||
<th>File</th>
|
||||
<th></th>
|
||||
<th></th>
|
||||
<th></th>
|
||||
<th></th>
|
||||
</tr>
|
||||
for _, uri := range uris {
|
||||
<tr>
|
||||
<td>{ uri }</td>
|
||||
<td><a href={ getMapURL(uri) }>Mapping</a></td>
|
||||
<td><a href={ getSourceMapURL(uri) }>Source Map</a></td>
|
||||
<td><a href={ getTemplURL(uri) }>Templ</a></td>
|
||||
<td><a href={ getGoURL(uri) }>Go</a></td>
|
||||
</tr>
|
||||
}
|
||||
</table>
|
||||
}
|
99
templ/cmd/templ/lspcmd/httpdebug/list_templ.go
Normal file
99
templ/cmd/templ/lspcmd/httpdebug/list_templ.go
Normal file
@@ -0,0 +1,99 @@
|
||||
// Code generated by templ - DO NOT EDIT.
|
||||
|
||||
// templ: version: v0.3.833
|
||||
package httpdebug
|
||||
|
||||
//lint:file-ignore SA4006 This context is only used if a nested component is present.
|
||||
|
||||
import "github.com/a-h/templ"
|
||||
import templruntime "github.com/a-h/templ/runtime"
|
||||
|
||||
func list(uris []string) templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var1 == nil {
|
||||
templ_7745c5c3_Var1 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<table><tr><th>File</th><th></th><th></th><th></th><th></th></tr>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
for _, uri := range uris {
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "<tr><td>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var2 string
|
||||
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(uri)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/lspcmd/httpdebug/list.templ`, Line: 14, Col: 13}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 3, "</td><td><a href=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var3 templ.SafeURL = getMapURL(uri)
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(string(templ_7745c5c3_Var3)))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 4, "\">Mapping</a></td><td><a href=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var4 templ.SafeURL = getSourceMapURL(uri)
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(string(templ_7745c5c3_Var4)))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 5, "\">Source Map</a></td><td><a href=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var5 templ.SafeURL = getTemplURL(uri)
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(string(templ_7745c5c3_Var5)))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 6, "\">Templ</a></td><td><a href=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var6 templ.SafeURL = getGoURL(uri)
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(string(templ_7745c5c3_Var6)))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "\">Go</a></td></tr>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, "</table>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var _ = templruntime.GeneratedTemplate
|
957
templ/cmd/templ/lspcmd/lsp_test.go
Normal file
957
templ/cmd/templ/lspcmd/lsp_test.go
Normal file
@@ -0,0 +1,957 @@
|
||||
package lspcmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd/modcheck"
|
||||
"github.com/a-h/templ/cmd/templ/lspcmd/lspdiff"
|
||||
"github.com/a-h/templ/cmd/templ/testproject"
|
||||
"github.com/a-h/templ/lsp/jsonrpc2"
|
||||
"github.com/a-h/templ/lsp/protocol"
|
||||
"github.com/a-h/templ/lsp/uri"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestCompletion(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
|
||||
|
||||
ctx, appDir, _, server, teardown, err := Setup(ctx, log)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
defer cancel()
|
||||
|
||||
templFile, err := os.ReadFile(appDir + "/templates.templ")
|
||||
if err != nil {
|
||||
t.Errorf("failed to read file %q: %v", appDir+"/templates.templ", err)
|
||||
return
|
||||
}
|
||||
err = server.DidOpen(ctx, &protocol.DidOpenTextDocumentParams{
|
||||
TextDocument: protocol.TextDocumentItem{
|
||||
URI: uri.URI("file://" + appDir + "/templates.templ"),
|
||||
LanguageID: "templ",
|
||||
Version: 1,
|
||||
Text: string(templFile),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed to register open file: %v", err)
|
||||
return
|
||||
}
|
||||
log.Info("Calling completion")
|
||||
|
||||
globalSnippetsLen := 1
|
||||
|
||||
// Edit the file.
|
||||
// Replace:
|
||||
// <div data-testid="count">{ fmt.Sprintf("%d", count) }</div>
|
||||
// With various tests:
|
||||
// <div data-testid="count">{ f
|
||||
tests := []struct {
|
||||
line int
|
||||
replacement string
|
||||
cursor string
|
||||
assert func(t *testing.T, cl *protocol.CompletionList) (msg string, ok bool)
|
||||
}{
|
||||
{
|
||||
line: 13,
|
||||
replacement: ` <div data-testid="count">{ `,
|
||||
cursor: ` ^`,
|
||||
assert: func(t *testing.T, actual *protocol.CompletionList) (msg string, ok bool) {
|
||||
if actual != nil && len(actual.Items) != globalSnippetsLen {
|
||||
return "expected completion list to be empty", false
|
||||
}
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
{
|
||||
line: 13,
|
||||
replacement: ` <div data-testid="count">{ fmt.`,
|
||||
cursor: ` ^`,
|
||||
assert: func(t *testing.T, actual *protocol.CompletionList) (msg string, ok bool) {
|
||||
if !lspdiff.CompletionListContainsText(actual, "fmt.Sprintf") {
|
||||
return fmt.Sprintf("expected fmt.Sprintf to be in the completion list, but got %#v", actual), false
|
||||
}
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
{
|
||||
line: 13,
|
||||
replacement: ` <div data-testid="count">{ fmt.Sprintf("%d",`,
|
||||
cursor: ` ^`,
|
||||
assert: func(t *testing.T, actual *protocol.CompletionList) (msg string, ok bool) {
|
||||
if actual != nil && len(actual.Items) != globalSnippetsLen {
|
||||
return "expected completion list to be empty", false
|
||||
}
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
|
||||
// Edit the file.
|
||||
updated := testproject.MustReplaceLine(string(templFile), test.line, test.replacement)
|
||||
err = server.DidChange(ctx, &protocol.DidChangeTextDocumentParams{
|
||||
TextDocument: protocol.VersionedTextDocumentIdentifier{
|
||||
TextDocumentIdentifier: protocol.TextDocumentIdentifier{
|
||||
URI: uri.URI("file://" + appDir + "/templates.templ"),
|
||||
},
|
||||
Version: int32(i + 2),
|
||||
},
|
||||
ContentChanges: []protocol.TextDocumentContentChangeEvent{
|
||||
{
|
||||
Range: nil,
|
||||
Text: updated,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed to change file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Give CI/CD pipeline executors some time because they're often quite slow.
|
||||
var ok bool
|
||||
var msg string
|
||||
for i := 0; i < 3; i++ {
|
||||
actual, err := server.Completion(ctx, &protocol.CompletionParams{
|
||||
Context: &protocol.CompletionContext{
|
||||
TriggerCharacter: ".",
|
||||
TriggerKind: protocol.CompletionTriggerKindTriggerCharacter,
|
||||
},
|
||||
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: uri.URI("file://" + appDir + "/templates.templ"),
|
||||
},
|
||||
// Positions are zero indexed.
|
||||
Position: protocol.Position{
|
||||
Line: uint32(test.line - 1),
|
||||
Character: uint32(len(test.cursor) - 1),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed to get completion: %v", err)
|
||||
return
|
||||
}
|
||||
msg, ok = test.assert(t, actual)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
}
|
||||
if !ok {
|
||||
t.Error(msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
log.Info("Completed test")
|
||||
}
|
||||
|
||||
func TestHover(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
|
||||
|
||||
ctx, appDir, _, server, teardown, err := Setup(ctx, log)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
defer cancel()
|
||||
|
||||
templFile, err := os.ReadFile(appDir + "/templates.templ")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file %q: %v", appDir+"/templates.templ", err)
|
||||
}
|
||||
err = server.DidOpen(ctx, &protocol.DidOpenTextDocumentParams{
|
||||
TextDocument: protocol.TextDocumentItem{
|
||||
URI: uri.URI("file://" + appDir + "/templates.templ"),
|
||||
LanguageID: "templ",
|
||||
Version: 1,
|
||||
Text: string(templFile),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed to register open file: %v", err)
|
||||
return
|
||||
}
|
||||
log.Info("Calling hover")
|
||||
|
||||
// Edit the file.
|
||||
// Replace:
|
||||
// <div data-testid="count">{ fmt.Sprintf("%d", count) }</div>
|
||||
// With various tests:
|
||||
// <div data-testid="count">{ f
|
||||
tests := []struct {
|
||||
line int
|
||||
replacement string
|
||||
cursor string
|
||||
assert func(t *testing.T, hr *protocol.Hover) (msg string, ok bool)
|
||||
}{
|
||||
{
|
||||
line: 13,
|
||||
replacement: ` <div data-testid="count">{ fmt.Sprintf("%d", count) }</div>`,
|
||||
cursor: ` ^`,
|
||||
assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) {
|
||||
expectedHover := protocol.Hover{
|
||||
Contents: protocol.MarkupContent{
|
||||
Kind: "markdown",
|
||||
Value: "```go\npackage fmt\n```\n\n---\n\n[`fmt` on pkg.go.dev](https://pkg.go.dev/fmt)",
|
||||
},
|
||||
}
|
||||
if diff := lspdiff.Hover(expectedHover, *actual); diff != "" {
|
||||
return fmt.Sprintf("unexpected hover: %v\n\n: markdown: %#v", diff, actual.Contents.Value), false
|
||||
}
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
{
|
||||
line: 13,
|
||||
replacement: ` <div data-testid="count">{ fmt.Sprintf("%d", count) }</div>`,
|
||||
cursor: ` ^`,
|
||||
assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) {
|
||||
expectedHover := protocol.Hover{
|
||||
Contents: protocol.MarkupContent{
|
||||
Kind: "markdown",
|
||||
Value: "```go\nfunc fmt.Sprintf(format string, a ...any) string\n```\n\n---\n\nSprintf formats according to a format specifier and returns the resulting string.\n\n\n---\n\n[`fmt.Sprintf` on pkg.go.dev](https://pkg.go.dev/fmt#Sprintf)",
|
||||
},
|
||||
}
|
||||
if actual == nil {
|
||||
return "expected hover to be non-nil", false
|
||||
}
|
||||
if diff := lspdiff.Hover(expectedHover, *actual); diff != "" {
|
||||
return fmt.Sprintf("unexpected hover: %v", diff), false
|
||||
}
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
{
|
||||
line: 19,
|
||||
replacement: `var nihao = "你好"`,
|
||||
cursor: ` ^`,
|
||||
assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) {
|
||||
// There's nothing to hover, just want to make sure it doesn't panic.
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
{
|
||||
line: 19,
|
||||
replacement: `var nihao = "你好"`,
|
||||
cursor: ` ^`, // Your text editor might not render this well, but it's the hao.
|
||||
assert: func(t *testing.T, actual *protocol.Hover) (msg string, ok bool) {
|
||||
// There's nothing to hover, just want to make sure it doesn't panic.
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
|
||||
// Put the file back to the initial point.
|
||||
err = server.DidChange(ctx, &protocol.DidChangeTextDocumentParams{
|
||||
TextDocument: protocol.VersionedTextDocumentIdentifier{
|
||||
TextDocumentIdentifier: protocol.TextDocumentIdentifier{
|
||||
URI: uri.URI("file://" + appDir + "/templates.templ"),
|
||||
},
|
||||
Version: int32(i + 2),
|
||||
},
|
||||
ContentChanges: []protocol.TextDocumentContentChangeEvent{
|
||||
{
|
||||
Range: nil,
|
||||
Text: string(templFile),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed to change file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Give CI/CD pipeline executors some time because they're often quite slow.
|
||||
var ok bool
|
||||
var msg string
|
||||
for i := 0; i < 3; i++ {
|
||||
lspCharIndex, err := runeIndexToUTF8ByteIndex(test.replacement, len(test.cursor)-1)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
actual, err := server.Hover(ctx, &protocol.HoverParams{
|
||||
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: uri.URI("file://" + appDir + "/templates.templ"),
|
||||
},
|
||||
// Positions are zero indexed.
|
||||
Position: protocol.Position{
|
||||
Line: uint32(test.line - 1),
|
||||
Character: lspCharIndex,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed to hover: %v", err)
|
||||
return
|
||||
}
|
||||
msg, ok = test.assert(t, actual)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
}
|
||||
if !ok {
|
||||
t.Error(msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReferences(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
|
||||
|
||||
ctx, appDir, _, server, teardown, err := Setup(ctx, log)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
return
|
||||
}
|
||||
defer teardown(t)
|
||||
defer cancel()
|
||||
|
||||
log.Info("Calling References")
|
||||
|
||||
tests := []struct {
|
||||
line int
|
||||
character int
|
||||
filename string
|
||||
assert func(t *testing.T, l []protocol.Location) (msg string, ok bool)
|
||||
}{
|
||||
{
|
||||
// this is the definition of the templ function in the templates.templ file.
|
||||
line: 5,
|
||||
character: 9,
|
||||
filename: "/templates.templ",
|
||||
assert: func(t *testing.T, actual []protocol.Location) (msg string, ok bool) {
|
||||
expectedReference := []protocol.Location{
|
||||
{
|
||||
// This is the usage of the templ function in the main.go file.
|
||||
URI: uri.URI("file://" + appDir + "/main.go"),
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{
|
||||
Line: uint32(24),
|
||||
Character: uint32(7),
|
||||
},
|
||||
End: protocol.Position{
|
||||
Line: uint32(24),
|
||||
Character: uint32(11),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if diff := lspdiff.References(expectedReference, actual); diff != "" {
|
||||
return fmt.Sprintf("Expected: %+v\nActual: %+v", expectedReference, actual), false
|
||||
}
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
{
|
||||
// this is the definition of the struct in the templates.templ file.
|
||||
line: 21,
|
||||
character: 9,
|
||||
filename: "/templates.templ",
|
||||
assert: func(t *testing.T, actual []protocol.Location) (msg string, ok bool) {
|
||||
expectedReference := []protocol.Location{
|
||||
{
|
||||
// This is the usage of the struct in the templates.templ file.
|
||||
URI: uri.URI("file://" + appDir + "/templates.templ"),
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{
|
||||
Line: uint32(24),
|
||||
Character: uint32(8),
|
||||
},
|
||||
End: protocol.Position{
|
||||
Line: uint32(24),
|
||||
Character: uint32(14),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if diff := lspdiff.References(expectedReference, actual); diff != "" {
|
||||
return fmt.Sprintf("Expected: %+v\nActual: %+v", expectedReference, actual), false
|
||||
}
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
{
|
||||
// this test is for inclusions from a remote file that has not been explicitly called with didOpen
|
||||
line: 3,
|
||||
character: 9,
|
||||
filename: "/remotechild.templ",
|
||||
assert: func(t *testing.T, actual []protocol.Location) (msg string, ok bool) {
|
||||
expectedReference := []protocol.Location{
|
||||
{
|
||||
URI: uri.URI("file://" + appDir + "/remoteparent.templ"),
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{
|
||||
Line: uint32(3),
|
||||
Character: uint32(2),
|
||||
},
|
||||
End: protocol.Position{
|
||||
Line: uint32(3),
|
||||
Character: uint32(8),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
URI: uri.URI("file://" + appDir + "/remoteparent.templ"),
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{
|
||||
Line: uint32(7),
|
||||
Character: uint32(2),
|
||||
},
|
||||
End: protocol.Position{
|
||||
Line: uint32(7),
|
||||
Character: uint32(8),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if diff := lspdiff.References(expectedReference, actual); diff != "" {
|
||||
return fmt.Sprintf("Expected: %+v\nActual: %+v", expectedReference, actual), false
|
||||
}
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
|
||||
// Give CI/CD pipeline executors some time because they're often quite slow.
|
||||
var ok bool
|
||||
var msg string
|
||||
for i := 0; i < 3; i++ {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
actual, err := server.References(ctx, &protocol.ReferenceParams{
|
||||
TextDocumentPositionParams: protocol.TextDocumentPositionParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: uri.URI("file://" + appDir + test.filename),
|
||||
},
|
||||
// Positions are zero indexed.
|
||||
Position: protocol.Position{
|
||||
Line: uint32(test.line - 1),
|
||||
Character: uint32(test.character - 1),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed to get references: %v", err)
|
||||
return
|
||||
}
|
||||
msg, ok = test.assert(t, actual)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
}
|
||||
if !ok {
|
||||
t.Error(msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodeAction(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
|
||||
|
||||
ctx, appDir, _, server, teardown, err := Setup(ctx, log)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
defer cancel()
|
||||
|
||||
templFile, err := os.ReadFile(appDir + "/templates.templ")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read file %q: %v", appDir+"/templates.templ", err)
|
||||
}
|
||||
err = server.DidOpen(ctx, &protocol.DidOpenTextDocumentParams{
|
||||
TextDocument: protocol.TextDocumentItem{
|
||||
URI: uri.URI("file://" + appDir + "/templates.templ"),
|
||||
LanguageID: "templ",
|
||||
Version: 1,
|
||||
Text: string(templFile),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed to register open file: %v", err)
|
||||
return
|
||||
}
|
||||
log.Info("Calling codeAction")
|
||||
|
||||
tests := []struct {
|
||||
line int
|
||||
replacement string
|
||||
cursor string
|
||||
assert func(t *testing.T, hr []protocol.CodeAction) (msg string, ok bool)
|
||||
}{
|
||||
{
|
||||
line: 25,
|
||||
replacement: `var s = Struct{}`,
|
||||
cursor: ` ^`,
|
||||
assert: func(t *testing.T, actual []protocol.CodeAction) (msg string, ok bool) {
|
||||
var expected []protocol.CodeAction
|
||||
// To support code actions, update cmd/templ/lspcmd/proxy/server.go and add the
|
||||
// Title (e.g. Organize Imports, or Fill Struct) to the supportedCodeActions map.
|
||||
|
||||
// Some Code Actions are simple edits, so all that is needed is for the server
|
||||
// to remap the source code positions.
|
||||
|
||||
// However, other Code Actions are commands, where the arguments must be rewritten
|
||||
// and will need to be handled individually.
|
||||
if diff := lspdiff.CodeAction(expected, actual); diff != "" {
|
||||
return fmt.Sprintf("unexpected codeAction: %v", diff), false
|
||||
}
|
||||
return "", true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
|
||||
// Put the file back to the initial point.
|
||||
err = server.DidChange(ctx, &protocol.DidChangeTextDocumentParams{
|
||||
TextDocument: protocol.VersionedTextDocumentIdentifier{
|
||||
TextDocumentIdentifier: protocol.TextDocumentIdentifier{
|
||||
URI: uri.URI("file://" + appDir + "/templates.templ"),
|
||||
},
|
||||
Version: int32(i + 2),
|
||||
},
|
||||
ContentChanges: []protocol.TextDocumentContentChangeEvent{
|
||||
{
|
||||
Range: nil,
|
||||
Text: string(templFile),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed to change file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Give CI/CD pipeline executors some time because they're often quite slow.
|
||||
var ok bool
|
||||
var msg string
|
||||
for i := 0; i < 3; i++ {
|
||||
lspCharIndex, err := runeIndexToUTF8ByteIndex(test.replacement, len(test.cursor)-1)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
actual, err := server.CodeAction(ctx, &protocol.CodeActionParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: uri.URI("file://" + appDir + "/templates.templ"),
|
||||
},
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{
|
||||
Line: uint32(test.line - 1),
|
||||
Character: lspCharIndex,
|
||||
},
|
||||
End: protocol.Position{
|
||||
Line: uint32(test.line - 1),
|
||||
Character: lspCharIndex + 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed code action: %v", err)
|
||||
return
|
||||
}
|
||||
msg, ok = test.assert(t, actual)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
}
|
||||
if !ok {
|
||||
t.Error(msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocumentSymbol(t *testing.T) {
|
||||
if testing.Short() {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
|
||||
|
||||
ctx, appDir, _, server, teardown, err := Setup(ctx, log)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup test: %v", err)
|
||||
}
|
||||
defer teardown(t)
|
||||
defer cancel()
|
||||
|
||||
tests := []struct {
|
||||
uri string
|
||||
expect []protocol.SymbolInformationOrDocumentSymbol
|
||||
}{
|
||||
{
|
||||
uri: "file://" + appDir + "/templates.templ",
|
||||
expect: []protocol.SymbolInformationOrDocumentSymbol{
|
||||
{
|
||||
SymbolInformation: &protocol.SymbolInformation{
|
||||
Name: "Page",
|
||||
Kind: protocol.SymbolKindFunction,
|
||||
Location: protocol.Location{
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{Line: 11, Character: 0},
|
||||
End: protocol.Position{Line: 50, Character: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
SymbolInformation: &protocol.SymbolInformation{
|
||||
Name: "nihao",
|
||||
Kind: protocol.SymbolKindVariable,
|
||||
Location: protocol.Location{
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{Line: 18, Character: 4},
|
||||
End: protocol.Position{Line: 18, Character: 16},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
SymbolInformation: &protocol.SymbolInformation{
|
||||
Name: "Struct",
|
||||
Kind: protocol.SymbolKindStruct,
|
||||
Location: protocol.Location{
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{Line: 20, Character: 5},
|
||||
End: protocol.Position{Line: 22, Character: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
SymbolInformation: &protocol.SymbolInformation{
|
||||
Name: "s",
|
||||
Kind: protocol.SymbolKindVariable,
|
||||
Location: protocol.Location{
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{Line: 24, Character: 4},
|
||||
End: protocol.Position{Line: 24, Character: 16},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
uri: "file://" + appDir + "/remoteparent.templ",
|
||||
expect: []protocol.SymbolInformationOrDocumentSymbol{
|
||||
{
|
||||
SymbolInformation: &protocol.SymbolInformation{
|
||||
Name: "RemoteInclusionTest",
|
||||
Kind: protocol.SymbolKindFunction,
|
||||
Location: protocol.Location{
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{Line: 9, Character: 0},
|
||||
End: protocol.Position{Line: 35, Character: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
SymbolInformation: &protocol.SymbolInformation{
|
||||
Name: "Remote2",
|
||||
Kind: protocol.SymbolKindFunction,
|
||||
Location: protocol.Location{
|
||||
Range: protocol.Range{
|
||||
Start: protocol.Position{Line: 37, Character: 0},
|
||||
End: protocol.Position{Line: 63, Character: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
|
||||
actual, err := server.DocumentSymbol(ctx, &protocol.DocumentSymbolParams{
|
||||
TextDocument: protocol.TextDocumentIdentifier{
|
||||
URI: uri.URI(test.uri),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("failed to get document symbol: %v", err)
|
||||
}
|
||||
|
||||
// Set expected URI.
|
||||
for i, v := range test.expect {
|
||||
if v.SymbolInformation != nil {
|
||||
v.SymbolInformation.Location.URI = uri.URI(test.uri)
|
||||
test.expect[i] = v
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("failed to convert expect to any slice: %v", err)
|
||||
}
|
||||
diff := cmp.Diff(test.expect, actual)
|
||||
if diff != "" {
|
||||
t.Errorf("unexpected document symbol: %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runeIndexToUTF8ByteIndex(s string, runeIndex int) (lspChar uint32, err error) {
|
||||
for i, r := range []rune(s) {
|
||||
if i == runeIndex {
|
||||
break
|
||||
}
|
||||
l := utf8.RuneLen(r)
|
||||
if l < 0 {
|
||||
return 0, fmt.Errorf("invalid rune in string at index %d", runeIndex)
|
||||
}
|
||||
lspChar += uint32(l)
|
||||
}
|
||||
return lspChar, nil
|
||||
}
|
||||
|
||||
func NewTestClient(log *slog.Logger) TestClient {
|
||||
return TestClient{
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
type TestClient struct {
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func (tc TestClient) Progress(ctx context.Context, params *protocol.ProgressParams) (err error) {
|
||||
tc.log.Info("client: Received Progress", slog.Any("params", params))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc TestClient) WorkDoneProgressCreate(ctx context.Context, params *protocol.WorkDoneProgressCreateParams) (err error) {
|
||||
tc.log.Info("client: Received WorkDoneProgressCreate", slog.Any("params", params))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc TestClient) LogMessage(ctx context.Context, params *protocol.LogMessageParams) (err error) {
|
||||
tc.log.Info("client: Received LogMessage", slog.Any("params", params))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc TestClient) PublishDiagnostics(ctx context.Context, params *protocol.PublishDiagnosticsParams) (err error) {
|
||||
tc.log.Info("client: Received PublishDiagnostics", slog.Any("params", params))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc TestClient) ShowMessage(ctx context.Context, params *protocol.ShowMessageParams) (err error) {
|
||||
tc.log.Info("client: Received ShowMessage", slog.Any("params", params))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc TestClient) ShowMessageRequest(ctx context.Context, params *protocol.ShowMessageRequestParams) (result *protocol.MessageActionItem, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (tc TestClient) Telemetry(ctx context.Context, params any) (err error) {
|
||||
tc.log.Info("client: Received Telemetry", slog.Any("params", params))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc TestClient) RegisterCapability(ctx context.Context, params *protocol.RegistrationParams,
|
||||
) (err error) {
|
||||
tc.log.Info("client: Received RegisterCapability", slog.Any("params", params))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc TestClient) UnregisterCapability(ctx context.Context, params *protocol.UnregistrationParams) (err error) {
|
||||
tc.log.Info("client: Received UnregisterCapability", slog.Any("params", params))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tc TestClient) ApplyEdit(ctx context.Context, params *protocol.ApplyWorkspaceEditParams) (result *protocol.ApplyWorkspaceEditResponse, err error) {
|
||||
tc.log.Info("client: Received ApplyEdit", slog.Any("params", params))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (tc TestClient) Configuration(ctx context.Context, params *protocol.ConfigurationParams) (result []any, err error) {
|
||||
tc.log.Info("client: Received Configuration", slog.Any("params", params))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (tc TestClient) WorkspaceFolders(ctx context.Context) (result []protocol.WorkspaceFolder, err error) {
|
||||
tc.log.Info("client: Received WorkspaceFolders")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func Setup(ctx context.Context, log *slog.Logger) (clientCtx context.Context, appDir string, client protocol.Client, server protocol.Server, teardown func(t *testing.T), err error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return ctx, appDir, client, server, teardown, fmt.Errorf("could not find working dir: %w", err)
|
||||
}
|
||||
moduleRoot, err := modcheck.WalkUp(wd)
|
||||
if err != nil {
|
||||
return ctx, appDir, client, server, teardown, fmt.Errorf("could not find local templ go.mod file: %v", err)
|
||||
}
|
||||
|
||||
appDir, err = testproject.Create(moduleRoot)
|
||||
if err != nil {
|
||||
return ctx, appDir, client, server, teardown, fmt.Errorf("failed to create test project: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var cmdErr error
|
||||
|
||||
// Copy from the LSP to the Client, and vice versa.
|
||||
fromClient, toLSP := io.Pipe()
|
||||
fromLSP, toClient := io.Pipe()
|
||||
clientStream := jsonrpc2.NewStream(newStdRwc(log, "clientStream", toLSP, fromLSP))
|
||||
serverStream := jsonrpc2.NewStream(newStdRwc(log, "serverStream", toClient, fromClient))
|
||||
|
||||
// Create the client that the server needs.
|
||||
client = NewTestClient(log)
|
||||
ctx, _, server = protocol.NewClient(ctx, client, clientStream, log)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
log.Info("Running")
|
||||
// Create the server that the client needs.
|
||||
cmdErr = run(ctx, log, serverStream, Arguments{})
|
||||
if cmdErr != nil {
|
||||
log.Error("Failed to run", slog.Any("error", cmdErr))
|
||||
}
|
||||
log.Info("Stopped")
|
||||
}()
|
||||
|
||||
// Initialize.
|
||||
ir, err := server.Initialize(ctx, &protocol.InitializeParams{
|
||||
ClientInfo: &protocol.ClientInfo{},
|
||||
Capabilities: protocol.ClientCapabilities{
|
||||
Workspace: &protocol.WorkspaceClientCapabilities{
|
||||
ApplyEdit: true,
|
||||
WorkspaceEdit: &protocol.WorkspaceClientCapabilitiesWorkspaceEdit{
|
||||
DocumentChanges: true,
|
||||
},
|
||||
WorkspaceFolders: true,
|
||||
FileOperations: &protocol.WorkspaceClientCapabilitiesFileOperations{
|
||||
DidCreate: true,
|
||||
WillCreate: true,
|
||||
DidRename: true,
|
||||
WillRename: true,
|
||||
DidDelete: true,
|
||||
WillDelete: true,
|
||||
},
|
||||
},
|
||||
TextDocument: &protocol.TextDocumentClientCapabilities{
|
||||
Synchronization: &protocol.TextDocumentSyncClientCapabilities{
|
||||
DidSave: true,
|
||||
},
|
||||
Completion: &protocol.CompletionTextDocumentClientCapabilities{
|
||||
CompletionItem: &protocol.CompletionTextDocumentClientCapabilitiesItem{
|
||||
SnippetSupport: true,
|
||||
DeprecatedSupport: true,
|
||||
InsertReplaceSupport: true,
|
||||
},
|
||||
},
|
||||
Hover: &protocol.HoverTextDocumentClientCapabilities{},
|
||||
SignatureHelp: &protocol.SignatureHelpTextDocumentClientCapabilities{},
|
||||
Declaration: &protocol.DeclarationTextDocumentClientCapabilities{},
|
||||
Definition: &protocol.DefinitionTextDocumentClientCapabilities{},
|
||||
TypeDefinition: &protocol.TypeDefinitionTextDocumentClientCapabilities{},
|
||||
Implementation: &protocol.ImplementationTextDocumentClientCapabilities{},
|
||||
References: &protocol.ReferencesTextDocumentClientCapabilities{},
|
||||
DocumentHighlight: &protocol.DocumentHighlightClientCapabilities{},
|
||||
DocumentSymbol: &protocol.DocumentSymbolClientCapabilities{},
|
||||
CodeAction: &protocol.CodeActionClientCapabilities{},
|
||||
CodeLens: &protocol.CodeLensClientCapabilities{},
|
||||
Formatting: &protocol.DocumentFormattingClientCapabilities{},
|
||||
RangeFormatting: &protocol.DocumentRangeFormattingClientCapabilities{},
|
||||
OnTypeFormatting: &protocol.DocumentOnTypeFormattingClientCapabilities{},
|
||||
PublishDiagnostics: &protocol.PublishDiagnosticsClientCapabilities{},
|
||||
Rename: &protocol.RenameClientCapabilities{},
|
||||
FoldingRange: &protocol.FoldingRangeClientCapabilities{},
|
||||
SelectionRange: &protocol.SelectionRangeClientCapabilities{},
|
||||
CallHierarchy: &protocol.CallHierarchyClientCapabilities{},
|
||||
SemanticTokens: &protocol.SemanticTokensClientCapabilities{},
|
||||
LinkedEditingRange: &protocol.LinkedEditingRangeClientCapabilities{},
|
||||
},
|
||||
Window: &protocol.WindowClientCapabilities{},
|
||||
General: &protocol.GeneralClientCapabilities{},
|
||||
Experimental: nil,
|
||||
},
|
||||
WorkspaceFolders: []protocol.WorkspaceFolder{
|
||||
{
|
||||
URI: "file://" + appDir,
|
||||
Name: "templ-test",
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Error("Failed to init", slog.Any("error", err))
|
||||
}
|
||||
if ir.ServerInfo.Name != "templ-lsp" {
|
||||
return ctx, appDir, client, server, teardown, fmt.Errorf("expected server name to be templ-lsp, got %q", ir.ServerInfo.Name)
|
||||
}
|
||||
|
||||
// Confirm initialization.
|
||||
log.Info("Confirming initialization...")
|
||||
if err = server.Initialized(ctx, &protocol.InitializedParams{}); err != nil {
|
||||
return ctx, appDir, client, server, teardown, fmt.Errorf("failed to confirm initialization: %v", err)
|
||||
}
|
||||
log.Info("Initialized")
|
||||
|
||||
// Wait for exit.
|
||||
teardown = func(t *testing.T) {
|
||||
log.Info("Tearing down LSP")
|
||||
wg.Wait()
|
||||
if cmdErr != nil {
|
||||
t.Errorf("failed to run lsp cmd: %v", err)
|
||||
}
|
||||
|
||||
if err = os.RemoveAll(appDir); err != nil {
|
||||
t.Errorf("failed to remove test dir %q: %v", appDir, err)
|
||||
}
|
||||
}
|
||||
return ctx, appDir, client, server, teardown, err
|
||||
}
|
42
templ/cmd/templ/lspcmd/lspdiff/lspdiff.go
Normal file
42
templ/cmd/templ/lspcmd/lspdiff/lspdiff.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package lspdiff
|
||||
|
||||
import (
|
||||
"github.com/a-h/templ/lsp/protocol"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
)
|
||||
|
||||
// This package provides a way to compare LSP protocol messages, ignoring irrelevant fields.
|
||||
|
||||
func Hover(expected, actual protocol.Hover) string {
|
||||
return cmp.Diff(expected, actual,
|
||||
cmpopts.IgnoreFields(protocol.Hover{}, "Range"),
|
||||
cmpopts.IgnoreFields(protocol.MarkupContent{}, "Kind"),
|
||||
)
|
||||
}
|
||||
|
||||
func CodeAction(expected, actual []protocol.CodeAction) string {
|
||||
return cmp.Diff(expected, actual)
|
||||
}
|
||||
|
||||
func CompletionList(expected, actual *protocol.CompletionList) string {
|
||||
return cmp.Diff(expected, actual,
|
||||
cmpopts.IgnoreFields(protocol.CompletionList{}, "IsIncomplete"),
|
||||
)
|
||||
}
|
||||
|
||||
func References(expected, actual []protocol.Location) string {
|
||||
return cmp.Diff(expected, actual)
|
||||
}
|
||||
|
||||
func CompletionListContainsText(cl *protocol.CompletionList, text string) bool {
|
||||
if cl == nil {
|
||||
return false
|
||||
}
|
||||
for _, item := range cl.Items {
|
||||
if item.Label == text {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
131
templ/cmd/templ/lspcmd/main.go
Normal file
131
templ/cmd/templ/lspcmd/main.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package lspcmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/lspcmd/httpdebug"
|
||||
"github.com/a-h/templ/cmd/templ/lspcmd/pls"
|
||||
"github.com/a-h/templ/cmd/templ/lspcmd/proxy"
|
||||
"github.com/a-h/templ/lsp/jsonrpc2"
|
||||
"github.com/a-h/templ/lsp/protocol"
|
||||
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
type Arguments struct {
|
||||
Log string
|
||||
GoplsLog string
|
||||
GoplsRPCTrace bool
|
||||
// PPROF sets whether to start a profiling server on localhost:9999
|
||||
PPROF bool
|
||||
// HTTPDebug sets the HTTP endpoint to listen on. Leave empty for no web debug.
|
||||
HTTPDebug string
|
||||
}
|
||||
|
||||
func Run(stdin io.Reader, stdout, stderr io.Writer, args Arguments) (err error) {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
signalChan := make(chan os.Signal, 1)
|
||||
signal.Notify(signalChan, os.Interrupt)
|
||||
defer func() {
|
||||
signal.Stop(signalChan)
|
||||
cancel()
|
||||
}()
|
||||
if args.PPROF {
|
||||
go func() {
|
||||
_ = http.ListenAndServe("localhost:9999", nil)
|
||||
}()
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-signalChan: // First signal, cancel context.
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
}
|
||||
<-signalChan // Second signal, hard exit.
|
||||
os.Exit(2)
|
||||
}()
|
||||
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
|
||||
if args.Log != "" {
|
||||
file, err := os.OpenFile(args.Log, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open log file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Create a new logger with a file writer
|
||||
log = slog.New(slog.NewJSONHandler(file, nil))
|
||||
log.Debug("Logging to file", slog.String("file", args.Log))
|
||||
}
|
||||
templStream := jsonrpc2.NewStream(newStdRwc(log, "templStream", stdout, stdin))
|
||||
return run(ctx, log, templStream, args)
|
||||
}
|
||||
|
||||
func run(ctx context.Context, log *slog.Logger, templStream jsonrpc2.Stream, args Arguments) (err error) {
|
||||
log.Info("lsp: starting up...")
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error("handled panic", slog.Any("recovered", r))
|
||||
}
|
||||
}()
|
||||
|
||||
log.Info("lsp: starting gopls...")
|
||||
rwc, err := pls.NewGopls(ctx, log, pls.Options{
|
||||
Log: args.GoplsLog,
|
||||
RPCTrace: args.GoplsRPCTrace,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error("failed to start gopls", slog.Any("error", err))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cache := proxy.NewSourceMapCache()
|
||||
diagnosticCache := proxy.NewDiagnosticCache()
|
||||
|
||||
log.Info("creating gopls client")
|
||||
clientProxy, clientInit := proxy.NewClient(log, cache, diagnosticCache)
|
||||
_, goplsConn, goplsServer := protocol.NewClient(ctx, clientProxy, jsonrpc2.NewStream(rwc), log)
|
||||
defer goplsConn.Close()
|
||||
|
||||
log.Info("creating proxy")
|
||||
// Create the proxy to sit between.
|
||||
serverProxy := proxy.NewServer(log, goplsServer, cache, diagnosticCache)
|
||||
|
||||
// Create templ server.
|
||||
log.Info("creating templ server")
|
||||
_, templConn, templClient := protocol.NewServer(context.Background(), serverProxy, templStream, log)
|
||||
defer templConn.Close()
|
||||
|
||||
// Allow both the server and the client to initiate outbound requests.
|
||||
clientInit(templClient)
|
||||
|
||||
// Start the web server if required.
|
||||
if args.HTTPDebug != "" {
|
||||
log.Info("starting debug http server", slog.String("addr", args.HTTPDebug))
|
||||
h := httpdebug.NewHandler(log, serverProxy)
|
||||
go func() {
|
||||
if err := http.ListenAndServe(args.HTTPDebug, h); err != nil {
|
||||
log.Error("web server failed", slog.Any("error", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
log.Info("listening")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("context closed")
|
||||
case <-templConn.Done():
|
||||
log.Info("templConn closed")
|
||||
case <-goplsConn.Done():
|
||||
log.Info("goplsConn closed")
|
||||
}
|
||||
log.Info("shutdown complete")
|
||||
return
|
||||
}
|
124
templ/cmd/templ/lspcmd/pls/main.go
Normal file
124
templ/cmd/templ/lspcmd/pls/main.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package pls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// Options for the gopls client.
|
||||
type Options struct {
|
||||
Log string
|
||||
RPCTrace bool
|
||||
}
|
||||
|
||||
// AsArguments converts the options into command line arguments for gopls.
|
||||
func (opts Options) AsArguments() []string {
|
||||
var args []string
|
||||
if opts.Log != "" {
|
||||
args = append(args, "-logfile", opts.Log)
|
||||
}
|
||||
if opts.RPCTrace {
|
||||
args = append(args, "-rpc.trace")
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func FindGopls() (location string, err error) {
|
||||
executableName := "gopls"
|
||||
if runtime.GOOS == "windows" {
|
||||
executableName = "gopls.exe"
|
||||
}
|
||||
|
||||
pathLocation, err := exec.LookPath(executableName)
|
||||
if err == nil {
|
||||
// Found on the path.
|
||||
return pathLocation, nil
|
||||
}
|
||||
// Unexpected error.
|
||||
if !errors.Is(err, exec.ErrNotFound) {
|
||||
return "", fmt.Errorf("unexpected error looking for gopls: %w", err)
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unexpected error looking for gopls: %w", err)
|
||||
}
|
||||
|
||||
// Probe standard locations.
|
||||
locations := []string{
|
||||
path.Join(home, "go", "bin", executableName),
|
||||
path.Join(home, ".local", "bin", executableName),
|
||||
}
|
||||
for _, location := range locations {
|
||||
_, err = os.Stat(location)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Found in a standard location.
|
||||
return location, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("cannot find gopls on the path (%q), in $HOME/go/bin or $HOME/.local/bin/gopls. You can install gopls with `go install golang.org/x/tools/gopls@latest`", os.Getenv("PATH"))
|
||||
}
|
||||
|
||||
// NewGopls starts gopls and opens up a jsonrpc2 connection to it.
|
||||
func NewGopls(ctx context.Context, log *slog.Logger, opts Options) (rwc io.ReadWriteCloser, err error) {
|
||||
location, err := FindGopls()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmd := exec.Command(location, opts.AsArguments()...)
|
||||
return newProcessReadWriteCloser(log, cmd)
|
||||
}
|
||||
|
||||
// newProcessReadWriteCloser creates a processReadWriteCloser to allow stdin/stdout to be used as
|
||||
// a JSON RPC 2.0 transport.
|
||||
func newProcessReadWriteCloser(logger *slog.Logger, cmd *exec.Cmd) (rwc processReadWriteCloser, err error) {
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rwc = processReadWriteCloser{
|
||||
in: stdin,
|
||||
out: stdout,
|
||||
}
|
||||
go func() {
|
||||
if err := cmd.Run(); err != nil {
|
||||
logger.Error("gopls command error", slog.Any("error", err))
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
type processReadWriteCloser struct {
|
||||
in io.WriteCloser
|
||||
out io.ReadCloser
|
||||
}
|
||||
|
||||
func (prwc processReadWriteCloser) Read(p []byte) (n int, err error) {
|
||||
return prwc.out.Read(p)
|
||||
}
|
||||
|
||||
func (prwc processReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||
return prwc.in.Write(p)
|
||||
}
|
||||
|
||||
func (prwc processReadWriteCloser) Close() error {
|
||||
errInClose := prwc.in.Close()
|
||||
errOutClose := prwc.out.Close()
|
||||
if errInClose != nil || errOutClose != nil {
|
||||
return fmt.Errorf("error closing process - in: %v, out: %v", errInClose, errOutClose)
|
||||
}
|
||||
return nil
|
||||
}
|
143
templ/cmd/templ/lspcmd/proxy/client.go
Normal file
143
templ/cmd/templ/lspcmd/proxy/client.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
lsp "github.com/a-h/templ/lsp/protocol"
|
||||
)
|
||||
|
||||
// Client is responsible for rewriting messages that are
|
||||
// originated from gopls, and are sent to the client.
|
||||
//
|
||||
// Since `gopls` is working on Go files, and this is the `templ` LSP,
|
||||
// the job of this code is to rewrite incoming requests to adjust the
|
||||
// file name from `*_templ.go` to `*.templ`, and to remap the char
|
||||
// positions where required.
|
||||
type Client struct {
|
||||
Log *slog.Logger
|
||||
Target lsp.Client
|
||||
SourceMapCache *SourceMapCache
|
||||
DiagnosticCache *DiagnosticCache
|
||||
}
|
||||
|
||||
func NewClient(log *slog.Logger, cache *SourceMapCache, diagnosticCache *DiagnosticCache) (c *Client, init func(lsp.Client)) {
|
||||
c = &Client{
|
||||
Log: log,
|
||||
SourceMapCache: cache,
|
||||
DiagnosticCache: diagnosticCache,
|
||||
}
|
||||
return c, func(target lsp.Client) {
|
||||
c.Target = target
|
||||
}
|
||||
}
|
||||
|
||||
func (p Client) Progress(ctx context.Context, params *lsp.ProgressParams) (err error) {
|
||||
p.Log.Info("client <- server: Progress")
|
||||
return p.Target.Progress(ctx, params)
|
||||
}
|
||||
func (p Client) WorkDoneProgressCreate(ctx context.Context, params *lsp.WorkDoneProgressCreateParams) (err error) {
|
||||
p.Log.Info("client <- server: WorkDoneProgressCreate")
|
||||
return p.Target.WorkDoneProgressCreate(ctx, params)
|
||||
}
|
||||
|
||||
func (p Client) LogMessage(ctx context.Context, params *lsp.LogMessageParams) (err error) {
|
||||
p.Log.Info("client <- server: LogMessage", slog.String("message", params.Message))
|
||||
return p.Target.LogMessage(ctx, params)
|
||||
}
|
||||
|
||||
func (p Client) PublishDiagnostics(ctx context.Context, params *lsp.PublishDiagnosticsParams) (err error) {
|
||||
p.Log.Info("client <- server: PublishDiagnostics")
|
||||
if strings.HasSuffix(string(params.URI), "go.mod") {
|
||||
p.Log.Info("client <- server: PublishDiagnostics: skipping go.mod diagnostics")
|
||||
return nil
|
||||
}
|
||||
// Log diagnostics.
|
||||
for i, diagnostic := range params.Diagnostics {
|
||||
p.Log.Info(fmt.Sprintf("client <- server: PublishDiagnostics: [%d]", i), slog.Any("diagnostic", diagnostic))
|
||||
}
|
||||
// Get the sourcemap from the cache.
|
||||
uri := strings.TrimSuffix(string(params.URI), "_templ.go") + ".templ"
|
||||
sourceMap, ok := p.SourceMapCache.Get(uri)
|
||||
if !ok {
|
||||
p.Log.Error("unable to complete because the sourcemap for the URI doesn't exist in the cache", slog.String("uri", uri))
|
||||
return fmt.Errorf("unable to complete because the sourcemap for %q doesn't exist in the cache, has the didOpen notification been sent yet?", uri)
|
||||
}
|
||||
params.URI = lsp.DocumentURI(uri)
|
||||
// Rewrite the positions.
|
||||
for i := 0; i < len(params.Diagnostics); i++ {
|
||||
item := params.Diagnostics[i]
|
||||
start, ok := sourceMap.SourcePositionFromTarget(item.Range.Start.Line, item.Range.Start.Character)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if item.Range.Start.Line == item.Range.End.Line {
|
||||
length := item.Range.End.Character - item.Range.Start.Character
|
||||
item.Range.Start.Line = start.Line
|
||||
item.Range.Start.Character = start.Col
|
||||
item.Range.End.Line = start.Line
|
||||
item.Range.End.Character = start.Col + length
|
||||
params.Diagnostics[i] = item
|
||||
p.Log.Info(fmt.Sprintf("diagnostic [%d] rewritten", i), slog.Any("diagnostic", item))
|
||||
continue
|
||||
}
|
||||
end, ok := sourceMap.SourcePositionFromTarget(item.Range.End.Line, item.Range.End.Character)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
item.Range.Start.Line = start.Line
|
||||
item.Range.Start.Character = start.Col
|
||||
item.Range.End.Line = end.Line
|
||||
item.Range.End.Character = end.Col
|
||||
params.Diagnostics[i] = item
|
||||
p.Log.Info(fmt.Sprintf("diagnostic [%d] rewritten", i), slog.Any("diagnostic", item))
|
||||
}
|
||||
params.Diagnostics = p.DiagnosticCache.AddTemplDiagnostics(uri, params.Diagnostics)
|
||||
err = p.Target.PublishDiagnostics(ctx, params)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p Client) ShowMessage(ctx context.Context, params *lsp.ShowMessageParams) (err error) {
|
||||
p.Log.Info("client <- server: ShowMessage", slog.String("message", params.Message))
|
||||
if strings.HasPrefix(params.Message, "Do not edit this file!") {
|
||||
return
|
||||
}
|
||||
return p.Target.ShowMessage(ctx, params)
|
||||
}
|
||||
|
||||
func (p Client) ShowMessageRequest(ctx context.Context, params *lsp.ShowMessageRequestParams) (result *lsp.MessageActionItem, err error) {
|
||||
p.Log.Info("client <- server: ShowMessageRequest", slog.String("message", params.Message))
|
||||
return p.Target.ShowMessageRequest(ctx, params)
|
||||
}
|
||||
|
||||
func (p Client) Telemetry(ctx context.Context, params any) (err error) {
|
||||
p.Log.Info("client <- server: Telemetry")
|
||||
return p.Target.Telemetry(ctx, params)
|
||||
}
|
||||
|
||||
func (p Client) RegisterCapability(ctx context.Context, params *lsp.RegistrationParams) (err error) {
|
||||
p.Log.Info("client <- server: RegisterCapability")
|
||||
return p.Target.RegisterCapability(ctx, params)
|
||||
}
|
||||
|
||||
func (p Client) UnregisterCapability(ctx context.Context, params *lsp.UnregistrationParams) (err error) {
|
||||
p.Log.Info("client <- server: UnregisterCapability")
|
||||
return p.Target.UnregisterCapability(ctx, params)
|
||||
}
|
||||
|
||||
func (p Client) ApplyEdit(ctx context.Context, params *lsp.ApplyWorkspaceEditParams) (result *lsp.ApplyWorkspaceEditResponse, err error) {
|
||||
p.Log.Info("client <- server: ApplyEdit")
|
||||
return p.Target.ApplyEdit(ctx, params)
|
||||
}
|
||||
|
||||
func (p Client) Configuration(ctx context.Context, params *lsp.ConfigurationParams) (result []any, err error) {
|
||||
p.Log.Info("client <- server: Configuration")
|
||||
return p.Target.Configuration(ctx, params)
|
||||
}
|
||||
|
||||
func (p Client) WorkspaceFolders(ctx context.Context) (result []lsp.WorkspaceFolder, err error) {
|
||||
p.Log.Info("client <- server: WorkspaceFolders")
|
||||
return p.Target.WorkspaceFolders(ctx)
|
||||
}
|
61
templ/cmd/templ/lspcmd/proxy/diagnosticcache.go
Normal file
61
templ/cmd/templ/lspcmd/proxy/diagnosticcache.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
lsp "github.com/a-h/templ/lsp/protocol"
|
||||
)
|
||||
|
||||
func NewDiagnosticCache() *DiagnosticCache {
|
||||
return &DiagnosticCache{
|
||||
m: &sync.Mutex{},
|
||||
cache: make(map[string]fileDiagnostic),
|
||||
}
|
||||
}
|
||||
|
||||
type fileDiagnostic struct {
|
||||
templDiagnostics []lsp.Diagnostic
|
||||
goplsDiagnostics []lsp.Diagnostic
|
||||
}
|
||||
|
||||
type DiagnosticCache struct {
|
||||
m *sync.Mutex
|
||||
cache map[string]fileDiagnostic
|
||||
}
|
||||
|
||||
func zeroLengthSliceIfNil(diags []lsp.Diagnostic) []lsp.Diagnostic {
|
||||
if diags == nil {
|
||||
return make([]lsp.Diagnostic, 0)
|
||||
}
|
||||
return diags
|
||||
}
|
||||
|
||||
func (dc *DiagnosticCache) AddTemplDiagnostics(uri string, goDiagnostics []lsp.Diagnostic) []lsp.Diagnostic {
|
||||
goDiagnostics = zeroLengthSliceIfNil(goDiagnostics)
|
||||
dc.m.Lock()
|
||||
defer dc.m.Unlock()
|
||||
diag := dc.cache[uri]
|
||||
diag.goplsDiagnostics = goDiagnostics
|
||||
diag.templDiagnostics = zeroLengthSliceIfNil(diag.templDiagnostics)
|
||||
dc.cache[uri] = diag
|
||||
return append(diag.templDiagnostics, goDiagnostics...)
|
||||
}
|
||||
|
||||
func (dc *DiagnosticCache) ClearTemplDiagnostics(uri string) {
|
||||
dc.m.Lock()
|
||||
defer dc.m.Unlock()
|
||||
diag := dc.cache[uri]
|
||||
diag.templDiagnostics = make([]lsp.Diagnostic, 0)
|
||||
dc.cache[uri] = diag
|
||||
}
|
||||
|
||||
func (dc *DiagnosticCache) AddGoDiagnostics(uri string, templDiagnostics []lsp.Diagnostic) []lsp.Diagnostic {
|
||||
templDiagnostics = zeroLengthSliceIfNil(templDiagnostics)
|
||||
dc.m.Lock()
|
||||
defer dc.m.Unlock()
|
||||
diag := dc.cache[uri]
|
||||
diag.templDiagnostics = templDiagnostics
|
||||
diag.goplsDiagnostics = zeroLengthSliceIfNil(diag.goplsDiagnostics)
|
||||
dc.cache[uri] = diag
|
||||
return append(diag.goplsDiagnostics, templDiagnostics...)
|
||||
}
|
215
templ/cmd/templ/lspcmd/proxy/documentcontents.go
Normal file
215
templ/cmd/templ/lspcmd/proxy/documentcontents.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
lsp "github.com/a-h/templ/lsp/protocol"
|
||||
)
|
||||
|
||||
// newDocumentContents creates a document content processing tool.
|
||||
func newDocumentContents(log *slog.Logger) *DocumentContents {
|
||||
return &DocumentContents{
|
||||
m: new(sync.Mutex),
|
||||
uriToContents: make(map[string]*Document),
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
type DocumentContents struct {
|
||||
m *sync.Mutex
|
||||
uriToContents map[string]*Document
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
// Set the contents of a document.
|
||||
func (dc *DocumentContents) Set(uri string, d *Document) {
|
||||
dc.m.Lock()
|
||||
defer dc.m.Unlock()
|
||||
dc.uriToContents[uri] = d
|
||||
}
|
||||
|
||||
// Get the contents of a document.
|
||||
func (dc *DocumentContents) Get(uri string) (d *Document, ok bool) {
|
||||
dc.m.Lock()
|
||||
defer dc.m.Unlock()
|
||||
d, ok = dc.uriToContents[uri]
|
||||
return
|
||||
}
|
||||
|
||||
// Delete a document from memory.
|
||||
func (dc *DocumentContents) Delete(uri string) {
|
||||
dc.m.Lock()
|
||||
defer dc.m.Unlock()
|
||||
delete(dc.uriToContents, uri)
|
||||
}
|
||||
|
||||
func (dc *DocumentContents) URIs() (uris []string) {
|
||||
dc.m.Lock()
|
||||
defer dc.m.Unlock()
|
||||
uris = make([]string, len(dc.uriToContents))
|
||||
var i int
|
||||
for k := range dc.uriToContents {
|
||||
uris[i] = k
|
||||
i++
|
||||
}
|
||||
return uris
|
||||
}
|
||||
|
||||
// Apply changes to the document from the client, and return a list of change requests to send back to the client.
|
||||
func (dc *DocumentContents) Apply(uri string, changes []lsp.TextDocumentContentChangeEvent) (d *Document, err error) {
|
||||
dc.m.Lock()
|
||||
defer dc.m.Unlock()
|
||||
var ok bool
|
||||
d, ok = dc.uriToContents[uri]
|
||||
if !ok {
|
||||
err = fmt.Errorf("document not found")
|
||||
return
|
||||
}
|
||||
for _, change := range changes {
|
||||
d.Apply(change.Range, change.Text)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func NewDocument(log *slog.Logger, s string) *Document {
|
||||
return &Document{
|
||||
Log: log,
|
||||
Lines: strings.Split(s, "\n"),
|
||||
}
|
||||
}
|
||||
|
||||
type Document struct {
|
||||
Log *slog.Logger
|
||||
Lines []string
|
||||
}
|
||||
|
||||
func (d *Document) LineLengths() (lens []int) {
|
||||
lens = make([]int, len(d.Lines))
|
||||
for i, l := range d.Lines {
|
||||
lens[i] = len(l)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Document) Len() (line, col int) {
|
||||
line = len(d.Lines)
|
||||
col = len(d.Lines[len(d.Lines)-1])
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Document) Overwrite(fromLine, fromCol, toLine, toCol int, lines []string) {
|
||||
suffix := d.Lines[toLine][toCol:]
|
||||
toLen := d.LineLengths()[toLine]
|
||||
d.Delete(fromLine, fromCol, toLine, toLen)
|
||||
lines[len(lines)-1] = lines[len(lines)-1] + suffix
|
||||
d.Insert(fromLine, fromCol, lines)
|
||||
}
|
||||
|
||||
func (d *Document) Insert(line, col int, lines []string) {
|
||||
prefix := d.Lines[line][:col]
|
||||
suffix := d.Lines[line][col:]
|
||||
lines[0] = prefix + lines[0]
|
||||
d.Lines[line] = lines[0]
|
||||
|
||||
if len(lines) > 1 {
|
||||
d.InsertLines(line+1, lines[1:])
|
||||
}
|
||||
|
||||
d.Lines[line+len(lines)-1] = lines[len(lines)-1] + suffix
|
||||
}
|
||||
|
||||
func (d *Document) InsertLines(i int, withLines []string) {
|
||||
d.Lines = append(d.Lines[:i], append(withLines, d.Lines[i:]...)...)
|
||||
}
|
||||
|
||||
func (d *Document) Delete(fromLine, fromCol, toLine, toCol int) {
|
||||
prefix := d.Lines[fromLine][:fromCol]
|
||||
suffix := d.Lines[toLine][toCol:]
|
||||
|
||||
// Delete intermediate lines.
|
||||
deleteFrom := fromLine
|
||||
deleteTo := fromLine + (toLine - fromLine)
|
||||
d.DeleteLines(deleteFrom, deleteTo)
|
||||
|
||||
// Merge the contents of the final line.
|
||||
d.Lines[fromLine] = prefix + suffix
|
||||
}
|
||||
|
||||
func (d *Document) DeleteLines(i, j int) {
|
||||
d.Lines = append(d.Lines[:i], d.Lines[j:]...)
|
||||
}
|
||||
|
||||
func (d *Document) String() string {
|
||||
return strings.Join(d.Lines, "\n")
|
||||
}
|
||||
|
||||
func (d *Document) Replace(with string) {
|
||||
d.Lines = strings.Split(with, "\n")
|
||||
}
|
||||
|
||||
func (d *Document) Apply(r *lsp.Range, with string) {
|
||||
withLines := strings.Split(with, "\n")
|
||||
d.normalize(r)
|
||||
if d.isWholeDocument(r) {
|
||||
d.Lines = withLines
|
||||
return
|
||||
}
|
||||
if d.isInsert(r, with) {
|
||||
d.Insert(int(r.Start.Line), int(r.Start.Character), withLines)
|
||||
return
|
||||
}
|
||||
if d.isDelete(r, with) {
|
||||
d.Delete(int(r.Start.Line), int(r.Start.Character), int(r.End.Line), int(r.End.Character))
|
||||
return
|
||||
}
|
||||
if d.isOverwrite(r, with) {
|
||||
d.Overwrite(int(r.Start.Line), int(r.Start.Character), int(r.End.Line), int(r.End.Character), withLines)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Document) normalize(r *lsp.Range) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
lens := d.LineLengths()
|
||||
if r.Start.Line >= uint32(len(lens)) {
|
||||
r.Start.Line = uint32(len(lens) - 1)
|
||||
r.Start.Character = uint32(lens[r.Start.Line])
|
||||
}
|
||||
if r.Start.Character > uint32(lens[r.Start.Line]) {
|
||||
r.Start.Character = uint32(lens[r.Start.Line])
|
||||
}
|
||||
if r.End.Line >= uint32(len(lens)) {
|
||||
r.End.Line = uint32(len(lens) - 1)
|
||||
r.End.Character = uint32(lens[r.End.Line])
|
||||
}
|
||||
if r.End.Character > uint32(lens[r.End.Line]) {
|
||||
r.End.Character = uint32(lens[r.End.Line])
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Document) isOverwrite(r *lsp.Range, with string) bool {
|
||||
return (r.End.Line != r.Start.Line || r.Start.Character != r.End.Character) && with != ""
|
||||
}
|
||||
|
||||
func (d *Document) isInsert(r *lsp.Range, with string) bool {
|
||||
return r.End.Line == r.Start.Line && r.Start.Character == r.End.Character && with != ""
|
||||
}
|
||||
|
||||
func (d *Document) isDelete(r *lsp.Range, with string) bool {
|
||||
return (r.End.Line != r.Start.Line || r.Start.Character != r.End.Character) && with == ""
|
||||
}
|
||||
|
||||
func (d *Document) isWholeDocument(r *lsp.Range) bool {
|
||||
if r == nil {
|
||||
return true
|
||||
}
|
||||
if r.Start.Line != 0 || r.Start.Character != 0 {
|
||||
return false
|
||||
}
|
||||
l, c := d.Len()
|
||||
return r.End.Line == uint32(l) || r.End.Character == uint32(c)
|
||||
}
|
571
templ/cmd/templ/lspcmd/proxy/documentcontents_test.go
Normal file
571
templ/cmd/templ/lspcmd/proxy/documentcontents_test.go
Normal file
@@ -0,0 +1,571 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
lsp "github.com/a-h/templ/lsp/protocol"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestDocument(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
start string
|
||||
operations []func(d *Document)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Replace all content if the range is nil",
|
||||
start: "0\n1\n2",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(nil, "replaced")
|
||||
},
|
||||
},
|
||||
expected: "replaced",
|
||||
},
|
||||
{
|
||||
name: "If the range matches the length of the file, all of it is replaced",
|
||||
start: "0\n1\n2",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 2,
|
||||
Character: 1,
|
||||
},
|
||||
}, "replaced")
|
||||
},
|
||||
},
|
||||
expected: "replaced",
|
||||
},
|
||||
{
|
||||
name: "Can insert new text",
|
||||
start: ``,
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
}, "abc")
|
||||
},
|
||||
},
|
||||
expected: "abc",
|
||||
},
|
||||
{
|
||||
name: "Can insert new text that ends with a newline",
|
||||
start: ``,
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
}, "abc\n")
|
||||
},
|
||||
},
|
||||
expected: `abc
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Can insert a new line at the end of existing text",
|
||||
start: `abc
|
||||
`,
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 3,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 3,
|
||||
},
|
||||
}, "\n")
|
||||
},
|
||||
},
|
||||
expected: `abc
|
||||
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Can insert a word at the start of existing text",
|
||||
start: `bc`,
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
}, "a")
|
||||
},
|
||||
},
|
||||
expected: `abc`,
|
||||
},
|
||||
{
|
||||
name: "Can remove whole line",
|
||||
start: "0\n1\n2",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 1,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 2,
|
||||
Character: 0,
|
||||
},
|
||||
}, "")
|
||||
},
|
||||
},
|
||||
expected: "0\n2",
|
||||
},
|
||||
{
|
||||
name: "Can remove line prefix",
|
||||
start: "abcdef",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 3,
|
||||
},
|
||||
}, "")
|
||||
},
|
||||
},
|
||||
expected: "def",
|
||||
},
|
||||
{
|
||||
name: "Can remove line substring",
|
||||
start: "abcdef",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 2,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 3,
|
||||
},
|
||||
}, "")
|
||||
},
|
||||
},
|
||||
expected: "abdef",
|
||||
},
|
||||
{
|
||||
name: "Can remove line suffix",
|
||||
start: "abcdef",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 4,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 6,
|
||||
},
|
||||
}, "")
|
||||
},
|
||||
},
|
||||
expected: "abcd",
|
||||
},
|
||||
{
|
||||
name: "Can remove across lines",
|
||||
start: "0\n1\n22",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 1,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 2,
|
||||
Character: 1,
|
||||
},
|
||||
}, "")
|
||||
},
|
||||
},
|
||||
expected: "0\n2",
|
||||
},
|
||||
{
|
||||
name: "Can remove part of two lines",
|
||||
start: "Line one\nLine two\nLine three",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 4,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 2,
|
||||
Character: 4,
|
||||
},
|
||||
}, "")
|
||||
},
|
||||
},
|
||||
expected: "Line three",
|
||||
},
|
||||
{
|
||||
name: "Can remove all lines",
|
||||
start: "0\n1\n2",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 2,
|
||||
Character: 1,
|
||||
},
|
||||
}, "")
|
||||
},
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Can replace line prefix",
|
||||
start: "012345",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 3,
|
||||
},
|
||||
}, "ABCDEFG")
|
||||
},
|
||||
},
|
||||
expected: "ABCDEFG345",
|
||||
},
|
||||
{
|
||||
name: "Can replace text across line boundaries",
|
||||
start: "Line one\nLine two\nLine three",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 4,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 2,
|
||||
Character: 4,
|
||||
},
|
||||
}, " one test\nNew Line 2\nNew line")
|
||||
},
|
||||
},
|
||||
expected: "Line one test\nNew Line 2\nNew line three",
|
||||
},
|
||||
{
|
||||
name: "Can add new line to end of single line",
|
||||
start: `a`,
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 1,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 1,
|
||||
},
|
||||
}, "\nb")
|
||||
},
|
||||
},
|
||||
expected: "a\nb",
|
||||
},
|
||||
{
|
||||
name: "Exceeding the col and line count rounds down to the end of the file",
|
||||
start: `a`,
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 200,
|
||||
Character: 600,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 300,
|
||||
Character: 1200,
|
||||
},
|
||||
}, "\nb")
|
||||
},
|
||||
},
|
||||
expected: "a\nb",
|
||||
},
|
||||
{
|
||||
name: "Can remove a line and add it back from the end of the previous line (insert)",
|
||||
start: "a\nb\nc",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
// Delete.
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 1,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 2,
|
||||
Character: 0,
|
||||
},
|
||||
}, "")
|
||||
// Put it back.
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 1,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 1,
|
||||
},
|
||||
}, "\nb")
|
||||
},
|
||||
},
|
||||
expected: "a\nb\nc",
|
||||
},
|
||||
{
|
||||
name: "Can remove a line and add it back from the end of the previous line (overwrite)",
|
||||
start: "a\nb\nc",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
// Delete.
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 1,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 2,
|
||||
Character: 0,
|
||||
},
|
||||
}, "")
|
||||
// Put it back.
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 1,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 1,
|
||||
Character: 0,
|
||||
},
|
||||
}, "\nb\n")
|
||||
},
|
||||
},
|
||||
expected: "a\nb\nc",
|
||||
},
|
||||
{
|
||||
name: "Add new line with indent to the end of the line",
|
||||
// Based on log entry.
|
||||
// {"level":"info","ts":"2022-06-04T20:55:15+01:00","caller":"proxy/server.go:391","msg":"client -> server: DidChange","params":{"textDocument":{"uri":"file:///Users/adrian/github.com/a-h/templ/generator/test-call/template.templ","version":2},"contentChanges":[{"range":{"start":{"line":4,"character":21},"end":{"line":4,"character":21}},"text":"\n\t\t"}]}}
|
||||
start: `package testcall
|
||||
|
||||
templ personTemplate(p person) {
|
||||
<div>
|
||||
<h1>{ p.name }</h1>
|
||||
</div>
|
||||
}
|
||||
`,
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 4,
|
||||
Character: 21,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 4,
|
||||
Character: 21,
|
||||
},
|
||||
}, "\n\t\t")
|
||||
},
|
||||
},
|
||||
expected: `package testcall
|
||||
|
||||
templ personTemplate(p person) {
|
||||
<div>
|
||||
<h1>{ p.name }</h1>
|
||||
|
||||
</div>
|
||||
}
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Recreate error smaller",
|
||||
// Based on log entry.
|
||||
// {"level":"info","ts":"2022-06-04T20:55:15+01:00","caller":"proxy/server.go:391","msg":"client -> server: DidChange","params":{"textDocument":{"uri":"file:///Users/adrian/github.com/a-h/templ/generator/test-call/template.templ","version":2},"contentChanges":[{"range":{"start":{"line":4,"character":21},"end":{"line":4,"character":21}},"text":"\n\t\t"}]}}
|
||||
start: "line1\n\t\tline2\nline3",
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
// Remove \t\tline2
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 1,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 2,
|
||||
Character: 0,
|
||||
},
|
||||
}, "")
|
||||
// Put it back.
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 5,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 1,
|
||||
Character: 0,
|
||||
},
|
||||
},
|
||||
"\n\t\tline2\n")
|
||||
},
|
||||
},
|
||||
expected: "line1\n\t\tline2\nline3",
|
||||
},
|
||||
{
|
||||
name: "Recreate error",
|
||||
// Based on log entry.
|
||||
// {"level":"info","ts":"2022-06-04T20:55:15+01:00","caller":"proxy/server.go:391","msg":"client -> server: DidChange","params":{"textDocument":{"uri":"file:///Users/adrian/github.com/a-h/templ/generator/test-call/template.templ","version":2},"contentChanges":[{"range":{"start":{"line":4,"character":21},"end":{"line":4,"character":21}},"text":"\n\t\t"}]}}
|
||||
start: ` <footer data-testid="footerTemplate">
|
||||
<div>© { fmt.Sprintf("%d", time.Now().Year()) }</div>
|
||||
</footer>
|
||||
}
|
||||
`,
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
// Remove <div>© { fmt.Sprintf("%d", time.Now().Year()) }</div>
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 1,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 2,
|
||||
Character: 0,
|
||||
},
|
||||
}, "")
|
||||
// Put it back.
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 38,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 1,
|
||||
Character: 0,
|
||||
},
|
||||
},
|
||||
"\n\t\t<div>© { fmt.Sprintf(\"%d\", time.Now().Year()) }</div>\n")
|
||||
},
|
||||
},
|
||||
expected: ` <footer data-testid="footerTemplate">
|
||||
<div>© { fmt.Sprintf("%d", time.Now().Year()) }</div>
|
||||
</footer>
|
||||
}
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Insert at start of line",
|
||||
// Based on log entry.
|
||||
// {"level":"info","ts":"2023-03-25T17:17:38Z","caller":"proxy/server.go:393","msg":"client -> server: DidChange","params":{"textDocument":{"uri":"file:///Users/adrian/github.com/a-h/templ/generator/test-call/template.templ","version":5},"contentChanges":[{"range":{"start":{"line":6,"character":0},"end":{"line":6,"character":0}},"text":"a"}]}}
|
||||
start: `b`,
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 0,
|
||||
Character: 0,
|
||||
},
|
||||
}, "a")
|
||||
},
|
||||
},
|
||||
expected: `ab`,
|
||||
},
|
||||
{
|
||||
name: "Insert full new line",
|
||||
start: `a
|
||||
c
|
||||
d`,
|
||||
operations: []func(d *Document){
|
||||
func(d *Document) {
|
||||
d.Apply(&lsp.Range{
|
||||
Start: lsp.Position{
|
||||
Line: 1,
|
||||
Character: 0,
|
||||
},
|
||||
End: lsp.Position{
|
||||
Line: 1,
|
||||
Character: 0,
|
||||
},
|
||||
}, "b\n")
|
||||
},
|
||||
},
|
||||
expected: `a
|
||||
b
|
||||
c
|
||||
d`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
logger := slog.New(slog.NewJSONHandler(os.Stderr, nil))
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := NewDocument(logger, tt.start)
|
||||
for _, f := range tt.operations {
|
||||
f(d)
|
||||
}
|
||||
actual := d.String()
|
||||
if diff := cmp.Diff(tt.expected, actual); diff != "" {
|
||||
t.Error(diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
293
templ/cmd/templ/lspcmd/proxy/import_test.go
Normal file
293
templ/cmd/templ/lspcmd/proxy/import_test.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestFindLastImport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
templContents string
|
||||
packageName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "if there are no imports, add a single line import",
|
||||
templContents: `package main
|
||||
|
||||
templ example() {
|
||||
}
|
||||
`,
|
||||
packageName: "strings",
|
||||
expected: `package main
|
||||
|
||||
import "strings"
|
||||
|
||||
templ example() {
|
||||
}
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "if there is an existing single-line imports, add one at the end",
|
||||
templContents: `package main
|
||||
|
||||
import "strings"
|
||||
|
||||
templ example() {
|
||||
}
|
||||
`,
|
||||
packageName: "fmt",
|
||||
expected: `package main
|
||||
|
||||
import "strings"
|
||||
import "fmt"
|
||||
|
||||
templ example() {
|
||||
}
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "if there are multiple existing single-line imports, add one at the end",
|
||||
templContents: `package main
|
||||
|
||||
import "strings"
|
||||
import "fmt"
|
||||
|
||||
templ example() {
|
||||
}
|
||||
`,
|
||||
packageName: "time",
|
||||
expected: `package main
|
||||
|
||||
import "strings"
|
||||
import "fmt"
|
||||
import "time"
|
||||
|
||||
templ example() {
|
||||
}
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "if there are existing multi-line imports, add one at the end",
|
||||
templContents: `package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
templ example() {
|
||||
}
|
||||
`,
|
||||
packageName: "fmt",
|
||||
expected: `package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
templ example() {
|
||||
}
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "ignore imports that happen after templates",
|
||||
templContents: `package main
|
||||
|
||||
import "strings"
|
||||
|
||||
templ example() {
|
||||
}
|
||||
|
||||
import "other"
|
||||
`,
|
||||
packageName: "fmt",
|
||||
expected: `package main
|
||||
|
||||
import "strings"
|
||||
import "fmt"
|
||||
|
||||
templ example() {
|
||||
}
|
||||
|
||||
import "other"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "ignore imports that happen after funcs in the file",
|
||||
templContents: `package main
|
||||
|
||||
import "strings"
|
||||
|
||||
func example() {
|
||||
}
|
||||
|
||||
import "other"
|
||||
`,
|
||||
packageName: "fmt",
|
||||
expected: `package main
|
||||
|
||||
import "strings"
|
||||
import "fmt"
|
||||
|
||||
func example() {
|
||||
}
|
||||
|
||||
import "other"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "ignore imports that happen after css expressions in the file",
|
||||
templContents: `package main
|
||||
|
||||
import "strings"
|
||||
|
||||
css example() {
|
||||
}
|
||||
|
||||
import "other"
|
||||
`,
|
||||
packageName: "fmt",
|
||||
expected: `package main
|
||||
|
||||
import "strings"
|
||||
import "fmt"
|
||||
|
||||
css example() {
|
||||
}
|
||||
|
||||
import "other"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "ignore imports that happen after script expressions in the file",
|
||||
templContents: `package main
|
||||
|
||||
import "strings"
|
||||
|
||||
script example() {
|
||||
}
|
||||
|
||||
import "other"
|
||||
`,
|
||||
packageName: "fmt",
|
||||
expected: `package main
|
||||
|
||||
import "strings"
|
||||
import "fmt"
|
||||
|
||||
script example() {
|
||||
}
|
||||
|
||||
import "other"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "ignore imports that happen after var expressions in the file",
|
||||
templContents: `package main
|
||||
|
||||
import "strings"
|
||||
|
||||
var s string
|
||||
|
||||
import "other"
|
||||
`,
|
||||
packageName: "fmt",
|
||||
expected: `package main
|
||||
|
||||
import "strings"
|
||||
import "fmt"
|
||||
|
||||
var s string
|
||||
|
||||
import "other"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "ignore imports that happen after const expressions in the file",
|
||||
templContents: `package main
|
||||
|
||||
import "strings"
|
||||
|
||||
const s = "test"
|
||||
|
||||
import "other"
|
||||
`,
|
||||
packageName: "fmt",
|
||||
expected: `package main
|
||||
|
||||
import "strings"
|
||||
import "fmt"
|
||||
|
||||
const s = "test"
|
||||
|
||||
import "other"
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "ignore imports that happen after type expressions in the file",
|
||||
templContents: `package main
|
||||
|
||||
import "strings"
|
||||
|
||||
type Value int
|
||||
|
||||
import "other"
|
||||
`,
|
||||
packageName: "fmt",
|
||||
expected: `package main
|
||||
|
||||
import "strings"
|
||||
import "fmt"
|
||||
|
||||
type Value int
|
||||
|
||||
import "other"
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
lines := strings.Split(test.templContents, "\n")
|
||||
imp := addImport(lines, fmt.Sprintf("%q", test.packageName))
|
||||
textWithoutNewline := strings.TrimSuffix(imp.Text, "\n")
|
||||
actualLines := append(lines[:imp.LineIndex], append([]string{textWithoutNewline}, lines[imp.LineIndex:]...)...)
|
||||
actual := strings.Join(actualLines, "\n")
|
||||
if diff := cmp.Diff(test.expected, actual); diff != "" {
|
||||
t.Error(diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPackageFromItemDetail(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
input: `"fmt"`,
|
||||
expected: `"fmt"`,
|
||||
},
|
||||
{
|
||||
input: `func(state fmt.State, verb rune) string (from "fmt")`,
|
||||
expected: `"fmt"`,
|
||||
},
|
||||
{
|
||||
input: `non matching`,
|
||||
expected: `non matching`,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
actual := getPackageFromItemDetail(test.input)
|
||||
if test.expected != actual {
|
||||
t.Errorf("expected %q, got %q", test.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
24
templ/cmd/templ/lspcmd/proxy/rewrite.go
Normal file
24
templ/cmd/templ/lspcmd/proxy/rewrite.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
lsp "github.com/a-h/templ/lsp/protocol"
|
||||
)
|
||||
|
||||
func convertTemplToGoURI(templURI lsp.DocumentURI) (isTemplFile bool, goURI lsp.DocumentURI) {
|
||||
base, fileName := path.Split(string(templURI))
|
||||
if !strings.HasSuffix(fileName, ".templ") {
|
||||
return
|
||||
}
|
||||
return true, lsp.DocumentURI(base + (strings.TrimSuffix(fileName, ".templ") + "_templ.go"))
|
||||
}
|
||||
|
||||
func convertTemplGoToTemplURI(goURI lsp.DocumentURI) (isTemplGoFile bool, templURI lsp.DocumentURI) {
|
||||
base, fileName := path.Split(string(goURI))
|
||||
if !strings.HasSuffix(fileName, "_templ.go") {
|
||||
return
|
||||
}
|
||||
return true, lsp.DocumentURI(base + (strings.TrimSuffix(fileName, "_templ.go") + ".templ"))
|
||||
}
|
1289
templ/cmd/templ/lspcmd/proxy/server.go
Normal file
1289
templ/cmd/templ/lspcmd/proxy/server.go
Normal file
File diff suppressed because it is too large
Load Diff
111
templ/cmd/templ/lspcmd/proxy/snippets.go
Normal file
111
templ/cmd/templ/lspcmd/proxy/snippets.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package proxy
|
||||
|
||||
import lsp "github.com/a-h/templ/lsp/protocol"
|
||||
|
||||
var htmlSnippets = []lsp.CompletionItem{
|
||||
{
|
||||
Label: "<?>",
|
||||
InsertText: `${1}>
|
||||
${0}
|
||||
</${1}>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "a",
|
||||
InsertText: `a href="${1:}">${2:}</a>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "button",
|
||||
InsertText: `button type="button" ${1:}>${2:}</button>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "div",
|
||||
InsertText: `div>
|
||||
${0}
|
||||
</div>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "p",
|
||||
InsertText: `p>
|
||||
${0}
|
||||
</p>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "head",
|
||||
InsertText: `head>
|
||||
${0}
|
||||
</head>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "body",
|
||||
InsertText: `body>
|
||||
${0}
|
||||
</body>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "title",
|
||||
InsertText: `title>${0}</title>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "h1",
|
||||
InsertText: `h1>${0}</h1>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "h2",
|
||||
InsertText: `h2>${0}</h2>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "h3",
|
||||
InsertText: `h3>${0}</h3>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "h4",
|
||||
InsertText: `h4>${0}</h4>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "h5",
|
||||
InsertText: `h5>${0}</h5>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
{
|
||||
Label: "h6",
|
||||
InsertText: `h6>${0}</h6>`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
}
|
||||
|
||||
var snippet = []lsp.CompletionItem{
|
||||
{
|
||||
Label: "templ",
|
||||
InsertText: `templ ${2:TemplateName}() {
|
||||
${0}
|
||||
}`,
|
||||
Kind: lsp.CompletionItemKind(lsp.CompletionItemKindSnippet),
|
||||
InsertTextFormat: lsp.InsertTextFormatSnippet,
|
||||
},
|
||||
}
|
52
templ/cmd/templ/lspcmd/proxy/sourcemapcache.go
Normal file
52
templ/cmd/templ/lspcmd/proxy/sourcemapcache.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/a-h/templ/parser/v2"
|
||||
)
|
||||
|
||||
// NewSourceMapCache creates a cache of .templ file URIs to the source map.
|
||||
func NewSourceMapCache() *SourceMapCache {
|
||||
return &SourceMapCache{
|
||||
m: new(sync.Mutex),
|
||||
uriToSourceMap: make(map[string]*parser.SourceMap),
|
||||
}
|
||||
}
|
||||
|
||||
// SourceMapCache is a cache of .templ file URIs to the source map.
|
||||
type SourceMapCache struct {
|
||||
m *sync.Mutex
|
||||
uriToSourceMap map[string]*parser.SourceMap
|
||||
}
|
||||
|
||||
func (fc *SourceMapCache) Set(uri string, m *parser.SourceMap) {
|
||||
fc.m.Lock()
|
||||
defer fc.m.Unlock()
|
||||
fc.uriToSourceMap[uri] = m
|
||||
}
|
||||
|
||||
func (fc *SourceMapCache) Get(uri string) (m *parser.SourceMap, ok bool) {
|
||||
fc.m.Lock()
|
||||
defer fc.m.Unlock()
|
||||
m, ok = fc.uriToSourceMap[uri]
|
||||
return
|
||||
}
|
||||
|
||||
func (fc *SourceMapCache) Delete(uri string) {
|
||||
fc.m.Lock()
|
||||
defer fc.m.Unlock()
|
||||
delete(fc.uriToSourceMap, uri)
|
||||
}
|
||||
|
||||
func (fc *SourceMapCache) URIs() (uris []string) {
|
||||
fc.m.Lock()
|
||||
defer fc.m.Unlock()
|
||||
uris = make([]string, len(fc.uriToSourceMap))
|
||||
var i int
|
||||
for k := range fc.uriToSourceMap {
|
||||
uris[i] = k
|
||||
i++
|
||||
}
|
||||
return uris
|
||||
}
|
50
templ/cmd/templ/lspcmd/stdrwc.go
Normal file
50
templ/cmd/templ/lspcmd/stdrwc.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package lspcmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// stdrwc (standard read/write closer) reads from stdin, and writes to stdout.
|
||||
func newStdRwc(log *slog.Logger, name string, w io.Writer, r io.Reader) stdrwc {
|
||||
return stdrwc{
|
||||
log: log,
|
||||
name: name,
|
||||
w: w,
|
||||
r: r,
|
||||
}
|
||||
}
|
||||
|
||||
type stdrwc struct {
|
||||
log *slog.Logger
|
||||
name string
|
||||
w io.Writer
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func (s stdrwc) Read(p []byte) (int, error) {
|
||||
return s.r.Read(p)
|
||||
}
|
||||
|
||||
func (s stdrwc) Write(p []byte) (int, error) {
|
||||
return s.w.Write(p)
|
||||
}
|
||||
|
||||
func (s stdrwc) Close() error {
|
||||
s.log.Info("rwc: closing", slog.String("name", s.name))
|
||||
var errs []error
|
||||
if closer, isCloser := s.r.(io.Closer); isCloser {
|
||||
if err := closer.Close(); err != nil {
|
||||
s.log.Error("rwc: error closing reader", slog.String("name", s.name), slog.Any("error", err))
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
if closer, isCloser := s.w.(io.Closer); isCloser {
|
||||
if err := closer.Close(); err != nil {
|
||||
s.log.Error("rwc: error closing writer", slog.String("name", s.name), slog.Any("error", err))
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
394
templ/cmd/templ/main.go
Normal file
394
templ/cmd/templ/main.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"github.com/a-h/templ/cmd/templ/fmtcmd"
|
||||
"github.com/a-h/templ/cmd/templ/generatecmd"
|
||||
"github.com/a-h/templ/cmd/templ/infocmd"
|
||||
"github.com/a-h/templ/cmd/templ/lspcmd"
|
||||
"github.com/a-h/templ/cmd/templ/sloghandler"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
func main() {
|
||||
code := run(os.Stdin, os.Stdout, os.Stderr, os.Args)
|
||||
if code != 0 {
|
||||
os.Exit(code)
|
||||
}
|
||||
}
|
||||
|
||||
const usageText = `usage: templ <command> [<args>...]
|
||||
|
||||
templ - build HTML UIs with Go
|
||||
|
||||
See docs at https://templ.guide
|
||||
|
||||
commands:
|
||||
generate Generates Go code from templ files
|
||||
fmt Formats templ files
|
||||
lsp Starts a language server for templ files
|
||||
info Displays information about the templ environment
|
||||
version Prints the version
|
||||
`
|
||||
|
||||
func run(stdin io.Reader, stdout, stderr io.Writer, args []string) (code int) {
|
||||
if len(args) < 2 {
|
||||
fmt.Fprint(stderr, usageText)
|
||||
return 64 // EX_USAGE
|
||||
}
|
||||
switch args[1] {
|
||||
case "info":
|
||||
return infoCmd(stdout, stderr, args[2:])
|
||||
case "generate":
|
||||
return generateCmd(stdout, stderr, args[2:])
|
||||
case "fmt":
|
||||
return fmtCmd(stdin, stdout, stderr, args[2:])
|
||||
case "lsp":
|
||||
return lspCmd(stdin, stdout, stderr, args[2:])
|
||||
case "version", "--version":
|
||||
fmt.Fprintln(stdout, templ.Version())
|
||||
return 0
|
||||
case "help", "-help", "--help", "-h":
|
||||
fmt.Fprint(stdout, usageText)
|
||||
return 0
|
||||
}
|
||||
fmt.Fprint(stderr, usageText)
|
||||
return 64 // EX_USAGE
|
||||
}
|
||||
|
||||
func newLogger(logLevel string, verbose bool, stderr io.Writer) *slog.Logger {
|
||||
if verbose {
|
||||
logLevel = "debug"
|
||||
}
|
||||
level := slog.LevelInfo.Level()
|
||||
switch logLevel {
|
||||
case "debug":
|
||||
level = slog.LevelDebug.Level()
|
||||
case "warn":
|
||||
level = slog.LevelWarn.Level()
|
||||
case "error":
|
||||
level = slog.LevelError.Level()
|
||||
}
|
||||
return slog.New(sloghandler.NewHandler(stderr, &slog.HandlerOptions{
|
||||
AddSource: logLevel == "debug",
|
||||
Level: level,
|
||||
}))
|
||||
}
|
||||
|
||||
const infoUsageText = `usage: templ info [<args>...]
|
||||
|
||||
Displays information about the templ environment.
|
||||
|
||||
Args:
|
||||
-json
|
||||
Output information in JSON format to stdout. (default false)
|
||||
-v
|
||||
Set log verbosity level to "debug". (default "info")
|
||||
-log-level
|
||||
Set log verbosity level. (default "info", options: "debug", "info", "warn", "error")
|
||||
-help
|
||||
Print help and exit.
|
||||
`
|
||||
|
||||
func infoCmd(stdout, stderr io.Writer, args []string) (code int) {
|
||||
cmd := flag.NewFlagSet("diagnose", flag.ExitOnError)
|
||||
jsonFlag := cmd.Bool("json", false, "")
|
||||
verboseFlag := cmd.Bool("v", false, "")
|
||||
logLevelFlag := cmd.String("log-level", "info", "")
|
||||
helpFlag := cmd.Bool("help", false, "")
|
||||
err := cmd.Parse(args)
|
||||
if err != nil {
|
||||
fmt.Fprint(stderr, infoUsageText)
|
||||
return 64 // EX_USAGE
|
||||
}
|
||||
if *helpFlag {
|
||||
fmt.Fprint(stdout, infoUsageText)
|
||||
return
|
||||
}
|
||||
|
||||
log := newLogger(*logLevelFlag, *verboseFlag, stderr)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
signalChan := make(chan os.Signal, 1)
|
||||
signal.Notify(signalChan, os.Interrupt)
|
||||
go func() {
|
||||
<-signalChan
|
||||
fmt.Fprintln(stderr, "Stopping...")
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err = infocmd.Run(ctx, log, stdout, infocmd.Arguments{
|
||||
JSON: *jsonFlag,
|
||||
})
|
||||
if err != nil {
|
||||
color.New(color.FgRed).Fprint(stderr, "(✗) ")
|
||||
fmt.Fprintln(stderr, "Command failed: "+err.Error())
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
const generateUsageText = `usage: templ generate [<args>...]
|
||||
|
||||
Generates Go code from templ files.
|
||||
|
||||
Args:
|
||||
-path <path>
|
||||
Generates code for all files in path. (default .)
|
||||
-f <file>
|
||||
Optionally generates code for a single file, e.g. -f header.templ
|
||||
-stdout
|
||||
Prints to stdout instead of writing generated files to the filesystem.
|
||||
Only applicable when -f is used.
|
||||
-source-map-visualisations
|
||||
Set to true to generate HTML files to visualise the templ code and its corresponding Go code.
|
||||
-include-version
|
||||
Set to false to skip inclusion of the templ version in the generated code. (default true)
|
||||
-include-timestamp
|
||||
Set to true to include the current time in the generated code.
|
||||
-watch
|
||||
Set to true to watch the path for changes and regenerate code.
|
||||
-watch-pattern <regexp>
|
||||
Set the regexp pattern of files that will be watched for changes. (default: '(.+\.go$)|(.+\.templ$)|(.+_templ\.txt$)')
|
||||
-cmd <cmd>
|
||||
Set the command to run after generating code.
|
||||
-proxy
|
||||
Set the URL to proxy after generating code and executing the command.
|
||||
-proxyport
|
||||
The port the proxy will listen on. (default 7331)
|
||||
-proxybind
|
||||
The address the proxy will listen on. (default 127.0.0.1)
|
||||
-notify-proxy
|
||||
If present, the command will issue a reload event to the proxy 127.0.0.1:7331, or use proxyport and proxybind to specify a different address.
|
||||
-w
|
||||
Number of workers to use when generating code. (default runtime.NumCPUs)
|
||||
-lazy
|
||||
Only generate .go files if the source .templ file is newer.
|
||||
-pprof
|
||||
Port to run the pprof server on.
|
||||
-keep-orphaned-files
|
||||
Keeps orphaned generated templ files. (default false)
|
||||
-v
|
||||
Set log verbosity level to "debug". (default "info")
|
||||
-log-level
|
||||
Set log verbosity level. (default "info", options: "debug", "info", "warn", "error")
|
||||
-help
|
||||
Print help and exit.
|
||||
|
||||
Examples:
|
||||
|
||||
Generate code for all files in the current directory and subdirectories:
|
||||
|
||||
templ generate
|
||||
|
||||
Generate code for a single file:
|
||||
|
||||
templ generate -f header.templ
|
||||
|
||||
Watch the current directory and subdirectories for changes and regenerate code:
|
||||
|
||||
templ generate -watch
|
||||
`
|
||||
|
||||
func generateCmd(stdout, stderr io.Writer, args []string) (code int) {
|
||||
cmd := flag.NewFlagSet("generate", flag.ExitOnError)
|
||||
fileNameFlag := cmd.String("f", "", "")
|
||||
pathFlag := cmd.String("path", ".", "")
|
||||
toStdoutFlag := cmd.Bool("stdout", false, "")
|
||||
sourceMapVisualisationsFlag := cmd.Bool("source-map-visualisations", false, "")
|
||||
includeVersionFlag := cmd.Bool("include-version", true, "")
|
||||
includeTimestampFlag := cmd.Bool("include-timestamp", false, "")
|
||||
watchFlag := cmd.Bool("watch", false, "")
|
||||
watchPatternFlag := cmd.String("watch-pattern", "(.+\\.go$)|(.+\\.templ$)|(.+_templ\\.txt$)", "")
|
||||
openBrowserFlag := cmd.Bool("open-browser", true, "")
|
||||
cmdFlag := cmd.String("cmd", "", "")
|
||||
proxyFlag := cmd.String("proxy", "", "")
|
||||
proxyPortFlag := cmd.Int("proxyport", 7331, "")
|
||||
proxyBindFlag := cmd.String("proxybind", "127.0.0.1", "")
|
||||
notifyProxyFlag := cmd.Bool("notify-proxy", false, "")
|
||||
workerCountFlag := cmd.Int("w", runtime.NumCPU(), "")
|
||||
pprofPortFlag := cmd.Int("pprof", 0, "")
|
||||
keepOrphanedFilesFlag := cmd.Bool("keep-orphaned-files", false, "")
|
||||
verboseFlag := cmd.Bool("v", false, "")
|
||||
logLevelFlag := cmd.String("log-level", "info", "")
|
||||
lazyFlag := cmd.Bool("lazy", false, "")
|
||||
helpFlag := cmd.Bool("help", false, "")
|
||||
err := cmd.Parse(args)
|
||||
if err != nil {
|
||||
fmt.Fprint(stderr, generateUsageText)
|
||||
return 64 // EX_USAGE
|
||||
}
|
||||
if *helpFlag {
|
||||
fmt.Fprint(stdout, generateUsageText)
|
||||
return
|
||||
}
|
||||
|
||||
log := newLogger(*logLevelFlag, *verboseFlag, stderr)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
signalChan := make(chan os.Signal, 1)
|
||||
signal.Notify(signalChan, os.Interrupt)
|
||||
go func() {
|
||||
<-signalChan
|
||||
fmt.Fprintln(stderr, "Stopping...")
|
||||
cancel()
|
||||
}()
|
||||
|
||||
var fw generatecmd.FileWriterFunc
|
||||
if *toStdoutFlag {
|
||||
fw = generatecmd.WriterFileWriter(stdout)
|
||||
}
|
||||
|
||||
err = generatecmd.Run(ctx, log, generatecmd.Arguments{
|
||||
FileName: *fileNameFlag,
|
||||
Path: *pathFlag,
|
||||
FileWriter: fw,
|
||||
Watch: *watchFlag,
|
||||
WatchPattern: *watchPatternFlag,
|
||||
OpenBrowser: *openBrowserFlag,
|
||||
Command: *cmdFlag,
|
||||
Proxy: *proxyFlag,
|
||||
ProxyPort: *proxyPortFlag,
|
||||
ProxyBind: *proxyBindFlag,
|
||||
NotifyProxy: *notifyProxyFlag,
|
||||
WorkerCount: *workerCountFlag,
|
||||
GenerateSourceMapVisualisations: *sourceMapVisualisationsFlag,
|
||||
IncludeVersion: *includeVersionFlag,
|
||||
IncludeTimestamp: *includeTimestampFlag,
|
||||
PPROFPort: *pprofPortFlag,
|
||||
KeepOrphanedFiles: *keepOrphanedFilesFlag,
|
||||
Lazy: *lazyFlag,
|
||||
})
|
||||
if err != nil {
|
||||
color.New(color.FgRed).Fprint(stderr, "(✗) ")
|
||||
fmt.Fprintln(stderr, "Command failed: "+err.Error())
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
const fmtUsageText = `usage: templ fmt [<args> ...]
|
||||
|
||||
Format all files in directory:
|
||||
|
||||
templ fmt .
|
||||
|
||||
Format stdin to stdout:
|
||||
|
||||
templ fmt < header.templ
|
||||
|
||||
Format file or directory to stdout:
|
||||
|
||||
templ fmt -stdout FILE
|
||||
|
||||
Args:
|
||||
-stdout
|
||||
Prints to stdout instead of in-place format
|
||||
-stdin-filepath
|
||||
Provides the formatter with filepath context when using -stdout.
|
||||
Required for organising imports.
|
||||
-v
|
||||
Set log verbosity level to "debug". (default "info")
|
||||
-log-level
|
||||
Set log verbosity level. (default "info", options: "debug", "info", "warn", "error")
|
||||
-w
|
||||
Number of workers to use when formatting code. (default runtime.NumCPUs).
|
||||
-fail
|
||||
Fails with exit code 1 if files are changed. (e.g. in CI)
|
||||
-help
|
||||
Print help and exit.
|
||||
`
|
||||
|
||||
func fmtCmd(stdin io.Reader, stdout, stderr io.Writer, args []string) (code int) {
|
||||
cmd := flag.NewFlagSet("fmt", flag.ExitOnError)
|
||||
helpFlag := cmd.Bool("help", false, "")
|
||||
workerCountFlag := cmd.Int("w", runtime.NumCPU(), "")
|
||||
verboseFlag := cmd.Bool("v", false, "")
|
||||
logLevelFlag := cmd.String("log-level", "info", "")
|
||||
failIfChanged := cmd.Bool("fail", false, "")
|
||||
stdoutFlag := cmd.Bool("stdout", false, "")
|
||||
stdinFilepath := cmd.String("stdin-filepath", "", "")
|
||||
err := cmd.Parse(args)
|
||||
if err != nil {
|
||||
fmt.Fprint(stderr, fmtUsageText)
|
||||
return 64 // EX_USAGE
|
||||
}
|
||||
if *helpFlag {
|
||||
fmt.Fprint(stdout, fmtUsageText)
|
||||
return
|
||||
}
|
||||
|
||||
log := newLogger(*logLevelFlag, *verboseFlag, stderr)
|
||||
|
||||
err = fmtcmd.Run(log, stdin, stdout, fmtcmd.Arguments{
|
||||
ToStdout: *stdoutFlag,
|
||||
Files: cmd.Args(),
|
||||
WorkerCount: *workerCountFlag,
|
||||
StdinFilepath: *stdinFilepath,
|
||||
FailIfChanged: *failIfChanged,
|
||||
})
|
||||
if err != nil {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
const lspUsageText = `usage: templ lsp [<args> ...]
|
||||
|
||||
Starts a language server for templ.
|
||||
|
||||
Args:
|
||||
-log string
|
||||
The file to log templ LSP output to, or leave empty to disable logging.
|
||||
-goplsLog string
|
||||
The file to log gopls output, or leave empty to disable logging.
|
||||
-goplsRPCTrace
|
||||
Set gopls to log input and output messages.
|
||||
-help
|
||||
Print help and exit.
|
||||
-pprof
|
||||
Enable pprof web server (default address is localhost:9999)
|
||||
-http string
|
||||
Enable http debug server by setting a listen address (e.g. localhost:7474)
|
||||
`
|
||||
|
||||
func lspCmd(stdin io.Reader, stdout, stderr io.Writer, args []string) (code int) {
|
||||
cmd := flag.NewFlagSet("lsp", flag.ExitOnError)
|
||||
logFlag := cmd.String("log", "", "")
|
||||
goplsLog := cmd.String("goplsLog", "", "")
|
||||
goplsRPCTrace := cmd.Bool("goplsRPCTrace", false, "")
|
||||
helpFlag := cmd.Bool("help", false, "")
|
||||
pprofFlag := cmd.Bool("pprof", false, "")
|
||||
httpDebugFlag := cmd.String("http", "", "")
|
||||
err := cmd.Parse(args)
|
||||
if err != nil {
|
||||
fmt.Fprint(stderr, lspUsageText)
|
||||
return 64 // EX_USAGE
|
||||
}
|
||||
if *helpFlag {
|
||||
fmt.Fprint(stdout, lspUsageText)
|
||||
return
|
||||
}
|
||||
|
||||
err = lspcmd.Run(stdin, stdout, stderr, lspcmd.Arguments{
|
||||
Log: *logFlag,
|
||||
GoplsLog: *goplsLog,
|
||||
GoplsRPCTrace: *goplsRPCTrace,
|
||||
PPROF: *pprofFlag,
|
||||
HTTPDebug: *httpDebugFlag,
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Fprintln(stderr, err.Error())
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
102
templ/cmd/templ/main_test.go
Normal file
102
templ/cmd/templ/main_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestMain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedStdout string
|
||||
expectedStderr string
|
||||
expectedCode int
|
||||
}{
|
||||
{
|
||||
name: "no args prints usage",
|
||||
args: []string{},
|
||||
expectedStderr: usageText,
|
||||
expectedCode: 64, // EX_USAGE
|
||||
},
|
||||
{
|
||||
name: `"templ help" prints help`,
|
||||
args: []string{"templ", "help"},
|
||||
expectedStdout: usageText,
|
||||
expectedCode: 0,
|
||||
},
|
||||
{
|
||||
name: `"templ --help" prints help`,
|
||||
args: []string{"templ", "--help"},
|
||||
expectedStdout: usageText,
|
||||
expectedCode: 0,
|
||||
},
|
||||
{
|
||||
name: `"templ version" prints version`,
|
||||
args: []string{"templ", "version"},
|
||||
expectedStdout: templ.Version() + "\n",
|
||||
expectedCode: 0,
|
||||
},
|
||||
{
|
||||
name: `"templ --version" prints version`,
|
||||
args: []string{"templ", "--version"},
|
||||
expectedStdout: templ.Version() + "\n",
|
||||
expectedCode: 0,
|
||||
},
|
||||
{
|
||||
name: `"templ fmt --help" prints usage`,
|
||||
args: []string{"templ", "fmt", "--help"},
|
||||
expectedStdout: fmtUsageText,
|
||||
expectedCode: 0,
|
||||
},
|
||||
{
|
||||
name: `"templ generate --help" prints usage`,
|
||||
args: []string{"templ", "generate", "--help"},
|
||||
expectedStdout: generateUsageText,
|
||||
expectedCode: 0,
|
||||
},
|
||||
{
|
||||
name: `"templ lsp --help" prints usage`,
|
||||
args: []string{"templ", "lsp", "--help"},
|
||||
expectedStdout: lspUsageText,
|
||||
expectedCode: 0,
|
||||
},
|
||||
{
|
||||
name: `"templ info --help" prints usage`,
|
||||
args: []string{"templ", "info", "--help"},
|
||||
expectedStdout: infoUsageText,
|
||||
expectedCode: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
stdin := strings.NewReader("")
|
||||
stdout := bytes.NewBuffer(nil)
|
||||
stderr := bytes.NewBuffer(nil)
|
||||
actualCode := run(stdin, stdout, stderr, test.args)
|
||||
|
||||
if actualCode != test.expectedCode {
|
||||
t.Errorf("expected code %v, got %v", test.expectedCode, actualCode)
|
||||
}
|
||||
if diff := cmp.Diff(test.expectedStdout, stdout.String()); diff != "" {
|
||||
t.Error(diff)
|
||||
t.Error("expected stdout:")
|
||||
t.Error(test.expectedStdout)
|
||||
t.Error("actual stdout:")
|
||||
t.Error(stdout.String())
|
||||
}
|
||||
if diff := cmp.Diff(test.expectedStderr, stderr.String()); diff != "" {
|
||||
t.Error(diff)
|
||||
t.Error("expected stderr:")
|
||||
t.Error(test.expectedStderr)
|
||||
t.Error("actual stderr:")
|
||||
t.Error(stderr.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
80
templ/cmd/templ/processor/processor.go
Normal file
80
templ/cmd/templ/processor/processor.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package processor
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
FileName string
|
||||
Duration time.Duration
|
||||
Error error
|
||||
ChangesMade bool
|
||||
}
|
||||
|
||||
func Process(dir string, f func(fileName string) (error, bool), workerCount int, results chan<- Result) {
|
||||
templates := make(chan string)
|
||||
go func() {
|
||||
defer close(templates)
|
||||
if err := FindTemplates(dir, templates); err != nil {
|
||||
results <- Result{Error: err}
|
||||
}
|
||||
}()
|
||||
ProcessChannel(templates, dir, f, workerCount, results)
|
||||
}
|
||||
|
||||
func shouldSkipDir(dir string) bool {
|
||||
if dir == "." {
|
||||
return false
|
||||
}
|
||||
if dir == "vendor" || dir == "node_modules" {
|
||||
return true
|
||||
}
|
||||
_, name := path.Split(dir)
|
||||
// These directories are ignored by the Go tool.
|
||||
if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func FindTemplates(srcPath string, output chan<- string) (err error) {
|
||||
return filepath.WalkDir(srcPath, func(currentPath string, info fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.IsDir() && shouldSkipDir(currentPath) {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if !info.IsDir() && strings.HasSuffix(currentPath, ".templ") {
|
||||
output <- currentPath
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func ProcessChannel(templates <-chan string, dir string, f func(fileName string) (error, bool), workerCount int, results chan<- Result) {
|
||||
defer close(results)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(workerCount)
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for sourceFileName := range templates {
|
||||
start := time.Now()
|
||||
outErr, outChanged := f(sourceFileName)
|
||||
results <- Result{
|
||||
FileName: sourceFileName,
|
||||
Error: outErr,
|
||||
Duration: time.Since(start),
|
||||
ChangesMade: outChanged,
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
19
templ/cmd/templ/processor/processor_test.go
Normal file
19
templ/cmd/templ/processor/processor_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package processor
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFindTemplates(t *testing.T) {
|
||||
t.Run("returns an error if the directory does not exist", func(t *testing.T) {
|
||||
output := make(chan string)
|
||||
err := FindTemplates("nonexistent", output)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, but got nil")
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
t.Fatalf("expected os.IsNotExist(err) to be true, but got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
101
templ/cmd/templ/sloghandler/handler.go
Normal file
101
templ/cmd/templ/sloghandler/handler.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package sloghandler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
var _ slog.Handler = &Handler{}
|
||||
|
||||
type Handler struct {
|
||||
h slog.Handler
|
||||
m *sync.Mutex
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
var levelToIcon = map[slog.Level]string{
|
||||
slog.LevelDebug: "(✓)",
|
||||
slog.LevelInfo: "(✓)",
|
||||
slog.LevelWarn: "(!)",
|
||||
slog.LevelError: "(✗)",
|
||||
}
|
||||
var levelToColor = map[slog.Level]*color.Color{
|
||||
slog.LevelDebug: color.New(color.FgCyan),
|
||||
slog.LevelInfo: color.New(color.FgGreen),
|
||||
slog.LevelWarn: color.New(color.FgYellow),
|
||||
slog.LevelError: color.New(color.FgRed),
|
||||
}
|
||||
|
||||
func NewHandler(w io.Writer, opts *slog.HandlerOptions) *Handler {
|
||||
if opts == nil {
|
||||
opts = &slog.HandlerOptions{}
|
||||
}
|
||||
return &Handler{
|
||||
w: w,
|
||||
h: slog.NewTextHandler(w, &slog.HandlerOptions{
|
||||
Level: opts.Level,
|
||||
AddSource: opts.AddSource,
|
||||
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
|
||||
if opts.ReplaceAttr != nil {
|
||||
a = opts.ReplaceAttr(groups, a)
|
||||
}
|
||||
if a.Key == slog.LevelKey {
|
||||
level, ok := levelToIcon[a.Value.Any().(slog.Level)]
|
||||
if !ok {
|
||||
level = a.Value.Any().(slog.Level).String()
|
||||
}
|
||||
a.Value = slog.StringValue(level)
|
||||
return a
|
||||
}
|
||||
if a.Key == slog.TimeKey {
|
||||
return slog.Attr{}
|
||||
}
|
||||
return a
|
||||
},
|
||||
}),
|
||||
m: &sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Enabled(ctx context.Context, level slog.Level) bool {
|
||||
return h.h.Enabled(ctx, level)
|
||||
}
|
||||
|
||||
func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
return &Handler{h: h.h.WithAttrs(attrs), w: h.w, m: h.m}
|
||||
}
|
||||
|
||||
func (h *Handler) WithGroup(name string) slog.Handler {
|
||||
return &Handler{h: h.h.WithGroup(name), w: h.w, m: h.m}
|
||||
}
|
||||
|
||||
var keyValueColor = color.New(color.Faint & color.FgBlack)
|
||||
|
||||
func (h *Handler) Handle(ctx context.Context, r slog.Record) (err error) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(levelToColor[r.Level].Sprint(levelToIcon[r.Level]))
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(r.Message)
|
||||
|
||||
if r.NumAttrs() != 0 {
|
||||
sb.WriteString(" [")
|
||||
r.Attrs(func(a slog.Attr) bool {
|
||||
sb.WriteString(keyValueColor.Sprintf(" %s=%s", a.Key, a.Value.String()))
|
||||
return true
|
||||
})
|
||||
sb.WriteString(" ]")
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
h.m.Lock()
|
||||
defer h.m.Unlock()
|
||||
_, err = io.WriteString(h.w, sb.String())
|
||||
return err
|
||||
}
|
3
templ/cmd/templ/testproject/testdata/css-classes/classes.go
vendored
Normal file
3
templ/cmd/templ/testproject/testdata/css-classes/classes.go
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
package cssclasses
|
||||
|
||||
const Header = "header"
|
7
templ/cmd/templ/testproject/testdata/go.mod.embed
vendored
Normal file
7
templ/cmd/templ/testproject/testdata/go.mod.embed
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
module templ/testproject
|
||||
|
||||
go 1.23
|
||||
|
||||
require github.com/a-h/templ v0.2.513 // indirect
|
||||
|
||||
replace github.com/a-h/templ => {moduleRoot}
|
2
templ/cmd/templ/testproject/testdata/go.sum
vendored
Normal file
2
templ/cmd/templ/testproject/testdata/go.sum
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
33
templ/cmd/templ/testproject/testdata/main.go
vendored
Normal file
33
templ/cmd/templ/testproject/testdata/main.go
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
)
|
||||
|
||||
var flagPort = flag.Int("port", 0, "Set the HTTP listen port")
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
if *flagPort == 0 {
|
||||
fmt.Println("missing port flag")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var count int
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
count++
|
||||
c := Page(count)
|
||||
templ.Handler(c).ServeHTTP(w, r)
|
||||
})
|
||||
err := http.ListenAndServe(fmt.Sprintf("localhost:%d", *flagPort), nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Error listening: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
5
templ/cmd/templ/testproject/testdata/remotechild.templ
vendored
Normal file
5
templ/cmd/templ/testproject/testdata/remotechild.templ
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
package main
|
||||
|
||||
templ Remote() {
|
||||
<p>This is remote content</p>
|
||||
}
|
40
templ/cmd/templ/testproject/testdata/remotechild_templ.go
vendored
Normal file
40
templ/cmd/templ/testproject/testdata/remotechild_templ.go
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
// Code generated by templ - DO NOT EDIT.
|
||||
|
||||
// templ: version: v0.3.833
|
||||
package main
|
||||
|
||||
//lint:file-ignore SA4006 This context is only used if a nested component is present.
|
||||
|
||||
import "github.com/a-h/templ"
|
||||
import templruntime "github.com/a-h/templ/runtime"
|
||||
|
||||
func Remote() templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var1 == nil {
|
||||
templ_7745c5c3_Var1 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<p>This is remote content</p>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var _ = templruntime.GeneratedTemplate
|
9
templ/cmd/templ/testproject/testdata/remoteparent.templ
vendored
Normal file
9
templ/cmd/templ/testproject/testdata/remoteparent.templ
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
package main
|
||||
|
||||
templ RemoteInclusionTest() {
|
||||
@Remote
|
||||
}
|
||||
|
||||
templ Remote2() {
|
||||
@Remote
|
||||
}
|
69
templ/cmd/templ/testproject/testdata/remoteparent_templ.go
vendored
Normal file
69
templ/cmd/templ/testproject/testdata/remoteparent_templ.go
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
// Code generated by templ - DO NOT EDIT.
|
||||
|
||||
// templ: version: v0.3.833
|
||||
package main
|
||||
|
||||
//lint:file-ignore SA4006 This context is only used if a nested component is present.
|
||||
|
||||
import "github.com/a-h/templ"
|
||||
import templruntime "github.com/a-h/templ/runtime"
|
||||
|
||||
func RemoteInclusionTest() templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var1 == nil {
|
||||
templ_7745c5c3_Var1 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
templ_7745c5c3_Err = Remote.Render(ctx, templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func Remote2() templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var2 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var2 == nil {
|
||||
templ_7745c5c3_Var2 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
templ_7745c5c3_Err = Remote.Render(ctx, templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var _ = templruntime.GeneratedTemplate
|
25
templ/cmd/templ/testproject/testdata/templates.templ
vendored
Normal file
25
templ/cmd/templ/testproject/testdata/templates.templ
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
templ Page(count int) {
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>templ test page</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Count</h1>
|
||||
<div data-testid="count">{ fmt.Sprintf("%d", count) }</div>
|
||||
<div data-testid="modification">Original</div>
|
||||
</body>
|
||||
</html>
|
||||
}
|
||||
|
||||
var nihao = "你好"
|
||||
|
||||
type Struct struct {
|
||||
Count int
|
||||
}
|
||||
|
||||
var s = Struct{}
|
63
templ/cmd/templ/testproject/testdata/templates_templ.go
vendored
Normal file
63
templ/cmd/templ/testproject/testdata/templates_templ.go
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
// Code generated by templ - DO NOT EDIT.
|
||||
|
||||
// templ: version: v0.3.833
|
||||
package main
|
||||
|
||||
//lint:file-ignore SA4006 This context is only used if a nested component is present.
|
||||
|
||||
import "github.com/a-h/templ"
|
||||
import templruntime "github.com/a-h/templ/runtime"
|
||||
|
||||
import "fmt"
|
||||
|
||||
func Page(count int) templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var1 == nil {
|
||||
templ_7745c5c3_Var1 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<!doctype html><html><head><title>templ test page</title></head><body><h1>Count</h1><div data-testid=\"count\">")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var2 string
|
||||
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(fmt.Sprintf("%d", count))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/testproject/testdata/templates.templ`, Line: 13, Col: 54}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "</div><div data-testid=\"modification\">Original</div></body></html>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var nihao = "你好"
|
||||
|
||||
type Struct struct {
|
||||
Count int
|
||||
}
|
||||
|
||||
var s = Struct{}
|
||||
|
||||
var _ = templruntime.GeneratedTemplate
|
70
templ/cmd/templ/testproject/testproject.go
Normal file
70
templ/cmd/templ/testproject/testproject.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package testproject
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:embed testdata/*
|
||||
var testdata embed.FS
|
||||
|
||||
func Create(moduleRoot string) (dir string, err error) {
|
||||
dir, err = os.MkdirTemp("", "templ_test_*")
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to make test dir: %w", err)
|
||||
}
|
||||
files, err := testdata.ReadDir("testdata")
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to read embedded dir: %w", err)
|
||||
}
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
if err = os.MkdirAll(filepath.Join(dir, file.Name()), 0777); err != nil {
|
||||
return dir, fmt.Errorf("failed to create dir: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
src := filepath.Join("testdata", file.Name())
|
||||
data, err := testdata.ReadFile(src)
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
target := filepath.Join(dir, file.Name())
|
||||
if file.Name() == "go.mod.embed" {
|
||||
data = bytes.ReplaceAll(data, []byte("{moduleRoot}"), []byte(moduleRoot))
|
||||
target = filepath.Join(dir, "go.mod")
|
||||
}
|
||||
err = os.WriteFile(target, data, 0660)
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to copy file: %w", err)
|
||||
}
|
||||
}
|
||||
files, err = testdata.ReadDir("testdata/css-classes")
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to read embedded dir: %w", err)
|
||||
}
|
||||
for _, file := range files {
|
||||
src := filepath.Join("testdata", "css-classes", file.Name())
|
||||
data, err := testdata.ReadFile(src)
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
target := filepath.Join(dir, "css-classes", file.Name())
|
||||
err = os.WriteFile(target, data, 0660)
|
||||
if err != nil {
|
||||
return dir, fmt.Errorf("failed to copy file: %w", err)
|
||||
}
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
func MustReplaceLine(file string, line int, replacement string) string {
|
||||
lines := strings.Split(file, "\n")
|
||||
lines[line-1] = replacement
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
64
templ/cmd/templ/visualize/sourcemapvisualisation.templ
Normal file
64
templ/cmd/templ/visualize/sourcemapvisualisation.templ
Normal file
@@ -0,0 +1,64 @@
|
||||
package visualize
|
||||
|
||||
css row() {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
css column() {
|
||||
flex: 50%;
|
||||
overflow-y: scroll;
|
||||
max-height: 100vh;
|
||||
}
|
||||
|
||||
css code() {
|
||||
font-family: monospace;
|
||||
}
|
||||
|
||||
templ combine(templFileName string, left, right templ.Component) {
|
||||
<html>
|
||||
<head>
|
||||
<title>{ templFileName }- Source Map Visualisation</title>
|
||||
<style type="text/css">
|
||||
.mapped { background-color: green }
|
||||
.highlighted { background-color: yellow }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>{ templFileName }</h1>
|
||||
<div class={ templ.Classes(row()) }>
|
||||
<div class={ templ.Classes(column(), code()) }>
|
||||
@left
|
||||
</div>
|
||||
<div class={ templ.Classes(column(), code()) }>
|
||||
@right
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
}
|
||||
|
||||
script highlight(sourceId, targetId string) {
|
||||
let items = document.getElementsByClassName(sourceId);
|
||||
for(let i = 0; i < items.length; i ++) {
|
||||
items[i].classList.add("highlighted");
|
||||
}
|
||||
items = document.getElementsByClassName(targetId);
|
||||
for(let i = 0; i < items.length; i ++) {
|
||||
items[i].classList.add("highlighted");
|
||||
}
|
||||
}
|
||||
|
||||
script removeHighlight(sourceId, targetId string) {
|
||||
let items = document.getElementsByClassName(sourceId);
|
||||
for(let i = 0; i < items.length; i ++) {
|
||||
items[i].classList.remove("highlighted");
|
||||
}
|
||||
items = document.getElementsByClassName(targetId);
|
||||
for(let i = 0; i < items.length; i ++) {
|
||||
items[i].classList.remove("highlighted");
|
||||
}
|
||||
}
|
||||
|
||||
templ mappedCharacter(s string, sourceID, targetID string) {
|
||||
<span class={ templ.Classes(templ.Class("mapped"), templ.Class(sourceID), templ.Class(targetID)) } onMouseOver={ highlight(sourceID, targetID) } onMouseOut={ removeHighlight(sourceID, targetID) }>{ s }</span>
|
||||
}
|
296
templ/cmd/templ/visualize/sourcemapvisualisation_templ.go
Normal file
296
templ/cmd/templ/visualize/sourcemapvisualisation_templ.go
Normal file
@@ -0,0 +1,296 @@
|
||||
// Code generated by templ - DO NOT EDIT.
|
||||
|
||||
// templ: version: v0.3.833
|
||||
package visualize
|
||||
|
||||
//lint:file-ignore SA4006 This context is only used if a nested component is present.
|
||||
|
||||
import "github.com/a-h/templ"
|
||||
import templruntime "github.com/a-h/templ/runtime"
|
||||
|
||||
func row() templ.CSSClass {
|
||||
templ_7745c5c3_CSSBuilder := templruntime.GetBuilder()
|
||||
templ_7745c5c3_CSSBuilder.WriteString(`display:flex;`)
|
||||
templ_7745c5c3_CSSID := templ.CSSID(`row`, templ_7745c5c3_CSSBuilder.String())
|
||||
return templ.ComponentCSSClass{
|
||||
ID: templ_7745c5c3_CSSID,
|
||||
Class: templ.SafeCSS(`.` + templ_7745c5c3_CSSID + `{` + templ_7745c5c3_CSSBuilder.String() + `}`),
|
||||
}
|
||||
}
|
||||
|
||||
func column() templ.CSSClass {
|
||||
templ_7745c5c3_CSSBuilder := templruntime.GetBuilder()
|
||||
templ_7745c5c3_CSSBuilder.WriteString(`flex:50%;`)
|
||||
templ_7745c5c3_CSSBuilder.WriteString(`overflow-y:scroll;`)
|
||||
templ_7745c5c3_CSSBuilder.WriteString(`max-height:100vh;`)
|
||||
templ_7745c5c3_CSSID := templ.CSSID(`column`, templ_7745c5c3_CSSBuilder.String())
|
||||
return templ.ComponentCSSClass{
|
||||
ID: templ_7745c5c3_CSSID,
|
||||
Class: templ.SafeCSS(`.` + templ_7745c5c3_CSSID + `{` + templ_7745c5c3_CSSBuilder.String() + `}`),
|
||||
}
|
||||
}
|
||||
|
||||
func code() templ.CSSClass {
|
||||
templ_7745c5c3_CSSBuilder := templruntime.GetBuilder()
|
||||
templ_7745c5c3_CSSBuilder.WriteString(`font-family:monospace;`)
|
||||
templ_7745c5c3_CSSID := templ.CSSID(`code`, templ_7745c5c3_CSSBuilder.String())
|
||||
return templ.ComponentCSSClass{
|
||||
ID: templ_7745c5c3_CSSID,
|
||||
Class: templ.SafeCSS(`.` + templ_7745c5c3_CSSID + `{` + templ_7745c5c3_CSSBuilder.String() + `}`),
|
||||
}
|
||||
}
|
||||
|
||||
func combine(templFileName string, left, right templ.Component) templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var1 == nil {
|
||||
templ_7745c5c3_Var1 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<html><head><title>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var2 string
|
||||
templ_7745c5c3_Var2, templ_7745c5c3_Err = templ.JoinStringErrs(templFileName)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 20, Col: 25}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var2))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "- Source Map Visualisation</title><style type=\"text/css\">\n\t\t\t\t.mapped { background-color: green }\n\t\t\t\t.highlighted { background-color: yellow }\n\t\t\t</style></head><body><h1>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var3 string
|
||||
templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(templFileName)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 27, Col: 22}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var3))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 3, "</h1>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var4 = []any{templ.Classes(row())}
|
||||
templ_7745c5c3_Err = templ.RenderCSSItems(ctx, templ_7745c5c3_Buffer, templ_7745c5c3_Var4...)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 4, "<div class=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var5 string
|
||||
templ_7745c5c3_Var5, templ_7745c5c3_Err = templ.JoinStringErrs(templ.CSSClasses(templ_7745c5c3_Var4).String())
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 1, Col: 0}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var5))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 5, "\">")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var6 = []any{templ.Classes(column(), code())}
|
||||
templ_7745c5c3_Err = templ.RenderCSSItems(ctx, templ_7745c5c3_Buffer, templ_7745c5c3_Var6...)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 6, "<div class=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var7 string
|
||||
templ_7745c5c3_Var7, templ_7745c5c3_Err = templ.JoinStringErrs(templ.CSSClasses(templ_7745c5c3_Var6).String())
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 1, Col: 0}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var7))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "\">")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = left.Render(ctx, templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, "</div>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var8 = []any{templ.Classes(column(), code())}
|
||||
templ_7745c5c3_Err = templ.RenderCSSItems(ctx, templ_7745c5c3_Buffer, templ_7745c5c3_Var8...)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 9, "<div class=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var9 string
|
||||
templ_7745c5c3_Var9, templ_7745c5c3_Err = templ.JoinStringErrs(templ.CSSClasses(templ_7745c5c3_Var8).String())
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 1, Col: 0}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var9))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 10, "\">")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = right.Render(ctx, templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 11, "</div></div></body></html>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func highlight(sourceId, targetId string) templ.ComponentScript {
|
||||
return templ.ComponentScript{
|
||||
Name: `__templ_highlight_ae80`,
|
||||
Function: `function __templ_highlight_ae80(sourceId, targetId){let items = document.getElementsByClassName(sourceId);
|
||||
for(let i = 0; i < items.length; i ++) {
|
||||
items[i].classList.add("highlighted");
|
||||
}
|
||||
items = document.getElementsByClassName(targetId);
|
||||
for(let i = 0; i < items.length; i ++) {
|
||||
items[i].classList.add("highlighted");
|
||||
}
|
||||
}`,
|
||||
Call: templ.SafeScript(`__templ_highlight_ae80`, sourceId, targetId),
|
||||
CallInline: templ.SafeScriptInline(`__templ_highlight_ae80`, sourceId, targetId),
|
||||
}
|
||||
}
|
||||
|
||||
func removeHighlight(sourceId, targetId string) templ.ComponentScript {
|
||||
return templ.ComponentScript{
|
||||
Name: `__templ_removeHighlight_58f2`,
|
||||
Function: `function __templ_removeHighlight_58f2(sourceId, targetId){let items = document.getElementsByClassName(sourceId);
|
||||
for(let i = 0; i < items.length; i ++) {
|
||||
items[i].classList.remove("highlighted");
|
||||
}
|
||||
items = document.getElementsByClassName(targetId);
|
||||
for(let i = 0; i < items.length; i ++) {
|
||||
items[i].classList.remove("highlighted");
|
||||
}
|
||||
}`,
|
||||
Call: templ.SafeScript(`__templ_removeHighlight_58f2`, sourceId, targetId),
|
||||
CallInline: templ.SafeScriptInline(`__templ_removeHighlight_58f2`, sourceId, targetId),
|
||||
}
|
||||
}
|
||||
|
||||
func mappedCharacter(s string, sourceID, targetID string) templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var10 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var10 == nil {
|
||||
templ_7745c5c3_Var10 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
var templ_7745c5c3_Var11 = []any{templ.Classes(templ.Class("mapped"), templ.Class(sourceID), templ.Class(targetID))}
|
||||
templ_7745c5c3_Err = templ.RenderCSSItems(ctx, templ_7745c5c3_Buffer, templ_7745c5c3_Var11...)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templ.RenderScriptItems(ctx, templ_7745c5c3_Buffer, highlight(sourceID, targetID), removeHighlight(sourceID, targetID))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "<span class=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var12 string
|
||||
templ_7745c5c3_Var12, templ_7745c5c3_Err = templ.JoinStringErrs(templ.CSSClasses(templ_7745c5c3_Var11).String())
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 1, Col: 0}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var12))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "\" onMouseOver=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var13 templ.ComponentScript = highlight(sourceID, targetID)
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ_7745c5c3_Var13.Call)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 14, "\" onMouseOut=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var14 templ.ComponentScript = removeHighlight(sourceID, targetID)
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ_7745c5c3_Var14.Call)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 15, "\">")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var15 string
|
||||
templ_7745c5c3_Var15, templ_7745c5c3_Err = templ.JoinStringErrs(s)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templ/cmd/templ/visualize/sourcemapvisualisation.templ`, Line: 63, Col: 200}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var15))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 16, "</span>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var _ = templruntime.GeneratedTemplate
|
87
templ/cmd/templ/visualize/types.go
Normal file
87
templ/cmd/templ/visualize/types.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package visualize
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/a-h/templ"
|
||||
"github.com/a-h/templ/parser/v2"
|
||||
)
|
||||
|
||||
func HTML(templFileName string, templContents, goContents string, sourceMap *parser.SourceMap) templ.Component {
|
||||
tl := templLines{contents: string(templContents), sourceMap: sourceMap}
|
||||
gl := goLines{contents: string(goContents), sourceMap: sourceMap}
|
||||
return combine(templFileName, tl, gl)
|
||||
}
|
||||
|
||||
type templLines struct {
|
||||
contents string
|
||||
sourceMap *parser.SourceMap
|
||||
}
|
||||
|
||||
func (tl templLines) Render(ctx context.Context, w io.Writer) (err error) {
|
||||
templLines := strings.Split(tl.contents, "\n")
|
||||
for lineIndex, line := range templLines {
|
||||
if _, err = w.Write([]byte("<span>" + strconv.Itoa(lineIndex) + " </span>\n")); err != nil {
|
||||
return
|
||||
}
|
||||
for colIndex, c := range line {
|
||||
if tgt, ok := tl.sourceMap.TargetPositionFromSource(uint32(lineIndex), uint32(colIndex)); ok {
|
||||
sourceID := fmt.Sprintf("src_%d_%d", lineIndex, colIndex)
|
||||
targetID := fmt.Sprintf("tgt_%d_%d", tgt.Line, tgt.Col)
|
||||
if err := mappedCharacter(string(c), sourceID, targetID).Render(ctx, w); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
s := html.EscapeString(string(c))
|
||||
s = strings.ReplaceAll(s, "\t", " ")
|
||||
s = strings.ReplaceAll(s, " ", " ")
|
||||
if _, err := w.Write([]byte(s)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, err = w.Write([]byte("\n<br/>\n")); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type goLines struct {
|
||||
contents string
|
||||
sourceMap *parser.SourceMap
|
||||
}
|
||||
|
||||
func (gl goLines) Render(ctx context.Context, w io.Writer) (err error) {
|
||||
templLines := strings.Split(gl.contents, "\n")
|
||||
for lineIndex, line := range templLines {
|
||||
if _, err = w.Write([]byte("<span>" + strconv.Itoa(lineIndex) + " </span>\n")); err != nil {
|
||||
return
|
||||
}
|
||||
for colIndex, c := range line {
|
||||
if src, ok := gl.sourceMap.SourcePositionFromTarget(uint32(lineIndex), uint32(colIndex)); ok {
|
||||
sourceID := fmt.Sprintf("src_%d_%d", src.Line, src.Col)
|
||||
targetID := fmt.Sprintf("tgt_%d_%d", lineIndex, colIndex)
|
||||
if err := mappedCharacter(string(c), sourceID, targetID).Render(ctx, w); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
s := html.EscapeString(string(c))
|
||||
s = strings.ReplaceAll(s, "\t", " ")
|
||||
s = strings.ReplaceAll(s, " ", " ")
|
||||
if _, err := w.Write([]byte(s)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, err = w.Write([]byte("\n<br/>\n")); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user