Skip to content

Commit 2993aec

Browse files
authored
Merge pull request #2796 from hashicorp/b-handle_any_identity_type
b/handle identity types that are not strings.
2 parents a3f80df + 5026ed4 commit 2993aec

File tree

4 files changed

+156
-52
lines changed

4 files changed

+156
-52
lines changed

internal/generic/resource.go

Lines changed: 8 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -480,32 +480,10 @@ func (r *genericResource) Create(ctx context.Context, request resource.CreateReq
480480
pi = pi.AddRegionID()
481481
}
482482

483-
for _, v := range pi {
484-
if v.RequiredForImport {
485-
var out types.String
486-
response.Diagnostics.Append(response.State.GetAttribute(ctx, path.Root(v.Name), &out)...)
487-
if response.Diagnostics.HasError() {
488-
return
489-
}
490-
491-
response.Diagnostics.Append(response.Identity.SetAttribute(ctx, path.Root(v.Name), out.ValueString())...)
492-
if response.Diagnostics.HasError() {
493-
return
494-
}
495-
} else {
496-
switch v.Name {
497-
case identity.NameAccountID:
498-
response.Diagnostics.Append(response.Identity.SetAttribute(ctx, path.Root(identity.NameAccountID), r.provider.AccountID(ctx))...)
499-
if response.Diagnostics.HasError() {
500-
return
501-
}
502-
case identity.NameRegion:
503-
response.Diagnostics.Append(response.Identity.SetAttribute(ctx, path.Root(identity.NameRegion), r.provider.Region(ctx))...)
504-
if response.Diagnostics.HasError() {
505-
return
506-
}
507-
}
508-
}
483+
d := pi.SetIdentity(ctx, r.provider, &response.State, response.Identity)
484+
response.Diagnostics.Append(d...)
485+
if response.Diagnostics.HasError() {
486+
return
509487
}
510488

511489
tflog.Debug(ctx, "Response.State.Raw", map[string]interface{}{
@@ -590,32 +568,10 @@ func (r *genericResource) Read(ctx context.Context, request resource.ReadRequest
590568
pi = pi.AddRegionID()
591569
}
592570

593-
for _, v := range pi {
594-
if v.RequiredForImport {
595-
var out types.String
596-
response.Diagnostics.Append(response.State.GetAttribute(ctx, path.Root(v.Name), &out)...)
597-
if response.Diagnostics.HasError() {
598-
return
599-
}
600-
601-
response.Diagnostics.Append(response.Identity.SetAttribute(ctx, path.Root(v.Name), out.ValueString())...)
602-
if response.Diagnostics.HasError() {
603-
return
604-
}
605-
} else {
606-
switch v.Name {
607-
case identity.NameAccountID:
608-
response.Diagnostics.Append(response.Identity.SetAttribute(ctx, path.Root(identity.NameAccountID), r.provider.AccountID(ctx))...)
609-
if response.Diagnostics.HasError() {
610-
return
611-
}
612-
case identity.NameRegion:
613-
response.Diagnostics.Append(response.Identity.SetAttribute(ctx, path.Root(identity.NameRegion), r.provider.Region(ctx))...)
614-
if response.Diagnostics.HasError() {
615-
return
616-
}
617-
}
618-
}
571+
d := pi.SetIdentity(ctx, r.provider, &response.State, response.Identity)
572+
response.Diagnostics.Append(d...)
573+
if response.Diagnostics.HasError() {
574+
return
619575
}
620576

621577
tflog.Debug(ctx, "Response.State.Raw", map[string]interface{}{

internal/identity/identifier.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33

44
package identity
55

6+
import (
7+
"context"
8+
9+
"github.com/hashicorp/terraform-plugin-framework/attr"
10+
"github.com/hashicorp/terraform-plugin-framework/diag"
11+
"github.com/hashicorp/terraform-plugin-framework/path"
12+
"github.com/hashicorp/terraform-provider-awscc/internal/service/cloudcontrol"
13+
)
14+
615
const (
716
NameAccountID = "account_id"
817
NameRegion = "region"
@@ -30,3 +39,42 @@ func (a Identifiers) AddRegionID() Identifiers {
3039
Description: "Region where this resource is managed",
3140
})
3241
}
42+
43+
type IdentitySetter interface {
44+
GetAttribute(context.Context, path.Path, any) diag.Diagnostics
45+
SetAttribute(context.Context, path.Path, any) diag.Diagnostics
46+
}
47+
48+
// SetIdentity sets the identity in state using the primary identifiers.
49+
func (a Identifiers) SetIdentity(ctx context.Context, provider cloudcontrol.Provider, state, identity IdentitySetter) diag.Diagnostics {
50+
var diags diag.Diagnostics
51+
for _, v := range a {
52+
if v.RequiredForImport {
53+
var out attr.Value
54+
diags.Append(state.GetAttribute(ctx, path.Root(v.Name), &out)...)
55+
if diags.HasError() {
56+
return diags
57+
}
58+
59+
diags.Append(identity.SetAttribute(ctx, path.Root(v.Name), ValueAsString(ctx, out))...)
60+
if diags.HasError() {
61+
return diags
62+
}
63+
} else {
64+
switch v.Name {
65+
case NameAccountID:
66+
diags.Append(identity.SetAttribute(ctx, path.Root(NameAccountID), provider.AccountID(ctx))...)
67+
if diags.HasError() {
68+
return diags
69+
}
70+
case NameRegion:
71+
diags.Append(identity.SetAttribute(ctx, path.Root(NameRegion), provider.Region(ctx))...)
72+
if diags.HasError() {
73+
return diags
74+
}
75+
}
76+
}
77+
}
78+
79+
return diags
80+
}

internal/identity/values.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package identity
5+
6+
import (
7+
"context"
8+
"fmt"
9+
10+
"github.com/hashicorp/terraform-plugin-framework/attr"
11+
"github.com/hashicorp/terraform-plugin-framework/types"
12+
)
13+
14+
func ValueAsString(ctx context.Context, v attr.Value) string {
15+
if v.IsNull() || v.IsUnknown() {
16+
return ""
17+
}
18+
19+
switch v.Type(ctx) {
20+
case types.StringType:
21+
return v.(types.String).ValueString()
22+
case types.Float64Type:
23+
return fmt.Sprintf("%v", v.(types.Float64).ValueFloat64())
24+
case types.Int64Type:
25+
return fmt.Sprintf("%d", v.(types.Int64).ValueInt64())
26+
case types.Int32Type:
27+
return fmt.Sprintf("%d", v.(types.Int32).ValueInt32())
28+
case types.NumberType:
29+
return fmt.Sprintf("%v", v.(types.Number).ValueBigFloat())
30+
case types.BoolType:
31+
return fmt.Sprintf("%t", v.(types.Bool).ValueBool())
32+
default:
33+
return ""
34+
}
35+
}

internal/identity/values_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package identity
5+
6+
import (
7+
"context"
8+
"math/big"
9+
"testing"
10+
11+
"github.com/hashicorp/terraform-plugin-framework/attr"
12+
"github.com/hashicorp/terraform-plugin-framework/types"
13+
)
14+
15+
func TestValueAsString(t *testing.T) {
16+
t.Parallel()
17+
18+
tests := []struct {
19+
name string
20+
input attr.Value
21+
expected string
22+
}{
23+
{
24+
name: "null string",
25+
input: types.StringNull(),
26+
expected: "",
27+
},
28+
{
29+
name: "valid string",
30+
input: types.StringValue("hello"),
31+
expected: "hello",
32+
},
33+
{
34+
name: "valid int64",
35+
input: types.Int64Value(1),
36+
expected: "1",
37+
},
38+
{
39+
name: "valid float64",
40+
input: types.Float64Value(3.14),
41+
expected: "3.14",
42+
},
43+
{
44+
name: "valid number",
45+
input: types.NumberValue(big.NewFloat(3)),
46+
expected: "3",
47+
},
48+
{
49+
name: "valid bool",
50+
input: types.BoolValue(true),
51+
expected: "true",
52+
},
53+
}
54+
55+
for _, tt := range tests {
56+
t.Run(tt.name, func(t *testing.T) {
57+
t.Parallel()
58+
59+
result := ValueAsString(context.TODO(), tt.input)
60+
if result != tt.expected {
61+
t.Fatalf("expected %q but got %q", tt.expected, result)
62+
}
63+
})
64+
}
65+
}

0 commit comments

Comments
 (0)