Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions internal/service/rdsdata/query_data_source.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package rdsdata

import (
"context"
"encoding/json"

"github.com/aws/aws-sdk-go-v2/service/rdsdata"
rdsdatatypes "github.com/aws/aws-sdk-go-v2/service/rdsdata/types"
"github.com/hashicorp/terraform-plugin-framework/datasource"
"github.com/hashicorp/terraform-plugin-framework/datasource/schema"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-provider-aws/internal/framework"
"github.com/hashicorp/terraform-provider-aws/names"
)

// @FrameworkDataSource("aws_rdsdata_query", name="Query")
func newDataSourceQuery(context.Context) (datasource.DataSourceWithConfigure, error) {
return &dataSourceQuery{}, nil
}

type dataSourceQuery struct {
framework.DataSourceWithModel[dataSourceQueryModel]
}

func (d *dataSourceQuery) Schema(ctx context.Context, req datasource.SchemaRequest, resp *datasource.SchemaResponse) {
resp.Schema = schema.Schema{
Attributes: map[string]schema.Attribute{
names.AttrID: framework.IDAttribute(),
names.AttrDatabase: schema.StringAttribute{
Optional: true,
},
names.AttrResourceARN: schema.StringAttribute{
Required: true,
},
"secret_arn": schema.StringAttribute{
Required: true,
},
"sql": schema.StringAttribute{
Required: true,
},
"records": schema.StringAttribute{
Computed: true,
},
"number_of_records_updated": schema.Int64Attribute{
Computed: true,
},
},
Blocks: map[string]schema.Block{
names.AttrParameters: schema.ListNestedBlock{
NestedObject: schema.NestedBlockObject{
Attributes: map[string]schema.Attribute{
names.AttrName: schema.StringAttribute{
Required: true,
},
names.AttrValue: schema.StringAttribute{
Required: true,
},
"type_hint": schema.StringAttribute{
Optional: true,
},
},
},
},
},
}
}

type dataSourceQueryModel struct {
framework.WithRegionModel
ID types.String `tfsdk:"id"`
Database types.String `tfsdk:"database"`
ResourceARN types.String `tfsdk:"resource_arn"`
SecretARN types.String `tfsdk:"secret_arn"`
SQL types.String `tfsdk:"sql"`
Parameters []dataSourceQueryParameterModel `tfsdk:"parameters"`
Records types.String `tfsdk:"records"`
NumberOfRecordsUpdated types.Int64 `tfsdk:"number_of_records_updated"`
}

type dataSourceQueryParameterModel struct {
Name types.String `tfsdk:"name"`
Value types.String `tfsdk:"value"`
TypeHint types.String `tfsdk:"type_hint"`
}

func (d *dataSourceQuery) Read(ctx context.Context, req datasource.ReadRequest, resp *datasource.ReadResponse) {
var data dataSourceQueryModel
resp.Diagnostics.Append(req.Config.Get(ctx, &data)...)
if resp.Diagnostics.HasError() {
return
}

conn := d.Meta().RDSDataClient(ctx)

input := rdsdata.ExecuteStatementInput{
ResourceArn: data.ResourceARN.ValueStringPointer(),
SecretArn: data.SecretARN.ValueStringPointer(),
Sql: data.SQL.ValueStringPointer(),
FormatRecordsAs: rdsdatatypes.RecordsFormatTypeJson,
}

if !data.Database.IsNull() {
input.Database = data.Database.ValueStringPointer()
}

if len(data.Parameters) > 0 {
input.Parameters = expandSQLParameters(data.Parameters)
}

output, err := conn.ExecuteStatement(ctx, &input)
if err != nil {
resp.Diagnostics.AddError("executing RDS Data API statement", err.Error())
return
}

data.ID = types.StringValue(data.ResourceARN.ValueString() + ":" + data.SQL.ValueString())
data.Records = types.StringPointerValue(output.FormattedRecords)
data.NumberOfRecordsUpdated = types.Int64Value(output.NumberOfRecordsUpdated)

resp.Diagnostics.Append(resp.State.Set(ctx, &data)...)
}

