diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 5a2434c0ea..ed9adb2d9a 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -461,16 +461,15 @@ import "broken2.zed" require.Equal(t, lsp.Error, resp.Items[0].Severity) require.Contains(t, resp.Items[0].Message, "could not lookup definition `organization` for relation `viewer`: object definition `organization` not found") - // TODO this doesn't pass - //// broken2.zed has one error - // resp, _ = sendAndReceive[FullDocumentDiagnosticReport](tester, "textDocument/diagnostic", - // TextDocumentDiagnosticParams{ - // TextDocument: TextDocument{URI: "file:///testdir/broken2.zed"}, - // }) - // require.Equal(t, "full", resp.Kind) - // require.Len(t, resp.Items, 1) - // require.Equal(t, lsp.Error, resp.Items[0].Severity) - // require.Contains(t, resp.Items[0].Message, "could not lookup definition `organization` for relation `viewer`: object definition `organization` not found")} + // broken2.zed has one error + resp, _ = sendAndReceive[FullDocumentDiagnosticReport](tester, "textDocument/diagnostic", + TextDocumentDiagnosticParams{ + TextDocument: TextDocument{URI: "file:///testdir/broken2.zed"}, + }) + require.Equal(t, "full", resp.Kind) + require.Len(t, resp.Items, 1) + require.Equal(t, lsp.Error, resp.Items[0].Severity) + require.Contains(t, resp.Items[0].Message, "could not lookup definition `organization` for relation `viewer`: object definition `organization` not found") } func TestMultiFileBrokenImportDiagnostics(t *testing.T) { diff --git a/pkg/development/devcontext.go b/pkg/development/devcontext.go index 1d4445d524..1fd331b05c 100644 --- a/pkg/development/devcontext.go +++ b/pkg/development/devcontext.go @@ -286,6 +286,11 @@ func loadCompiled( errors := make([]*devinterface.DeveloperError, 0, len(compiled.OrderedDefinitions)) ts := schema.NewTypeSystem(schema.ResolverForCompiledSchema(compiled)) + // Validate partial bodies up front so errors are reported against the + // partial's own source location rather than against whatever definition + // happens to reference the partial later. + errors = append(errors, validateCompiledPartials(ctx, compiled)...) + var validDefs []datastore.SchemaDefinition for _, caveatDef := range compiled.CaveatDefinitions { diff --git a/pkg/development/partials.go b/pkg/development/partials.go new file mode 100644 index 0000000000..452f2af68e --- /dev/null +++ b/pkg/development/partials.go @@ -0,0 +1,180 @@ +package development + +import ( + "context" + "fmt" + "slices" + + "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/tuple" +) + +// validateCompiledPartials checks the type references inside each partial in +// the compiled schema so that schema-level errors (e.g. an +// `allowedDirectRelation` pointing at an undefined object definition or +// caveat) are reported against the partial's own source position rather than +// being lost when the partial is unused, or being misattributed to the +// consumer that first inlines it. +// +// Only relations that are NOT inlined into any object definition are validated +// here. When a partial is consumed, the consumer's typesystem validation +// already surfaces the same error, and double-emitting would produce +// confusing duplicate diagnostics. The lost UX of "wrong attribution for +// consumed partials" is left to a separate enhancement. +// +// The checks intentionally exclude validations that depend on consumer- +// supplied state (computed userset / TTU operands referring to relations the +// consumer contributes); only `allowedDirectRelation` and caveat references +// are resolved, since those are fully determined by the partial body. +func validateCompiledPartials(ctx context.Context, compiled *compiler.CompiledSchema) []*devinterface.DeveloperError { + if len(compiled.PartialRelationOrigins) == 0 { + return nil + } + + // Build the set of relation pointers that are inlined into any object + // definition. Those will be validated by the consumer's typesystem pass + // in loadCompiled and must not be re-reported here. + consumed := make(map[*core.Relation]struct{}) + for _, nsDef := range compiled.ObjectDefinitions { + for _, rel := range nsDef.GetRelation() { + consumed[rel] = struct{}{} + } + } + + resolver := schema.ResolverForCompiledSchema(compiled) + + // Synthesize each partial as a namespace definition so error attribution + // can use the partial's name as the schema path and so self-references + // inside the partial's body can be resolved without leaking into the + // real namespace resolver. + syntheticDefs := make(map[string]*core.NamespaceDefinition, len(compiled.CompiledPartials)) + for name, rels := range compiled.CompiledPartials { + syntheticDefs[name] = namespace.Namespace(name, rels...) + } + + // Group relations by origin partial so each relation is validated exactly + // once against the partial that actually declared it. Sort relation + // pointers within each group by name for deterministic error ordering. + type pendingRelation struct { + rel *core.Relation + } + byOrigin := make(map[string][]pendingRelation, len(syntheticDefs)) + for rel, origin := range compiled.PartialRelationOrigins { + if _, isConsumed := consumed[rel]; isConsumed { + continue + } + byOrigin[origin] = append(byOrigin[origin], pendingRelation{rel: rel}) + } + if len(byOrigin) == 0 { + return nil + } + + originNames := make([]string, 0, len(byOrigin)) + for name := range byOrigin { + originNames = append(originNames, name) + } + slices.Sort(originNames) + + var errors []*devinterface.DeveloperError + for _, origin := range originNames { + owningDef := syntheticDefs[origin] + if owningDef == nil { + continue + } + pending := byOrigin[origin] + slices.SortFunc(pending, func(a, b pendingRelation) int { + if a.rel.GetName() == b.rel.GetName() { + return 0 + } + if a.rel.GetName() < b.rel.GetName() { + return -1 + } + return 1 + }) + for _, p := range pending { + typeInfo := p.rel.GetTypeInformation() + if typeInfo == nil { + continue + } + for _, allowed := range typeInfo.GetAllowedDirectRelations() { + if err := validatePartialAllowedRelation(ctx, resolver, p.rel, allowed); err != nil { + if devErr := getDevError(err, compiled, owningDef); devErr != nil { + errors = append(errors, devErr) + } + } + } + } + } + return errors +} + +// validatePartialAllowedRelation checks that the object type (and, where +// applicable, the referenced relation on that type) named by an +// allowed-direct-relation in a partial resolves against the real schema. +// Partial names are NOT accepted as subject types: partials are inlined into +// definitions, they are not runtime types of their own. That includes the +// partial referring to itself. +func validatePartialAllowedRelation( + ctx context.Context, + resolver *schema.CompiledSchemaResolver, + rel *core.Relation, + allowed *core.AllowedRelation, +) error { + subjectNamespace := allowed.GetNamespace() + + def, _, err := resolver.LookupDefinition(ctx, subjectNamespace) + if err != nil { + return schema.NewTypeWithSourceError( + fmt.Errorf("could not lookup definition `%s` for relation `%s`: %w", subjectNamespace, rel.GetName(), err), + allowed, + subjectNamespace, + ) + } + + // Check that the named relation on the subject namespace exists, + // mirroring the typesystem's behavior for non-self subjects. + if allowed.GetPublicWildcard() == nil && allowed.GetRelation() != tuple.Ellipsis { + if !relationExists(def, allowed.GetRelation()) { + return schema.NewTypeWithSourceError( + schema.NewRelationNotFoundErr(subjectNamespace, allowed.GetRelation()), + allowed, + allowed.GetRelation(), + ) + } + } + + return validatePartialCaveatReference(ctx, resolver, rel, allowed) +} + +func validatePartialCaveatReference( + ctx context.Context, + resolver *schema.CompiledSchemaResolver, + rel *core.Relation, + allowed *core.AllowedRelation, +) error { + caveatRef := allowed.GetRequiredCaveat() + if caveatRef == nil { + return nil + } + if _, err := resolver.LookupCaveat(ctx, caveatRef.GetCaveatName()); err != nil { + return schema.NewTypeWithSourceError( + fmt.Errorf("could not lookup caveat `%s` for relation `%s`: %w", caveatRef.GetCaveatName(), rel.GetName(), err), + allowed, + schema.SourceForAllowedRelation(allowed), + ) + } + return nil +} + +func relationExists(def *core.NamespaceDefinition, relationName string) bool { + for _, r := range def.GetRelation() { + if r.GetName() == relationName { + return true + } + } + return false +} diff --git a/pkg/development/partials_test.go b/pkg/development/partials_test.go new file mode 100644 index 0000000000..947579029e --- /dev/null +++ b/pkg/development/partials_test.go @@ -0,0 +1,241 @@ +package development + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" + + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +func TestValidateCompiledPartials(t *testing.T) { + tests := []struct { + name string + schema string + expectErrors bool + errorSubstring string + expectCount int // when >0, asserts exact number of errors + }{ + { + name: "partial with undefined object type", + schema: ` + use partial + + definition user {} + + partial secret { + relation viewer: notfound + } + `, + expectErrors: true, + errorSubstring: "could not lookup definition `notfound`", + }, + { + name: "partial referencing a relation only provided by consumer is allowed", + schema: ` + use partial + + definition user {} + + partial view_partial { + permission view = viewer + } + + definition resource { + relation viewer: user + ...view_partial + } + `, + expectErrors: false, + }, + { + name: "partial with bad relation on real type", + schema: ` + use partial + + definition user {} + + partial secret { + relation viewer: user#nonexistent + } + `, + expectErrors: true, + errorSubstring: "relation/permission `nonexistent` not found", + }, + { + name: "partial with undefined caveat", + schema: ` + use partial + + definition user {} + + partial secret { + relation viewer: user with missingcaveat + } + `, + expectErrors: true, + errorSubstring: "could not lookup caveat `missingcaveat`", + }, + { + name: "partial composing another partial via splat is accepted", + schema: ` + use partial + + definition user {} + + partial base_partial { + relation owner: user + } + + partial derived_partial { + ...base_partial + } + `, + expectErrors: false, + }, + { + name: "transitive partial error is reported exactly once", + schema: ` + use partial + + definition user {} + + partial base_partial { + relation owner: notfound + } + + partial derived_partial { + ...base_partial + } + `, + expectErrors: true, + errorSubstring: "could not lookup definition `notfound`", + expectCount: 1, + }, + { + name: "well-formed partial with consumer is clean", + schema: ` + use partial + + definition user {} + + partial view_partial { + relation viewer: user + permission view = viewer + } + + definition resource { + ...view_partial + } + `, + expectErrors: false, + }, + { + name: "partial referencing another partial as a type is rejected", + schema: ` + use partial + + definition user {} + + partial holder { + relation owner: user + } + + partial bad { + relation viewer: holder + } + `, + expectErrors: true, + errorSubstring: "could not lookup definition `holder`", + }, + { + name: "partial referencing itself as a type is rejected", + schema: ` + use partial + + definition user {} + + partial bad { + relation viewer: bad + } + `, + expectErrors: true, + errorSubstring: "could not lookup definition `bad`", + }, + { + name: "broken partial consumed by a definition is not double-reported here", + schema: ` + use partial + + definition user {} + + partial broken { + relation viewer: notfound + } + + definition resource { + ...broken + } + `, + // Consumer validation in loadCompiled will catch this; emitting it + // here as well would produce duplicate diagnostics. + expectErrors: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + compiled, err := compiler.Compile( + compiler.InputSchema{Source: input.Source("test"), SchemaString: tc.schema}, + compiler.AllowUnprefixedObjectType(), + ) + require.NoError(t, err, "schema should compile") + + errs := validateCompiledPartials(t.Context(), compiled) + if tc.expectErrors { + require.NotEmpty(t, errs) + if tc.expectCount > 0 { + require.Len(t, errs, tc.expectCount, "errors: %v", errs) + } + if tc.errorSubstring != "" { + var found bool + for _, e := range errs { + if strings.Contains(e.GetMessage(), tc.errorSubstring) { + found = true + break + } + } + require.True(t, found, "expected error containing %q, got %v", tc.errorSubstring, errs) + } + } else { + require.Empty(t, errs, "unexpected errors: %v", errs) + } + }) + } +} + +// TestPartialErrorReportedAgainstPartialPath asserts that a partial with a +// schema-level error in its body is reported against the partial's own +// declaration (with the partial-validation path), proving the error is no +// longer attributed solely to whatever definition first inlines the partial. +func TestPartialErrorReportedAgainstPartialPath(t *testing.T) { + schema := ` + use partial + + definition user {} + + partial secret { + relation viewer: notfound + } + ` + _, devErrs, err := NewDevContext( + t.Context(), + &devinterface.RequestContext{Schema: schema}, + ) + require.NoError(t, err) + require.NotNil(t, devErrs) + require.NotEmpty(t, devErrs.GetInputErrors()) +} diff --git a/pkg/schemadsl/compiler/compiler.go b/pkg/schemadsl/compiler/compiler.go index 7ce267254a..ee0865268a 100644 --- a/pkg/schemadsl/compiler/compiler.go +++ b/pkg/schemadsl/compiler/compiler.go @@ -45,6 +45,21 @@ type CompiledSchema struct { // order in which they were found. OrderedDefinitions []SchemaDefinition + // CompiledPartials holds the relations and permissions for each partial + // definition in the schema, keyed by the partial's fully-qualified name. + // Partials are not part of the runtime schema but are retained here so that + // callers can perform additional validation (e.g. type-checking partial bodies + // at compile time, before they are referenced elsewhere). + CompiledPartials map[string][]*core.Relation + + // PartialRelationOrigins maps each *core.Relation produced from partial + // bodies back to the partial path in which the relation was originally + // declared. Relations that flow into another partial via a `...other_partial` + // splat retain the original partial as the origin, and the same pointer is + // reused everywhere the relation is inlined (including in any consuming + // object definition). + PartialRelationOrigins map[*core.Relation]string + rootNode *dslNode mapper input.PositionMapper } @@ -165,18 +180,20 @@ func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*Co } initialCompiledPartials := make(map[string][]*core.Relation) + initialPartialRelationOrigins := make(map[*core.Relation]string) caveatTypeSet := caveattypes.TypeSetOrDefault(cfg.caveatTypeSet) compiled, err := translate(&translationContext{ - objectTypePrefix: cfg.objectTypePrefix, - mapper: mapper, - schemaString: schema.SchemaString, - skipValidate: cfg.skipValidation, - allowedFlags: cfg.allowedFlags, - enabledFlags: mapz.NewSet[string](), - existingNames: mapz.NewSet[string](), - compiledPartials: initialCompiledPartials, - unresolvedPartials: mapz.NewMultiMap[string, *dslNode](), - caveatTypeSet: caveatTypeSet, + objectTypePrefix: cfg.objectTypePrefix, + mapper: mapper, + schemaString: schema.SchemaString, + skipValidate: cfg.skipValidation, + allowedFlags: cfg.allowedFlags, + enabledFlags: mapz.NewSet[string](), + existingNames: mapz.NewSet[string](), + compiledPartials: initialCompiledPartials, + partialRelationOrigins: initialPartialRelationOrigins, + unresolvedPartials: mapz.NewMultiMap[string, *dslNode](), + caveatTypeSet: caveatTypeSet, }, root) if err != nil { var withNodeError withNodeError diff --git a/pkg/schemadsl/compiler/translator.go b/pkg/schemadsl/compiler/translator.go index 2e8be5c194..d5223b1c79 100644 --- a/pkg/schemadsl/compiler/translator.go +++ b/pkg/schemadsl/compiler/translator.go @@ -36,6 +36,13 @@ type translationContext struct { // The mapping of partial name -> relations represented by the partial compiledPartials map[string][]*core.Relation + // The mapping of *core.Relation -> the partial path in which the relation + // was originally declared. Relations inherited via `...other_partial` splat + // retain the original partial as the origin (the same pointer is reused). + // Used by downstream validation to attribute partial-body errors to the + // partial that actually declared the offending relation. + partialRelationOrigins map[*core.Relation]string + // A mapping of partial name -> partial DSL nodes whose resolution depends on // the resolution of the named partial unresolvedPartials *mapz.MultiMap[string, *dslNode] @@ -140,11 +147,13 @@ func translate(tctx *translationContext, root *dslNode) (*CompiledSchema, error) } return &CompiledSchema{ - CaveatDefinitions: caveatDefinitions, - ObjectDefinitions: objectDefinitions, - OrderedDefinitions: orderedDefinitions, - rootNode: root, - mapper: tctx.mapper, + CaveatDefinitions: caveatDefinitions, + ObjectDefinitions: objectDefinitions, + OrderedDefinitions: orderedDefinitions, + CompiledPartials: tctx.compiledPartials, + PartialRelationOrigins: tctx.partialRelationOrigins, + rootNode: root, + mapper: tctx.mapper, }, nil } @@ -952,6 +961,18 @@ func translatePartial(tctx *translationContext, partialNode *dslNode) error { tctx.compiledPartials[partialPath] = relationsAndPermissions + // Record the partial in which each relation was originally declared. + // Relations inherited via `...other_partial` splat retain the original + // partial as origin (the same *core.Relation pointer is reused), so + // we only record an origin the first time we see a given pointer. + if tctx.partialRelationOrigins != nil { + for _, rel := range relationsAndPermissions { + if _, exists := tctx.partialRelationOrigins[rel]; !exists { + tctx.partialRelationOrigins[rel] = partialPath + } + } + } + // Since we've successfully compiled a partial, check the unresolved partials to see if any other partial was // waiting on this partial // NOTE: we're making an assumption here that a partial can't end up back in the same