55#include " logger/SFLogger.hpp"
66#include < aws/core/Aws.h>
77#include < aws/core/auth/AWSCredentialsProvider.h>
8- #include < aws/core/auth/AWSAuthSigner.h>
9- #include " snowflake/AWSUtils.hpp"
8+ #include < aws/sts/STSClient.h>
9+ #include < aws/sts/model/AssumeRoleRequest.h>
10+ #include < aws/core/utils/UUID.h>
11+ #include < sstream>
1012
11- namespace Snowflake {
12- namespace Client {
13-
14- // We don't need x-amz-content-sha256 header, because there is no payload to be signed.
15- // If x-amz-content-sha256 contain EMPTY_STRING_SHA256, the server responds with
16- // "The AWS STS request contained unacceptable headers."
17- class AWS_CORE_API AWSAuthV4SignerNoPayload : public Aws::Client::AWSAuthV4Signer
18- {
19- public:
20- AWSAuthV4SignerNoPayload (const std::shared_ptr<Aws::Auth::AWSCredentialsProvider>& credentialsProvider, const char * serviceName, const Aws::String& region)
21- : AWSAuthV4Signer(credentialsProvider, serviceName, region) { m_includeSha256HashHeader = false ; }
22- };
23-
24- boost::optional<Attestation> createAwsAttestation (const AttestationConfig& config) {
25- auto awsSdkInit = AwsUtils::initAwsSdk ();
26- auto creds = config.awsSdkWrapper ->getCredentials ();
27- if (creds.IsEmpty ()) {
28- CXX_LOG_INFO (" Failed to get AWS credentials" );
29- return boost::none;
13+ namespace Snowflake ::Client {
14+ // We don't need x-amz-content-sha256 header, because there is no payload to be signed.
15+ // If x-amz-content-sha256 contain EMPTY_STRING_SHA256, the server responds with
16+ // "The AWS STS request contained unacceptable headers."
17+ class AWS_CORE_API AWSAuthV4SignerNoPayload : public Aws::Client::AWSAuthV4Signer
18+ {
19+ public:
20+ AWSAuthV4SignerNoPayload (const std::shared_ptr<Aws::Auth::AWSCredentialsProvider>& credentialsProvider, const char * serviceName, const Aws::String& region)
21+ : AWSAuthV4Signer(credentialsProvider, serviceName, region) { m_includeSha256HashHeader = false ; }
22+ };
23+
24+ // Splits comma-separated role ARN impersonation path
25+ std::vector<std::string> parseRoleArnChain (const std::string &path) {
26+ std::vector<std::string> result;
27+ std::stringstream ss (path);
28+ std::string item;
29+
30+ while (std::getline (ss, item, ' ,' )) {
31+ const auto start = item.find_first_not_of (" \t " );
32+ const auto end = item.find_last_not_of (" \t " );
33+
34+ if (start != std::string::npos) {
35+ result.push_back (item.substr (start, end - start + 1 ));
3036 }
37+ }
38+
39+ return result;
40+ }
41+
42+ // Assumes a single AWS role and returns temporary credentials
43+ boost::optional<Aws::Auth::AWSCredentials> assumeAwsRole (
44+ const Aws::Auth::AWSCredentials ¤tCreds,
45+ const std::string &roleArn) {
46+
47+ CXX_LOG_DEBUG (" Assuming AWS role: %s" , roleArn.c_str ());
48+
49+ const Aws::STS::STSClient stsClient (currentCreds);
50+
51+ Aws::STS::Model::AssumeRoleRequest assumeRoleRequest;
52+ assumeRoleRequest.SetRoleArn (roleArn.c_str ());
53+
54+ const std::string sessionName = " snowflake-wif-" + std::string (Aws::Utils::UUID::PseudoRandomUUID ());
55+ assumeRoleRequest.SetRoleSessionName (sessionName.c_str ());
56+ assumeRoleRequest.SetDurationSeconds (3600 );
57+
58+ const auto outcome = stsClient.AssumeRole (assumeRoleRequest);
59+
60+ if (!outcome.IsSuccess ()) {
61+ CXX_LOG_ERROR (" Failed to assume role %s: %s" ,
62+ roleArn.c_str (),
63+ outcome.GetError ().GetMessage ().c_str ());
64+ return boost::none;
65+ }
66+
67+ const auto &credentials = outcome.GetResult ().GetCredentials ();
68+ return Aws::Auth::AWSCredentials (
69+ credentials.GetAccessKeyId (),
70+ credentials.GetSecretAccessKey (),
71+ credentials.GetSessionToken ()
72+ );
73+ }
74+
75+ // Assumes a chain of AWS roles sequentially
76+ boost::optional<Aws::Auth::AWSCredentials> assumeAwsRoleChain (
77+ const Aws::Auth::AWSCredentials &initialCreds,
78+ const std::vector<std::string> &roleArnChain) {
79+
80+ if (roleArnChain.empty ()) {
81+ CXX_LOG_ERROR (" Role ARN chain is empty" );
82+ return boost::none;
83+ }
84+
85+ Aws::Auth::AWSCredentials currentCreds = initialCreds;
3186
32- auto regionOpt = config.awsSdkWrapper ->getEC2Region ();
33- if (!regionOpt) {
34- CXX_LOG_INFO (" Failed to get AWS region" );
87+ for (const auto &roleArn: roleArnChain) {
88+ auto assumedCredsOpt = assumeAwsRole (currentCreds, roleArn);
89+ if (!assumedCredsOpt) {
90+ CXX_LOG_ERROR (" Failed to assume role in chain: %s" , roleArn.c_str ());
3591 return boost::none;
3692 }
37- const std::string& region = regionOpt.get ();
38- const std::string domain = AwsUtils::getDomainSuffixForRegionalUrl (region);
39- const std::string host = std::string (" sts" ) + " ." + region + " ." + domain;
40- const std::string url = std::string (" https://" ) + host + " /?Action=GetCallerIdentity&Version=2011-06-15" ;
41-
42- auto request = Aws::Http::CreateHttpRequest (
43- Aws::String (url),
44- Aws::Http::HttpMethod::HTTP_POST,
45- Aws::Utils::Stream::DefaultResponseStreamFactoryMethod
46- );
47-
48- request->SetHeaderValue (" Host" , host);
49- request->SetHeaderValue (" X-Snowflake-Audience" , " snowflakecomputing.com" );
50-
51- auto simpleCredProvider = std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(creds);
52- AWSAuthV4SignerNoPayload signer (simpleCredProvider, " sts" , region);
53-
54- // Sign the request
55- if (!signer.SignRequest (*request)) {
56- CXX_LOG_ERROR (" Failed to sign request" );
93+ currentCreds = assumedCredsOpt.get ();
94+ }
95+
96+ return currentCreds;
97+ }
98+
99+ boost::optional<Attestation> createAwsAttestation (const AttestationConfig &config) {
100+ auto awsSdkInit = AwsUtils::initAwsSdk ();
101+ auto creds = config.awsSdkWrapper ->getCredentials ();
102+ if (creds.IsEmpty ()) {
103+ CXX_LOG_INFO (" Failed to get AWS credentials" );
104+ return boost::none;
105+ }
106+
107+ auto regionOpt = config.awsSdkWrapper ->getEC2Region ();
108+ if (!regionOpt) {
109+ CXX_LOG_INFO (" Failed to get AWS region" );
110+ return boost::none;
111+ }
112+ const std::string ®ion = regionOpt.get ();
113+
114+ // Check if role assumption chain is configured
115+ if (config.workloadIdentityImpersonationPath &&
116+ !config.workloadIdentityImpersonationPath .get ().empty ()) {
117+
118+ CXX_LOG_INFO (" Using AWS role assumption chain for impersonation" );
119+
120+ auto roleArnChain = parseRoleArnChain (config.workloadIdentityImpersonationPath .get ());
121+
122+ if (roleArnChain.empty ()) {
123+ CXX_LOG_ERROR (" Failed to parse role ARN chain" );
57124 return boost::none;
58125 }
59126
60- picojson::object obj ;
61- obj[ " url " ] = picojson::value (request-> GetURIString ());
62- obj[ " method " ] = picojson::value ( Aws::Http::HttpMethodMapper::GetNameForHttpMethod (request-> GetMethod ()) );
63- picojson::object headers;
64- for ( const auto &h: request-> GetHeaders ()) {
65- headers[h. first ] = picojson::value (h. second ) ;
127+ CXX_LOG_DEBUG ( " Role ARN chain size: %zu " , roleArnChain. size ()) ;
128+
129+ auto assumedCredsOpt = assumeAwsRoleChain (creds, roleArnChain );
130+ if (!assumedCredsOpt) {
131+ CXX_LOG_ERROR ( " Failed to assume role chain " );
132+ return boost::none ;
66133 }
67- obj[" headers" ] = picojson::value (headers);
68- std::string json = picojson::value (obj).serialize (true );
69- std::string base64;
70- Util::Base64::encodePadding (json.begin (), json.end (), std::back_inserter (base64));
71- return Attestation::makeAws (base64);
134+
135+ creds = assumedCredsOpt.get ();
136+ }
137+
138+ const std::string domain = AwsUtils::getDomainSuffixForRegionalUrl (region);
139+ const std::string host = std::string (" sts" ) + " ." + region + " ." + domain;
140+ const std::string url = std::string (" https://" ) + host + " /?Action=GetCallerIdentity&Version=2011-06-15" ;
141+
142+ auto request = Aws::Http::CreateHttpRequest (
143+ Aws::String (url),
144+ Aws::Http::HttpMethod::HTTP_POST,
145+ Aws::Utils::Stream::DefaultResponseStreamFactoryMethod
146+ );
147+
148+ request->SetHeaderValue (" Host" , host);
149+ request->SetHeaderValue (" X-Snowflake-Audience" , " snowflakecomputing.com" );
150+
151+ auto simpleCredProvider = std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(creds);
152+ AWSAuthV4SignerNoPayload signer (simpleCredProvider, " sts" , region);
153+
154+ // Sign the request
155+ if (!signer.SignRequest (*request)) {
156+ CXX_LOG_ERROR (" Failed to sign request" );
157+ return boost::none;
158+ }
159+
160+ picojson::object obj;
161+ obj[" url" ] = picojson::value (request->GetURIString ());
162+ obj[" method" ] = picojson::value (Aws::Http::HttpMethodMapper::GetNameForHttpMethod (request->GetMethod ()));
163+ picojson::object headers;
164+ for (const auto &h: request->GetHeaders ()) {
165+ headers[h.first ] = picojson::value (h.second );
72166 }
167+ obj[" headers" ] = picojson::value (headers);
168+ std::string json = picojson::value (obj).serialize (true );
169+ std::string base64;
170+ Util::Base64::encodePadding (json.begin (), json.end (), std::back_inserter (base64));
171+ return Attestation::makeAws (base64);
73172 }
74173}
0 commit comments