From 9080e2115864242014fb74f0dd49cabbbd3bf2c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20=C5=9Awi=C4=99cki?= Date: Fri, 8 May 2026 18:13:18 +0200 Subject: [PATCH 1/7] Properly shutdown the server When starting the shutdown connection use a new context with specific timeout since the one used in the server has already been canceled. --- config.go | 3 +++ server.go | 15 +++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 3e55ecb..d7a9481 100644 --- a/config.go +++ b/config.go @@ -21,6 +21,7 @@ const ( defaultServerResultStreamMaxWait = 20 * time.Second defaultServerMaxRequestBodySize int64 = 8 << 10 // 8KiB defaultServerCascadeLabels string = "" // 8KiB + defaultServerShutdownTimeout = 40 * time.Second defaultCircuitHalfOpenSuccesses = 10 defaultCircuitOpenTimeout = 0 @@ -56,6 +57,7 @@ var config struct { ResultStreamMaxWait time.Duration MaxRequestBodySize int64 CascadeLabels string + ShutdownTimeout time.Duration TopProviderCardinality int TopProviderReportInterval time.Duration } @@ -83,6 +85,7 @@ func init() { config.Server.ResultStreamMaxWait = getEnvOrDefault[time.Duration]("SERVER_RESULT_STREAM_MAX_WAIT", defaultServerResultStreamMaxWait) config.Server.MaxRequestBodySize = getEnvOrDefault[int64]("SERVER_MAX_REQUEST_BODY_SIZE", defaultServerMaxRequestBodySize) config.Server.CascadeLabels = getEnvOrDefault[string]("SERVER_CASCADE_LABELS", defaultServerCascadeLabels) + config.Server.ShutdownTimeout = getEnvOrDefault[time.Duration]("SERVER_SHUTDOWN_TIMEOUT", defaultServerShutdownTimeout) config.Server.TopProviderCardinality = getEnvOrDefault[int]("SERVER_TOP_PROVIDER_CARDINALITY", defaultStatMaxProviders) config.Server.TopProviderReportInterval = getEnvOrDefault[time.Duration]("SERVER_TOP_PROVIDER_REPORT_INVERVAL", defaultStatProviderReportUpdate) diff --git a/server.go b/server.go index e94bf6a..7bb1135 100644 --- a/server.go +++ b/server.go @@ -333,12 +333,23 @@ func (s *server) Serve() chan error { defer close(ec) <-s.Context.Done() - err := serv.Shutdown(s.Context) + + shutdownCtx, cancel := context.WithTimeout(context.Background(), config.Server.ShutdownTimeout) + defer cancel() + + err := serv.Shutdown(shutdownCtx) if err != nil { - log.Warnw("failed shutdown", "err", err) + log.Warnw("failed to shutdown", "err", err) + ec <- err + } + + err = metricsServ.Shutdown(shutdownCtx) + if err != nil { + log.Warnw("failed to shutdown metrics server", "err", err) ec <- err } }() + return ec } From c265dcc36727891f84a883e7308cb487fde291c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20=C5=9Awi=C4=99cki?= Date: Mon, 18 May 2026 12:48:40 +0200 Subject: [PATCH 2/7] Don't leak listeners on initialization errors --- server.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/server.go b/server.go index 7bb1135..85a5e12 100644 --- a/server.go +++ b/server.go @@ -63,14 +63,28 @@ type providersBackend struct { } func NewServer(c *cli.Context) (*server, error) { - bound, err := net.Listen("tcp", c.String("listen")) + var lc net.ListenConfig + + bound, err := lc.Listen(c.Context, "tcp", c.String("listen")) if err != nil { return nil, err } - mb, err := net.Listen("tcp", c.String("metrics")) + defer func() { + if bound != nil { + bound.Close() + } + }() + + mb, err := lc.Listen(c.Context, "tcp", c.String("metrics")) if err != nil { return nil, err } + defer func() { + if mb != nil { + mb.Close() + } + }() + servers := c.StringSlice(backendsArg) cascadeServers := c.StringSlice(cascadeBackendsArg) dhServers := c.StringSlice(dhBackendsArg) @@ -158,6 +172,9 @@ func NewServer(c *cli.Context) (*server, error) { pcounts: pCounts, } + // Listeners propagated to the server, don't close on defer + bound, mb = nil, nil + go func() { for { select { From f007d50622a3f29c41a5df20bc29bf5429f7e922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20=C5=9Awi=C4=99cki?= Date: Mon, 18 May 2026 12:50:06 +0200 Subject: [PATCH 3/7] Properly propagate process termination through context --- main.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/main.go b/main.go index cbb3fcd..1cfd050 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "os" @@ -17,6 +18,9 @@ import ( const configCheckInterval = 5 * time.Second func main() { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + defer stop() + app := &cli.App{ Name: "indexstar", Usage: "indexstar is a point in the content routing galaxy - routes requests in a star topology", @@ -64,8 +68,6 @@ func main() { }, }, Action: func(c *cli.Context) error { - exit := make(chan os.Signal, 1) - signal.Notify(exit, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) s, err := NewServer(c) if err != nil { @@ -112,8 +114,6 @@ func main() { case reloadSig <- struct{}{}: default: } - case <-exit: - return nil case err := <-done: return err case <-reloadSig: @@ -138,11 +138,12 @@ func main() { } }, } - err := app.Run(os.Args) + err := app.RunContext(ctx, os.Args) if err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } + os.Exit(0) } From e5234d0c386e49fa0182bb2531bff99bdf2f25c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20=C5=9Awi=C4=99cki?= Date: Mon, 18 May 2026 12:52:46 +0200 Subject: [PATCH 4/7] Simplify config reload handling * Use non-blocking send when config time change is detected avoiding rare deadlock * Use same channel for signal detection for config file time change detection --- main.go | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/main.go b/main.go index 1cfd050..3633fd3 100644 --- a/main.go +++ b/main.go @@ -74,8 +74,8 @@ func main() { return err } - sighup := make(chan os.Signal, 1) - signal.Notify(sighup, syscall.SIGHUP) + reloadSig := make(chan os.Signal, 1) + signal.Notify(reloadSig, syscall.SIGHUP) done := s.Serve() @@ -106,22 +106,12 @@ func main() { } } - reloadSig := make(chan struct{}, 1) for { select { - case <-sighup: - select { - case reloadSig <- struct{}{}: - default: - } case err := <-done: return err - case <-reloadSig: - err := s.Reload(c) - if err != nil { - log.Warnf("couldn't reload servers: %s", err) - } case <-timeChan: + // Detect config file changes and reload config if needed. var changed bool modTime, changed, err = fileChanged(s.cfgBase, modTime) if err != nil { @@ -132,7 +122,16 @@ func main() { continue } if changed { - reloadSig <- struct{}{} + select { + case reloadSig <- syscall.SIGHUP: + default: + } + } + + case <-reloadSig: + err := s.Reload(c) + if err != nil { + log.Warnf("couldn't reload servers: %s", err) } } } From 52fdc3f44b761f6c187cdadcfb005bafc49f0b7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20=C5=9Awi=C4=99cki?= Date: Mon, 18 May 2026 12:54:11 +0200 Subject: [PATCH 5/7] Errors channel can carry multiple errors --- main.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 3633fd3..acdf74c 100644 --- a/main.go +++ b/main.go @@ -68,6 +68,9 @@ func main() { }, }, Action: func(c *cli.Context) error { + ctx, ctxCancel := context.WithCancel(c.Context) + defer ctxCancel() + c.Context = ctx s, err := NewServer(c) if err != nil { @@ -109,7 +112,17 @@ func main() { for { select { case err := <-done: - return err + // Ensure we've started the shutdown sequence + ctxCancel() + + // All errors must be collected to ensure the shutdown sequence is complete + allErrs := []error{err} + for err = range done { + allErrs = append(allErrs, err) + } + + return errors.Join(allErrs...) + case <-timeChan: // Detect config file changes and reload config if needed. var changed bool From f0eb80879f5db9b9a9c8c5439e4b76a613b0258b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20=C5=9Awi=C4=99cki?= Date: Mon, 18 May 2026 21:43:28 +0200 Subject: [PATCH 6/7] Extract runApp into a separate function This function is responsible for integration between the running server and external signals that could affect the server: termination and config reload. --- main.go | 40 ++++++++++++++++++++++++++-------------- server.go | 14 ++++++++++++-- server_test.go | 2 +- 3 files changed, 39 insertions(+), 17 deletions(-) diff --git a/main.go b/main.go index acdf74c..702a563 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,26 @@ func main() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT, syscall.SIGTERM) defer stop() + reloadSig := make(chan os.Signal, 1) + signal.Notify(reloadSig, syscall.SIGHUP) + defer signal.Stop(reloadSig) + + err := runApp(ctx, os.Args, reloadSig, configCheckInterval, NewServer) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + os.Exit(0) +} + +func runApp( + ctx context.Context, + args []string, + reloadSig chan os.Signal, + configCheckInterval time.Duration, + newServer func(c *cli.Context) (serverInterface, error), +) error { app := &cli.App{ Name: "indexstar", Usage: "indexstar is a point in the content routing galaxy - routes requests in a star topology", @@ -72,14 +92,11 @@ func main() { defer ctxCancel() c.Context = ctx - s, err := NewServer(c) + s, err := newServer(c) if err != nil { return err } - reloadSig := make(chan os.Signal, 1) - signal.Notify(reloadSig, syscall.SIGHUP) - done := s.Serve() var ( @@ -89,7 +106,7 @@ func main() { timeChan <-chan time.Time ) if configCheckInterval != 0 { - cfgPath = s.cfgBase + cfgPath = s.GetCfgBase() if cfgPath == "" { cfgPath, err = Path("", "") if err != nil { @@ -126,12 +143,12 @@ func main() { case <-timeChan: // Detect config file changes and reload config if needed. var changed bool - modTime, changed, err = fileChanged(s.cfgBase, modTime) + modTime, changed, err = fileChanged(s.GetCfgBase(), modTime) if err != nil { log.Errorw("Cannot stat config file", "err", err, "path", cfgPath) ticker.Stop() ticker = nil - timeChan = nil // reading from nil channel blocks forever + timeChan = nil // disable timeChan from the select statement continue } if changed { @@ -150,13 +167,8 @@ func main() { } }, } - err := app.RunContext(ctx, os.Args) - if err != nil { - fmt.Fprintf(os.Stderr, "error: %v\n", err) - os.Exit(1) - } - - os.Exit(0) + err := app.RunContext(ctx, args) + return err } func fileChanged(filePath string, modTime time.Time) (time.Time, bool, error) { diff --git a/server.go b/server.go index 85a5e12..abb3bae 100644 --- a/server.go +++ b/server.go @@ -34,6 +34,12 @@ const ( providersBackendsArg = "providersBackends" ) +type serverInterface interface { + Serve() <-chan error + Reload(c *cli.Context) error + GetCfgBase() string +} + type server struct { context.Context http.Client @@ -62,7 +68,7 @@ type providersBackend struct { Backend } -func NewServer(c *cli.Context) (*server, error) { +func NewServer(c *cli.Context) (serverInterface, error) { var lc net.ListenConfig bound, err := lc.Listen(c.Context, "tcp", c.String("listen")) @@ -283,7 +289,7 @@ func (s *server) updateTopProviders() { } } -func (s *server) Serve() chan error { +func (s *server) Serve() <-chan error { mux := http.NewServeMux() mux.HandleFunc("/cid/", func(w http.ResponseWriter, r *http.Request) { s.findCid(w, r, false) }) mux.HandleFunc("/encrypted/cid/", func(w http.ResponseWriter, r *http.Request) { s.findCid(w, r, true) }) @@ -389,3 +395,7 @@ func writeJsonResponse(w http.ResponseWriter, status int, body []byte) { http.Error(w, "", http.StatusInternalServerError) } } + +func (s *server) GetCfgBase() string { + return s.cfgBase +} diff --git a/server_test.go b/server_test.go index 65188db..6d76748 100644 --- a/server_test.go +++ b/server_test.go @@ -34,7 +34,7 @@ type serverTestSuite struct { srv *server srvCancel context.CancelFunc - srvErrChan chan error + srvErrChan <-chan error } func TestServerTestSuite(t *testing.T) { From e2951baee839ee35ffe2d279e09dcba8cb385b91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20=C5=9Awi=C4=99cki?= Date: Tue, 19 May 2026 12:01:43 +0200 Subject: [PATCH 7/7] Add test for runApp function --- main_test.go | 361 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 361 insertions(+) create mode 100644 main_test.go diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..329275c --- /dev/null +++ b/main_test.go @@ -0,0 +1,361 @@ +package main + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync" + "sync/atomic" + "syscall" + "testing" + "testing/synctest" + "time" + + logging "github.com/ipfs/go-log/v2" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + cli "github.com/urfave/cli/v2" +) + +type mockServer struct { + ctx context.Context + + cfgBase string + errorsToEmitBeforeCancel []error + errorsToEmitAfterCancel []error + + reloadCallCnt atomic.Int64 +} + +func (s *mockServer) Serve() <-chan error { + errChan := make(chan error, 1) + + go func() { + defer close(errChan) + + for _, err := range s.errorsToEmitBeforeCancel { + errChan <- err + } + + <-s.ctx.Done() + + for _, err := range s.errorsToEmitAfterCancel { + errChan <- err + } + + }() + + return errChan +} + +func (s *mockServer) Reload(c *cli.Context) error { + s.reloadCallCnt.Add(1) + return nil +} + +func (s *mockServer) GetCfgBase() string { + return s.cfgBase +} + +type runAppTestSuite struct { + suite.Suite + errCnt int +} + +func (s *runAppTestSuite) genError() error { + s.errCnt++ + return fmt.Errorf("test error %d at %s", s.errCnt, time.Now()) +} + +func TestRunAppTestSuite(t *testing.T) { + suite.Run(t, &runAppTestSuite{}) +} + +func (s *runAppTestSuite) TestFailureInNewServer() { + synctest.Test(s.T(), func(t *testing.T) { + errToReturn := s.genError() + + err := runApp( + context.Background(), + []string{}, + nil, + configCheckInterval, + func(c *cli.Context) (serverInterface, error) { + return nil, errToReturn + }, + ) + + require.ErrorIs(t, err, errToReturn) + }) +} + +func (s *runAppTestSuite) TestSuccess() { + synctest.Test(s.T(), func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + start := time.Now() + runAppFinished := false + + go func() { + err := runApp( + ctx, + []string{}, + nil, + time.Minute, // More than the duration of this test + func(c *cli.Context) (serverInterface, error) { + return &mockServer{ctx: c.Context}, nil + }, + ) + require.NoError(t, err) + require.Equal(t, time.Second, time.Since(start)) + runAppFinished = true + }() + + // Advance teh execution to a state where the server has started + time.Sleep(time.Second) + synctest.Wait() + require.False(t, runAppFinished) + + // Simulate terminating signal + cancel() + synctest.Wait() + require.True(t, runAppFinished) + }) +} + +func (s *runAppTestSuite) TestErrorBeforeCancel() { + synctest.Test(s.T(), func(t *testing.T) { + errToEmit1 := s.genError() + errToEmit2 := s.genError() + + err := runApp( + t.Context(), + []string{}, + nil, + 0, + func(c *cli.Context) (serverInterface, error) { + return &mockServer{ + errorsToEmitBeforeCancel: []error{ + errToEmit1, errToEmit2, + }, + ctx: c.Context, + }, nil + }, + ) + + require.ErrorIs(t, err, errToEmit1) + require.ErrorIs(t, err, errToEmit2) + }) +} + +func (s *runAppTestSuite) TestErrorAfterCancel() { + synctest.Test(s.T(), func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + errToEmit1 := s.genError() + errToEmit2 := s.genError() + + runAppFinished := false + + wg := sync.WaitGroup{} + + wg.Go(func() { + err := runApp( + ctx, + []string{}, + nil, + 0, + func(c *cli.Context) (serverInterface, error) { + return &mockServer{ + errorsToEmitAfterCancel: []error{ + errToEmit1, errToEmit2, + }, + ctx: c.Context, + }, nil + }, + ) + require.ErrorIs(t, err, errToEmit1) + require.ErrorIs(t, err, errToEmit2) + + runAppFinished = true + }) + + synctest.Wait() + require.False(t, runAppFinished) + + cancel() + + synctest.Wait() + require.True(t, runAppFinished) + + wg.Wait() + }) +} + +func (s *runAppTestSuite) TestErrorBeforeAndAfterCancel() { + synctest.Test(s.T(), func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + errToEmitBefore1 := s.genError() + errToEmitBefore2 := s.genError() + errToEmitAfter1 := s.genError() + errToEmitAfter2 := s.genError() + + runAppFinished := false + + wg := sync.WaitGroup{} + + wg.Go(func() { + err := runApp( + ctx, + []string{}, + nil, + 0, + func(c *cli.Context) (serverInterface, error) { + return &mockServer{ + errorsToEmitBeforeCancel: []error{ + errToEmitBefore1, errToEmitBefore2, + }, + errorsToEmitAfterCancel: []error{ + errToEmitAfter1, errToEmitAfter2, + }, + ctx: c.Context, + }, nil + }, + ) + require.ErrorIs(t, err, errToEmitBefore1) + require.ErrorIs(t, err, errToEmitBefore2) + require.ErrorIs(t, err, errToEmitAfter1) + require.ErrorIs(t, err, errToEmitAfter2) + + runAppFinished = true + }) + + synctest.Wait() + + cancel() + + synctest.Wait() + require.True(t, runAppFinished) + + wg.Wait() + }) +} + +func (s *runAppTestSuite) TestReloadConfigFromSignal() { + synctest.Test(s.T(), func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + cfgPath := filepath.Join(t.TempDir(), "config.txt") + + reloadSignal := make(chan os.Signal, 1) + + var srv *mockServer + + wg := sync.WaitGroup{} + wg.Go(func() { + err := runApp( + ctx, + []string{}, + reloadSignal, + 0, + func(c *cli.Context) (serverInterface, error) { + srv = &mockServer{ + ctx: c.Context, + cfgBase: cfgPath, + } + return srv, nil + }, + ) + require.NoError(t, err) + }) + + synctest.Wait() + require.Zero(t, srv.reloadCallCnt.Load()) + + reloadSignal <- syscall.SIGHUP + synctest.Wait() + require.EqualValues(t, 1, srv.reloadCallCnt.Load()) + + cancel() + wg.Wait() + }) +} + +func (s *runAppTestSuite) TestReloadConfigFromFileChange() { + synctest.Test(s.T(), func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + wg := sync.WaitGroup{} + + cfgPath := filepath.Join(t.TempDir(), "config.txt") + require.NoError(t, os.WriteFile(cfgPath, []byte("test1"), 0600)) + + reloadSignal := make(chan os.Signal, 1) + + var srv *mockServer + + wg.Go(func() { + err := runApp( + ctx, + []string{}, + reloadSignal, + configCheckInterval, + func(c *cli.Context) (serverInterface, error) { + srv = &mockServer{ + ctx: c.Context, + cfgBase: cfgPath, + } + return srv, nil + }, + ) + require.NoError(t, err) + }) + + // must not reload config initially + synctest.Wait() + require.Zero(t, srv.reloadCallCnt.Load()) + + // must not reload config if config file has not changed + time.Sleep(2 * configCheckInterval) + synctest.Wait() + require.Zero(t, srv.reloadCallCnt.Load()) + + // must reload config if config file has changed + require.NoError(t, os.WriteFile(cfgPath, []byte("updated file"), 0600)) + + time.Sleep(configCheckInterval) + synctest.Wait() + require.EqualValues(t, 1, srv.reloadCallCnt.Load()) + + // no changes afterwards, must not reload + time.Sleep(2 * configCheckInterval) + synctest.Wait() + require.EqualValues(t, 1, srv.reloadCallCnt.Load()) + + // stop reloading if the file is removed + plr := logging.NewPipeReader() + defer plr.Close() + + require.NoError(t, os.Remove(cfgPath)) + time.Sleep(2 * configCheckInterval) + synctest.Wait() + require.EqualValues(t, 1, srv.reloadCallCnt.Load()) + + buf := make([]byte, 1000) + n, err := plr.Read(buf) + require.NoError(t, err) + buf = buf[:n] + + require.Contains(t, string(buf), "Cannot stat config file") + + cancel() + wg.Wait() + }) +}