From 289dc5b0384026274186c2c746cf9a2e8ed550f0 Mon Sep 17 00:00:00 2001 From: Kern Walster Date: Wed, 20 May 2026 21:58:40 -0700 Subject: [PATCH] shim: signal vminitd to shut down gracefully before stopping the VM Before killing the VM, the host shim now calls the guest task service's Shutdown RPC. This gives vminitd a chance to run its registered shutdown callbacks (DHCP release, socket/stream teardown, etc.) before the VM is forcibly stopped. On the guest side, the Shutdown handler is updated to wait for shutdownSvc.Done() before returning, mirroring the host shim's own Shutdown behaviour and letting the host know cleanup is complete. To avoid a deadlock (ts.Shutdown was a shutdown callback that would wait for in-flight handlers, while the handler waited for callbacks to complete), ts.Shutdown is removed from the shutdown service callbacks. Signed-off-by: Kern Walster --- internal/shim/task/service.go | 13 ++++++++----- internal/vminit/task/service.go | 13 ++++++++----- pkg/vminit/initd/initd.go | 1 - 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/internal/shim/task/service.go b/internal/shim/task/service.go index 1e63ea4f..15a6088c 100644 --- a/internal/shim/task/service.go +++ b/internal/shim/task/service.go @@ -152,7 +152,7 @@ func (s *service) shutdown(ctx context.Context) error { // to flush ext4 journals and dirty pages to the virtio-blk devices. // Best-effort with a short retry for transient EBUSY. if vmc, err := s.sb.Client(); err != nil { - log.G(ctx).WithError(err).Warn("failed to get VM client; skipping unmount of block volumes before VM shutdown") + log.G(ctx).WithError(err).Warn("failed to get VM client; skipping unmount and guest shutdown before VM stop") } else { unmountCtx, cancel := context.WithTimeout(ctx, 30*time.Second) err := unmountAllWithRetry(unmountCtx, mountAPI.NewTTRPCMountClient(vmc)) @@ -160,6 +160,13 @@ func (s *service) shutdown(ctx context.Context) error { if err != nil { log.G(ctx).WithError(err).Warn("failed to unmount all block volumes before VM shutdown") } + + shutdownCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + _, err = taskAPI.NewTTRPCTaskClient(vmc).Shutdown(shutdownCtx, &taskAPI.ShutdownRequest{}) + cancel() + if err != nil { + log.G(ctx).WithError(err).Warn("failed to wait for guest shutdown before VM stop") + } } if err := s.sb.Stop(ctx); err != nil { @@ -709,10 +716,6 @@ func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*task func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*ptypes.Empty, error) { log.G(ctx).WithFields(log.Fields{"id": r.ID}).Info("shutdown") - // TODO: Should we forward this to VM? - // tc := taskAPI.NewTTRPCTaskClient(s.vm.Client()) - // return tc.Shutdown(ctx, r) - s.initiateShutdownOnce.Do(s.initiateShutdown) select { diff --git a/internal/vminit/task/service.go b/internal/vminit/task/service.go index e67e36e4..b0459cb2 100644 --- a/internal/vminit/task/service.go +++ b/internal/vminit/task/service.go @@ -649,18 +649,21 @@ func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*task func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*ptypes.Empty, error) { s.mu.Lock() - defer s.mu.Unlock() - // return out if the shim is still servicing containers if len(s.containers) > 0 { + s.mu.Unlock() return empty, nil } + s.mu.Unlock() - // please make sure that temporary resource has been cleanup or registered - // for cleanup before calling shutdown s.shutdown.Shutdown() - return empty, nil + select { + case <-s.shutdown.Done(): + return empty, nil + case <-ctx.Done(): + return nil, errgrpc.ToGRPC(ctx.Err()) + } } func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) { diff --git a/pkg/vminit/initd/initd.go b/pkg/vminit/initd/initd.go index 78d1503b..28e0560f 100644 --- a/pkg/vminit/initd/initd.go +++ b/pkg/vminit/initd/initd.go @@ -307,7 +307,6 @@ func newService(ctx context.Context, config Config, shutdownSvc shutdown.Service if err != nil { return nil, err } - shutdownSvc.RegisterCallback(ts.Shutdown) registry.Register(&plugin.Registration{ Type: cplugins.InternalPlugin,