Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 53 additions & 35 deletions internal/provider/framework/identity_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"

"github.com/hashicorp/terraform-plugin-framework/attr"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/path"
"github.com/hashicorp/terraform-plugin-framework/resource"
"github.com/hashicorp/terraform-plugin-framework/tfsdk"
Expand Down Expand Up @@ -58,41 +59,49 @@ func (r identityInterceptor) create(ctx context.Context, opts interceptorOptions
}
}
}

case OnError:
identity := response.Identity
if identity == nil {
break
}

if identityIsFullyNull(ctx, identity, r.attributes) {
for _, att := range r.attributes {
switch att.Name() {
case names.AttrAccountID:
opts.response.Diagnostics.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.AccountID(ctx))...)
if opts.response.Diagnostics.HasError() {
return
}
var diags diag.Diagnostics

case names.AttrRegion:
opts.response.Diagnostics.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.Region(ctx))...)
if opts.response.Diagnostics.HasError() {
return
}
identityLoop:
for _, att := range r.attributes {
switch att.Name() {
case names.AttrAccountID:
diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.AccountID(ctx))...)
if diags.HasError() {
break identityLoop
}

default:
var attrVal attr.Value
opts.response.Diagnostics.Append(response.State.GetAttribute(ctx, path.Root(att.ResourceAttributeName()), &attrVal)...)
if opts.response.Diagnostics.HasError() {
return
}
case names.AttrRegion:
diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.Region(ctx))...)
if diags.HasError() {
break identityLoop
}

opts.response.Diagnostics.Append(identity.SetAttribute(ctx, path.Root(att.Name()), attrVal)...)
if opts.response.Diagnostics.HasError() {
return
}
default:
var attrVal attr.Value
diags.Append(response.State.GetAttribute(ctx, path.Root(att.ResourceAttributeName()), &attrVal)...)
if diags.HasError() {
break identityLoop
}

diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), attrVal)...)
if diags.HasError() {
break identityLoop
}
}
}

if diags.HasError() {
response.Identity = nil
}

opts.response.Diagnostics.Append(diags...)
}
}

Expand Down Expand Up @@ -189,32 +198,41 @@ func (r identityInterceptor) update(ctx context.Context, opts interceptorOptions
}

if identityIsFullyNull(ctx, identity, r.attributes) {
var diags diag.Diagnostics

identityLoop:
for _, att := range r.attributes {
switch att.Name() {
case names.AttrAccountID:
opts.response.Diagnostics.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.AccountID(ctx))...)
if opts.response.Diagnostics.HasError() {
return
diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.AccountID(ctx))...)
if diags.HasError() {
break identityLoop
}

case names.AttrRegion:
opts.response.Diagnostics.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.Region(ctx))...)
if opts.response.Diagnostics.HasError() {
return
diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.Region(ctx))...)
if diags.HasError() {
break identityLoop
}

default:
var attrVal attr.Value
opts.response.Diagnostics.Append(response.State.GetAttribute(ctx, path.Root(att.ResourceAttributeName()), &attrVal)...)
if opts.response.Diagnostics.HasError() {
return
diags.Append(response.State.GetAttribute(ctx, path.Root(att.ResourceAttributeName()), &attrVal)...)
if diags.HasError() {
break identityLoop
}

opts.response.Diagnostics.Append(identity.SetAttribute(ctx, path.Root(att.Name()), attrVal)...)
if opts.response.Diagnostics.HasError() {
return
diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), attrVal)...)
if diags.HasError() {
break identityLoop
}
}

if diags.HasError() {
response.Identity = nil
}

opts.response.Diagnostics.Append(diags...)
}
}
}
Expand Down
150 changes: 145 additions & 5 deletions internal/provider/framework/identity_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,7 @@ func create(ctx context.Context, interceptor identityInterceptor, resourceSchema
}

interceptor.create(ctx, opts)
if response.Diagnostics.HasError() {
return nil, response.Diagnostics
}

return response.Identity, response.Diagnostics
}

Expand All @@ -157,9 +155,151 @@ func read(ctx context.Context, interceptor identityInterceptor, resourceSchema s
}

interceptor.read(ctx, opts)
if response.Diagnostics.HasError() {
return nil, response.Diagnostics

return response.Identity, response.Diagnostics
}

