diff --git a/README.md b/README.md index 9b84b95..dd51d9d 100644 --- a/README.md +++ b/README.md @@ -103,11 +103,15 @@ Pal provides several functions for registering services: Pal also provides functions for retrieving services: -- `Invoke[T](ctx, invoker, args...)` - Retrieves or creates an instance of type `T` from the container, factory services may require argumens +- `Invoke[T](ctx, invoker, args...)` - Retrieves or creates an instance of type `T` from the container, factory services may require argumens. - `InvokeAs[T, C](ctx, invoker, args...)` - A wrapper around `Inoke`, castes invoked service to specified `C`, returns an error if casging fails. - `Build[S](ctx, invoker)` - Creates an instance of S, resolves its dependencies, injects them into its fields. - `InjectInto[S](ctx, invoker, *S)` - Resolves S's dependencies and injects them into its fields. +All these functions accept nil as invoker, in this case, a Pal instance will be extracted from the context. +Pal automatilly adds itself into contexts paseed to `Init`, `Shutdown` and `Run` under `pal.CtxValue` key. +You can extract it manually with `pal.FromContext` + ## Service Types Pal supports several types of services, each designed for different use cases: diff --git a/api.go b/api.go index 9d6d1ea..0f09129 100644 --- a/api.go +++ b/api.go @@ -82,8 +82,17 @@ func ProvidePal(pal *Pal) *ServiceList { } // Invoke retrieves or creates an instance of type T from the given Pal container. +// Invoker may be nil, in this case an instance of Pal will be extracted from the context, +// if the context does not contain a Pal instance, an error will be returned. func Invoke[T any](ctx context.Context, invoker Invoker, args ...any) (T, error) { name := typetostring.GetType[T]() + if invoker == nil { + var err error + invoker, err = FromContext(ctx) + if err != nil { + return empty[T](), err + } + } a, err := invoker.Invoke(ctx, name, args...) if err != nil { @@ -105,6 +114,8 @@ func MustInvoke[T any](ctx context.Context, invoker Invoker, args ...any) T { // InvokeAs invokes a service and casts it to the expected type. It returns an error if the cast fails. // May be useful when invoking a service with an interface type and you want to cast it to a concrete type. +// Invoker may be nil, in this case an instance of Pal will be extracted from the context, +// if the context does not contain a Pal instance, an error will be returned. func InvokeAs[T any, C any](ctx context.Context, invoker Invoker, args ...any) (*C, error) { service, err := Invoke[T](ctx, invoker, args...) if err != nil { @@ -128,6 +139,8 @@ func MustInvokeAs[T any, C any](ctx context.Context, invoker Invoker, args ...an // Build resolves dependencies for a struct of type T using the provided context and Invoker. // It initializes the struct's fields by injecting appropriate dependencies based on the field types. // Returns the fully initialized struct or an error if dependency resolution fails. +// Invoker may be nil, in this case an instance of Pal will be extracted from the context, +// if the context does not contain a Pal instance, an error will be returned. func Build[T any](ctx context.Context, invoker Invoker) (*T, error) { s := new(T) @@ -148,6 +161,13 @@ func MustBuild[T any](ctx context.Context, invoker Invoker) *T { // It only sets fields that are exported and match a resolvable dependency, skipping fields when ErrServiceNotFound occurs. // Returns an error if dependency invocation fails or other unrecoverable errors occur during injection. func InjectInto[T any](ctx context.Context, invoker Invoker, s *T) error { + if invoker == nil { + var err error + invoker, err = FromContext(ctx) + if err != nil { + return err + } + } return invoker.InjectInto(ctx, s) } diff --git a/errors.go b/errors.go index 6abccf3..9dbf2bc 100644 --- a/errors.go +++ b/errors.go @@ -29,4 +29,7 @@ var ( // ErrServiceInvalidCast is returned when a service is cast to a different type. ErrServiceInvalidCast = errors.New("failed to cast service to the expected type") + + // ErrInvokerIsNotInContext is returned when a context passed to Invoke does not contain a Pal instance. + ErrInvokerIsNotInContext = errors.New("invoker is not in context") ) diff --git a/hook_priority_test.go b/hook_priority_test.go index 8086467..e7832e7 100644 --- a/hook_priority_test.go +++ b/hook_priority_test.go @@ -119,7 +119,7 @@ func TestHookPriority_ToHealthCheck(t *testing.T) { }) p := newPal(palService) - ctx := context.WithValue(t.Context(), pal.CtxValue, p) + ctx := pal.WithPal(t.Context(), p) // Initialize first err := p.Init(t.Context()) @@ -280,7 +280,7 @@ func TestHookPriority_MultipleHooks(t *testing.T) { }) p := newPal(palService) - ctx := context.WithValue(t.Context(), pal.CtxValue, p) + ctx := pal.WithPal(t.Context(), p) // Initialize err := p.Init(t.Context()) diff --git a/pal.go b/pal.go index abb5e9d..a01179f 100644 --- a/pal.go +++ b/pal.go @@ -58,9 +58,22 @@ func New(services ...ServiceDef) *Pal { } // FromContext retrieves a *Pal from the provided context, expecting it to be stored under the CtxValue key. -// Panics if ctx misses the value. -func FromContext(ctx context.Context) *Pal { - return ctx.Value(CtxValue).(*Pal) +func FromContext(ctx context.Context) (*Pal, error) { + invoker, ok := ctx.Value(CtxValue).(*Pal) + if !ok { + return nil, ErrInvokerIsNotInContext + } + + return invoker, nil +} + +// MustFromContext is like FromContext but panics if an error occurs. +func MustFromContext(ctx context.Context) *Pal { + return must(FromContext(ctx)) +} + +func WithPal(ctx context.Context, pal *Pal) context.Context { + return context.WithValue(ctx, CtxValue, pal) } // InitTimeout sets the timeout for the initialization of the services. @@ -124,7 +137,7 @@ func (p *Pal) Init(ctx context.Context) error { return nil } - ctx = context.WithValue(ctx, CtxValue, p) + ctx = WithPal(ctx, p) if err := p.config.Validate(ctx); err != nil { return err @@ -155,7 +168,7 @@ func (p *Pal) Run(ctx context.Context, signals ...os.Signal) error { signals = DefaultShutdownSignals } - ctx = context.WithValue(ctx, CtxValue, p) + ctx = WithPal(ctx, p) ctx, stop := signal.NotifyContext(ctx, signals...) defer stop() @@ -196,6 +209,7 @@ func (p *Pal) Run(ctx context.Context, signals ...os.Signal) error { }() shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Duration(float64(p.config.ShutdownTimeout)*0.9)) + shutdownCtx = WithPal(shutdownCtx, p) defer cancel() return errors.Join(runErr, p.container.Shutdown(shutdownCtx)) @@ -211,7 +225,7 @@ func (p *Pal) Services() map[string]ServiceDef { // It implements the Invoker interface. // The context is enriched with the Pal instance before being passed to the container. func (p *Pal) Invoke(ctx context.Context, name string, args ...any) (any, error) { - ctx = context.WithValue(ctx, CtxValue, p) + ctx = WithPal(ctx, p) return p.container.Invoke(ctx, name, args...) } @@ -220,7 +234,7 @@ func (p *Pal) Invoke(ctx context.Context, name string, args ...any) (any, error) // It implements the Invoker interface. // The context is enriched with the Pal instance before being passed to the container. func (p *Pal) InjectInto(ctx context.Context, target any) error { - ctx = context.WithValue(ctx, CtxValue, p) + ctx = WithPal(ctx, p) return p.container.InjectInto(ctx, target) } diff --git a/pal_test.go b/pal_test.go index 035c280..2fd5fb4 100644 --- a/pal_test.go +++ b/pal_test.go @@ -70,9 +70,10 @@ func Test_FromContext(t *testing.T) { t.Parallel() p := newPal() - ctx := context.WithValue(t.Context(), pal.CtxValue, p) + ctx := pal.WithPal(t.Context(), p) - result := pal.FromContext(ctx) + result, err := pal.FromContext(ctx) + assert.NoError(t, err) assert.Same(t, p, result) }) diff --git a/service_const_test.go b/service_const_test.go index 1068bd8..d950dee 100644 --- a/service_const_test.go +++ b/service_const_test.go @@ -32,7 +32,7 @@ func TestService_Instance(t *testing.T) { service := pal.Provide(NewMockRunnerServiceStruct(t)) p := newPal(service) - ctx := context.WithValue(t.Context(), pal.CtxValue, p) + ctx := pal.WithPal(t.Context(), p) err := p.Init(t.Context()) assert.NoError(t, err) @@ -65,7 +65,7 @@ func TestService_ToInit(t *testing.T) { }) p := newPal(service) - ctx := context.WithValue(t.Context(), pal.CtxValue, p) + ctx := pal.WithPal(t.Context(), p) err := p.Init(t.Context()) assert.NoError(t, err) @@ -86,7 +86,7 @@ func TestService_ToInit(t *testing.T) { service := pal.Provide[any](s) p := newPal(service) - ctx := context.WithValue(t.Context(), pal.CtxValue, p) + ctx := pal.WithPal(t.Context(), p) err := p.Init(t.Context()) assert.NoError(t, err) diff --git a/service_factory1_test.go b/service_factory1_test.go index e4160e0..41540b6 100644 --- a/service_factory1_test.go +++ b/service_factory1_test.go @@ -28,7 +28,7 @@ func TestServiceFactory1_Instance(t *testing.T) { }) p := newPal(service) - ctx := context.WithValue(t.Context(), pal.CtxValue, p) + ctx := pal.WithPal(t.Context(), p) err := p.Init(t.Context()) assert.NoError(t, err) @@ -57,7 +57,7 @@ func TestServiceFactory1_Instance(t *testing.T) { }) p := newPal(service) - ctx := context.WithValue(t.Context(), pal.CtxValue, p) + ctx := pal.WithPal(t.Context(), p) err := p.Init(t.Context()) assert.NoError(t, err) @@ -75,7 +75,7 @@ func TestServiceFactory1_Instance(t *testing.T) { }) p := newPal(service) - ctx := context.WithValue(t.Context(), pal.CtxValue, p) + ctx := pal.WithPal(t.Context(), p) err := p.Init(t.Context()) assert.NoError(t, err) @@ -93,7 +93,7 @@ func TestServiceFactory1_Instance(t *testing.T) { }) p := newPal(service) - ctx := context.WithValue(t.Context(), pal.CtxValue, p) + ctx := pal.WithPal(t.Context(), p) err := p.Init(t.Context()) assert.NoError(t, err)