Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 45 additions & 34 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1902,16 +1902,20 @@ func TestPointerArgEquivalence(t *testing.T) {
type input struct {
In string `json:",omitempty"`
}
inputSchema := json.RawMessage(`{"type":"object","properties":{"In":{"type":"string"}},"additionalProperties":false}`)
type output struct {
Out string
}
outputSchema := json.RawMessage(`{"type":"object","required":["Out"],"properties":{"Out":{"type":"string"}},"additionalProperties":false}`)
cs, _, cleanup := basicConnection(t, func(s *Server) {
// Add two equivalent tools, one of which operates in the 'pointer' realm,
// the other of which does not.
// Add three equivalent tools:
// - one operates on pointers, with inferred schemas
// - one operates on pointers, with user-provided schemas
// - one operates on non-pointers
//
// We handle a few different types of results, to assert they behave the
// same in all cases.
AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *CallToolRequest, in *input) (*CallToolResult, *output, error) {
handlePointers := func(_ context.Context, req *CallToolRequest, in *input) (*CallToolResult, *output, error) {
switch in.In {
case "":
return nil, nil, fmt.Errorf("must provide input")
Expand All @@ -1924,7 +1928,13 @@ func TestPointerArgEquivalence(t *testing.T) {
default:
panic("unreachable")
}
})
}
AddTool(s, &Tool{Name: "pointer-inferred"}, handlePointers)
AddTool(s, &Tool{
Name: "pointer-provided",
InputSchema: inputSchema,
OutputSchema: outputSchema,
}, handlePointers)
AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *CallToolRequest, in input) (*CallToolResult, output, error) {
switch in.In {
case "":
Expand All @@ -1947,50 +1957,51 @@ func TestPointerArgEquivalence(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if got, want := len(tools.Tools), 2; got != want {
if got, want := len(tools.Tools), 3; got != want {
t.Fatalf("got %d tools, want %d", got, want)
}
t0 := tools.Tools[0]
t1 := tools.Tools[1]

// First, check that the tool schemas don't differ.
if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" {
t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
}
if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" {
t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
}

// Then, check that we handle empty input equivalently.
for _, args := range []any{nil, struct{}{}} {
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args})
if err != nil {
t.Fatal(err)
}
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args})
if err != nil {
t.Fatal(err)
t0 := tools.Tools[0]
for _, t1 := range tools.Tools[1:] {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

t1 is a confusing name here since it is index 1 or 2.

tn might be slightly better.

Maybe it would be even clearer if we iterated over adjacent pairs in this array; then we could use names like i and j and avoid giving explicit indices.

// First, check that the tool schemas don't differ.
if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" {
t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
}
if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" {
t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff)
if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" {
t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
}
}

// Then, check that we handle different types of output equivalently.
for _, in := range []string{"nil", "empty", "ok"} {
t.Run(in, func(t *testing.T) {
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}})
// Then, check that we handle empty input equivalently.
for _, args := range []any{nil, struct{}{}} {
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args})
if err != nil {
t.Fatal(err)
}
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}})
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args})
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" {
t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff)
t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff)
}
})
}

// Then, check that we handle different types of output equivalently.
for _, in := range []string{"nil", "empty", "ok"} {
t.Run(in, func(t *testing.T) {
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}})
if err != nil {
t.Fatal(err)
}
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}})
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" {
t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff)
}
})
}
}
}

Expand Down
17 changes: 5 additions & 12 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,21 +359,14 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
// Pointers are treated equivalently to non-pointers when deriving the schema.
// If an indirection occurred to derive the schema, a non-nil zero value is
// returned to be used in place of the typed nil zero value.
//
// Note that if sfield already holds a schema, zero will be nil even if T is a
// pointer: if the user provided the schema, they may have intentionally
// derived it from the pointer type, and handling of zero values is up to them.
//
// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we
// should have a jsonschema.Zero(schema) helper?
func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add some more documentation around sfield and rfield, and what exactly this function is doing? E.g., it seems like sfield is an input, and rfield is an output, in addition to the 2 return variables. I think it might be beneficial to document the type of sfield, since it is declared as any, but I think it is an any that is expected to unmarshal into a valid jsonschema object (cf. Tool.OutputSchema).

I am also not familiar with the sfield and rfield names but I could be missing some conventions around this.

I could be missing some context around how this function is used, or conventions around these names.

rt := reflect.TypeFor[T]()
if rt.Kind() == reflect.Pointer {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm on the fence: should we put a comment here explaining the importance of setting rt and zero when sfield != nil? As in, the change in this PR is (I think) making it so we never return a typed nil for pointer elements and instead return an element-typed zero. Is there a good reference as to why we should never return typed nils? From the comment in the PR it sounds like it is always a bug to return typed nils, but why?

rt = rt.Elem()
zero = reflect.Zero(rt).Interface()
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somewhat related to my previous comment, should we check if our handling of zero/nil for pointer types is "interfering" with the user provided schema? It sounds like actually the user should never be including the ability to return typed nils in their schema, so maybe this is something we can check at validation time?

It's also OK if the answer is "no" and this is too much added complexity; I just wanted to double-check whether there are any corner cases we could easily identify and flag for the user to make their UX better.

var internalSchema *jsonschema.Schema
if *sfield == nil {
rt := reflect.TypeFor[T]()
if rt.Kind() == reflect.Pointer {
rt = rt.Elem()
zero = reflect.Zero(rt).Interface()
}
// TODO: we should be able to pass nil opts here.
internalSchema, err = jsonschema.ForType(rt, &jsonschema.ForOptions{})
if err == nil {
Expand Down