Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,15 @@ Pal provides several functions for registering services:

Pal also provides functions for retrieving services:

- `Invoke[T](ctx, invoker, args...)` - Retrieves or creates an instance of type `T` from the container, factory services may require argumens
- `Invoke[T](ctx, invoker, args...)` - Retrieves or creates an instance of type `T` from the container, factory services may require argumens.
- `InvokeAs[T, C](ctx, invoker, args...)` - A wrapper around `Inoke`, castes invoked service to specified `C`, returns an error if casging fails.
- `Build[S](ctx, invoker)` - Creates an instance of S, resolves its dependencies, injects them into its fields.
- `InjectInto[S](ctx, invoker, *S)` - Resolves S's dependencies and injects them into its fields.

All these functions accept nil as invoker, in this case, a Pal instance will be extracted from the context.
Pal automatilly adds itself into contexts paseed to `Init`, `Shutdown` and `Run` under `pal.CtxValue` key.
You can extract it manually with `pal.FromContext`

## Service Types

Pal supports several types of services, each designed for different use cases:
Expand Down
20 changes: 20 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,17 @@ func ProvidePal(pal *Pal) *ServiceList {
}

// Invoke retrieves or creates an instance of type T from the given Pal container.
// Invoker may be nil, in this case an instance of Pal will be extracted from the context,
// if the context does not contain a Pal instance, an error will be returned.
func Invoke[T any](ctx context.Context, invoker Invoker, args ...any) (T, error) {
name := typetostring.GetType[T]()
if invoker == nil {
var err error
invoker, err = FromContext(ctx)
if err != nil {
return empty[T](), err
}
}

a, err := invoker.Invoke(ctx, name, args...)
if err != nil {
Expand All @@ -105,6 +114,8 @@ func MustInvoke[T any](ctx context.Context, invoker Invoker, args ...any) T {

// InvokeAs invokes a service and casts it to the expected type. It returns an error if the cast fails.
// May be useful when invoking a service with an interface type and you want to cast it to a concrete type.
// Invoker may be nil, in this case an instance of Pal will be extracted from the context,
// if the context does not contain a Pal instance, an error will be returned.
func InvokeAs[T any, C any](ctx context.Context, invoker Invoker, args ...any) (*C, error) {
service, err := Invoke[T](ctx, invoker, args...)
if err != nil {
Expand All @@ -128,6 +139,8 @@ func MustInvokeAs[T any, C any](ctx context.Context, invoker Invoker, args ...an
// Build resolves dependencies for a struct of type T using the provided context and Invoker.
// It initializes the struct's fields by injecting appropriate dependencies based on the field types.
// Returns the fully initialized struct or an error if dependency resolution fails.
// Invoker may be nil, in this case an instance of Pal will be extracted from the context,
// if the context does not contain a Pal instance, an error will be returned.
func Build[T any](ctx context.Context, invoker Invoker) (*T, error) {
s := new(T)

Expand All @@ -148,6 +161,13 @@ func MustBuild[T any](ctx context.Context, invoker Invoker) *T {
// It only sets fields that are exported and match a resolvable dependency, skipping fields when ErrServiceNotFound occurs.
// Returns an error if dependency invocation fails or other unrecoverable errors occur during injection.
func InjectInto[T any](ctx context.Context, invoker Invoker, s *T) error {
if invoker == nil {
var err error
invoker, err = FromContext(ctx)
if err != nil {
return err
}
}
return invoker.InjectInto(ctx, s)
}

Expand Down
3 changes: 3 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,7 @@ var (

// ErrServiceInvalidCast is returned when a service is cast to a different type.
ErrServiceInvalidCast = errors.New("failed to cast service to the expected type")

// ErrInvokerIsNotInContext is returned when a context passed to Invoke does not contain a Pal instance.
ErrInvokerIsNotInContext = errors.New("invoker is not in context")
)
4 changes: 2 additions & 2 deletions hook_priority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func TestHookPriority_ToHealthCheck(t *testing.T) {
})

p := newPal(palService)
ctx := context.WithValue(t.Context(), pal.CtxValue, p)
ctx := pal.WithPal(t.Context(), p)

// Initialize first
err := p.Init(t.Context())
Expand Down Expand Up @@ -280,7 +280,7 @@ func TestHookPriority_MultipleHooks(t *testing.T) {
})

p := newPal(palService)
ctx := context.WithValue(t.Context(), pal.CtxValue, p)
ctx := pal.WithPal(t.Context(), p)

// Initialize
err := p.Init(t.Context())
Expand Down
28 changes: 21 additions & 7 deletions pal.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,22 @@ func New(services ...ServiceDef) *Pal {
}

// FromContext retrieves a *Pal from the provided context, expecting it to be stored under the CtxValue key.
// Panics if ctx misses the value.
func FromContext(ctx context.Context) *Pal {
return ctx.Value(CtxValue).(*Pal)
func FromContext(ctx context.Context) (*Pal, error) {
invoker, ok := ctx.Value(CtxValue).(*Pal)
if !ok {
return nil, ErrInvokerIsNotInContext
}

return invoker, nil
}

// MustFromContext is like FromContext but panics if an error occurs.
func MustFromContext(ctx context.Context) *Pal {
return must(FromContext(ctx))
}

func WithPal(ctx context.Context, pal *Pal) context.Context {
return context.WithValue(ctx, CtxValue, pal)
}

// InitTimeout sets the timeout for the initialization of the services.
Expand Down Expand Up @@ -124,7 +137,7 @@ func (p *Pal) Init(ctx context.Context) error {
return nil
}

ctx = context.WithValue(ctx, CtxValue, p)
ctx = WithPal(ctx, p)

if err := p.config.Validate(ctx); err != nil {
return err
Expand Down Expand Up @@ -155,7 +168,7 @@ func (p *Pal) Run(ctx context.Context, signals ...os.Signal) error {
signals = DefaultShutdownSignals
}

ctx = context.WithValue(ctx, CtxValue, p)
ctx = WithPal(ctx, p)

ctx, stop := signal.NotifyContext(ctx, signals...)
defer stop()
Expand Down Expand Up @@ -196,6 +209,7 @@ func (p *Pal) Run(ctx context.Context, signals ...os.Signal) error {
}()

shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Duration(float64(p.config.ShutdownTimeout)*0.9))
shutdownCtx = WithPal(shutdownCtx, p)
defer cancel()

return errors.Join(runErr, p.container.Shutdown(shutdownCtx))
Expand All @@ -211,7 +225,7 @@ func (p *Pal) Services() map[string]ServiceDef {
// It implements the Invoker interface.
// The context is enriched with the Pal instance before being passed to the container.
func (p *Pal) Invoke(ctx context.Context, name string, args ...any) (any, error) {
ctx = context.WithValue(ctx, CtxValue, p)
ctx = WithPal(ctx, p)

return p.container.Invoke(ctx, name, args...)
}
Expand All @@ -220,7 +234,7 @@ func (p *Pal) Invoke(ctx context.Context, name string, args ...any) (any, error)
// It implements the Invoker interface.
// The context is enriched with the Pal instance before being passed to the container.
func (p *Pal) InjectInto(ctx context.Context, target any) error {
ctx = context.WithValue(ctx, CtxValue, p)
ctx = WithPal(ctx, p)

return p.container.InjectInto(ctx, target)
}
Expand Down
5 changes: 3 additions & 2 deletions pal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ func Test_FromContext(t *testing.T) {
t.Parallel()

p := newPal()
ctx := context.WithValue(t.Context(), pal.CtxValue, p)
ctx := pal.WithPal(t.Context(), p)

result := pal.FromContext(ctx)
result, err := pal.FromContext(ctx)
assert.NoError(t, err)

assert.Same(t, p, result)
})
Expand Down
6 changes: 3 additions & 3 deletions service_const_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestService_Instance(t *testing.T) {
service := pal.Provide(NewMockRunnerServiceStruct(t))
p := newPal(service)

ctx := context.WithValue(t.Context(), pal.CtxValue, p)
ctx := pal.WithPal(t.Context(), p)

err := p.Init(t.Context())
assert.NoError(t, err)
Expand Down Expand Up @@ -65,7 +65,7 @@ func TestService_ToInit(t *testing.T) {
})
p := newPal(service)

ctx := context.WithValue(t.Context(), pal.CtxValue, p)
ctx := pal.WithPal(t.Context(), p)

err := p.Init(t.Context())
assert.NoError(t, err)
Expand All @@ -86,7 +86,7 @@ func TestService_ToInit(t *testing.T) {
service := pal.Provide[any](s)
p := newPal(service)

ctx := context.WithValue(t.Context(), pal.CtxValue, p)
ctx := pal.WithPal(t.Context(), p)

err := p.Init(t.Context())
assert.NoError(t, err)
Expand Down
8 changes: 4 additions & 4 deletions service_factory1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestServiceFactory1_Instance(t *testing.T) {
})
p := newPal(service)

ctx := context.WithValue(t.Context(), pal.CtxValue, p)
ctx := pal.WithPal(t.Context(), p)

err := p.Init(t.Context())
assert.NoError(t, err)
Expand Down Expand Up @@ -57,7 +57,7 @@ func TestServiceFactory1_Instance(t *testing.T) {
})
p := newPal(service)

ctx := context.WithValue(t.Context(), pal.CtxValue, p)
ctx := pal.WithPal(t.Context(), p)

err := p.Init(t.Context())
assert.NoError(t, err)
Expand All @@ -75,7 +75,7 @@ func TestServiceFactory1_Instance(t *testing.T) {
})
p := newPal(service)

ctx := context.WithValue(t.Context(), pal.CtxValue, p)
ctx := pal.WithPal(t.Context(), p)

err := p.Init(t.Context())
assert.NoError(t, err)
Expand All @@ -93,7 +93,7 @@ func TestServiceFactory1_Instance(t *testing.T) {
})
p := newPal(service)

ctx := context.WithValue(t.Context(), pal.CtxValue, p)
ctx := pal.WithPal(t.Context(), p)

err := p.Init(t.Context())
assert.NoError(t, err)
Expand Down