diff --git a/cpp/AwsAttestation.cpp b/cpp/AwsAttestation.cpp index 6be31208c5..38bdfcdd5e 100644 --- a/cpp/AwsAttestation.cpp +++ b/cpp/AwsAttestation.cpp @@ -5,70 +5,169 @@ #include "logger/SFLogger.hpp" #include #include -#include -#include "snowflake/AWSUtils.hpp" +#include +#include +#include +#include -namespace Snowflake { - namespace Client { - - // We don't need x-amz-content-sha256 header, because there is no payload to be signed. - // If x-amz-content-sha256 contain EMPTY_STRING_SHA256, the server responds with - // "The AWS STS request contained unacceptable headers." - class AWS_CORE_API AWSAuthV4SignerNoPayload : public Aws::Client::AWSAuthV4Signer - { - public: - AWSAuthV4SignerNoPayload(const std::shared_ptr& credentialsProvider, const char* serviceName, const Aws::String& region) - : AWSAuthV4Signer(credentialsProvider, serviceName, region) { m_includeSha256HashHeader = false; } - }; - - boost::optional createAwsAttestation(const AttestationConfig& config) { - auto awsSdkInit = AwsUtils::initAwsSdk(); - auto creds = config.awsSdkWrapper->getCredentials(); - if (creds.IsEmpty()) { - CXX_LOG_INFO("Failed to get AWS credentials"); - return boost::none; +namespace Snowflake::Client { + // We don't need x-amz-content-sha256 header, because there is no payload to be signed. + // If x-amz-content-sha256 contain EMPTY_STRING_SHA256, the server responds with + // "The AWS STS request contained unacceptable headers." + class AWS_CORE_API AWSAuthV4SignerNoPayload : public Aws::Client::AWSAuthV4Signer + { + public: + AWSAuthV4SignerNoPayload(const std::shared_ptr& credentialsProvider, const char* serviceName, const Aws::String& region) + : AWSAuthV4Signer(credentialsProvider, serviceName, region) { m_includeSha256HashHeader = false; } + }; + + // Splits comma-separated role ARN impersonation path + std::vector parseRoleArnChain(const std::string &path) { + std::vector result; + std::stringstream ss(path); + std::string item; + + while (std::getline(ss, item, ',')) { + const auto start = item.find_first_not_of(" \t"); + const auto end = item.find_last_not_of(" \t"); + + if (start != std::string::npos) { + result.push_back(item.substr(start, end - start + 1)); } + } + + return result; + } + + // Assumes a single AWS role and returns temporary credentials + boost::optional assumeAwsRole( + const Aws::Auth::AWSCredentials ¤tCreds, + const std::string &roleArn) { + + CXX_LOG_DEBUG("Assuming AWS role: %s", roleArn.c_str()); + + const Aws::STS::STSClient stsClient(currentCreds); + + Aws::STS::Model::AssumeRoleRequest assumeRoleRequest; + assumeRoleRequest.SetRoleArn(roleArn.c_str()); + + const std::string sessionName = "snowflake-wif-" + std::string(Aws::Utils::UUID::PseudoRandomUUID()); + assumeRoleRequest.SetRoleSessionName(sessionName.c_str()); + assumeRoleRequest.SetDurationSeconds(3600); + + const auto outcome = stsClient.AssumeRole(assumeRoleRequest); + + if (!outcome.IsSuccess()) { + CXX_LOG_ERROR("Failed to assume role %s: %s", + roleArn.c_str(), + outcome.GetError().GetMessage().c_str()); + return boost::none; + } + + const auto &credentials = outcome.GetResult().GetCredentials(); + return Aws::Auth::AWSCredentials( + credentials.GetAccessKeyId(), + credentials.GetSecretAccessKey(), + credentials.GetSessionToken() + ); + } + + // Assumes a chain of AWS roles sequentially + boost::optional assumeAwsRoleChain( + const Aws::Auth::AWSCredentials &initialCreds, + const std::vector &roleArnChain) { + + if (roleArnChain.empty()) { + CXX_LOG_ERROR("Role ARN chain is empty"); + return boost::none; + } + + Aws::Auth::AWSCredentials currentCreds = initialCreds; - auto regionOpt = config.awsSdkWrapper->getEC2Region(); - if (!regionOpt) { - CXX_LOG_INFO("Failed to get AWS region"); + for (const auto &roleArn: roleArnChain) { + auto assumedCredsOpt = assumeAwsRole(currentCreds, roleArn); + if (!assumedCredsOpt) { + CXX_LOG_ERROR("Failed to assume role in chain: %s", roleArn.c_str()); return boost::none; } - const std::string& region = regionOpt.get(); - const std::string domain = AwsUtils::getDomainSuffixForRegionalUrl(region); - const std::string host = std::string("sts") + "." + region + "." + domain; - const std::string url = std::string("https://") + host + "/?Action=GetCallerIdentity&Version=2011-06-15"; - - auto request = Aws::Http::CreateHttpRequest( - Aws::String(url), - Aws::Http::HttpMethod::HTTP_POST, - Aws::Utils::Stream::DefaultResponseStreamFactoryMethod - ); - - request->SetHeaderValue("Host", host); - request->SetHeaderValue("X-Snowflake-Audience", "snowflakecomputing.com"); - - auto simpleCredProvider = std::make_shared(creds); - AWSAuthV4SignerNoPayload signer(simpleCredProvider, "sts", region); - - // Sign the request - if (!signer.SignRequest(*request)) { - CXX_LOG_ERROR("Failed to sign request"); + currentCreds = assumedCredsOpt.get(); + } + + return currentCreds; + } + + boost::optional createAwsAttestation(const AttestationConfig &config) { + auto awsSdkInit = AwsUtils::initAwsSdk(); + auto creds = config.awsSdkWrapper->getCredentials(); + if (creds.IsEmpty()) { + CXX_LOG_INFO("Failed to get AWS credentials"); + return boost::none; + } + + auto regionOpt = config.awsSdkWrapper->getEC2Region(); + if (!regionOpt) { + CXX_LOG_INFO("Failed to get AWS region"); + return boost::none; + } + const std::string ®ion = regionOpt.get(); + + // Check if role assumption chain is configured + if (config.workloadIdentityImpersonationPath && + !config.workloadIdentityImpersonationPath.get().empty()) { + + CXX_LOG_INFO("Using AWS role assumption chain for impersonation"); + + auto roleArnChain = parseRoleArnChain(config.workloadIdentityImpersonationPath.get()); + + if (roleArnChain.empty()) { + CXX_LOG_ERROR("Failed to parse role ARN chain"); return boost::none; } - picojson::object obj; - obj["url"] = picojson::value(request->GetURIString()); - obj["method"] = picojson::value(Aws::Http::HttpMethodMapper::GetNameForHttpMethod(request->GetMethod())); - picojson::object headers; - for (const auto &h: request->GetHeaders()) { - headers[h.first] = picojson::value(h.second); + CXX_LOG_DEBUG("Role ARN chain size: %zu", roleArnChain.size()); + + auto assumedCredsOpt = assumeAwsRoleChain(creds, roleArnChain); + if (!assumedCredsOpt) { + CXX_LOG_ERROR("Failed to assume role chain"); + return boost::none; } - obj["headers"] = picojson::value(headers); - std::string json = picojson::value(obj).serialize(true); - std::string base64; - Util::Base64::encodePadding(json.begin(), json.end(), std::back_inserter(base64)); - return Attestation::makeAws(base64); + + creds = assumedCredsOpt.get(); + } + + const std::string domain = AwsUtils::getDomainSuffixForRegionalUrl(region); + const std::string host = std::string("sts") + "." + region + "." + domain; + const std::string url = std::string("https://") + host + "/?Action=GetCallerIdentity&Version=2011-06-15"; + + auto request = Aws::Http::CreateHttpRequest( + Aws::String(url), + Aws::Http::HttpMethod::HTTP_POST, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod + ); + + request->SetHeaderValue("Host", host); + request->SetHeaderValue("X-Snowflake-Audience", "snowflakecomputing.com"); + + auto simpleCredProvider = std::make_shared(creds); + AWSAuthV4SignerNoPayload signer(simpleCredProvider, "sts", region); + + // Sign the request + if (!signer.SignRequest(*request)) { + CXX_LOG_ERROR("Failed to sign request"); + return boost::none; + } + + picojson::object obj; + obj["url"] = picojson::value(request->GetURIString()); + obj["method"] = picojson::value(Aws::Http::HttpMethodMapper::GetNameForHttpMethod(request->GetMethod())); + picojson::object headers; + for (const auto &h: request->GetHeaders()) { + headers[h.first] = picojson::value(h.second); } + obj["headers"] = picojson::value(headers); + std::string json = picojson::value(obj).serialize(true); + std::string base64; + Util::Base64::encodePadding(json.begin(), json.end(), std::back_inserter(base64)); + return Attestation::makeAws(base64); } } \ No newline at end of file diff --git a/cpp/AzureAttestation.cpp b/cpp/AzureAttestation.cpp index 902241f739..a029e4aade 100644 --- a/cpp/AzureAttestation.cpp +++ b/cpp/AzureAttestation.cpp @@ -13,6 +13,12 @@ namespace { namespace Snowflake { namespace Client { boost::optional createAzureAttestation(AttestationConfig& config) { + if (config.workloadIdentityImpersonationPath && + !config.workloadIdentityImpersonationPath.get().empty()) { + CXX_LOG_ERROR("Workload identity impersonation is not supported for Azure"); + return boost::none; + } + auto azureConfigOpt = AzureAttestationConfig::fromConfig(config); if (!azureConfigOpt) { return boost::none; diff --git a/cpp/GcpAttestation.cpp b/cpp/GcpAttestation.cpp index d6a17811b1..6550f5a7e8 100644 --- a/cpp/GcpAttestation.cpp +++ b/cpp/GcpAttestation.cpp @@ -1,25 +1,206 @@ - #include "GcpAttestation.hpp" #include "jwt/Jwt.hpp" #include "snowflake/HttpClient.hpp" #include "logger/SFLogger.hpp" +#include +#include +#include + +namespace Snowflake::Client { + constexpr auto SNOWFLAKE_AUDIENCE = "snowflakecomputing.com"; + constexpr auto GCP_METADATA_SERVER_BASE_URL = "http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/"; + constexpr auto GCP_IAM_CREDENTIALS_BASE_URL = "https://iamcredentials.googleapis.com/v1"; + + // Splits comma-separated impersonation path + std::vector parseImpersonationPath(const std::string &path) { + std::vector result; + std::stringstream ss(path); + std::string item; + + while (std::getline(ss, item, ',')) { + const auto start = item.find_first_not_of(" \t"); + const auto end = item.find_last_not_of(" \t"); + + if (start != std::string::npos) { + result.push_back(item.substr(start, end - start + 1)); + } + } + + return result; + } + + // Fetches access token from metadata server + boost::optional getGcpAccessToken(IHttpClient *httpClient) { + const auto url = boost::urls::url(std::string(GCP_METADATA_SERVER_BASE_URL) + "token"); + + const HttpRequest req{ + HttpRequest::Method::GET, + url, + { + {"Metadata-Flavor", "Google"}, + } + }; + + auto responseOpt = httpClient->run(req); + if (!responseOpt) { + CXX_LOG_INFO("No response from GCP metadata server for access token."); + return boost::none; + } + + const auto &response = responseOpt.get(); + if (response.code != 200) { + CXX_LOG_ERROR("GCP metadata server access token request was not successful. Code: %ld", response.code); + return boost::none; + } -namespace Snowflake { - namespace Client { + const std::string response_body = response.getBody(); + picojson::value json; + const std::string err = picojson::parse(json, response_body); + if (!err.empty()) { + CXX_LOG_ERROR("Error parsing GCP access token response: %s", err.c_str()); + return boost::none; + } + + if (!json.is() || !json.get("access_token").is()) { + CXX_LOG_ERROR("No access_token found in GCP response."); + return boost::none; + } + + return json.get("access_token").get(); + } + + // Fetches identity token using delegation chain + boost::optional getIdentityTokenWithDelegation( + IHttpClient *httpClient, + const std::string &accessToken, + const std::vector &serviceAccountChain) { + if (serviceAccountChain.empty()) { + CXX_LOG_ERROR("Service account chain is empty"); + return boost::none; + } + + const std::string &targetServiceAccount = serviceAccountChain.back(); - constexpr const char* SNOWFLAKE_AUDIENCE = "snowflakecomputing.com"; + const std::vector delegates( + serviceAccountChain.begin(), + serviceAccountChain.end() - 1 + ); + + std::string idTokenUrl = std::string(GCP_IAM_CREDENTIALS_BASE_URL) + + "/projects/-/serviceAccounts/" + + targetServiceAccount + ":generateIdToken"; + + picojson::object requestBody; + requestBody["audience"] = picojson::value(std::string(SNOWFLAKE_AUDIENCE)); + requestBody["includeEmail"] = picojson::value(true); + + if (!delegates.empty()) { + picojson::array delegatesArray; + for (const auto &delegate: delegates) { + // Format: projects/-/serviceAccounts/{email} + std::string delegateStr = "projects/-/serviceAccounts/" + delegate; + delegatesArray.emplace_back(delegateStr); + } + requestBody["delegates"] = picojson::value(delegatesArray); + } + + std::string requestBodyStr = picojson::value(requestBody).serialize(); + CXX_LOG_DEBUG("GCP generateIdToken request body: %s", requestBodyStr.c_str()); + + auto url = boost::urls::parse_uri(idTokenUrl); + if (!url) { + CXX_LOG_ERROR("Invalid ID token URL: %s", idTokenUrl.c_str()); + return boost::none; + } - boost::optional createGcpAttestation(AttestationConfig& config) - { - auto url = boost::urls::url("http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity"); + HttpRequest req{ + HttpRequest::Method::POST, + url.value(), + { + {"Authorization", "Bearer " + accessToken}, + {"Content-Type", "application/json"}, + }, + requestBodyStr + }; + + auto responseOpt = httpClient->run(req); + if (!responseOpt) { + CXX_LOG_ERROR("No response from GCP generateIdToken API."); + return boost::none; + } + + const auto &response = responseOpt.get(); + if (response.code != 200) { + CXX_LOG_ERROR("GCP generateIdToken API request failed with code %ld: %s", + response.code, response.getBody().c_str()); + return boost::none; + } + + std::string response_body = response.getBody(); + picojson::value json; + std::string err = picojson::parse(json, response_body); + if (!err.empty()) { + CXX_LOG_ERROR("Error parsing GCP ID token response: %s", err.c_str()); + return boost::none; + } + + if (!json.is() || !json.get("token").is()) { + CXX_LOG_ERROR("No token found in ID token response."); + return boost::none; + } + + return json.get("token").get(); + } + + boost::optional createGcpAttestation(AttestationConfig &config) { + std::string jwtStr; + + // Check if service account impersonation is configured + if (config.workloadIdentityImpersonationPath && + !config.workloadIdentityImpersonationPath.get().empty()) { + CXX_LOG_INFO("Using GCP service account impersonation with delegation"); + + auto serviceAccountChain = parseImpersonationPath( + config.workloadIdentityImpersonationPath.get()); + + if (serviceAccountChain.empty()) { + CXX_LOG_ERROR("Failed to parse service account impersonation path"); + return boost::none; + } + + CXX_LOG_DEBUG("Service account chain size: %zu", serviceAccountChain.size()); + + // Get access token from metadata server + auto accessTokenOpt = getGcpAccessToken(config.httpClient); + if (!accessTokenOpt) { + CXX_LOG_ERROR("Failed to get access token from metadata server"); + return boost::none; + } + + // Get identity token from IAM Credentials API with delegation + auto idTokenOpt = getIdentityTokenWithDelegation( + config.httpClient, + accessTokenOpt.get(), + serviceAccountChain); + if (!idTokenOpt) { + CXX_LOG_ERROR("Failed to get identity token with delegation"); + return boost::none; + } + + jwtStr = idTokenOpt.get(); + } else { + // Get identity token directly from metadata server + CXX_LOG_INFO("Using direct GCP identity token from metadata server"); + + auto url = boost::urls::url(std::string(GCP_METADATA_SERVER_BASE_URL) + "identity"); url.params().append({"audience", SNOWFLAKE_AUDIENCE}); - HttpRequest req { - HttpRequest::Method::GET, - url, - { - {"Metadata-Flavor", "Google"}, - } + HttpRequest req{ + HttpRequest::Method::GET, + url, + { + {"Metadata-Flavor", "Google"}, + } }; auto responseOpt = config.httpClient->run(req); @@ -28,29 +209,29 @@ namespace Snowflake { return boost::none; } - const auto& response = responseOpt.get(); + const auto &response = responseOpt.get(); if (response.code != 200) { CXX_LOG_ERROR("GCP metadata server request was not successful."); return boost::none; } - std::string jwtStr = response.getBody(); + jwtStr = response.getBody(); if (jwtStr.empty()) { CXX_LOG_ERROR("No JWT found in GCP response."); return boost::none; } + } - Jwt::JWTObject jwt(jwtStr); - auto claimSet = jwt.getClaimSet(); - std::string issuer = claimSet->getClaimInString("iss"); - std::string subject = claimSet->getClaimInString("sub"); - if (issuer.empty() || subject.empty()) { - CXX_LOG_ERROR("No issuer or subject found in GCP JWT."); - return boost::none; - } - - return Attestation::makeGcp(jwtStr, issuer, subject); + // Parse JWT and extract issuer/subject + Jwt::JWTObject jwt(jwtStr); + auto claimSet = jwt.getClaimSet(); + std::string issuer = claimSet->getClaimInString("iss"); + std::string subject = claimSet->getClaimInString("sub"); + if (issuer.empty() || subject.empty()) { + CXX_LOG_ERROR("No issuer or subject found in GCP JWT."); + return boost::none; } + + return Attestation::makeGcp(jwtStr, issuer, subject); } } - diff --git a/cpp/http/HttpClient.cpp b/cpp/http/HttpClient.cpp index ef77407ff7..0617abe25e 100644 --- a/cpp/http/HttpClient.cpp +++ b/cpp/http/HttpClient.cpp @@ -23,6 +23,11 @@ namespace Snowflake { curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, SimpleHttpClient::write); curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void *) &response); + if (!req.body.empty()) { + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, req.body.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, req.body.size()); + } + struct curl_slist *header_list = nullptr; for (const auto &h: req.headers) { std::string hdr = h.first + ": " + h.second; diff --git a/include/snowflake/HttpClient.hpp b/include/snowflake/HttpClient.hpp index ca73805999..8a9a1b6b2d 100644 --- a/include/snowflake/HttpClient.hpp +++ b/include/snowflake/HttpClient.hpp @@ -45,6 +45,7 @@ namespace Snowflake { boost::urls::url url; std::map headers; + std::string body{}; }; struct HttpClientConfig { diff --git a/include/snowflake/WifAttestation.hpp b/include/snowflake/WifAttestation.hpp index 120093a0fb..eb8fba6162 100644 --- a/include/snowflake/WifAttestation.hpp +++ b/include/snowflake/WifAttestation.hpp @@ -78,6 +78,7 @@ namespace Client { boost::optional type; boost::optional token; boost::optional snowflakeEntraResource; + boost::optional workloadIdentityImpersonationPath; IHttpClient* httpClient = NULL; AwsUtils::ISdkWrapper* awsSdkWrapper = NULL; }; diff --git a/tests/test_create_wif_attestation.cpp b/tests/test_create_wif_attestation.cpp index b74ee66724..e1f3c1497a 100644 --- a/tests/test_create_wif_attestation.cpp +++ b/tests/test_create_wif_attestation.cpp @@ -191,6 +191,107 @@ void test_unit_aws_attestation_cred_missing(void **) { test_unit_aws_attestation_failed(&awsSdkWrapper); } +// These tests only verify the impersonation path parsing and configuration handling +// not the actual STS calls, which requires valid AWS setup. + +void test_unit_aws_attestation_impersonation_single_role(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = "arn:aws:iam::123456789012:role/TestRole"; + + // Will fail at actual STS call, but validates path parsing and code path + const auto attestationOpt = createAttestation(config); + assert_true(!attestationOpt || attestationOpt.has_value()); +} + +void test_unit_aws_attestation_impersonation_role_chain(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = + "arn:aws:iam::123456789012:role/Role1," + "arn:aws:iam::123456789012:role/Role2," + "arn:aws:iam::123456789012:role/Role3"; + + // Will fail at STS, but validates chain parsing + const auto attestationOpt = createAttestation(config); + assert_true(!attestationOpt || attestationOpt.has_value()); +} + +void test_unit_aws_attestation_impersonation_whitespace_handling(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = + " arn:aws:iam::123456789012:role/Role1 , " + " arn:aws:iam::123456789012:role/Role2 "; + + // Will fail at STS, but validates whitespace trimming + const auto attestationOpt = createAttestation(config); + assert_true(!attestationOpt || attestationOpt.has_value()); +} + +void test_unit_aws_attestation_impersonation_cross_account(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = + "arn:aws:iam::111111111111:role/SourceRole," + "arn:aws:iam::222222222222:role/TargetRole"; + + // Will fail at STS, but validates cross-account ARN parsing + const auto attestationOpt = createAttestation(config); + assert_true(!attestationOpt || attestationOpt.has_value()); +} + +void test_unit_aws_attestation_impersonation_empty_path_fallback(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = ""; + + auto attestationOpt = createAttestation(config); + assert_true(attestationOpt.has_value()); + + const auto& attestation = attestationOpt.get(); + assert_true(attestation.type == AttestationType::AWS); +} + +void test_unit_aws_attestation_impersonation_with_missing_region(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(boost::none, AWS_TEST_CREDS); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = "arn:aws:iam::123456789012:role/TestRole"; + + const auto attestationOpt = createAttestation(config); + assert_false(attestationOpt.has_value()); +} + +void test_unit_aws_attestation_impersonation_with_missing_credentials(void **) { + auto awsSdkWrapper = FakeAwsSdkWrapper(AWS_TEST_REGION, Aws::Auth::AWSCredentials()); + + AttestationConfig config; + config.type = AttestationType::AWS; + config.awsSdkWrapper = &awsSdkWrapper; + config.workloadIdentityImpersonationPath = "arn:aws:iam::123456789012:role/TestRole"; + + const auto attestationOpt = createAttestation(config); + assert_false(attestationOpt.has_value()); +} + std::vector makeGCPToken(boost::optional issuer, boost::optional subject) { auto jwtObj = Jwt::JWTObject(); jwtObj.getHeader()->setAlgorithm(Jwt::AlgorithmType::RS256); @@ -210,10 +311,12 @@ const std::string GCP_TEST_ISSUER = "https://accounts.google.com"; const std::string GCP_TEST_SUBJECT = "107562638633288735786"; const std::string GCP_TEST_AUDIENCE = "snowflakecomputing.com"; +const std::string GCP_TEST_METADATA_ENDPOINT_HOST = "169.254.169.254"; + FakeHttpClient makeSuccessfulGCPHttpClient(const std::vector &token) { return FakeHttpClient([=](Snowflake::Client::HttpRequest req) { assert_true((*req.url.params().find("audience")).value == GCP_TEST_AUDIENCE); - assert_true(req.url.host() == "169.254.169.254"); + assert_true(req.url.host() == GCP_TEST_METADATA_ENDPOINT_HOST); assert_true(req.url.scheme() == "http"); HttpResponse response; response.code = 200; @@ -284,6 +387,261 @@ void test_unit_gcp_attestation_bad_request(void **) { assert_true(!attestationOpt); } +const std::string GCP_TEST_SUBJECT_ACCESS = "107562638633288735787"; + +const std::string GCP_TEST_IAM_ENDPOINT_HOST = "iamcredentials.googleapis.com"; + +// Multi-path fake HTTP client for GCP service account impersonation +enum class AcceptedHosts { + Metadata, + Iam, + Other +}; + +auto getHost(const std::string& host) -> AcceptedHosts { + if (host == GCP_TEST_METADATA_ENDPOINT_HOST) return AcceptedHosts::Metadata; + if (host == GCP_TEST_IAM_ENDPOINT_HOST) return AcceptedHosts::Iam; + return AcceptedHosts::Other; +} + +FakeHttpClient makeSuccessfulGCPImpersonationHttpClient( + const std::vector& accessToken, + const std::vector& idToken, + const std::vector& expectedDelegates, + const std::string& expectedTargetServiceAccount) { + return FakeHttpClient([=](Snowflake::Client::HttpRequest req) { + HttpResponse response; + response.code = 200; + + switch (getHost(req.url.host())) { + case AcceptedHosts::Metadata: { + if (req.url.encoded_path() == "/computeMetadata/v1/instance/service-accounts/default/token") { + assert_true(req.headers.find("Metadata-Flavor")->second == "Google"); + response.buffer = accessToken; + } + break; + } + case AcceptedHosts::Iam: { + std::string expectedPath = "/v1/projects/-/serviceAccounts/" + + expectedTargetServiceAccount + ":generateIdToken"; + assert_true(req.url.encoded_path() == expectedPath); + assert_true(req.method == HttpRequest::Method::POST); + const auto accessTokenStr = std::string(accessToken.data(), accessToken.size()); + assert_true(req.headers.find("Authorization")->second == "Bearer " + accessTokenStr); + assert_true(req.headers.find("Content-Type")->second == "application/json"); + + picojson::value bodyJson; + std::string err = picojson::parse(bodyJson, req.body); + assert_true(err.empty()); + assert_true(bodyJson.is()); + + auto bodyObj = bodyJson.get(); + assert_true(bodyObj["audience"].get() == GCP_TEST_AUDIENCE); + assert_true(bodyObj["includeEmail"].get() == true); + + if (!expectedDelegates.empty()) { + assert_true(bodyObj.find("delegates") != bodyObj.end()); + auto delegates = bodyObj["delegates"].get(); + assert_true(delegates.size() == expectedDelegates.size()); + for (size_t i = 0; i < expectedDelegates.size(); ++i) { + std::string expected = "projects/-/serviceAccounts/" + expectedDelegates[i]; + assert_true(delegates[i].get() == expected); + } + } + + response.buffer = idToken; + break; + } + case AcceptedHosts::Other: { + // Leave response as default. + break; + } + } + + return response; + }); +} + +void test_unit_gcp_impersonation_single_account_success(void **) { + const auto accessToken = makeGCPToken(GCP_TEST_ISSUER, GCP_TEST_SUBJECT_ACCESS); + const auto idToken = makeGCPToken(GCP_TEST_ISSUER, GCP_TEST_SUBJECT); + const std::string targetServiceAccount = "target@project.iam.gserviceaccount.com"; + + auto fakeHttpClient = makeSuccessfulGCPImpersonationHttpClient( + accessToken, + idToken, + {}, + targetServiceAccount); + + AttestationConfig config; + config.type = AttestationType::GCP; + config.httpClient = &fakeHttpClient; + config.workloadIdentityImpersonationPath = targetServiceAccount; + + const auto attestationOpt = createAttestation(config); + assert_true(attestationOpt.has_value()); + const auto &[type, credential, issuer, subject] = attestationOpt.get(); + assert_true(type == AttestationType::GCP); + assert_true(credential == std::string(idToken.data(), idToken.size())); + assert_true(subject == GCP_TEST_SUBJECT); + assert_true(issuer == GCP_TEST_ISSUER); +} + +void test_unit_gcp_impersonation_chain_success(void **) { + const auto accessToken = makeGCPToken(GCP_TEST_ISSUER, GCP_TEST_SUBJECT_ACCESS); + const auto idToken = makeGCPToken(GCP_TEST_ISSUER, GCP_TEST_SUBJECT); + const std::vector delegates = { + "delegate1@project.iam.gserviceaccount.com", + "delegate2@project.iam.gserviceaccount.com" + }; + const std::string targetServiceAccount = "target@project.iam.gserviceaccount.com"; + + auto fakeHttpClient = makeSuccessfulGCPImpersonationHttpClient( + accessToken, + idToken, + delegates, + targetServiceAccount); + + AttestationConfig config; + config.type = AttestationType::GCP; + config.httpClient = &fakeHttpClient; + + std::string workloadIdentityImpersonationPath; + for (const auto &delegate: delegates) { + workloadIdentityImpersonationPath += delegate + ","; + } + workloadIdentityImpersonationPath += targetServiceAccount; + config.workloadIdentityImpersonationPath = workloadIdentityImpersonationPath; + + const auto attestationOpt = createAttestation(config); + assert_true(attestationOpt.has_value()); + const auto &[type, credential, issuer, subject] = attestationOpt.get(); + assert_true(type == AttestationType::GCP); + assert_true(credential == std::string(idToken.data(), idToken.size())); + assert_true(subject == GCP_TEST_SUBJECT); +} + +void test_unit_gcp_impersonation_whitespace_in_path(void **) { + const auto accessToken = makeGCPToken(GCP_TEST_ISSUER, GCP_TEST_SUBJECT_ACCESS); + const auto idToken = makeGCPToken(GCP_TEST_ISSUER, GCP_TEST_SUBJECT); + const std::vector delegates = { + "delegate1@project.iam.gserviceaccount.com", + "delegate2@project.iam.gserviceaccount.com" + }; + const std::string targetServiceAccount = "target@project.iam.gserviceaccount.com"; + + auto fakeHttpClient = makeSuccessfulGCPImpersonationHttpClient( + accessToken, + idToken, + delegates, + targetServiceAccount); + + AttestationConfig config; + config.type = AttestationType::GCP; + config.httpClient = &fakeHttpClient; + + std::string workloadIdentityImpersonationPath = " "; + for (const auto &delegate: delegates) { + workloadIdentityImpersonationPath += " " + delegate + ", "; + } + workloadIdentityImpersonationPath += targetServiceAccount + " "; + config.workloadIdentityImpersonationPath = workloadIdentityImpersonationPath; + + const auto attestationOpt = createAttestation(config); + assert_true(attestationOpt.has_value()); +} + +void test_unit_gcp_impersonation_access_token_failed(void **) { + auto fakeHttpClient = FakeHttpClient([](const HttpRequest &req) { + if (req.url.host() == GCP_TEST_METADATA_ENDPOINT_HOST) { + HttpResponse response; + response.code = 404; + return boost::optional(response); + } + return boost::optional(boost::none); + }); + + AttestationConfig config; + config.type = AttestationType::GCP; + config.httpClient = &fakeHttpClient; + config.workloadIdentityImpersonationPath = "target@project.iam.gserviceaccount.com"; + + const auto attestationOpt = createAttestation(config); + assert_false(attestationOpt.has_value()); +} + +void test_unit_gcp_impersonation_id_token_failed(void **) { + const auto accessToken = makeGCPToken(GCP_TEST_ISSUER, GCP_TEST_SUBJECT_ACCESS); + + auto fakeHttpClient = FakeHttpClient([=](const HttpRequest &req) { + if (req.url.host() == GCP_TEST_METADATA_ENDPOINT_HOST) { + HttpResponse response; + response.code = 200; + response.buffer = accessToken; + return boost::optional(response); + } + if (req.url.host() == GCP_TEST_IAM_ENDPOINT_HOST) { + HttpResponse response; + response.code = 403; + const std::string error = "Forbidden"; + response.buffer = std::vector(error.begin(), error.end()); + return boost::optional(response); + } + return boost::optional(boost::none); + }); + + AttestationConfig config; + config.type = AttestationType::GCP; + config.httpClient = &fakeHttpClient; + config.workloadIdentityImpersonationPath = "target@project.iam.gserviceaccount.com"; + + const auto attestationOpt = createAttestation(config); + assert_false(attestationOpt.has_value()); +} + +void test_unit_gcp_impersonation_empty_path(void **) { + const auto idToken = makeGCPToken(GCP_TEST_ISSUER, GCP_TEST_SUBJECT); + auto fakeHttpClient = makeSuccessfulGCPHttpClient(idToken); + + AttestationConfig config; + config.type = AttestationType::GCP; + config.httpClient = &fakeHttpClient; + // Empty path should use direct flow + config.workloadIdentityImpersonationPath = ""; + + const auto attestationOpt = createAttestation(config); + assert_true(attestationOpt.has_value()); +} + +void test_unit_gcp_impersonation_missing_token_in_response(void **) { + const auto accessToken = makeGCPToken(GCP_TEST_ISSUER, GCP_TEST_SUBJECT_ACCESS); + + auto fakeHttpClient = FakeHttpClient([=](const HttpRequest &req) { + if (req.url.host() == GCP_TEST_METADATA_ENDPOINT_HOST) { + HttpResponse response; + response.code = 200; + response.buffer = accessToken; + return boost::optional(response); + } + if (req.url.host() == GCP_TEST_IAM_ENDPOINT_HOST) { + HttpResponse response; + response.code = 200; + const std::string body = "{\"invalid_field\": \"value\"}"; + response.buffer = std::vector(body.begin(), body.end()); + return boost::optional(response); + } + return boost::optional(boost::none); + }); + + AttestationConfig config; + config.type = AttestationType::GCP; + config.httpClient = &fakeHttpClient; + config.workloadIdentityImpersonationPath = "target@project.iam.gserviceaccount.com"; + + const auto attestationOpt = createAttestation(config); + assert_false(attestationOpt.has_value()); +} + const std::string AZURE_TEST_ISSUER_ID = "123bdcc4-50e7-4fea-958d-32cdb3ad3aca"; const std::string AZURE_TEST_SUBJECT = "f05bdcc4-50e7-4fea-958d-32cdb12b3aca"; @@ -483,6 +841,13 @@ int main() { cmocka_unit_test(test_unit_aws_attestation_china_region_success), cmocka_unit_test(test_unit_aws_attestation_region_missing), cmocka_unit_test(test_unit_aws_attestation_cred_missing), + cmocka_unit_test(test_unit_aws_attestation_impersonation_single_role), + cmocka_unit_test(test_unit_aws_attestation_impersonation_role_chain), + cmocka_unit_test(test_unit_aws_attestation_impersonation_whitespace_handling), + cmocka_unit_test(test_unit_aws_attestation_impersonation_cross_account), + cmocka_unit_test(test_unit_aws_attestation_impersonation_empty_path_fallback), + cmocka_unit_test(test_unit_aws_attestation_impersonation_with_missing_region), + cmocka_unit_test(test_unit_aws_attestation_impersonation_with_missing_credentials), cmocka_unit_test(test_unit_gcp_attestation_success), cmocka_unit_test(test_unit_gcp_attestation_missing_issuer), cmocka_unit_test(test_unit_gcp_attestation_missing_subject),