func TestIdentityInterceptor_OnError(t *testing.T) {
t.Parallel()

accountID := "123456789012"
region := "us-west-2" //lintignore:AWSAT003
name := "a_name"

resourceSchema := schema.Schema{
Attributes: map[string]schema.Attribute{
"name": schema.StringAttribute{
Required: true,
},
"region": resourceattribute.Region(),
"type": schema.StringAttribute{
Optional: true,
},
},
}

client := mockClient{
accountID: accountID,
region: region,
}

testOperations := map[string]struct {
operation func(ctx context.Context, interceptor identityInterceptor, resourceSchema schema.Schema, stateAttrs map[string]string, identity *tfsdk.ResourceIdentity, client awsClient) (*tfsdk.ResourceIdentity, diag.Diagnostics)
stateAttrs map[string]string
}{
"create": {
operation: createOnError,
stateAttrs: map[string]string{
"name": name,
"region": region,
"type": "some_type",
},
},
"update": {
operation: updateOnError,
stateAttrs: map[string]string{
"name": name,
"region": region,
"type": "some_type",
},
},
}

for tname, tc := range testOperations {
t.Run(tname, func(t *testing.T) {
t.Parallel()

operation := tc.operation
stateAttrs := tc.stateAttrs

testCases := map[string]struct {
attrName string
identitySpec inttypes.Identity
}{
"same names": {
attrName: "name",
identitySpec: regionalSingleParameterIdentitySpec("name"),
},
"name mapped": {
attrName: "resource_name",
identitySpec: regionalSingleParameterIdentitySpecNameMapped("resource_name", "name"),
},
}

for tname, tc := range testCases {
t.Run(tname, func(t *testing.T) {
t.Parallel()
ctx := t.Context()

identitySchema := identity.NewIdentitySchema(tc.identitySpec)

interceptor := newIdentityInterceptor(tc.identitySpec.Attributes)

identity := emtpyIdentityFromSchema(ctx, &identitySchema)

responseIdentity, _ := operation(ctx, interceptor, resourceSchema, stateAttrs, identity, client)

if e, a := accountID, getIdentityAttributeValue(ctx, t, responseIdentity, path.Root("account_id")); e != a {
t.Errorf("expected Identity `account_id` to be %q, got %q", e, a)
}
if e, a := region, getIdentityAttributeValue(ctx, t, responseIdentity, path.Root("region")); e != a {
t.Errorf("expected Identity `region` to be %q, got %q", e, a)
}
if e, a := name, getIdentityAttributeValue(ctx, t, responseIdentity, path.Root(tc.attrName)); e != a {
t.Errorf("expected Identity `%s` to be %q, got %q", tc.attrName, e, a)
}
})
}
})
}
}

func createOnError(ctx context.Context, interceptor identityInterceptor, resourceSchema schema.Schema, stateAttrs map[string]string, identity *tfsdk.ResourceIdentity, client awsClient) (*tfsdk.ResourceIdentity, diag.Diagnostics) {
request := resource.CreateRequest{
Config: configFromSchema(ctx, resourceSchema, stateAttrs),
Plan: planFromSchema(ctx, resourceSchema, stateAttrs),
Identity: identity,
}
response := resource.CreateResponse{
State: stateFromSchema(ctx, resourceSchema, stateAttrs),
Identity: identity,
Diagnostics: diag.Diagnostics{
diag.NewErrorDiagnostic("summary", "detail"),
},
}
opts := interceptorOptions[resource.CreateRequest, resource.CreateResponse]{
c: client,
request: &request,
response: &response,
when: OnError,
}

interceptor.create(ctx, opts)

return response.Identity, response.Diagnostics
}

func updateOnError(ctx context.Context, interceptor identityInterceptor, resourceSchema schema.Schema, stateAttrs map[string]string, identity *tfsdk.ResourceIdentity, client awsClient) (*tfsdk.ResourceIdentity, diag.Diagnostics) {
request := resource.UpdateRequest{
State: stateFromSchema(ctx, resourceSchema, stateAttrs),
Identity: identity,
}
response := resource.UpdateResponse{
State: stateFromSchema(ctx, resourceSchema, stateAttrs),
Identity: identity,
Diagnostics: diag.Diagnostics{
diag.NewErrorDiagnostic("summary", "detail"),
},
}
opts := interceptorOptions[resource.UpdateRequest, resource.UpdateResponse]{
c: client,
request: &request,
response: &response,
when: OnError,
}

interceptor.update(ctx, opts)

return response.Identity, response.Diagnostics
}

Expand Down
Loading