func expandSQLParameters(tfList []dataSourceQueryParameterModel) []rdsdatatypes.SqlParameter {
if len(tfList) == 0 {
return nil
}

var apiObjects []rdsdatatypes.SqlParameter

for _, tfObj := range tfList {
apiObject := rdsdatatypes.SqlParameter{
Name: tfObj.Name.ValueStringPointer(),
}

if !tfObj.TypeHint.IsNull() {
apiObject.TypeHint = rdsdatatypes.TypeHint(tfObj.TypeHint.ValueString())
}

// Convert value to Field type
valueStr := tfObj.Value.ValueString()
var field rdsdatatypes.Field

// Try to parse as JSON first, otherwise treat as string
var jsonValue any
if err := json.Unmarshal([]byte(valueStr), &jsonValue); err == nil {
switch v := jsonValue.(type) {
case string:
field = &rdsdatatypes.FieldMemberStringValue{Value: v}
case float64:
field = &rdsdatatypes.FieldMemberDoubleValue{Value: v}
case bool:
field = &rdsdatatypes.FieldMemberBooleanValue{Value: v}
default:
field = &rdsdatatypes.FieldMemberStringValue{Value: valueStr}
}
} else {
field = &rdsdatatypes.FieldMemberStringValue{Value: valueStr}
}

apiObject.Value = field
apiObjects = append(apiObjects, apiObject)
}

return apiObjects
}
126 changes: 126 additions & 0 deletions internal/service/rdsdata/query_data_source_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package rdsdata_test

import (
"fmt"
"testing"

"github.com/hashicorp/terraform-plugin-testing/helper/resource"
"github.com/hashicorp/terraform-provider-aws/internal/acctest"
"github.com/hashicorp/terraform-provider-aws/names"
)

func TestAccRDSDataQueryDataSource_basic(t *testing.T) {
ctx := acctest.Context(t)
dataSourceName := "data.aws_rdsdata_query.test"
rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix)

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, names.RDSServiceID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
Steps: []resource.TestStep{
{
Config: testAccQueryDataSourceConfig_basic(rName),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttrSet(dataSourceName, "records"),
resource.TestCheckResourceAttr(dataSourceName, "sql", "SELECT SCHEMA_NAME FROM information_schema.SCHEMATA LIMIT 1"),
),
},
},
})
}

func TestAccRDSDataQueryDataSource_withParameters(t *testing.T) {
ctx := acctest.Context(t)
dataSourceName := "data.aws_rdsdata_query.test"
rName := acctest.RandomWithPrefix(t, acctest.ResourcePrefix)

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, names.RDSServiceID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
Steps: []resource.TestStep{
{
Config: testAccQueryDataSourceConfig_withParameters(rName),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttrSet(dataSourceName, "records"),
resource.TestCheckResourceAttr(dataSourceName, "sql", "SELECT :param1 as test_column"),
resource.TestCheckResourceAttr(dataSourceName, "parameters.#", "1"),
resource.TestCheckResourceAttr(dataSourceName, "parameters.0.name", "param1"),
resource.TestCheckResourceAttr(dataSourceName, "parameters.0.value", "test_value"),
),
},
},
})
}

func testAccQueryDataSourceConfig_basic(rName string) string {
return acctest.ConfigCompose(testAccQueryDataSourceConfig_base(rName), `
data "aws_rdsdata_query" "test" {
depends_on = [aws_rds_cluster_instance.test]
resource_arn = aws_rds_cluster.test.arn
secret_arn = aws_secretsmanager_secret.test.arn
sql = "SELECT SCHEMA_NAME FROM information_schema.SCHEMATA LIMIT 1"
}
`)
}

func testAccQueryDataSourceConfig_withParameters(rName string) string {
return acctest.ConfigCompose(testAccQueryDataSourceConfig_base(rName), `
data "aws_rdsdata_query" "test" {
depends_on = [aws_rds_cluster_instance.test]
resource_arn = aws_rds_cluster.test.arn
secret_arn = aws_secretsmanager_secret.test.arn
sql = "SELECT :param1 as test_column"

parameters {
name = "param1"
value = "test_value"
}
}
`)
}

func testAccQueryDataSourceConfig_base(rName string) string {
return fmt.Sprintf(`
resource "aws_rds_cluster" "test" {
cluster_identifier = %[1]q
engine = "aurora-mysql"
database_name = "test"
master_username = "username"
master_password = "mustbeeightcharacters"
backup_retention_period = 7
preferred_backup_window = "07:00-09:00"
preferred_maintenance_window = "tue:04:00-tue:04:30"
skip_final_snapshot = true
enable_http_endpoint = true

serverlessv2_scaling_configuration {
max_capacity = 8
min_capacity = 0.5
}
}

resource "aws_rds_cluster_instance" "test" {
cluster_identifier = aws_rds_cluster.test.id
instance_class = "db.serverless"
engine = aws_rds_cluster.test.engine
engine_version = aws_rds_cluster.test.engine_version
}

resource "aws_secretsmanager_secret" "test" {
name = %[1]q
}

resource "aws_secretsmanager_secret_version" "test" {
secret_id = aws_secretsmanager_secret.test.id
secret_string = jsonencode({
username = aws_rds_cluster.test.master_username
password = aws_rds_cluster.test.master_password
})
}
`, rName)
}
10 changes: 9 additions & 1 deletion internal/service/rdsdata/service_package_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading