From c0fd82e848bb0f41a708f0dfea86a4d9c732f623 Mon Sep 17 00:00:00 2001 From: Devansh Das Date: Mon, 24 Nov 2025 16:32:33 +0100 Subject: [PATCH 1/2] feature: Add WIF Impersonation support for AWS --- cpp/AwsAttestation.cpp | 211 +++++++++++++++++++++++++++++------------ 1 file changed, 152 insertions(+), 59 deletions(-) diff --git a/cpp/AwsAttestation.cpp b/cpp/AwsAttestation.cpp index 956d4fb226..38bdfcdd5e 100644 --- a/cpp/AwsAttestation.cpp +++ b/cpp/AwsAttestation.cpp @@ -5,76 +5,169 @@ #include "logger/SFLogger.hpp" #include #include -#include -#include "snowflake/AWSUtils.hpp" +#include +#include +#include +#include -namespace Snowflake { - namespace Client { - - // We don't need x-amz-content-sha256 header, because there is no payload to be signed. - // If x-amz-content-sha256 contain EMPTY_STRING_SHA256, the server responds with - // "The AWS STS request contained unacceptable headers." - class AWS_CORE_API AWSAuthV4SignerNoPayload : public Aws::Client::AWSAuthV4Signer - { - public: - AWSAuthV4SignerNoPayload(const std::shared_ptr& credentialsProvider, const char* serviceName, const Aws::String& region) - : AWSAuthV4Signer(credentialsProvider, serviceName, region) { m_includeSha256HashHeader = false; } - }; - - boost::optional createAwsAttestation(const AttestationConfig& config) { - if (config.workloadIdentityImpersonationPath && - !config.workloadIdentityImpersonationPath.get().empty()) { - CXX_LOG_ERROR("Workload identity impersonation is not supported for AWS"); - return boost::none; +namespace Snowflake::Client { + // We don't need x-amz-content-sha256 header, because there is no payload to be signed. + // If x-amz-content-sha256 contain EMPTY_STRING_SHA256, the server responds with + // "The AWS STS request contained unacceptable headers." + class AWS_CORE_API AWSAuthV4SignerNoPayload : public Aws::Client::AWSAuthV4Signer + { + public: + AWSAuthV4SignerNoPayload(const std::shared_ptr& credentialsProvider, const char* serviceName, const Aws::String& region) + : AWSAuthV4Signer(credentialsProvider, serviceName, region) { m_includeSha256HashHeader = false; } + }; + + // Splits comma-separated role ARN impersonation path + std::vector parseRoleArnChain(const std::string &path) { + std::vector result; + std::stringstream ss(path); + std::string item; + + while (std::getline(ss, item, ',')) { + const auto start = item.find_first_not_of(" \t"); + const auto end = item.find_last_not_of(" \t"); + + if (start != std::string::npos) { + result.push_back(item.substr(start, end - start + 1)); } + } + + return result; + } + + // Assumes a single AWS role and returns temporary credentials + boost::optional assumeAwsRole( + const Aws::Auth::AWSCredentials ¤tCreds, + const std::string &roleArn) { + + CXX_LOG_DEBUG("Assuming AWS role: %s", roleArn.c_str()); + + const Aws::STS::STSClient stsClient(currentCreds); + + Aws::STS::Model::AssumeRoleRequest assumeRoleRequest; + assumeRoleRequest.SetRoleArn(roleArn.c_str()); + + const std::string sessionName = "snowflake-wif-" + std::string(Aws::Utils::UUID::PseudoRandomUUID()); + assumeRoleRequest.SetRoleSessionName(sessionName.c_str()); + assumeRoleRequest.SetDurationSeconds(3600); + + const auto outcome = stsClient.AssumeRole(assumeRoleRequest); + + if (!outcome.IsSuccess()) { + CXX_LOG_ERROR("Failed to assume role %s: %s", + roleArn.c_str(), + outcome.GetError().GetMessage().c_str()); + return boost::none; + } + + const auto &credentials = outcome.GetResult().GetCredentials(); + return Aws::Auth::AWSCredentials( + credentials.GetAccessKeyId(), + credentials.GetSecretAccessKey(), + credentials.GetSessionToken() + ); + } + + // Assumes a chain of AWS roles sequentially + boost::optional assumeAwsRoleChain( + const Aws::Auth::AWSCredentials &initialCreds, + const std::vector &roleArnChain) { + + if (roleArnChain.empty()) { + CXX_LOG_ERROR("Role ARN chain is empty"); + return boost::none; + } + + Aws::Auth::AWSCredentials currentCreds = initialCreds; - auto awsSdkInit = AwsUtils::initAwsSdk(); - auto creds = config.awsSdkWrapper->getCredentials(); - if (creds.IsEmpty()) { - CXX_LOG_INFO("Failed to get AWS credentials"); + for (const auto &roleArn: roleArnChain) { + auto assumedCredsOpt = assumeAwsRole(currentCreds, roleArn); + if (!assumedCredsOpt) { + CXX_LOG_ERROR("Failed to assume role in chain: %s", roleArn.c_str()); return boost::none; } + currentCreds = assumedCredsOpt.get(); + } + + return currentCreds; + } - auto regionOpt = config.awsSdkWrapper->getEC2Region(); - if (!regionOpt) { - CXX_LOG_INFO("Failed to get AWS region"); + boost::optional createAwsAttestation(const AttestationConfig &config) { + auto awsSdkInit = AwsUtils::initAwsSdk(); + auto creds = config.awsSdkWrapper->getCredentials(); + if (creds.IsEmpty()) { + CXX_LOG_INFO("Failed to get AWS credentials"); + return boost::none; + } + + auto regionOpt = config.awsSdkWrapper->getEC2Region(); + if (!regionOpt) { + CXX_LOG_INFO("Failed to get AWS region"); + return boost::none; + } + const std::string ®ion = regionOpt.get(); + + // Check if role assumption chain is configured + if (config.workloadIdentityImpersonationPath && + !config.workloadIdentityImpersonationPath.get().empty()) { + + CXX_LOG_INFO("Using AWS role assumption chain for impersonation"); + + auto roleArnChain = parseRoleArnChain(config.workloadIdentityImpersonationPath.get()); + + if (roleArnChain.empty()) { + CXX_LOG_ERROR("Failed to parse role ARN chain"); return boost::none; } - const std::string& region = regionOpt.get(); - const std::string domain = AwsUtils::getDomainSuffixForRegionalUrl(region); - const std::string host = std::string("sts") + "." + region + "." + domain; - const std::string url = std::string("https://") + host + "/?Action=GetCallerIdentity&Version=2011-06-15"; - - auto request = Aws::Http::CreateHttpRequest( - Aws::String(url), - Aws::Http::HttpMethod::HTTP_POST, - Aws::Utils::Stream::DefaultResponseStreamFactoryMethod - ); - - request->SetHeaderValue("Host", host); - request->SetHeaderValue("X-Snowflake-Audience", "snowflakecomputing.com"); - - auto simpleCredProvider = std::make_shared(creds); - AWSAuthV4SignerNoPayload signer(simpleCredProvider, "sts", region); - - // Sign the request - if (!signer.SignRequest(*request)) { - CXX_LOG_ERROR("Failed to sign request"); + + CXX_LOG_DEBUG("Role ARN chain size: %zu", roleArnChain.size()); + + auto assumedCredsOpt = assumeAwsRoleChain(creds, roleArnChain); + if (!assumedCredsOpt) { + CXX_LOG_ERROR("Failed to assume role chain"); return boost::none; } - picojson::object obj; - obj["url"] = picojson::value(request->GetURIString()); - obj["method"] = picojson::value(Aws::Http::HttpMethodMapper::GetNameForHttpMethod(request->GetMethod())); - picojson::object headers; - for (const auto &h: request->GetHeaders()) { - headers[h.first] = picojson::value(h.second); - } - obj["headers"] = picojson::value(headers); - std::string json = picojson::value(obj).serialize(true); - std::string base64; - Util::Base64::encodePadding(json.begin(), json.end(), std::back_inserter(base64)); - return Attestation::makeAws(base64); + creds = assumedCredsOpt.get(); + } + + const std::string domain = AwsUtils::getDomainSuffixForRegionalUrl(region); + const std::string host = std::string("sts") + "." + region + "." + domain; + const std::string url = std::string("https://") + host + "/?Action=GetCallerIdentity&Version=2011-06-15"; + + auto request = Aws::Http::CreateHttpRequest( + Aws::String(url), + Aws::Http::HttpMethod::HTTP_POST, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod + ); + + request->SetHeaderValue("Host", host); + request->SetHeaderValue("X-Snowflake-Audience", "snowflakecomputing.com"); + + auto simpleCredProvider = std::make_shared(creds); + AWSAuthV4SignerNoPayload signer(simpleCredProvider, "sts", region); + + // Sign the request + if (!signer.SignRequest(*request)) { + CXX_LOG_ERROR("Failed to sign request"); + return boost::none; + } + + picojson::object obj; + obj["url"] = picojson::value(request->GetURIString()); + obj["method"] = picojson::value(Aws::Http::HttpMethodMapper::GetNameForHttpMethod(request->GetMethod())); + picojson::object headers; + for (const auto &h: request->GetHeaders()) { + headers[h.first] = picojson::value(h.second); } + obj["headers"] = picojson::value(headers); + std::string json = picojson::value(obj).serialize(true); + std::string base64; + Util::Base64::encodePadding(json.begin(), json.end(), std::back_inserter(base64)); + return Attestation::makeAws(base64); } } \ No newline at end of file From 66ccd821f77deab0e090b51689fa1036bd060a19 Mon Sep 17 00:00:00 2001 From: Devansh Das Date: Mon, 24 Nov 2025 16:32:10 +0100 Subject: [PATCH 2/2] feature: WIF Impersonation tests --- tests/test_create_wif_attestation.cpp | 108 ++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/tests/test_create_wif_attestation.cpp b/tests/test_create_wif_attestation.cpp index 99fc7baa0e..2d14502c99 100644 --- a/tests/test_create_wif_attestation.cpp +++ b/tests/test_create_wif_attestation.cpp @@ -191,6 +191,107 @@ void test_unit_aws_attestation_cred_missing(void **) { test_unit_aws_attestation_failed(&awsSdkWrapper); } +// These tests only verify the impersonation path parsing and configuration handling +// not the actual STS calls, which requires valid AWS setup. + +void test_unit_aws_attestation_impersonation_single_role(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = "arn:aws:iam::123456789012:role/TestRole"; + + // Will fail at actual STS call, but validates path parsing and code path + const auto attestationOpt = createAttestation(config); + assert_true(!attestationOpt || attestationOpt.has_value()); +} + +void test_unit_aws_attestation_impersonation_role_chain(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = + "arn:aws:iam::123456789012:role/Role1," + "arn:aws:iam::123456789012:role/Role2," + "arn:aws:iam::123456789012:role/Role3"; + + // Will fail at STS, but validates chain parsing + const auto attestationOpt = createAttestation(config); + assert_true(!attestationOpt || attestationOpt.has_value()); +} + +void test_unit_aws_attestation_impersonation_whitespace_handling(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = + " arn:aws:iam::123456789012:role/Role1 , " + " arn:aws:iam::123456789012:role/Role2 "; + + // Will fail at STS, but validates whitespace trimming + const auto attestationOpt = createAttestation(config); + assert_true(!attestationOpt || attestationOpt.has_value()); +} + +void test_unit_aws_attestation_impersonation_cross_account(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = + "arn:aws:iam::111111111111:role/SourceRole," + "arn:aws:iam::222222222222:role/TargetRole"; + + // Will fail at STS, but validates cross-account ARN parsing + const auto attestationOpt = createAttestation(config); + assert_true(!attestationOpt || attestationOpt.has_value()); +} + +void test_unit_aws_attestation_impersonation_empty_path_fallback(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = ""; + + auto attestationOpt = createAttestation(config); + assert_true(attestationOpt.has_value()); + + const auto& attestation = attestationOpt.get(); + assert_true(attestation.type == AttestationType::AWS); +} + +void test_unit_aws_attestation_impersonation_with_missing_region(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(boost::none, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = "arn:aws:iam::123456789012:role/TestRole"; + + const auto attestationOpt = createAttestation(config); + assert_false(attestationOpt.has_value()); +} + +void test_unit_aws_attestation_impersonation_with_missing_credentials(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, Aws::Auth::AWSCredentials()); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = "arn:aws:iam::123456789012:role/TestRole"; + + const auto attestationOpt = createAttestation(config); + assert_false(attestationOpt.has_value()); +} + std::vector makeGCPToken(boost::optional issuer, boost::optional subject) { auto jwtObj = Jwt::JWTObject(); jwtObj.getHeader()->setAlgorithm(Jwt::AlgorithmType::RS256); @@ -872,6 +973,13 @@ int main() { cmocka_unit_test(test_unit_aws_attestation_china_region_success), cmocka_unit_test(test_unit_aws_attestation_region_missing), cmocka_unit_test(test_unit_aws_attestation_cred_missing), + cmocka_unit_test(test_unit_aws_attestation_impersonation_single_role), + cmocka_unit_test(test_unit_aws_attestation_impersonation_role_chain), + cmocka_unit_test(test_unit_aws_attestation_impersonation_whitespace_handling), + cmocka_unit_test(test_unit_aws_attestation_impersonation_cross_account), + cmocka_unit_test(test_unit_aws_attestation_impersonation_empty_path_fallback), + cmocka_unit_test(test_unit_aws_attestation_impersonation_with_missing_region), + cmocka_unit_test(test_unit_aws_attestation_impersonation_with_missing_credentials), cmocka_unit_test(test_unit_gcp_attestation_success), cmocka_unit_test(test_unit_gcp_attestation_missing_issuer), cmocka_unit_test(test_unit_gcp_attestation_missing_subject),