Changed: DB Params
This commit is contained in:
174
templ/cmd/templ/imports/process.go
Normal file
174
templ/cmd/templ/imports/process.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package imports
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/format"
|
||||
"go/token"
|
||||
"path"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
goparser "go/parser"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/tools/go/ast/astutil"
|
||||
"golang.org/x/tools/imports"
|
||||
|
||||
"github.com/a-h/templ/generator"
|
||||
"github.com/a-h/templ/parser/v2"
|
||||
)
|
||||
|
||||
var internalImports = []string{"github.com/a-h/templ", "github.com/a-h/templ/runtime"}
|
||||
|
||||
func convertTemplToGoURI(templURI string) (isTemplFile bool, goURI string) {
|
||||
base, fileName := path.Split(templURI)
|
||||
if !strings.HasSuffix(fileName, ".templ") {
|
||||
return
|
||||
}
|
||||
return true, base + (strings.TrimSuffix(fileName, ".templ") + "_templ.go")
|
||||
}
|
||||
|
||||
var fset = token.NewFileSet()
|
||||
|
||||
func updateImports(name, src string) (updated []*ast.ImportSpec, err error) {
|
||||
// Apply auto imports.
|
||||
updatedGoCode, err := imports.Process(name, []byte(src), nil)
|
||||
if err != nil {
|
||||
return updated, fmt.Errorf("failed to process go code %q: %w", src, err)
|
||||
}
|
||||
// Get updated imports.
|
||||
gofile, err := goparser.ParseFile(fset, name, updatedGoCode, goparser.ImportsOnly)
|
||||
if err != nil {
|
||||
return updated, fmt.Errorf("failed to get imports from updated go code: %w", err)
|
||||
}
|
||||
for _, imp := range gofile.Imports {
|
||||
if !slices.Contains(internalImports, strings.Trim(imp.Path.Value, "\"")) {
|
||||
updated = append(updated, imp)
|
||||
}
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func Process(t parser.TemplateFile) (parser.TemplateFile, error) {
|
||||
if t.Filepath == "" {
|
||||
return t, nil
|
||||
}
|
||||
isTemplFile, fileName := convertTemplToGoURI(t.Filepath)
|
||||
if !isTemplFile {
|
||||
return t, fmt.Errorf("invalid filepath: %s", t.Filepath)
|
||||
}
|
||||
|
||||
// The first node always contains existing imports.
|
||||
// If there isn't one, create it.
|
||||
if len(t.Nodes) == 0 {
|
||||
t.Nodes = append(t.Nodes, parser.TemplateFileGoExpression{})
|
||||
}
|
||||
// If there is one, ensure it is a Go expression.
|
||||
if _, ok := t.Nodes[0].(parser.TemplateFileGoExpression); !ok {
|
||||
t.Nodes = append([]parser.TemplateFileNode{parser.TemplateFileGoExpression{}}, t.Nodes...)
|
||||
}
|
||||
|
||||
// Find all existing imports.
|
||||
importsNode := t.Nodes[0].(parser.TemplateFileGoExpression)
|
||||
|
||||
// Generate code.
|
||||
gw := bytes.NewBuffer(nil)
|
||||
var updatedImports []*ast.ImportSpec
|
||||
var eg errgroup.Group
|
||||
eg.Go(func() (err error) {
|
||||
if _, err := generator.Generate(t, gw); err != nil {
|
||||
return fmt.Errorf("failed to generate go code: %w", err)
|
||||
}
|
||||
updatedImports, err = updateImports(fileName, gw.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get imports from generated go code: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
var firstGoNodeInTemplate *ast.File
|
||||
// Update the template with the imports.
|
||||
// Ensure that there is a Go expression to add the imports to as the first node.
|
||||
eg.Go(func() (err error) {
|
||||
firstGoNodeInTemplate, err = goparser.ParseFile(fset, fileName, t.Package.Expression.Value+"\n"+importsNode.Expression.Value, goparser.AllErrors|goparser.ParseComments)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse imports section: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Wait for completion of both parts.
|
||||
if err := eg.Wait(); err != nil {
|
||||
return t, err
|
||||
}
|
||||
// Delete unused imports.
|
||||
for _, imp := range firstGoNodeInTemplate.Imports {
|
||||
if !containsImport(updatedImports, imp) {
|
||||
name, path, err := getImportDetails(imp)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
astutil.DeleteNamedImport(fset, firstGoNodeInTemplate, name, path)
|
||||
}
|
||||
}
|
||||
// Add imports, if there are any to add.
|
||||
for _, imp := range updatedImports {
|
||||
if !containsImport(firstGoNodeInTemplate.Imports, imp) {
|
||||
name, path, err := getImportDetails(imp)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
astutil.AddNamedImport(fset, firstGoNodeInTemplate, name, path)
|
||||
}
|
||||
}
|
||||
// Edge case: reinsert the import to use import syntax without parentheses.
|
||||
if len(firstGoNodeInTemplate.Imports) == 1 {
|
||||
name, path, err := getImportDetails(firstGoNodeInTemplate.Imports[0])
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
astutil.DeleteNamedImport(fset, firstGoNodeInTemplate, name, path)
|
||||
astutil.AddNamedImport(fset, firstGoNodeInTemplate, name, path)
|
||||
}
|
||||
// Write out the Go code with the imports.
|
||||
updatedGoCode := new(strings.Builder)
|
||||
err := format.Node(updatedGoCode, fset, firstGoNodeInTemplate)
|
||||
if err != nil {
|
||||
return t, fmt.Errorf("failed to write updated go code: %w", err)
|
||||
}
|
||||
// Remove the package statement from the node, by cutting the first line of the file.
|
||||
importsNode.Expression.Value = strings.TrimSpace(strings.SplitN(updatedGoCode.String(), "\n", 2)[1])
|
||||
if len(updatedImports) == 0 && importsNode.Expression.Value == "" {
|
||||
t.Nodes = t.Nodes[1:]
|
||||
return t, nil
|
||||
}
|
||||
t.Nodes[0] = importsNode
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func getImportDetails(imp *ast.ImportSpec) (name, importPath string, err error) {
|
||||
if imp.Name != nil {
|
||||
name = imp.Name.Name
|
||||
}
|
||||
if imp.Path != nil {
|
||||
importPath, err = strconv.Unquote(imp.Path.Value)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to unquote package path %s: %w", imp.Path.Value, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
return name, importPath, nil
|
||||
}
|
||||
|
||||
func containsImport(imports []*ast.ImportSpec, spec *ast.ImportSpec) bool {
|
||||
for _, imp := range imports {
|
||||
if imp.Path.Value == spec.Path.Value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
154
templ/cmd/templ/imports/process_test.go
Normal file
154
templ/cmd/templ/imports/process_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package imports
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/a-h/templ/cmd/templ/testproject"
|
||||
"github.com/a-h/templ/parser/v2"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"golang.org/x/tools/txtar"
|
||||
)
|
||||
|
||||
func TestFormatting(t *testing.T) {
|
||||
files, _ := filepath.Glob("testdata/*.txtar")
|
||||
if len(files) == 0 {
|
||||
t.Errorf("no test files found")
|
||||
}
|
||||
for _, file := range files {
|
||||
t.Run(filepath.Base(file), func(t *testing.T) {
|
||||
a, err := txtar.ParseFile(file)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse txtar file: %v", err)
|
||||
}
|
||||
if len(a.Files) != 2 {
|
||||
t.Fatalf("expected 2 files, got %d", len(a.Files))
|
||||
}
|
||||
template, err := parser.ParseString(clean(a.Files[0].Data))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse %v", err)
|
||||
}
|
||||
template.Filepath = a.Files[0].Name
|
||||
tf, err := Process(template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to process file: %v", err)
|
||||
}
|
||||
expected := string(a.Files[1].Data)
|
||||
actual := new(strings.Builder)
|
||||
if err := tf.Write(actual); err != nil {
|
||||
t.Fatalf("failed to write template file: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(expected, actual.String()); diff != "" {
|
||||
t.Errorf("%s:\n%s", file, diff)
|
||||
t.Errorf("expected:\n%s", showWhitespace(expected))
|
||||
t.Errorf("actual:\n%s", showWhitespace(actual.String()))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func showWhitespace(s string) string {
|
||||
s = strings.ReplaceAll(s, "\n", "⏎\n")
|
||||
s = strings.ReplaceAll(s, "\t", "→")
|
||||
s = strings.ReplaceAll(s, " ", "·")
|
||||
return s
|
||||
}
|
||||
|
||||
func clean(b []byte) string {
|
||||
b = bytes.ReplaceAll(b, []byte("$\n"), []byte("\n"))
|
||||
b = bytes.TrimSuffix(b, []byte("\n"))
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func TestImport(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
return
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
src string
|
||||
assertions func(t *testing.T, updated string)
|
||||
}{
|
||||
{
|
||||
name: "un-named imports are removed",
|
||||
src: `package main
|
||||
|
||||
import "fmt"
|
||||
import "github.com/a-h/templ/cmd/templ/testproject/css-classes"
|
||||
|
||||
templ Page(count int) {
|
||||
{ fmt.Sprintf("%d", count) }
|
||||
{ cssclasses.Header }
|
||||
}
|
||||
`,
|
||||
assertions: func(t *testing.T, updated string) {
|
||||
if count := strings.Count(updated, "github.com/a-h/templ/cmd/templ/testproject/css-classes"); count != 0 {
|
||||
t.Errorf("expected un-named import to be removed, but got %d instance of it", count)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "named imports are retained",
|
||||
src: `package main
|
||||
|
||||
import "fmt"
|
||||
import cssclasses "github.com/a-h/templ/cmd/templ/testproject/css-classes"
|
||||
|
||||
templ Page(count int) {
|
||||
{ fmt.Sprintf("%d", count) }
|
||||
{ cssclasses.Header }
|
||||
}
|
||||
`,
|
||||
assertions: func(t *testing.T, updated string) {
|
||||
if count := strings.Count(updated, "cssclasses \"github.com/a-h/templ/cmd/templ/testproject/css-classes\""); count != 1 {
|
||||
t.Errorf("expected named import to be retained, got %d instances of it", count)
|
||||
}
|
||||
if count := strings.Count(updated, "github.com/a-h/templ/cmd/templ/testproject/css-classes"); count != 1 {
|
||||
t.Errorf("expected one import, got %d", count)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
// Create test project.
|
||||
dir, err := testproject.Create("github.com/a-h/templ/cmd/templ/testproject")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test project: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// Load the templates.templ file.
|
||||
filePath := path.Join(dir, "templates.templ")
|
||||
err = os.WriteFile(filePath, []byte(test.src), 0660)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write file: %v", err)
|
||||
}
|
||||
|
||||
// Parse the new file.
|
||||
template, err := parser.Parse(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse %v", err)
|
||||
}
|
||||
template.Filepath = filePath
|
||||
tf, err := Process(template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to process file: %v", err)
|
||||
}
|
||||
|
||||
// Write it back out after processing.
|
||||
buf := new(strings.Builder)
|
||||
if err := tf.Write(buf); err != nil {
|
||||
t.Fatalf("failed to write template file: %v", err)
|
||||
}
|
||||
|
||||
// Assert.
|
||||
test.assertions(t, buf.String())
|
||||
}
|
||||
}
|
12
templ/cmd/templ/imports/testdata/comments.txtar
vendored
Normal file
12
templ/cmd/templ/imports/testdata/comments.txtar
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
28
templ/cmd/templ/imports/testdata/commentsbeforepackage.txtar
vendored
Normal file
28
templ/cmd/templ/imports/testdata/commentsbeforepackage.txtar
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
-- fmt_templ.templ --
|
||||
// Comments before.
|
||||
/*
|
||||
Some more comments
|
||||
*/
|
||||
package test
|
||||
|
||||
templ test() {
|
||||
<div>Hello</div>
|
||||
}
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
||||
-- fmt_templ.templ --
|
||||
// Comments before.
|
||||
/*
|
||||
Some more comments
|
||||
*/
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
templ test() {
|
||||
<div>Hello</div>
|
||||
}
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
14
templ/cmd/templ/imports/testdata/deleteimports.txtar
vendored
Normal file
14
templ/cmd/templ/imports/testdata/deleteimports.txtar
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import "strconv"
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
15
templ/cmd/templ/imports/testdata/extraspace.txtar
vendored
Normal file
15
templ/cmd/templ/imports/testdata/extraspace.txtar
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
const x = 123
|
||||
|
||||
|
||||
var x = fmt.Sprintf("Hello")
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
const x = 123
|
||||
|
||||
var x = fmt.Sprintf("Hello")
|
22
templ/cmd/templ/imports/testdata/groups.txtar
vendored
Normal file
22
templ/cmd/templ/imports/testdata/groups.txtar
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"fmt"
|
||||
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var _, _ = fmt.Print(strings.Contains(strconv.Quote("Hello"), ""))
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var _, _ = fmt.Print(strings.Contains(strconv.Quote("Hello"), ""))
|
21
templ/cmd/templ/imports/testdata/groupsmanynewlines.txtar
vendored
Normal file
21
templ/cmd/templ/imports/testdata/groupsmanynewlines.txtar
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var _, _ = fmt.Print(strconv.Quote("Hello"))
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var _, _ = fmt.Print(strconv.Quote("Hello"))
|
10
templ/cmd/templ/imports/testdata/header.txtar
vendored
Normal file
10
templ/cmd/templ/imports/testdata/header.txtar
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
var x = fmt.Sprintf("Hello")
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
var x = fmt.Sprintf("Hello")
|
19
templ/cmd/templ/imports/testdata/namedimportsadd.txtar
vendored
Normal file
19
templ/cmd/templ/imports/testdata/namedimportsadd.txtar
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
sconv "strconv"
|
||||
)
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf(sconv.Quote("Hello"))
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
sconv "strconv"
|
||||
)
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf(sconv.Quote("Hello"))
|
16
templ/cmd/templ/imports/testdata/namedimportsremoved.txtar
vendored
Normal file
16
templ/cmd/templ/imports/testdata/namedimportsremoved.txtar
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
sconv "strconv"
|
||||
)
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
||||
-- fmt_templ.templ --
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Comment on variable or function.
|
||||
var x = fmt.Sprintf("Hello")
|
12
templ/cmd/templ/imports/testdata/noimports.txtar
vendored
Normal file
12
templ/cmd/templ/imports/testdata/noimports.txtar
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
20
templ/cmd/templ/imports/testdata/noimportscode.txtar
vendored
Normal file
20
templ/cmd/templ/imports/testdata/noimportscode.txtar
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
func test() {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
func test() {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
templ Hello() {
|
||||
<div>Hello</div>
|
||||
}
|
14
templ/cmd/templ/imports/testdata/stringexp.txtar
vendored
Normal file
14
templ/cmd/templ/imports/testdata/stringexp.txtar
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
templ Hello(name string) {
|
||||
{ fmt.Sprintf("Hello, %s!", name) }
|
||||
}
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import "fmt"
|
||||
|
||||
templ Hello(name string) {
|
||||
{ fmt.Sprintf("Hello, %s!", name) }
|
||||
}
|
21
templ/cmd/templ/imports/testdata/twoimports.txtar
vendored
Normal file
21
templ/cmd/templ/imports/testdata/twoimports.txtar
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
templ Hello(name string) {
|
||||
<div id={ strconv.Atoi("123") }>
|
||||
{ fmt.Sprintf("Hello, %s!", name) }
|
||||
</div>
|
||||
}
|
||||
-- fmt.templ --
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
templ Hello(name string) {
|
||||
<div id={ strconv.Atoi("123") }>
|
||||
{ fmt.Sprintf("Hello, %s!", name) }
|
||||
</div>
|
||||
}
|
Reference in New Issue
Block a user