diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 4e8715bd5..cffaae8b0 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -58,6 +58,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable /// private const int MinProtocolVersion = 3; private static readonly TimeSpan s_stderrPumpShutdownTimeout = TimeSpan.FromSeconds(5); + private static readonly TimeSpan s_runtimeShutdownTimeout = TimeSpan.FromSeconds(10); /// /// Provides a thread-safe collection of active Copilot sessions, indexed by session identifier. @@ -291,7 +292,7 @@ async Task StartCoreAsync(CancellationToken ct) if (connection is not null) { - await CleanupConnectionAsync(connection, errors: null); + await CleanupConnectionAsync(connection, errors: null, gracefulRuntimeShutdown: false); } else if (cliProcess is not null) { @@ -312,6 +313,7 @@ async Task StartCoreAsync(CancellationToken ct) /// This method performs graceful cleanup: /// /// Closes all active sessions (releases in-memory resources) + /// Requests runtime shutdown for SDK-owned CLI processes /// Closes the JSON-RPC connection /// Terminates the CLI server process (if spawned by this client) /// @@ -346,7 +348,7 @@ public async Task StopAsync() _sessions.Clear(); - await CleanupConnectionAsync(errors); + await CleanupConnectionAsync(errors, gracefulRuntimeShutdown: true); ThrowErrors(errors); } @@ -378,7 +380,7 @@ public async Task ForceStopAsync() _sessions.Clear(); var errors = new List(); - await CleanupConnectionAsync(errors); + await CleanupConnectionAsync(errors, gracefulRuntimeShutdown: false); ThrowErrors(errors); } @@ -398,7 +400,7 @@ private static void ThrowErrors(List? errors) } } - private async Task CleanupConnectionAsync(List? errors) + private async Task CleanupConnectionAsync(List? errors, bool gracefulRuntimeShutdown) { var connectionTask = _connectionTask; if (connectionTask is null) @@ -419,11 +421,36 @@ private async Task CleanupConnectionAsync(List? 
errors) return; } - await CleanupConnectionAsync(ctx, errors); + await CleanupConnectionAsync(ctx, errors, gracefulRuntimeShutdown); } - private async Task CleanupConnectionAsync(Connection ctx, List? errors) + private async Task CleanupConnectionAsync(Connection ctx, List? errors, bool gracefulRuntimeShutdown) { + var runtimeShutdownCompleted = false; + if (gracefulRuntimeShutdown && ctx.CliProcess is not null) + { + var runtimeShutdownTimestamp = Stopwatch.GetTimestamp(); + try + { + using var cancellation = new CancellationTokenSource(s_runtimeShutdownTimeout); + await ctx.Server.Runtime.ShutdownAsync(cancellation.Token); + runtimeShutdownCompleted = true; + LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, + "CopilotClient.StopAsync runtime shutdown complete. Elapsed={Elapsed}", + runtimeShutdownTimestamp); + } + catch (Exception ex) when (ex is OperationCanceledException + or InvalidOperationException + or ObjectDisposedException + or IOException + or SocketException) + { + LoggingHelpers.LogTiming(_logger, LogLevel.Debug, ex, + "CopilotClient.StopAsync runtime shutdown failed. Elapsed={Elapsed}", + runtimeShutdownTimestamp); + } + } + try { ctx.Rpc.Dispose(); } catch (Exception ex) { AddCleanupError(errors, ex, _logger); } @@ -439,11 +466,11 @@ private async Task CleanupConnectionAsync(Connection ctx, List? error if (ctx.CliProcess is { } childProcess) { - await CleanupCliProcessAsync(childProcess, ctx.StderrPump, errors, _logger); + await CleanupCliProcessAsync(childProcess, ctx.StderrPump, errors, _logger, runtimeShutdownCompleted); } } - private static async Task CleanupCliProcessAsync(Process childProcess, ProcessStderrPump? stderrPump, List? errors, ILogger? logger) + private static async Task CleanupCliProcessAsync(Process childProcess, ProcessStderrPump? stderrPump, List? errors, ILogger? 
logger, bool waitForGracefulExit = false) { stderrPump?.Cancel(); @@ -451,10 +478,50 @@ private static async Task CleanupCliProcessAsync(Process childProcess, ProcessSt { if (!childProcess.HasExited) { + if (waitForGracefulExit) + { + var shutdownWaitTimestamp = Stopwatch.GetTimestamp(); + try + { + await childProcess.WaitForExitAsync().WaitAsync(s_runtimeShutdownTimeout); + } + catch (TimeoutException ex) + { + if (logger is not null) + { + LoggingHelpers.LogTiming(logger, LogLevel.Debug, ex, + "Timed out waiting for runtime process to exit after graceful shutdown. Elapsed={Elapsed}, Timeout={Timeout}", + shutdownWaitTimestamp, + s_runtimeShutdownTimeout); + } + } + } + + if (childProcess.HasExited) + { + return; + } + childProcess.Kill(entireProcessTree: true); // Kill is asynchronous; wait for the root CLI process to exit so cleanup callers // do not observe StopAsync/DisposeAsync completion while it is still tearing down. - await childProcess.WaitForExitAsync(); + var killWaitTimestamp = Stopwatch.GetTimestamp(); + try + { + await childProcess.WaitForExitAsync().WaitAsync(s_runtimeShutdownTimeout); + } + catch (TimeoutException ex) + { + if (logger is not null) + { + LoggingHelpers.LogTiming(logger, LogLevel.Debug, ex, + "Timed out waiting for runtime process to exit after kill. Elapsed={Elapsed}, Timeout={Timeout}", + killWaitTimestamp, + s_runtimeShutdownTimeout); + } + + AddCleanupError(errors, ex, logger); + } } } catch (Exception ex) @@ -2002,9 +2069,10 @@ private async Task ConnectToServerAsync(Process? cliProcess, string? "CopilotClient.ConnectToServerAsync transport setup complete. Elapsed={Elapsed}", setupTimestamp); - _serverRpc = new ServerRpc(rpc); + var connection = new Connection(rpc, cliProcess, networkStream, stderrPump); + _serverRpc = connection.Server; - return new Connection(rpc, cliProcess, networkStream, stderrPump); + return connection; } catch { @@ -2208,6 +2276,7 @@ private class Connection( { public Process? 
CliProcess => cliProcess; public JsonRpc Rpc => rpc; + public ServerRpc Server => field ?? Interlocked.CompareExchange(ref field, new(rpc), null) ?? field; public NetworkStream? NetworkStream => networkStream; public ProcessStderrPump? StderrPump => stderrPump; public StringBuilder? StderrBuffer => stderrPump?.Buffer; diff --git a/dotnet/test/E2E/TelemetryExportE2ETests.cs b/dotnet/test/E2E/TelemetryExportE2ETests.cs index ceec2326e..22ed5663d 100644 --- a/dotnet/test/E2E/TelemetryExportE2ETests.cs +++ b/dotnet/test/E2E/TelemetryExportE2ETests.cs @@ -47,10 +47,7 @@ public async Task Should_Export_File_Telemetry_For_Sdk_Interactions() await session.DisposeAsync(); await client.StopAsync(); - var entries = await ReadTelemetryEntriesAsync( - telemetryPath, - entries => entries.Any(entry => GetTypeName(entry) == "span" && - GetStringAttribute(entry, "gen_ai.operation.name") == "invoke_agent")); + var entries = await ReadTelemetryEntriesAsync(telemetryPath); var spans = entries.Where(entry => GetTypeName(entry) == "span").ToList(); Assert.NotEmpty(spans); @@ -89,46 +86,23 @@ public async Task Should_Export_File_Telemetry_For_Sdk_Interactions() static string EchoTelemetryMarker(string value) => value; } - private static async Task> ReadTelemetryEntriesAsync( - string path, - Func, bool> isComplete) + private static async Task> ReadTelemetryEntriesAsync(string path) { - IReadOnlyList entries = []; - await TestHelper.WaitForConditionAsync( - async () => - { - entries = await ReadTelemetryEntriesOnceAsync(path); - return entries.Count > 0 && isComplete(entries); - }, - timeout: TimeSpan.FromSeconds(30), - timeoutMessage: $"Timed out waiting for telemetry records in '{path}'.", - transientExceptionFilter: exception => TestHelper.IsTransientFileSystemException(exception) || exception is JsonException); - - return entries; - - static async Task> ReadTelemetryEntriesOnceAsync(string path) + var entries = new List(); + using var stream = new FileStream(path, FileMode.Open, 
FileAccess.Read, FileShare.Read); + using var reader = new StreamReader(stream); + while (await reader.ReadLineAsync() is { } line) { - if (!File.Exists(path) || new FileInfo(path).Length == 0) + if (string.IsNullOrWhiteSpace(line)) { - return []; + continue; } - var entries = new List(); - using var stream = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.ReadWrite | FileShare.Delete); - using var reader = new StreamReader(stream); - while (await reader.ReadLineAsync() is { } line) - { - if (string.IsNullOrWhiteSpace(line)) - { - continue; - } - - using var document = JsonDocument.Parse(line); - entries.Add(document.RootElement.Clone()); - } - - return entries; + using var document = JsonDocument.Parse(line); + entries.Add(document.RootElement.Clone()); } + + return entries; } private static string? GetTraceId(JsonElement entry) => GetStringProperty(entry, "traceId"); diff --git a/dotnet/test/Unit/ClientSessionLifetimeTests.cs b/dotnet/test/Unit/ClientSessionLifetimeTests.cs index c52148a03..2c11c7d6b 100644 --- a/dotnet/test/Unit/ClientSessionLifetimeTests.cs +++ b/dotnet/test/Unit/ClientSessionLifetimeTests.cs @@ -5,6 +5,7 @@ #if NET8_0_OR_GREATER using System.Net; using System.Net.Sockets; +using System.Diagnostics; using System.Reflection; using System.Runtime.CompilerServices; using System.Text; @@ -15,6 +16,57 @@ namespace GitHub.Copilot.Test.Unit; public sealed class ClientSessionLifetimeTests { + [Fact] + public async Task StopAsync_Requests_Runtime_Shutdown_For_Owned_Process() + { + await using var server = await FakeCopilotServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + await client.StartAsync(); + using var process = StartExitedProcess(); + await ReplaceConnectionCliProcessAsync(client, process); + + await client.StopAsync(); + + Assert.Equal(1, server.RuntimeShutdownCount); + } + + [Fact] + public async Task 
StopAsync_Does_Not_Throw_When_Runtime_Shutdown_Fails() + { + await using var server = await FakeCopilotServer.StartAsync(); + server.FailRuntimeShutdown(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + await client.StartAsync(); + using var process = StartExitedProcess(); + await ReplaceConnectionCliProcessAsync(client, process); + + await client.StopAsync(); + + Assert.Equal(1, server.RuntimeShutdownCount); + } + + [Fact] + public async Task ForceStopAsync_And_External_Stop_Do_Not_Request_Runtime_Shutdown() + { + await using var forceServer = await FakeCopilotServer.StartAsync(); + await using var forceClient = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(forceServer.Url) }); + await forceClient.StartAsync(); + using var process = StartExitedProcess(); + await ReplaceConnectionCliProcessAsync(forceClient, process); + + await forceClient.ForceStopAsync(); + + Assert.Equal(0, forceServer.RuntimeShutdownCount); + + await using var externalServer = await FakeCopilotServer.StartAsync(); + await using var externalClient = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(externalServer.Url) }); + await externalClient.StartAsync(); + + await externalClient.StopAsync(); + + Assert.Equal(0, externalServer.RuntimeShutdownCount); + } + [Fact] public async Task Dropped_Session_Remains_Rooted_By_Client() { @@ -186,6 +238,37 @@ private static int GetPrivateDictionaryCount(CopilotClient client, string fieldN return (int)count.GetValue(dictionary)!; } + private static async Task ReplaceConnectionCliProcessAsync(CopilotClient client, Process process) + { + var field = typeof(CopilotClient).GetField("_connectionTask", BindingFlags.Instance | BindingFlags.NonPublic) + ?? 
throw new InvalidOperationException("_connectionTask field was not found."); + var connectionTask = (Task)field.GetValue(client)!; + await connectionTask; + + var resultProperty = connectionTask.GetType().GetProperty(nameof(Task.Result)) + ?? throw new InvalidOperationException("Connection task result property was not found."); + var connection = resultProperty.GetValue(connectionTask)!; + var connectionType = connection.GetType(); + var rpc = connectionType.GetProperty("Rpc")!.GetValue(connection); + var networkStream = connectionType.GetProperty("NetworkStream")!.GetValue(connection); + var constructor = connectionType.GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public).Single(); + var updatedConnection = constructor.Invoke([rpc, process, networkStream, null]); + var fromResult = typeof(Task).GetMethod(nameof(Task.FromResult))!.MakeGenericMethod(connectionType); + field.SetValue(client, fromResult.Invoke(null, [updatedConnection])); + } + + private static Process StartExitedProcess() + { + var startInfo = OperatingSystem.IsWindows() + ? new ProcessStartInfo(Environment.GetEnvironmentVariable("COMSPEC") ?? "cmd.exe", "/c exit 0") + : new ProcessStartInfo("/bin/sh", "-c \"exit 0\""); + startInfo.UseShellExecute = false; + var process = Process.Start(startInfo) + ?? throw new InvalidOperationException("Failed to start test process."); + process.WaitForExit(); + return process; + } + private sealed class FakeCopilotServer : IAsyncDisposable { private readonly TcpListener _listener; @@ -196,6 +279,7 @@ private sealed class FakeCopilotServer : IAsyncDisposable private readonly Task _serverTask; private string? 
_lastSessionId; private bool _delayDestroy; + private bool _failRuntimeShutdown; private FakeCopilotServer(TcpListener listener) { @@ -221,6 +305,8 @@ public static Task StartAsync() public Task DestroyStarted => _destroyStarted.Task; + public int RuntimeShutdownCount { get; private set; } + public void DelayDestroy() { _delayDestroy = true; @@ -231,6 +317,11 @@ public void CompleteDestroy() _allowDestroy.TrySetResult(); } + public void FailRuntimeShutdown() + { + _failRuntimeShutdown = true; + } + public async ValueTask DisposeAsync() { _allowDestroy.TrySetResult(); @@ -275,6 +366,22 @@ private async Task HandleRequestAsync(Stream stream, JsonElement request, Cancel var id = idElement.Clone(); var method = request.GetProperty("method").GetString(); + if (method == "runtime.shutdown" && _failRuntimeShutdown) + { + RuntimeShutdownCount++; + await WriteMessageAsync(stream, new Dictionary + { + ["jsonrpc"] = "2.0", + ["id"] = id, + ["error"] = new Dictionary + { + ["code"] = -32000, + ["message"] = "runtime shutdown failed" + } + }, cancellationToken); + return; + } + object? 
result = method switch { "connect" => new Dictionary @@ -294,6 +401,7 @@ private async Task HandleRequestAsync(Stream stream, JsonElement request, Cancel ["success"] = true }, "session.destroy" => await DestroySessionAsync(cancellationToken), + "runtime.shutdown" => HandleRuntimeShutdown(), _ => throw new InvalidOperationException($"Unexpected RPC method '{method}'.") }; @@ -340,6 +448,12 @@ private async Task HandleRequestAsync(Stream stream, JsonElement request, Cancel return []; } + private Dictionary HandleRuntimeShutdown() + { + RuntimeShutdownCount++; + return []; + } + private async Task WriteMessageAsync(Stream stream, object payload, CancellationToken cancellationToken) { using var bodyStream = new MemoryStream(); diff --git a/go/client.go b/go/client.go index cad460557..0356af774 100644 --- a/go/client.go +++ b/go/client.go @@ -34,6 +34,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net" "os" "os/exec" @@ -378,8 +379,9 @@ func (c *Client) Start(ctx context.Context) error { // // This method performs graceful cleanup: // 1. Closes all active sessions (releases in-memory resources) -// 2. Closes the JSON-RPC connection -// 3. Terminates the CLI server process (if spawned by this client) +// 2. Requests runtime shutdown for SDK-owned CLI processes +// 3. Closes the JSON-RPC connection +// 4. Terminates the CLI server process (if spawned by this client) // // Note: session data on disk is preserved, so sessions can be resumed later. 
// To permanently remove session data before stopping, call [Client.DeleteSession] @@ -416,10 +418,54 @@ func (c *Client) Stop() error { c.startStopMux.Lock() defer c.startStopMux.Unlock() - // Kill CLI process FIRST (this closes stdout and unblocks readLoop) - only if we spawned it + runtimeShutdownCompleted := false + if c.process != nil && !c.isExternalServer && c.RPC != nil { + rpcClient := c.RPC + runtimeShutdownStart := time.Now() + shutdownDone := make(chan error, 1) + go func() { + _, err := rpcClient.Runtime.Shutdown(context.Background()) + shutdownDone <- err + }() + + select { + case err := <-shutdownDone: + if err != nil { + c.logDebugTiming(runtimeShutdownStart, "CopilotClient.Stop runtime shutdown failed") + errs = append(errs, fmt.Errorf("failed to gracefully shut down runtime: %w", err)) + } else { + runtimeShutdownCompleted = true + c.logDebugTiming(runtimeShutdownStart, "CopilotClient.Stop runtime shutdown complete") + } + case <-time.After(runtimeShutdownTimeout): + c.logDebugTiming(runtimeShutdownStart, "CopilotClient.Stop runtime shutdown timed out") + errs = append(errs, fmt.Errorf("timed out gracefully shutting down runtime after %s", runtimeShutdownTimeout)) + } + } + + // Give runtime.shutdown a bounded window to let the child exit on its own + // before falling back to killing it. if c.process != nil && !c.isExternalServer { - if err := c.killProcess(); err != nil { - errs = append(errs, err) + if c.processDone != nil { + if runtimeShutdownCompleted { + select { + case <-c.processDone: + c.osProcess.Swap(nil) + c.process = nil + case <-time.After(runtimeShutdownTimeout): + if err := c.killProcessAndWait(); err != nil { + errs = append(errs, err) + } + } + } else { + if err := c.killProcessAndWait(); err != nil { + errs = append(errs, err) + } + } + } else { + if err := c.killProcessAndWait(); err != nil { + errs = append(errs, err) + } } } c.process = nil @@ -453,6 +499,13 @@ func (c *Client) Stop() error { return errors.Join(errs...) 
} +func (c *Client) logDebugTiming(start time.Time, message string) { + switch strings.ToLower(c.options.LogLevel) { + case "debug", "all": + log.Printf("%s elapsed=%s", message, time.Since(start)) + } +} + // ForceStop forcefully stops the CLI server without graceful cleanup. // // Use this when [Client.Stop] fails or takes too long. This method: @@ -1548,6 +1601,8 @@ func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) { // minProtocolVersion is the minimum protocol version this SDK can communicate with. const minProtocolVersion = 3 +const runtimeShutdownTimeout = 10 * time.Second +const processExitTimeout = 10 * time.Second // verifyProtocolVersion sends the `connect` handshake (carrying the optional token) and // verifies the server's protocol version. Falls back to `ping` against legacy servers @@ -1821,6 +1876,21 @@ func (c *Client) killProcess() error { return nil } +func (c *Client) killProcessAndWait() error { + done := c.processDone + killErr := c.killProcess() + if done == nil { + return killErr + } + + select { + case <-done: + return killErr + case <-time.After(processExitTimeout): + return errors.Join(killErr, fmt.Errorf("timed out waiting for CLI process to exit after kill")) + } +} + // monitorProcess signals when the CLI process exits and captures any exit error. 
// processError is intentionally a local: each process lifecycle gets its own // error value, so goroutines from previous processes can't overwrite the diff --git a/go/client_test.go b/go/client_test.go index 6236a95ab..f57753f25 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -3,6 +3,7 @@ package copilot import ( "context" "encoding/json" + "net" "os" "os/exec" "path/filepath" @@ -13,6 +14,7 @@ import ( "sync" "testing" + "github.com/github/copilot-sdk/go/internal/jsonrpc2" "github.com/github/copilot-sdk/go/internal/truncbuffer" "github.com/github/copilot-sdk/go/rpc" ) @@ -134,6 +136,88 @@ func TestClient_URLParsing(t *testing.T) { }) } +func TestClient_StopRequestsRuntimeShutdownForOwnedProcess(t *testing.T) { + rpcClient, server, shutdownCalled := newRuntimeShutdownRpcPair(t) + client := &Client{ + process: &exec.Cmd{}, + client: rpcClient, + RPC: rpc.NewServerRPC(rpcClient), + sessions: make(map[string]*Session), + processDone: make(chan struct{}), + } + close(client.processDone) + + if err := client.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } + + select { + case <-shutdownCalled: + default: + t.Fatal("Stop did not request runtime.shutdown") + } + + server.Stop() +} + +func TestClient_ForceStopAndExternalStopDoNotRequestRuntimeShutdown(t *testing.T) { + rpcClient, server, shutdownCalled := newRuntimeShutdownRpcPair(t) + client := &Client{ + process: &exec.Cmd{}, + client: rpcClient, + RPC: rpc.NewServerRPC(rpcClient), + sessions: make(map[string]*Session), + } + + client.ForceStop() + assertRuntimeShutdownNotCalled(t, shutdownCalled) + server.Stop() + + externalRpcClient, externalServer, externalShutdownCalled := newRuntimeShutdownRpcPair(t) + externalClient := &Client{ + client: externalRpcClient, + RPC: rpc.NewServerRPC(externalRpcClient), + sessions: make(map[string]*Session), + isExternalServer: true, + } + + if err := externalClient.Stop(); err != nil { + t.Fatalf("external Stop failed: %v", err) + } + 
assertRuntimeShutdownNotCalled(t, externalShutdownCalled) + externalServer.Stop() +} + +func newRuntimeShutdownRpcPair(t *testing.T) (*jsonrpc2.Client, *jsonrpc2.Client, chan struct{}) { + t.Helper() + + clientConn, serverConn := net.Pipe() + t.Cleanup(func() { + clientConn.Close() + serverConn.Close() + }) + + rpcClient := jsonrpc2.NewClient(clientConn, clientConn) + server := jsonrpc2.NewClient(serverConn, serverConn) + shutdownCalled := make(chan struct{}, 1) + server.SetRequestHandler("runtime.shutdown", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + shutdownCalled <- struct{}{} + return []byte(`{}`), nil + }) + rpcClient.Start() + server.Start() + return rpcClient, server, shutdownCalled +} + +func assertRuntimeShutdownNotCalled(t *testing.T, shutdownCalled <-chan struct{}) { + t.Helper() + select { + case <-shutdownCalled: + t.Fatal("runtime.shutdown should not have been requested") + default: + } +} + func TestClient_SessionFSConfig(t *testing.T) { t.Run("should throw error when InitialWorkingDirectory is missing", func(t *testing.T) { defer func() { diff --git a/go/internal/e2e/telemetry_e2e_test.go b/go/internal/e2e/telemetry_e2e_test.go index 071030281..8f63d49a5 100644 --- a/go/internal/e2e/telemetry_e2e_test.go +++ b/go/internal/e2e/telemetry_e2e_test.go @@ -7,7 +7,6 @@ import ( "path/filepath" "strings" "testing" - "time" copilot "github.com/github/copilot-sdk/go" "github.com/github/copilot-sdk/go/internal/e2e/testharness" @@ -72,14 +71,7 @@ func TestTelemetryE2E(t *testing.T) { t.Logf("Stop returned: %v", err) } - entries, err := readTelemetryEntries(t, telemetryPath, 30*time.Second, func(es []map[string]any) bool { - for _, e := range es { - if telemetryType(e) == "span" && stringAttr(e, "gen_ai.operation.name") == "invoke_agent" { - return true - } - } - return false - }) + entries, err := readTelemetryEntries(t, telemetryPath) if err != nil { t.Fatalf("readTelemetryEntries failed: %v", err) } @@ -182,33 +174,27 @@ func 
TestTelemetryE2E(t *testing.T) { }) } -func readTelemetryEntries(t *testing.T, path string, timeout time.Duration, isComplete func([]map[string]any) bool) ([]map[string]any, error) { +func readTelemetryEntries(t *testing.T, path string) ([]map[string]any, error) { t.Helper() - deadline := time.Now().Add(timeout) - for time.Now().Before(deadline) { - if info, err := os.Stat(path); err == nil && info.Size() > 0 { - data, err := os.ReadFile(path) - if err == nil { - var entries []map[string]any - for _, line := range strings.Split(string(data), "\n") { - line = strings.TrimSpace(line) - if line == "" { - continue - } - var entry map[string]any - if err := json.Unmarshal([]byte(line), &entry); err != nil { - continue - } - entries = append(entries, entry) - } - if len(entries) > 0 && isComplete(entries) { - return entries, nil - } - } + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var entries []map[string]any + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + var entry map[string]any + if err := json.Unmarshal([]byte(line), &entry); err != nil { + return nil, fmt.Errorf("parse telemetry entry in %q: %w", path, err) } - time.Sleep(100 * time.Millisecond) + entries = append(entries, entry) } - return nil, fmt.Errorf("timed out waiting for telemetry records in %q", path) + return entries, nil } func telemetryType(e map[string]any) string { return stringProp(e, "type") } diff --git a/java/src/main/java/com/github/copilot/CopilotClient.java b/java/src/main/java/com/github/copilot/CopilotClient.java index 6422c773d..9f0f1748a 100644 --- a/java/src/main/java/com/github/copilot/CopilotClient.java +++ b/java/src/main/java/com/github/copilot/CopilotClient.java @@ -80,6 +80,7 @@ public final class CopilotClient implements AutoCloseable { * shutdown via {@link #stop()}. 
*/ public static final int AUTOCLOSEABLE_TIMEOUT_SECONDS = 10; + private static final int RUNTIME_SHUTDOWN_TIMEOUT_SECONDS = 10; private static final int FORCE_KILL_TIMEOUT_SECONDS = 10; /** @@ -260,7 +261,7 @@ private Connection startCoreBody() { } // Clean up the spawned process if connection setup failed if (process != null) { - cleanupCliProcess(process); + cleanupCliProcess(process, true); } String stderr = serverManager.getStderrOutput(); if (!stderr.isEmpty()) { @@ -329,6 +330,7 @@ private static boolean isUnsupportedConnectMethod(JsonRpcException ex) { * This method performs graceful cleanup: *
<ol>
      * <li>Closes all active sessions (releases in-memory resources)</li>
+     * <li>Requests runtime shutdown for SDK-owned CLI processes</li>
      * <li>Closes the JSON-RPC connection</li>
      * <li>Terminates the CLI server process (if spawned by this client)</li>
      * </ol>
@@ -363,7 +365,7 @@ public CompletableFuture stop() { sessions.clear(); return CompletableFuture.allOf(closeFutures.toArray(new CompletableFuture[0])) - .thenCompose(v -> cleanupConnection()); + .thenCompose(v -> cleanupConnection(true)); } /** @@ -379,10 +381,11 @@ public CompletableFuture forceStop() { // executor, so a plain whenComplete(...) here could land the awaitTermination // call on one of the very threads it is waiting to drain, forcing the full // AUTOCLOSEABLE_TIMEOUT_SECONDS timeout followed by shutdownNow(). - return cleanupConnection().whenCompleteAsync((ignored, error) -> shutdownOwnedExecutor(), SHUTDOWN_DISPATCHER); + return cleanupConnection(false).whenCompleteAsync((ignored, error) -> shutdownOwnedExecutor(), + SHUTDOWN_DISPATCHER); } - private CompletableFuture cleanupConnection() { + private CompletableFuture cleanupConnection(boolean gracefulRuntimeShutdown) { CompletableFuture future = connectionFuture; connectionFuture = null; @@ -393,27 +396,67 @@ private CompletableFuture cleanupConnection() { return CompletableFuture.completedFuture(null); } - return future.thenAccept(connection -> { - try { - connection.rpc.close(); - } catch (Exception e) { - LOG.log(Level.FINE, "Error closing RPC", e); + return future.handle((connection, startupError) -> { + if (startupError != null) { + LOG.log(Level.FINE, "Ignoring failed Copilot client startup during cleanup", startupError); + return CompletableFuture.completedFuture(null); } - if (connection.process != null) { - cleanupCliProcess(connection.process); + CompletableFuture shutdownFuture = CompletableFuture.completedFuture(null); + if (gracefulRuntimeShutdown && connection.process != null) { + long runtimeShutdownStartNanos = System.nanoTime(); + shutdownFuture = connection.rpc.invoke("runtime.shutdown", Map.of(), Void.class) + .orTimeout(RUNTIME_SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS) + .whenComplete((ignored, error) -> { + if (error == null) { + LoggingHelpers.logTiming(LOG, Level.FINE, + 
"CopilotClient.stop runtime shutdown complete. Elapsed={Elapsed}", + runtimeShutdownStartNanos); + } else { + LoggingHelpers.logTiming(LOG, Level.FINE, error, + "CopilotClient.stop runtime shutdown failed. Elapsed={Elapsed}", + runtimeShutdownStartNanos); + } + }); } - }).exceptionally(ex -> { - LOG.log(Level.FINE, "Ignoring failed Copilot client startup during cleanup", ex); - return null; - }); + + return shutdownFuture.handle((ignored, error) -> { + try { + connection.rpc.close(); + } catch (Exception e) { + LOG.log(Level.FINE, "Error closing RPC", e); + } + + if (connection.process != null) { + cleanupCliProcess(connection.process, !gracefulRuntimeShutdown || error != null); + } + return (Void) null; + }); + }).thenCompose(result -> result); } - private static void cleanupCliProcess(Process process) { + private static void cleanupCliProcess(Process process, boolean forceImmediately) { try { if (process.isAlive()) { - Process destroyedProcess = process.destroyForcibly(); - if (!destroyedProcess.waitFor(FORCE_KILL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + if (!forceImmediately && process.waitFor(FORCE_KILL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + return; + } + + if (forceImmediately) { + process.destroyForcibly(); + if (!process.waitFor(FORCE_KILL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + LOG.fine("Process did not terminate within force kill timeout"); + } + return; + } else { + process.destroy(); + } + if (!forceImmediately && process.waitFor(FORCE_KILL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + return; + } + + process.destroyForcibly(); + if (!process.waitFor(FORCE_KILL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { LOG.fine("Process did not terminate within force kill timeout"); } } diff --git a/java/src/test/java/com/github/copilot/CopilotClientTest.java b/java/src/test/java/com/github/copilot/CopilotClientTest.java index 1d6bfc704..7c97a3886 100644 --- a/java/src/test/java/com/github/copilot/CopilotClientTest.java +++ 
b/java/src/test/java/com/github/copilot/CopilotClientTest.java @@ -16,11 +16,15 @@ import java.lang.reflect.Field; import java.util.ArrayList; +import java.util.Map; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import static org.junit.jupiter.api.Assertions.*; -import java.util.Optional; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; /** * Tests for CopilotClient. @@ -38,6 +42,70 @@ static void setup() { cliPath = TestUtil.findCliPath(); } + @Test + void testStopRequestsRuntimeShutdownForOwnedProcess() throws Exception { + var client = new CopilotClient(new CopilotClientOptions().setAutoStart(false)); + var rpc = mock(JsonRpcClient.class); + when(rpc.invoke(eq("runtime.shutdown"), any(), eq(Void.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + var process = mock(Process.class); + when(process.isAlive()).thenReturn(true); + when(process.waitFor(anyLong(), any(TimeUnit.class))).thenReturn(true); + + setConnectionFuture(client, rpc, process); + + client.stop().get(); + + verify(rpc).invoke(eq("runtime.shutdown"), eq(Map.of()), eq(Void.class)); + verify(rpc).close(); + verify(process, never()).destroy(); + verify(process, never()).destroyForcibly(); + } + + @Test + void testStopDoesNotThrowWhenRuntimeShutdownFails() throws Exception { + var client = new CopilotClient(new CopilotClientOptions().setAutoStart(false)); + var rpc = mock(JsonRpcClient.class); + when(rpc.invoke(eq("runtime.shutdown"), any(), eq(Void.class))) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException("shutdown failed"))); + var process = mock(Process.class); + when(process.isAlive()).thenReturn(true); + when(process.destroyForcibly()).thenReturn(process); + when(process.waitFor(anyLong(), any(TimeUnit.class))).thenReturn(true); + + setConnectionFuture(client, rpc, process); + + assertDoesNotThrow(() -> 
client.stop().get()); + + verify(rpc).invoke(eq("runtime.shutdown"), eq(Map.of()), eq(Void.class)); + verify(rpc).close(); + verify(process).destroyForcibly(); + } + + @Test + void testForceStopAndExternalStopDoNotRequestRuntimeShutdown() throws Exception { + var forceClient = new CopilotClient(new CopilotClientOptions().setAutoStart(false)); + var forceRpc = mock(JsonRpcClient.class); + var process = mock(Process.class); + when(process.isAlive()).thenReturn(true); + when(process.destroyForcibly()).thenReturn(process); + when(process.waitFor(anyLong(), any(TimeUnit.class))).thenReturn(true); + setConnectionFuture(forceClient, forceRpc, process); + + forceClient.forceStop().get(); + + verify(forceRpc, never()).invoke(eq("runtime.shutdown"), any(), eq(Void.class)); + verify(process).destroyForcibly(); + + var externalClient = new CopilotClient(new CopilotClientOptions().setAutoStart(false)); + var externalRpc = mock(JsonRpcClient.class); + setConnectionFuture(externalClient, externalRpc, null); + + externalClient.stop().get(); + + verify(externalRpc, never()).invoke(eq("runtime.shutdown"), any(), eq(Void.class)); + } + @Test void testClientConstruction() { var client = new CopilotClient(); @@ -533,4 +601,16 @@ void testListModels_WithCustomHandler_WorksWithoutStart() throws Exception { assertEquals("no-start-model", models.get(0).getId()); } } + + private static void setConnectionFuture(CopilotClient client, JsonRpcClient rpc, Process process) throws Exception { + var connectionClass = Class.forName("com.github.copilot.CopilotClient$Connection"); + var constructor = connectionClass.getDeclaredConstructor(JsonRpcClient.class, Process.class, + com.github.copilot.generated.rpc.ServerRpc.class); + constructor.setAccessible(true); + var connection = constructor.newInstance(rpc, process, null); + + Field field = CopilotClient.class.getDeclaredField("connectionFuture"); + field.setAccessible(true); + field.set(client, CompletableFuture.completedFuture(connection)); + } } 
diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 8dc35b8d7..e53d0f655 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -78,6 +78,7 @@ import { defaultJoinSessionPermissionHandler } from "./types.js"; * Servers reporting a version below this are rejected. */ const MIN_PROTOCOL_VERSION = 3; +const RUNTIME_SHUTDOWN_TIMEOUT_MS = 10_000; /** * Check if value is a Zod schema (has toJSONSchema method) @@ -91,6 +92,53 @@ function isZodSchema(value: unknown): value is { toJSONSchema(): Record(promise: Promise, timeoutMs: number, message: string): Promise { + let timeout: ReturnType | undefined; + try { + return await Promise.race([ + promise, + new Promise((_, reject) => { + timeout = setTimeout(() => reject(new Error(message)), timeoutMs); + }), + ]); + } finally { + if (timeout !== undefined) { + clearTimeout(timeout); + } + } +} + +async function waitForChildExit(child: ChildProcess, timeoutMs: number): Promise { + if (child.exitCode != null || child.signalCode != null) { + return true; + } + + return new Promise((resolve) => { + let timeout: ReturnType; + let settled = false; + const onExit = () => { + if (settled) { + return; + } + settled = true; + clearTimeout(timeout); + resolve(true); + }; + timeout = setTimeout(() => { + if (settled) { + return; + } + settled = true; + child.off("exit", onExit); + resolve(false); + }, timeoutMs); + child.once("exit", onExit); + if (child.exitCode != null || child.signalCode != null) { + onExit(); + } + }); +} + /** * Convert tool parameters to JSON schema format for sending to CLI */ @@ -370,6 +418,13 @@ export class CopilotClient { return this._internalRpc; } + private logDebugTiming(message: string, startMs: number): void { + const level = this.options.logLevel?.toLowerCase(); + if (level === "debug" || level === "all") { + process.stderr.write(`[copilot-sdk] ${message}. Elapsed=${Date.now() - startMs}ms\n`); + } + } + /** * Creates a new CopilotClient instance. 
* @@ -620,8 +675,9 @@ export class CopilotClient { * * This method performs graceful cleanup: * 1. Closes all active sessions (releases in-memory resources) - * 2. Closes the JSON-RPC connection - * 3. Terminates the CLI server process (if spawned by this client) + * 2. Requests runtime shutdown for SDK-owned CLI processes + * 3. Closes the JSON-RPC connection + * 4. Terminates the CLI server process (if spawned by this client) * * Note: session data on disk is preserved, so sessions can be resumed later. * To permanently remove session data before stopping, call @@ -673,6 +729,38 @@ export class CopilotClient { } this.sessions.clear(); + // Ask SDK-owned runtimes to flush and clean up before we tear down + // their transport/process. External runtimes may be shared, so only + // close our connection to them. + let runtimeShutdownCompleted = false; + if (this.connection && this.cliProcess && !this.isExternalServer) { + const runtimeShutdownStart = Date.now(); + const shutdownPromise = this.rpc.runtime.shutdown(); + void shutdownPromise.catch(() => undefined); + try { + await withTimeout( + shutdownPromise, + RUNTIME_SHUTDOWN_TIMEOUT_MS, + `runtime.shutdown timed out after ${RUNTIME_SHUTDOWN_TIMEOUT_MS}ms` + ); + runtimeShutdownCompleted = true; + this.logDebugTiming( + "CopilotClient.stop runtime shutdown complete", + runtimeShutdownStart + ); + } catch (error) { + this.logDebugTiming( + "CopilotClient.stop runtime shutdown failed", + runtimeShutdownStart + ); + errors.push( + new Error( + `Failed to gracefully shut down runtime: ${error instanceof Error ? error.message : String(error)}` + ) + ); + } + } + // Close connection if (this.connection) { try { @@ -686,6 +774,7 @@ export class CopilotClient { } this.connection = null; this._rpc = null; + this._internalRpc = null; } // Clear models cache @@ -711,19 +800,26 @@ export class CopilotClient { } } - // Send SIGTERM and await child exit. 
If the child ignores SIGTERM we - // intentionally block here — callers who need a guaranteed-bounded - // shutdown should reach for forceStop() instead, which sends SIGKILL. + // Give runtime.shutdown a bounded window to let the child exit on its + // own before falling back to SIGTERM. if (this.cliProcess && !this.isExternalServer) { const child = this.cliProcess; this.cliProcess = null; try { - if (child.exitCode === null && child.signalCode === null) { - const exited = new Promise((resolve) => { - child.once("exit", () => resolve()); - }); - child.kill(); - await exited; + if (child.exitCode == null && child.signalCode == null) { + const exitedGracefully = runtimeShutdownCompleted + ? await waitForChildExit(child, RUNTIME_SHUTDOWN_TIMEOUT_MS) + : false; + if (!exitedGracefully) { + child.kill(); + if (!(await waitForChildExit(child, RUNTIME_SHUTDOWN_TIMEOUT_MS))) { + errors.push( + new Error( + `Timed out waiting for CLI process to exit after kill: ${RUNTIME_SHUTDOWN_TIMEOUT_MS}ms` + ) + ); + } + } } } catch (error) { errors.push( @@ -802,6 +898,7 @@ export class CopilotClient { } this.connection = null; this._rpc = null; + this._internalRpc = null; } // Clear models cache diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index 657ec7c9c..13d741101 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -1,4 +1,5 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ +import { EventEmitter } from "node:events"; import { describe, expect, it, onTestFinished, vi } from "vitest"; import { approveAll, @@ -2194,4 +2195,77 @@ describe("CopilotClient", () => { }); }); }); + + describe("shutdown", () => { + it("requests runtime shutdown when stopping an SDK-owned process", async () => { + const client = new CopilotClient(); + const calls: string[] = []; + const child = new EventEmitter() as EventEmitter & { + exitCode: number | null; + signalCode: string | null; + kill: ReturnType; + }; + child.exitCode = null; + 
child.signalCode = null; + child.kill = vi.fn(() => { + calls.push("kill"); + child.signalCode = "SIGTERM"; + child.emit("exit", null, "SIGTERM"); + return true; + }); + + (client as any).connection = { + sendRequest: vi.fn(async (method: string) => { + calls.push(method); + if (method === "runtime.shutdown") { + child.exitCode = 0; + child.emit("exit", 0, null); + return {}; + } + throw new Error(`unexpected method ${method}`); + }), + dispose: vi.fn(() => calls.push("dispose")), + }; + (client as any).cliProcess = child; + (client as any).isExternalServer = false; + + await expect(client.stop()).resolves.toEqual([]); + expect(calls).toEqual(["runtime.shutdown", "dispose"]); + expect(child.kill).not.toHaveBeenCalled(); + }); + + it("does not request runtime shutdown for force stop or external runtimes", async () => { + const forceClient = new CopilotClient(); + const forceChild = new EventEmitter() as EventEmitter & { + exitCode: number | null; + signalCode: string | null; + kill: ReturnType; + }; + forceChild.exitCode = null; + forceChild.signalCode = null; + forceChild.kill = vi.fn(() => true); + const forceSendRequest = vi.fn(); + (forceClient as any).connection = { + sendRequest: forceSendRequest, + dispose: vi.fn(), + }; + (forceClient as any).cliProcess = forceChild; + (forceClient as any).isExternalServer = false; + + await forceClient.forceStop(); + expect(forceSendRequest).not.toHaveBeenCalled(); + expect(forceChild.kill).toHaveBeenCalledWith("SIGKILL"); + + const externalClient = new CopilotClient(); + const externalSendRequest = vi.fn(); + (externalClient as any).connection = { + sendRequest: externalSendRequest, + dispose: vi.fn(), + }; + (externalClient as any).isExternalServer = true; + + await expect(externalClient.stop()).resolves.toEqual([]); + expect(externalSendRequest).not.toHaveBeenCalled(); + }); + }); }); diff --git a/nodejs/test/e2e/telemetry.e2e.test.ts b/nodejs/test/e2e/telemetry.e2e.test.ts index a71dad93d..9df6b7f88 100644 --- 
a/nodejs/test/e2e/telemetry.e2e.test.ts +++ b/nodejs/test/e2e/telemetry.e2e.test.ts @@ -2,7 +2,6 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import { existsSync, statSync } from "fs"; import { readFile } from "fs/promises"; import { join } from "path"; import { describe, expect, it } from "vitest"; @@ -34,32 +33,19 @@ function isRootSpan(entry: TelemetryEntry): boolean { return parent === "" || parent === "0000000000000000"; } -async function readTelemetryEntries( - path: string, - isComplete: (entries: TelemetryEntry[]) => boolean, - timeoutMs = 30_000 -): Promise { - const deadline = Date.now() + timeoutMs; - while (Date.now() < deadline) { - if (existsSync(path) && statSync(path).size > 0) { - const content = await readFile(path, "utf8"); - const entries: TelemetryEntry[] = []; - for (const line of content.split("\n")) { - const trimmed = line.trim(); - if (!trimmed) continue; - try { - entries.push(JSON.parse(trimmed)); - } catch { - // Skip malformed lines (file may still be writing) - } - } - if (entries.length > 0 && isComplete(entries)) { - return entries; - } +async function readTelemetryEntries(path: string): Promise { + const content = await readFile(path, "utf8"); + const entries: TelemetryEntry[] = []; + for (const line of content.split("\n")) { + const trimmed = line.trim(); + if (!trimmed) { + continue; } - await new Promise((resolve) => setTimeout(resolve, 100)); + + entries.push(JSON.parse(trimmed)); } - throw new Error(`Timed out waiting for telemetry records in '${path}'.`); + + return entries; } describe("Telemetry export", async () => { @@ -103,13 +89,7 @@ describe("Telemetry export", async () => { // Telemetry exporter writes to telemetryFileName resolved relative to the CLI cwd (workDir). 
const telemetryPath = join(workDir, telemetryFileName); - const entries = await readTelemetryEntries(telemetryPath, (entries) => - entries.some( - (entry) => - entry.type === "span" && - getStringAttribute(entry, "gen_ai.operation.name") === "invoke_agent" - ) - ); + const entries = await readTelemetryEntries(telemetryPath); const spans = entries.filter((entry) => entry.type === "span"); expect(spans.length).toBeGreaterThan(0); diff --git a/python/copilot/client.py b/python/copilot/client.py index 7dcec6e8f..f14a690e1 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -980,6 +980,8 @@ def _session_lifecycle_event_from_dict(data: dict) -> SessionLifecycleEvent: # Minimum protocol version this SDK can communicate with. # Servers reporting a version below this are rejected. _MIN_PROTOCOL_VERSION = 3 +_RUNTIME_SHUTDOWN_TIMEOUT_SECONDS = 10 +_CLI_PROCESS_EXIT_TIMEOUT_SECONDS = 5 def _get_bundled_cli_path() -> str | None: @@ -1225,7 +1227,8 @@ def __init__( if options.use_logged_in_user is None: options.use_logged_in_user = not bool(options.github_token) - self._process: subprocess.Popen | None = None + self._process: Any = None + self._cli_process: subprocess.Popen | None = None self._client: JsonRpcClient | None = None self._state: _ConnectionState = "disconnected" self._sessions: dict[str, CopilotSession] = {} @@ -1422,8 +1425,9 @@ async def start(self) -> None: exc_info=True, ) # Check if process exited and capture any remaining stderr - if self._process and hasattr(self._process, "poll"): - return_code = self._process.poll() + process = self._cli_process if self._cli_process is not None else self._process + if process and hasattr(process, "poll"): + return_code = process.poll() if return_code is not None and self._client: stderr_output = self._client.get_stderr_output() if stderr_output: @@ -1438,8 +1442,9 @@ async def stop(self) -> None: This method performs graceful cleanup: 1. Closes all active sessions (releases in-memory resources) - 2. 
Closes the JSON-RPC connection - 3. Terminates the CLI server process (if spawned by this client) + 2. Requests runtime shutdown for SDK-owned CLI processes + 3. Closes the JSON-RPC connection + 4. Terminates the CLI server process (if spawned by this client) Note: session data on disk is preserved, so sessions can be resumed later. To permanently remove session data before stopping, call @@ -1476,6 +1481,28 @@ async def stop(self) -> None: StopError(message=f"Failed to disconnect session {session.session_id}: {e}") ) + runtime_shutdown_completed = False + if self._rpc is not None and self._cli_process is not None and not self._is_external_server: + runtime_shutdown_start = time.perf_counter() + try: + await self._rpc.runtime.shutdown(timeout=_RUNTIME_SHUTDOWN_TIMEOUT_SECONDS) + runtime_shutdown_completed = True + log_timing( + logger, + logging.DEBUG, + "CopilotClient.stop runtime shutdown complete", + runtime_shutdown_start, + ) + except Exception as e: + log_timing( + logger, + logging.DEBUG, + "CopilotClient.stop runtime shutdown failed", + runtime_shutdown_start, + exc_info=True, + ) + errors.append(StopError(message=f"Failed to gracefully shut down runtime: {e}")) + # Close client if self._client: await self._client.stop() @@ -1486,15 +1513,74 @@ async def stop(self) -> None: async with self._models_cache_lock: self._models_cache = None - # Kill CLI process (only if we spawned it) - if self._process and not self._is_external_server: - self._process.terminate() + # Close TCP socket wrappers without treating them as owned processes. 
+ if self._process is not None and self._process is not self._cli_process: try: - self._process.wait(timeout=5) - except subprocess.TimeoutExpired: - self._process.kill() + self._process.terminate() + except Exception: + logger.debug("Error while closing Copilot runtime transport", exc_info=True) self._process = None + # Terminate CLI process (only if we spawned it) + if self._cli_process and not self._is_external_server: + poll = getattr(self._cli_process, "poll", None) + is_running = poll is None or poll() is None + if is_running: + if runtime_shutdown_completed: + try: + await asyncio.to_thread( + self._cli_process.wait, + timeout=_RUNTIME_SHUTDOWN_TIMEOUT_SECONDS, + ) + except subprocess.TimeoutExpired: + self._cli_process.terminate() + try: + await asyncio.to_thread( + self._cli_process.wait, + timeout=_CLI_PROCESS_EXIT_TIMEOUT_SECONDS, + ) + except subprocess.TimeoutExpired: + self._cli_process.kill() + try: + await asyncio.to_thread( + self._cli_process.wait, + timeout=_CLI_PROCESS_EXIT_TIMEOUT_SECONDS, + ) + except subprocess.TimeoutExpired as e: + errors.append( + StopError( + message=( + "Timed out waiting for CLI process to exit after kill: " + f"{e}" + ) + ) + ) + else: + self._cli_process.terminate() + try: + await asyncio.to_thread( + self._cli_process.wait, + timeout=_CLI_PROCESS_EXIT_TIMEOUT_SECONDS, + ) + except subprocess.TimeoutExpired: + self._cli_process.kill() + try: + await asyncio.to_thread( + self._cli_process.wait, + timeout=_CLI_PROCESS_EXIT_TIMEOUT_SECONDS, + ) + except subprocess.TimeoutExpired as e: + errors.append( + StopError( + message=( + f"Timed out waiting for CLI process to exit after kill: {e}" + ) + ) + ) + if self._process is self._cli_process: + self._process = None + self._cli_process = None + self._state = "disconnected" if not self._is_external_server: self._runtime_port = None @@ -1525,13 +1611,20 @@ async def force_stop(self) -> None: # Close the transport first to signal the server immediately. 
# For external servers (TCP), this closes the socket. # For spawned processes (stdio), this kills the process. - if self._process: + if self._process is not None or self._cli_process is not None: try: if self._is_external_server: - self._process.terminate() # closes the TCP socket + if self._process is not None: + self._process.terminate() # closes the TCP socket + self._process = None + self._cli_process = None else: - self._process.kill() + if self._process is not None and self._process is not self._cli_process: + self._process.terminate() + if self._cli_process is not None: + self._cli_process.kill() self._process = None + self._cli_process = None except Exception: logger.debug("Error while force-stopping Copilot CLI process", exc_info=True) @@ -3252,6 +3345,7 @@ async def _start_cli_server(self) -> None: env=env, creationflags=creationflags, ) + self._cli_process = self._process else: if tcp_port > 0: args.extend(["--port", str(tcp_port)]) @@ -3264,6 +3358,7 @@ async def _start_cli_server(self) -> None: env=env, creationflags=creationflags, ) + self._cli_process = self._process log_timing( logger, logging.DEBUG, diff --git a/python/e2e/test_telemetry_e2e.py b/python/e2e/test_telemetry_e2e.py index f18a9fb88..562fd658c 100644 --- a/python/e2e/test_telemetry_e2e.py +++ b/python/e2e/test_telemetry_e2e.py @@ -13,7 +13,6 @@ from __future__ import annotations -import asyncio import json import os import uuid @@ -45,22 +44,14 @@ def _is_root_span(entry: dict[str, Any]) -> bool: return parent in ("", "0000000000000000") -async def _read_telemetry_entries( - path: Path, complete: Any, *, timeout: float = 30.0 -) -> list[dict[str, Any]]: - deadline = asyncio.get_event_loop().time() + timeout - while asyncio.get_event_loop().time() < deadline: - if path.exists() and path.stat().st_size > 0: - entries: list[dict[str, Any]] = [] - for line in path.read_text(encoding="utf-8").splitlines(): - line = line.strip() - if not line: - continue - entries.append(json.loads(line)) - 
if entries and complete(entries): - return entries - await asyncio.sleep(0.1) - raise TimeoutError(f"Timed out waiting for telemetry records in '{path}'.") +def _read_telemetry_entries(path: Path) -> list[dict[str, Any]]: + entries: list[dict[str, Any]] = [] + for line in path.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line: + continue + entries.append(json.loads(line)) + return entries class TestTelemetryExport: @@ -119,14 +110,7 @@ def echo(invocation: ToolInvocation) -> ToolResult: finally: await client.stop() - entries = await _read_telemetry_entries( - telemetry_path, - lambda items: any( - item.get("type") == "span" - and _string_attribute(item, "gen_ai.operation.name") == "invoke_agent" - for item in items - ), - ) + entries = _read_telemetry_entries(telemetry_path) spans = [item for item in entries if item.get("type") == "span"] assert spans diff --git a/python/test_client.py b/python/test_client.py index 502d410ab..0ea202152 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -5,7 +5,7 @@ """ from datetime import UTC, datetime -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest @@ -27,6 +27,73 @@ from e2e.testharness import CLI_PATH +class TestClientShutdown: + @pytest.mark.asyncio + async def test_stop_requests_runtime_shutdown_for_owned_process(self): + calls: list[str] = [] + process = Mock() + process.poll.return_value = None + process.wait.return_value = 0 + + class Runtime: + async def shutdown(self, *, timeout=None): + calls.append("runtime.shutdown") + + client = CopilotClient(connection=RuntimeConnection.for_stdio(path="copilot")) + client._rpc = Mock(runtime=Runtime()) + client._process = process + client._cli_process = process + client._is_external_server = False + + await client.stop() + + assert calls == ["runtime.shutdown"] + process.terminate.assert_not_called() + process.kill.assert_not_called() + + @pytest.mark.asyncio + async def 
test_force_stop_and_external_stop_do_not_request_runtime_shutdown(self): + calls: list[str] = [] + process = Mock() + + class Runtime: + async def shutdown(self): + calls.append("runtime.shutdown") + + force_client = CopilotClient(connection=RuntimeConnection.for_stdio(path="copilot")) + force_client._rpc = Mock(runtime=Runtime()) + force_client._process = process + force_client._cli_process = process + force_client._is_external_server = False + + await force_client.force_stop() + + assert calls == [] + process.kill.assert_called_once() + + external_client = CopilotClient(connection=RuntimeConnection.for_uri("localhost:1234")) + external_client._rpc = Mock(runtime=Runtime()) + external_client._is_external_server = True + + await external_client.stop() + + assert calls == [] + + @pytest.mark.asyncio + async def test_force_stop_external_server_clears_process_references(self): + process = Mock() + client = CopilotClient(connection=RuntimeConnection.for_uri("localhost:1234")) + client._is_external_server = True + client._process = process + client._cli_process = process + + await client.force_stop() + + process.terminate.assert_called_once() + assert client._process is None + assert client._cli_process is None + + class TestPermissionHandlerOptional: @pytest.mark.asyncio async def test_create_session_allows_missing_permission_handler(self): diff --git a/rust/src/lib.rs b/rust/src/lib.rs index cab34b476..a66418241 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -59,7 +59,7 @@ use std::ffi::OsString; use std::path::{Path, PathBuf}; use std::process::Stdio; use std::sync::{Arc, OnceLock}; -use std::time::Instant; +use std::time::{Duration, Instant}; use async_trait::async_trait; // JSON-RPC wire types are internal transport details (like Go SDK's internal/jsonrpc2/). @@ -91,6 +91,7 @@ pub use subscription::{EventSubscription, LifecycleSubscription}; /// Minimum protocol version this SDK can communicate with. 
const MIN_PROTOCOL_VERSION: u32 = 3; +const RUNTIME_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10); /// How the SDK communicates with the CLI server. #[derive(Debug, Default)] @@ -1818,8 +1819,9 @@ impl Client { /// Cooperatively shut down the client and the CLI child process. /// /// Walks every still-registered session and sends `session.destroy` - /// for each one, then kills the CLI child. Errors from per-session - /// destroys and the final child-kill are collected into + /// for each one, asks SDK-owned runtimes to shut down, then kills the + /// CLI child. Errors from per-session destroys, runtime shutdown, and + /// the final child-kill are collected into /// [`StopErrors`] rather than short-circuiting on the first failure /// — so callers see the full picture of teardown. /// @@ -1868,13 +1870,67 @@ impl Client { self.inner.router.unregister(&session_id); } + let should_shutdown_runtime = self.inner.child.lock().is_some(); + let mut runtime_shutdown_completed = false; + if should_shutdown_runtime { + let runtime_shutdown_start = Instant::now(); + match tokio::time::timeout(RUNTIME_SHUTDOWN_TIMEOUT, self.rpc().runtime().shutdown()) + .await + { + Ok(Ok(())) => { + runtime_shutdown_completed = true; + debug!( + elapsed_ms = runtime_shutdown_start.elapsed().as_millis(), + "Client::stop runtime shutdown complete" + ); + } + Ok(Err(e)) => { + warn!( + elapsed_ms = runtime_shutdown_start.elapsed().as_millis(), + error = %e, + "runtime.shutdown failed during Client::stop", + ); + errors.push(e); + } + Err(_) => { + let e = std::io::Error::new( + std::io::ErrorKind::TimedOut, + "runtime.shutdown timed out during Client::stop", + ); + warn!( + elapsed_ms = runtime_shutdown_start.elapsed().as_millis(), + timeout = ?RUNTIME_SHUTDOWN_TIMEOUT, + error = %e, + "runtime.shutdown timed out during Client::stop", + ); + errors.push(e.into()); + } + } + } + let child = self.inner.child.lock().take(); *self.inner.state.lock() = ConnectionState::Disconnected; 
*self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new()); - if let Some(mut child) = child - && let Err(e) = child.kill().await - { - errors.push(e.into()); + if let Some(mut child) = child { + match child.try_wait() { + Ok(Some(_status)) => {} + Ok(None) => { + if runtime_shutdown_completed { + match tokio::time::timeout(RUNTIME_SHUTDOWN_TIMEOUT, child.wait()).await { + Ok(Ok(_status)) => {} + Ok(Err(e)) => errors.push(e.into()), + Err(_) => { + if let Err(e) = child.kill().await { + errors.push(e.into()); + } + } + } + } else if let Err(e) = child.kill().await { + errors.push(e.into()); + } + } + Err(e) => errors.push(e.into()), + } } info!(pid = ?pid, errors = errors.len(), "CLI process stopped"); diff --git a/rust/tests/e2e/telemetry.rs b/rust/tests/e2e/telemetry.rs index 10111be52..38bf4a404 100644 --- a/rust/tests/e2e/telemetry.rs +++ b/rust/tests/e2e/telemetry.rs @@ -9,7 +9,7 @@ use github_copilot_sdk::{ }; use serde_json::json; -use super::support::{assistant_message_content, wait_for_condition, with_e2e_context}; +use super::support::{assistant_message_content, with_e2e_context}; #[tokio::test] async fn should_export_file_telemetry_for_sdk_interactions() { @@ -66,7 +66,7 @@ async fn should_export_file_telemetry_for_sdk_interactions() { session.disconnect().await.expect("disconnect session"); client.stop().await.expect("stop client"); - let entries = read_telemetry_entries(&telemetry_path).await; + let entries = read_telemetry_entries(&telemetry_path); let spans: Vec<_> = entries .iter() .filter(|entry| string_property(entry, "type") == Some("span")) @@ -155,34 +155,13 @@ impl ToolHandler for EchoTelemetryTool { } } -async fn read_telemetry_entries(path: &std::path::Path) -> Vec { - wait_for_condition("telemetry file to contain spans", || { - let path = path.to_path_buf(); - async move { - read_telemetry_entries_once(&path).is_ok_and(|entries| { - entries.iter().any(|entry| { - string_property(entry, "type") == Some("span") - && 
string_attribute(entry, "gen_ai.operation.name").as_deref() - == Some("invoke_agent") - }) - }) - } - }) - .await; - read_telemetry_entries_once(path).expect("read telemetry entries") -} - -fn read_telemetry_entries_once(path: &std::path::Path) -> std::io::Result> { - if !path.exists() || path.metadata()?.len() == 0 { - return Ok(Vec::new()); - } - std::fs::read_to_string(path).map(|content| { - content - .lines() - .filter(|line| !line.trim().is_empty()) - .map(|line| serde_json::from_str(line).expect("telemetry JSON line")) - .collect() - }) +fn read_telemetry_entries(path: &std::path::Path) -> Vec { + std::fs::read_to_string(path) + .expect("read telemetry entries") + .lines() + .filter(|line| !line.trim().is_empty()) + .map(|line| serde_json::from_str(line).expect("telemetry JSON line")) + .collect() } fn find_span<'a>(spans: &'a [&'a serde_json::Value], operation: &str) -> &'a serde_json::Value {