learnlytics-go/templ/lsp/jsonrpc2/jsonrpc2_test.go
2025-03-20 12:35:13 +01:00

172 lines
3.6 KiB
Go

// SPDX-FileCopyrightText: 2021 The Go Language Server Authors
// SPDX-License-Identifier: BSD-3-Clause
package jsonrpc2_test
import (
"bytes"
"context"
"fmt"
"io"
"net"
"path"
"reflect"
"testing"
"encoding/json"
"github.com/a-h/templ/lsp/jsonrpc2"
)
const (
methodNoArgs = "no_args"
methodOneString = "one_string"
methodOneNumber = "one_number"
methodJoin = "join"
)
type callTest struct {
method string
params any
expect any
}
var callTests = []callTest{
{
method: methodNoArgs,
params: nil,
expect: true,
},
{
method: methodOneString,
params: "fish",
expect: "got:fish",
},
{
method: methodOneNumber,
params: 10,
expect: "got:10",
},
{
method: methodJoin,
params: []string{"a", "b", "c"},
expect: "a/b/c",
},
// TODO: expand the test cases
}
func (test *callTest) newResults() any {
switch e := test.expect.(type) {
case []any:
var r []any
for _, v := range e {
r = append(r, reflect.New(reflect.TypeOf(v)).Interface())
}
return r
case nil:
return nil
default:
return reflect.New(reflect.TypeOf(test.expect)).Interface()
}
}
func (test *callTest) verifyResults(t *testing.T, results any) {
t.Helper()
if results == nil {
return
}
val := reflect.Indirect(reflect.ValueOf(results)).Interface()
if !reflect.DeepEqual(val, test.expect) {
t.Errorf("%v:Results are incorrect, got %+v expect %+v", test.method, val, test.expect)
}
}
func TestRequest(t *testing.T) {
ctx := context.Background()
a, b, done := prepare(ctx, t)
defer done()
for _, test := range callTests {
t.Run(test.method, func(t *testing.T) {
results := test.newResults()
if _, err := a.Call(ctx, test.method, test.params, results); err != nil {
t.Fatalf("%v:Call failed: %v", test.method, err)
}
test.verifyResults(t, results)
if _, err := b.Call(ctx, test.method, test.params, results); err != nil {
t.Fatalf("%v:Call failed: %v", test.method, err)
}
test.verifyResults(t, results)
})
}
}
func prepare(ctx context.Context, t *testing.T) (a, b jsonrpc2.Conn, done func()) {
t.Helper()
// make a wait group that can be used to wait for the system to shut down
aPipe, bPipe := net.Pipe()
a = run(ctx, aPipe)
b = run(ctx, bPipe)
done = func() {
a.Close()
b.Close()
<-a.Done()
<-b.Done()
}
return a, b, done
}
func run(ctx context.Context, nc io.ReadWriteCloser) jsonrpc2.Conn {
stream := jsonrpc2.NewStream(nc)
conn := jsonrpc2.NewConn(stream)
conn.Go(ctx, testHandler())
return conn
}
func testHandler() jsonrpc2.Handler {
return func(ctx context.Context, reply jsonrpc2.Replier, req jsonrpc2.Request) error {
switch req.Method() {
case methodNoArgs:
if len(req.Params()) > 0 {
return reply(ctx, nil, fmt.Errorf("expected no params: %w", jsonrpc2.ErrInvalidParams))
}
return reply(ctx, true, nil)
case methodOneString:
var v string
dec := json.NewDecoder(bytes.NewReader(req.Params()))
if err := dec.Decode(&v); err != nil {
return reply(ctx, nil, fmt.Errorf("%s: %w", jsonrpc2.ErrParse, err))
}
return reply(ctx, "got:"+v, nil)
case methodOneNumber:
var v int
dec := json.NewDecoder(bytes.NewReader(req.Params()))
if err := dec.Decode(&v); err != nil {
return reply(ctx, nil, fmt.Errorf("%s: %w", jsonrpc2.ErrParse, err))
}
return reply(ctx, fmt.Sprintf("got:%d", v), nil)
case methodJoin:
var v []string
dec := json.NewDecoder(bytes.NewReader(req.Params()))
if err := dec.Decode(&v); err != nil {
return reply(ctx, nil, fmt.Errorf("%s: %w", jsonrpc2.ErrParse, err))
}
return reply(ctx, path.Join(v...), nil)
default:
return jsonrpc2.MethodNotFoundHandler(ctx, reply, req)
}
}
}