learnlytics-go/templ/parser/v2/goexpression/parse.go
2025-03-20 12:35:13 +01:00

344 lines
8.2 KiB
Go

package goexpression
import (
"errors"
"fmt"
"go/ast"
"go/parser"
"go/scanner"
"go/token"
"regexp"
"strings"
"unicode"
)
var (
// ErrContainerFuncNotFound reports that the synthetic templ container
// function could not be found in the parsed source.
ErrContainerFuncNotFound = errors.New("parser error: templ container function not found")
// ErrExpectedNodeNotFound reports that the content did not start with the
// expected keyword, or that the parsed statement was not of the expected
// kind.
ErrExpectedNodeNotFound = errors.New("parser error: expected node not found")
)
// defaultRegexp matches content that begins with the word "default" followed
// by optional whitespace and a colon. Case uses it to recognise default
// clauses.
var defaultRegexp = regexp.MustCompile(`^default\s*:`)
// Case locates a leading `case ...:` or `default:` clause in content.
//
// The content is wrapped in a bare `switch {` so that it parses as a
// statement; the first clause of that switch supplies the result. start is
// taken from the position of the clause keyword and end from the position of
// its colon, both rebased so that they index into content.
//
// ErrExpectedNodeNotFound is returned when content does not begin with
// "case " or a default clause, or when the parsed statement is not a switch
// whose first entry is a case clause.
func Case(content string) (start, end int, err error) {
	isCase := strings.HasPrefix(content, "case ")
	if !isCase && !defaultRegexp.MatchString(content) {
		return 0, 0, ErrExpectedNodeNotFound
	}

	const wrapper = "switch {\n"
	start, end, err = extract(wrapper+content, func(body []ast.Stmt) (int, int, error) {
		switchStmt, isSwitch := body[0].(*ast.SwitchStmt)
		if !isSwitch || switchStmt.Body == nil || len(switchStmt.Body.List) == 0 {
			return 0, 0, ErrExpectedNodeNotFound
		}
		clause, isClause := switchStmt.Body.List[0].(*ast.CaseClause)
		if !isClause {
			return 0, 0, ErrExpectedNodeNotFound
		}
		return int(clause.Case) - 1, int(clause.Colon), nil
	})
	if err != nil {
		return 0, 0, err
	}

	// The positions still include the `switch {` wrapper added above.
	return start - len(wrapper), end - len(wrapper), nil
}
// If locates the header of an `if` statement at the beginning of content.
//
// start is derived from the position of the `if` keyword plus the keyword
// length; end is the furthest end among the init statement and the condition
// (or start when neither is present). extract rebases both values onto
// content.
//
// ErrExpectedNodeNotFound is returned when content does not start with "if"
// or the first parsed statement is not an if statement.
func If(content string) (start, end int, err error) {
	if !strings.HasPrefix(content, "if") {
		return 0, 0, ErrExpectedNodeNotFound
	}
	locate := func(body []ast.Stmt) (int, int, error) {
		ifStmt, isIf := body[0].(*ast.IfStmt)
		if !isIf {
			return 0, 0, ErrExpectedNodeNotFound
		}
		from := int(ifStmt.If) + len("if")
		return from, latestEnd(from, ifStmt.Init, ifStmt.Cond), nil
	}
	return extract(content, locate)
}
// For locates the header of a `for` statement at the beginning of content.
//
// Both three-clause loops and range loops are handled. start is derived from
// the position of the `for` keyword plus the keyword length; end is the
// furthest end among the header parts (init/cond/post, or key/value/ranged
// expression). extract rebases both values onto content.
//
// ErrExpectedNodeNotFound is returned when content does not start with "for"
// or the first parsed statement is neither kind of loop.
func For(content string) (start, end int, err error) {
	if !strings.HasPrefix(content, "for") {
		return 0, 0, ErrExpectedNodeNotFound
	}
	return extract(content, func(body []ast.Stmt) (int, int, error) {
		const keyword = "for"
		if loop, ok := body[0].(*ast.ForStmt); ok {
			from := int(loop.For) + len(keyword)
			return from, latestEnd(from, loop.Init, loop.Cond, loop.Post), nil
		}
		if loop, ok := body[0].(*ast.RangeStmt); ok {
			from := int(loop.For) + len(keyword)
			return from, latestEnd(from, loop.Key, loop.Value, loop.X), nil
		}
		return 0, 0, ErrExpectedNodeNotFound
	})
}
// Switch locates the header of a `switch` statement at the beginning of
// content.
//
// Expression switches and type switches are both handled. start is derived
// from the position of the `switch` keyword plus the keyword length; end is
// the furthest end among the init statement and the tag (or, for a type
// switch, the assign statement). extract rebases both values onto content.
//
// ErrExpectedNodeNotFound is returned when content does not start with
// "switch" or the first parsed statement is not a switch of either kind.
func Switch(content string) (start, end int, err error) {
	if !strings.HasPrefix(content, "switch") {
		return 0, 0, ErrExpectedNodeNotFound
	}
	return extract(content, func(body []ast.Stmt) (int, int, error) {
		const keyword = "switch"
		if sw, ok := body[0].(*ast.SwitchStmt); ok {
			from := int(sw.Switch) + len(keyword)
			return from, latestEnd(from, sw.Init, sw.Tag), nil
		}
		if sw, ok := body[0].(*ast.TypeSwitchStmt); ok {
			from := int(sw.Switch) + len(keyword)
			return from, latestEnd(from, sw.Init, sw.Assign), nil
		}
		return 0, 0, ErrExpectedNodeNotFound
	})
}
// TemplExpression scans src token by token and feeds each token to an
// expression parser until that parser signals that it is done.
//
// It reads chains of identifiers, e.g.:
//
//	components.Variable
//	components[0].Variable
//	components["name"].Function()
//	functionCall(withLots(), func() { return true })
//
// The returned start is always 0 and end is the parser's End field. An error
// is returned only when the parser rejects a token. Diagnostics raised by the
// scanner are written to the named result by the error handler, but every
// return statement supplies its own error value, so they are not passed on
// to the caller.
func TemplExpression(src string) (start, end int, err error) {
	fset := token.NewFileSet()
	file := fset.AddFile("", fset.Base(), len(src))

	var sc scanner.Scanner
	sc.Init(file, []byte(src), func(pos token.Position, msg string) {
		err = fmt.Errorf("error parsing expression: %v", msg)
	}, scanner.ScanComments)

	ep := NewExpressionParser()
	for {
		pos, tok, lit := sc.Scan()
		done, insertErr := ep.Insert(pos, tok, lit)
		if insertErr != nil {
			return 0, 0, insertErr
		}
		if done {
			return 0, ep.End, nil
		}
	}
}
// Expression scans src as a Go expression and reports the span it occupies.
//
// It reads chains of identifiers and constants up until the closing brace of
// the surrounding block, e.g.:
//
//	true
//	123.45 == true
//	components.Variable
//	components[0].Variable
//	components["name"].Function()
//	functionCall(withLots(), func() { return true })
//	!true
//
// Scanning stops at the first `}` that has no matching `{` inside src, or at
// the end of the input. start is always 0. end is the offset just past the
// last token that belongs to the expression, so src[start:end] is the
// expression text; comments count as part of the expression.
//
// An error is returned only when the scanner yields an ILLEGAL token. Other
// scanner diagnostics are written to the named result by the error handler,
// but every return statement supplies its own error value, so they are not
// passed on to the caller.
func Expression(src string) (start, end int, err error) {
	var s scanner.Scanner
	fset := token.NewFileSet()
	file := fset.AddFile("", fset.Base(), len(src))
	errorHandler := func(pos token.Position, msg string) {
		err = fmt.Errorf("error parsing expression: %v", msg)
	}
	s.Init(file, []byte(src), errorHandler, scanner.ScanComments)

	// Only brace nesting decides where the expression stops, so it is the
	// only delimiter depth that is tracked.
	braceDepth := 0
	for {
		pos, tok, lit := s.Scan()
		switch tok {
		case token.EOF:
			return start, end, nil
		case token.ILLEGAL:
			return 0, 0, fmt.Errorf("illegal token: %v", lit)
		case token.SEMICOLON:
			// Semicolons (including automatically inserted ones) never
			// extend the expression.
		case token.LPAREN, token.LBRACK:
			// An opening delimiter alone does not extend the expression;
			// its closing counterpart does.
		case token.LBRACE:
			braceDepth++
		case token.RBRACE:
			braceDepth--
			if braceDepth < 0 {
				// Unmatched closing brace: the expression ended before it.
				return start, end, nil
			}
			end = int(pos)
		case token.RPAREN, token.RBRACK:
			// Positions in a fresh file set start at 1, so int(pos) is the
			// offset just past a single-character token.
			end = int(pos)
		case token.IDENT, token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING, token.COMMENT:
			end = int(pos) + len(lit) - 1
		default:
			// Operators and keywords carry no literal; their text is the
			// token's string form.
			end = int(pos) + len(tok.String()) - 1
		}
	}
}
// SliceArgs extracts the text of an argument list from the start of content.
//
// The content is parsed as the elements of a `[]any{...}` composite literal.
// The result is the source text from just after the opening brace up to the
// end of the last element. If anything other than whitespace sits between
// the last element and the closing brace (such as a trailing comma), the
// text up to the closing brace is returned instead.
//
// Parse errors are ignored as long as the parser produced an AST; the parse
// error is returned only when no AST is available.
func SliceArgs(content string) (expr string, err error) {
	const wrapper = "package main\nvar templ_args = []any{"
	src := wrapper + content + "}"
	file, parseErr := parser.ParseFile(token.NewFileSet(), "", src, parser.AllErrors)
	if file == nil {
		return expr, parseErr
	}

	var from, to int
	inspectFirstNode(file, func(n ast.Node) bool {
		lit, isLit := n.(*ast.CompositeLit)
		if !isLit {
			return true
		}
		closing := int(lit.Rbrace) - 1
		from, to = int(lit.Lbrace), closing

		// Prefer the end of the last element, but never reach past the
		// closing brace.
		for _, elt := range lit.Elts {
			to = int(elt.End()) - 1
		}
		if to > closing {
			to = closing
		}

		// Anything but whitespace before the closing brace is kept as part
		// of the argument text.
		for _, r := range src[to:closing] {
			if !unicode.IsSpace(r) {
				to = closing
				break
			}
		}
		return false
	})
	return src[from:to], err
}
// Func parses content as a top-level function declaration and returns the
// function name together with the leading part of its signature.
//
// expr is the source text that starts after the `func` keyword and the byte
// that follows it, and runs to the end of the parameter list. name is the
// declared function name.
//
// Parse errors are ignored as long as the parser produced an AST. An error
// is reported when no AST is available, or when the end of the parameter
// list lies beyond the end of the parsed source.
func Func(content string) (name, expr string, err error) {
	const wrapper = "package main\n"
	src := wrapper + content
	file, parseErr := parser.ParseFile(token.NewFileSet(), "", src, parser.AllErrors)
	if file == nil {
		return name, expr, parseErr
	}

	inspectFirstNode(file, func(n ast.Node) bool {
		// Only the first function declaration is of interest.
		decl, isFunc := n.(*ast.FuncDecl)
		if !isFunc {
			return true
		}
		from := int(decl.Pos()) + len("func")
		to := int(decl.Type.Params.End()) - 1
		if to > len(src) {
			err = errors.New("parser error: function identifier")
			return false
		}
		name, expr = decl.Name.Name, strings.Clone(src[from:to])
		return false
	})
	return name, expr, err
}
// latestEnd returns the largest value of End()-1 among the non-nil nodes, or
// start when no node ends later than start. Nil entries are skipped so that
// optional parts of a statement can be passed directly.
func latestEnd(start int, nodes ...ast.Node) (end int) {
	latest := start
	for i := range nodes {
		if nodes[i] == nil {
			continue
		}
		if candidate := int(nodes[i].End()) - 1; candidate > latest {
			latest = candidate
		}
	}
	return latest
}
// inspectFirstNode walks node with ast.Inspect and calls f for each visited
// node until f returns false for the first time. From then on f is never
// called again.
//
// ast.Inspect cannot be aborted, but a false result prevents it from
// descending into a node's children. Once f has stopped, every remaining
// node is therefore answered with false so that its subtree is skipped
// instead of being walked for nothing.
//
// As with ast.Inspect, f is also called with a nil node after the children
// of a node have been visited, so f must tolerate nil.
func inspectFirstNode(node ast.Node, f func(ast.Node) bool) {
	var stop bool
	ast.Inspect(node, func(n ast.Node) bool {
		if stop {
			// Already finished: prune instead of descending further.
			return false
		}
		if f(n) {
			return true
		}
		stop = true
		return false
	})
}
// Extractor locates a Go expression within the statements handed to it.
// extract calls it with the body statements of the synthetic templ_container
// function. The returned start and end are positions within the source that
// was parsed, wrapper prefix included; extract subtracts the prefix length
// afterwards. The extractors in this file return ErrExpectedNodeNotFound
// when the first statement is not of the expected kind.
type Extractor func(body []ast.Stmt) (start, end int, err error)
// extract wraps content in a synthetic templ_container function, parses the
// result and asks extractor to locate an expression among the statements of
// the function body.
//
// The start and end reported by extractor refer to the wrapped source, so the
// wrapper prefix length is subtracted from both. The result is then clamped:
// end never exceeds len(content) and start never exceeds end.
//
// Errors:
//   - the parser's error when no AST could be produced at all (a partial AST
//     accompanied by parse errors is accepted);
//   - ErrContainerFuncNotFound when the first function declaration is not
//     templ_container;
//   - ErrExpectedNodeNotFound when the container has no statements or no
//     function declaration was visited;
//   - whatever error extractor returns.
//
// On any error start and end are both 0.
func extract(content string, extractor Extractor) (start, end int, err error) {
	prefix := "package main\nfunc templ_container() {\n"
	src := prefix + content
	node, parseErr := parser.ParseFile(token.NewFileSet(), "", src, parser.AllErrors)
	if node == nil {
		return 0, 0, parseErr
	}

	var found bool
	inspectFirstNode(node, func(n ast.Node) bool {
		// Find the "templ_container" function.
		fn, ok := n.(*ast.FuncDecl)
		if !ok {
			return true
		}
		if fn.Name == nil || fn.Name.Name != "templ_container" {
			err = ErrContainerFuncNotFound
			return false
		}
		if fn.Body == nil || len(fn.Body.List) == 0 {
			err = ErrExpectedNodeNotFound
			return false
		}
		found = true
		start, end, err = extractor(fn.Body.List)
		return false
	})
	// Report a recorded failure as-is, with zeroed offsets, rather than
	// shifting meaningless positions by the prefix length.
	if err != nil {
		return 0, 0, err
	}
	if !found {
		return 0, 0, ErrExpectedNodeNotFound
	}

	// Rebase the positions from the wrapped source onto content.
	start -= len(prefix)
	end -= len(prefix)
	if end > len(content) {
		end = len(content)
	}
	if start > end {
		start = end
	}
	return start, end, nil
}