diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 333e34978..0775280e8 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -1933,6 +1933,14 @@ public class CustomAgentConfig /// [JsonPropertyName("skills")] public IList? Skills { get; set; } + + /// + /// Model identifier for this agent (e.g. "claude-haiku-4.5"). + /// When set, the runtime will attempt to use this model for the agent, + /// falling back to the parent session model if unavailable. + /// + [JsonPropertyName("model")] + public string? Model { get; set; } } /// diff --git a/dotnet/test/Unit/CloneTests.cs b/dotnet/test/Unit/CloneTests.cs index 07cbde3de..dd94d05ae 100644 --- a/dotnet/test/Unit/CloneTests.cs +++ b/dotnet/test/Unit/CloneTests.cs @@ -95,7 +95,7 @@ public void SessionConfig_Clone_CopiesAllProperties() EnableSessionTelemetry = false, IncludeSubAgentStreamingEvents = false, McpServers = new Dictionary { ["server1"] = new McpStdioServerConfig { Command = "echo" } }, - CustomAgents = [new CustomAgentConfig { Name = "agent1" }], + CustomAgents = [new CustomAgentConfig { Name = "agent1", Model = "claude-haiku-4.5" }], Agent = "agent1", DefaultAgent = new DefaultAgentConfig { ExcludedTools = ["hidden-tool"] }, SkillDirectories = ["/skills"], @@ -120,6 +120,7 @@ public void SessionConfig_Clone_CopiesAllProperties() Assert.Equal(original.IncludeSubAgentStreamingEvents, clone.IncludeSubAgentStreamingEvents); Assert.Equal(original.McpServers.Count, clone.McpServers!.Count); Assert.Equal(original.CustomAgents.Count, clone.CustomAgents!.Count); + Assert.Equal(original.CustomAgents[0].Model, clone.CustomAgents[0].Model); Assert.Equal(original.Agent, clone.Agent); Assert.Equal(original.DefaultAgent!.ExcludedTools, clone.DefaultAgent!.ExcludedTools); Assert.Equal(original.SkillDirectories, clone.SkillDirectories); diff --git a/go/types.go b/go/types.go index 566e54f0f..562019e59 100644 --- a/go/types.go +++ b/go/types.go @@ -532,6 +532,10 @@ type CustomAgentConfig struct { Infer *bool `json:"infer,omitempty"` // Skills is the list of skill names to preload into this agent's context at startup (opt-in; omit for none) Skills []string `json:"skills,omitempty"` + // Model is the model identifier for this agent (e.g. "claude-haiku-4.5"). + // When set, the runtime will attempt to use this model for the agent, + // falling back to the parent session model if unavailable. + Model string `json:"model,omitempty"` } // DefaultAgentConfig configures the default agent (the built-in agent that handles turns when no custom agent is selected). diff --git a/go/types_test.go b/go/types_test.go index d24e6342f..2d80d206c 100644 --- a/go/types_test.go +++ b/go/types_test.go @@ -216,3 +216,49 @@ func TestProviderConfig_JSONOmitsUnsetTokenFields(t *testing.T) { } } } + +func TestCustomAgentConfig_JSONIncludesModel(t *testing.T) { + cfg := CustomAgentConfig{ + Name: "model-agent", + Prompt: "You are a model agent.", + Model: "claude-haiku-4.5", + } + + data, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal CustomAgentConfig: %v", err) + } + + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal CustomAgentConfig: %v", err) + } + + if decoded["model"] != "claude-haiku-4.5" { + t.Errorf("expected model 'claude-haiku-4.5', got %v", decoded["model"]) + } + if decoded["name"] != "model-agent" { + t.Errorf("expected name 'model-agent', got %v", decoded["name"]) + } +} + +func TestCustomAgentConfig_JSONOmitsModelWhenEmpty(t *testing.T) { + cfg := CustomAgentConfig{ + Name: "no-model-agent", + Prompt: "You are an agent without a model.", + } + + data, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal CustomAgentConfig: %v", err) + } + + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("failed to unmarshal CustomAgentConfig: %v", err) + } + + if _, present := decoded["model"]; present { + t.Errorf("expected model to be omitted when empty, got %v", decoded["model"]) + } +} diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index c28b67a89..200363fdc 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -1213,6 +1213,12 @@ export interface CustomAgentConfig { * When omitted, no skills are injected (opt-in model). */ skills?: string[]; + /** + * Model identifier for this agent (e.g. "claude-haiku-4.5"). + * When set, the runtime will attempt to use this model for the agent, + * falling back to the parent session model if unavailable. + */ + model?: string; } /** diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index f33046bbc..69c851b7e 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -868,6 +868,29 @@ describe("CopilotClient", () => { expect(payload.customAgents).toEqual([expect.objectContaining({ name: "test-agent" })]); }); + it("forwards custom agent model in session.create request", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi.spyOn((client as any).connection!, "sendRequest"); + await client.createSession({ + onPermissionRequest: approveAll, + customAgents: [ + { + name: "model-agent", + prompt: "You are a model agent.", + model: "claude-haiku-4.5", + }, + ], + }); + + const payload = spy.mock.calls.find((c) => c[0] === "session.create")![1] as any; + expect(payload.customAgents).toEqual([ + expect.objectContaining({ name: "model-agent", model: "claude-haiku-4.5" }), + ]); + }); + it("forwards agent in session.resume request", async () => { const client = new CopilotClient(); await client.start(); diff --git a/python/copilot/client.py b/python/copilot/client.py index 4b265f6d5..3e2b367e5 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -2512,6 +2512,8 @@ def _convert_custom_agent_to_wire_format( wire_agent["infer"] = agent["infer"] if "skills" in agent: wire_agent["skills"] = agent["skills"] + if "model" in agent: + wire_agent["model"] = agent["model"] return wire_agent def _convert_default_agent_to_wire_format( diff --git a/python/copilot/session.py b/python/copilot/session.py index b682d22d7..efe89ce48 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -818,6 +818,8 @@ class CustomAgentConfig(TypedDict, total=False): infer: NotRequired[bool] # Whether agent is available for model inference # Skill names to preload into this agent's context at startup (opt-in; omit for none) skills: NotRequired[list[str]] + # Model identifier (e.g. "claude-haiku-4.5"); runtime falls back to parent model if unavailable + model: NotRequired[str] class DefaultAgentConfig(TypedDict, total=False): diff --git a/python/test_client.py b/python/test_client.py index 26de29287..64ad1a074 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -982,3 +982,34 @@ async def test_aexit_calls_disconnect(self): with patch.object(session, "disconnect", new_callable=AsyncMock) as mock_disconnect: await session.__aexit__(None, None, None) mock_disconnect.assert_awaited_once() + + +class TestCustomAgentWireFormat: + def test_model_field_is_forwarded_in_wire_format(self): + """The model key in CustomAgentConfig should appear as 'model' in the wire payload.""" + from copilot.client import CopilotClient + from copilot.session import CustomAgentConfig + + client = CopilotClient.__new__(CopilotClient) + agent: CustomAgentConfig = { + "name": "model-agent", + "prompt": "You are a model agent.", + "model": "claude-haiku-4.5", + } + wire = client._convert_custom_agent_to_wire_format(agent) + assert wire["model"] == "claude-haiku-4.5" + assert wire["name"] == "model-agent" + assert wire["prompt"] == "You are a model agent." + + def test_model_field_is_omitted_when_absent(self): + """When model is not set, it should not appear in the wire payload.""" + from copilot.client import CopilotClient + from copilot.session import CustomAgentConfig + + client = CopilotClient.__new__(CopilotClient) + agent: CustomAgentConfig = { + "name": "no-model-agent", + "prompt": "You are an agent without a model.", + } + wire = client._convert_custom_agent_to_wire_format(agent) + assert "model" not in wire diff --git a/rust/src/types.rs b/rust/src/types.rs index 68850dbbf..3c6e88746 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -526,6 +526,12 @@ pub struct CustomAgentConfig { /// Skill names to preload into this agent's context at startup. #[serde(default, skip_serializing_if = "Option::is_none")] pub skills: Option>, + /// Model identifier for this agent (e.g. `"claude-haiku-4.5"`). + /// + /// When set, the runtime will attempt to use this model for the agent, + /// falling back to the parent session model if unavailable. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub model: Option, } impl CustomAgentConfig { @@ -587,6 +593,12 @@ impl CustomAgentConfig { self.skills = Some(skills.into_iter().map(Into::into).collect()); self } + + /// Set the model identifier for this agent. + pub fn with_model(mut self, model: impl Into) -> Self { + self.model = Some(model.into()); + self + } } /// Configures the default (built-in) agent that handles turns when no @@ -3196,6 +3208,31 @@ mod tests { assert!(tool.skip_permission); } + #[test] + fn custom_agent_config_builder_with_model() { + let agent = CustomAgentConfig::new("my-agent", "You are helpful.") + .with_model("claude-haiku-4.5") + .with_display_name("My Agent"); + assert_eq!(agent.name, "my-agent"); + assert_eq!(agent.model.as_deref(), Some("claude-haiku-4.5")); + assert_eq!(agent.display_name.as_deref(), Some("My Agent")); + } + + #[test] + fn custom_agent_config_serializes_model() { + let agent = CustomAgentConfig::new("model-agent", "prompt").with_model("claude-haiku-4.5"); + let wire = serde_json::to_value(&agent).unwrap(); + assert_eq!(wire["model"], "claude-haiku-4.5"); + assert_eq!(wire["name"], "model-agent"); + } + + #[test] + fn custom_agent_config_omits_model_when_none() { + let agent = CustomAgentConfig::new("no-model-agent", "prompt"); + let wire = serde_json::to_value(&agent).unwrap(); + assert!(wire.get("model").is_none()); + } + #[test] fn tool_with_parameters_handles_non_object_value() { let tool = Tool::new("noop").with_parameters(json!(null));