Skip to content

Commit d3b9ede

Browse files
committed
feature: Add WIF Impersonation support for AWS
1 parent 40eca7e commit d3b9ede

File tree

1 file changed

+154
-55
lines changed

1 file changed

+154
-55
lines changed

cpp/AwsAttestation.cpp

Lines changed: 154 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,70 +5,169 @@
55
#include "logger/SFLogger.hpp"
66
#include <aws/core/Aws.h>
77
#include <aws/core/auth/AWSCredentialsProvider.h>
8-
#include <aws/core/auth/AWSAuthSigner.h>
9-
#include "snowflake/AWSUtils.hpp"
8+
#include <aws/sts/STSClient.h>
9+
#include <aws/sts/model/AssumeRoleRequest.h>
10+
#include <aws/core/utils/UUID.h>
11+
#include <sstream>
1012

11-
namespace Snowflake {
12-
namespace Client {
13-
14-
// We don't need x-amz-content-sha256 header, because there is no payload to be signed.
15-
// If x-amz-content-sha256 contain EMPTY_STRING_SHA256, the server responds with
16-
// "The AWS STS request contained unacceptable headers."
17-
class AWS_CORE_API AWSAuthV4SignerNoPayload : public Aws::Client::AWSAuthV4Signer
18-
{
19-
public:
20-
AWSAuthV4SignerNoPayload(const std::shared_ptr<Aws::Auth::AWSCredentialsProvider>& credentialsProvider, const char* serviceName, const Aws::String& region)
21-
: AWSAuthV4Signer(credentialsProvider, serviceName, region) { m_includeSha256HashHeader = false; }
22-
};
23-
24-
boost::optional<Attestation> createAwsAttestation(const AttestationConfig& config) {
25-
auto awsSdkInit = AwsUtils::initAwsSdk();
26-
auto creds = config.awsSdkWrapper->getCredentials();
27-
if (creds.IsEmpty()) {
28-
CXX_LOG_INFO("Failed to get AWS credentials");
29-
return boost::none;
13+
namespace Snowflake::Client {
14+
// We don't need x-amz-content-sha256 header, because there is no payload to be signed.
15+
// If x-amz-content-sha256 contain EMPTY_STRING_SHA256, the server responds with
16+
// "The AWS STS request contained unacceptable headers."
17+
class AWS_CORE_API AWSAuthV4SignerNoPayload : public Aws::Client::AWSAuthV4Signer
18+
{
19+
public:
20+
AWSAuthV4SignerNoPayload(const std::shared_ptr<Aws::Auth::AWSCredentialsProvider>& credentialsProvider, const char* serviceName, const Aws::String& region)
21+
: AWSAuthV4Signer(credentialsProvider, serviceName, region) { m_includeSha256HashHeader = false; }
22+
};
23+
24+
// Splits comma-separated role ARN impersonation path
25+
std::vector<std::string> parseRoleArnChain(const std::string &path) {
26+
std::vector<std::string> result;
27+
std::stringstream ss(path);
28+
std::string item;
29+
30+
while (std::getline(ss, item, ',')) {
31+
const auto start = item.find_first_not_of(" \t");
32+
const auto end = item.find_last_not_of(" \t");
33+
34+
if (start != std::string::npos) {
35+
result.push_back(item.substr(start, end - start + 1));
3036
}
37+
}
38+
39+
return result;
40+
}
41+
42+
// Assumes a single AWS role and returns temporary credentials
43+
boost::optional<Aws::Auth::AWSCredentials> assumeAwsRole(
44+
const Aws::Auth::AWSCredentials &currentCreds,
45+
const std::string &roleArn) {
46+
47+
CXX_LOG_DEBUG("Assuming AWS role: %s", roleArn.c_str());
48+
49+
const Aws::STS::STSClient stsClient(currentCreds);
50+
51+
Aws::STS::Model::AssumeRoleRequest assumeRoleRequest;
52+
assumeRoleRequest.SetRoleArn(roleArn.c_str());
53+
54+
const std::string sessionName = "snowflake-wif-" + std::string(Aws::Utils::UUID::PseudoRandomUUID());
55+
assumeRoleRequest.SetRoleSessionName(sessionName.c_str());
56+
assumeRoleRequest.SetDurationSeconds(3600);
57+
58+
const auto outcome = stsClient.AssumeRole(assumeRoleRequest);
59+
60+
if (!outcome.IsSuccess()) {
61+
CXX_LOG_ERROR("Failed to assume role %s: %s",
62+
roleArn.c_str(),
63+
outcome.GetError().GetMessage().c_str());
64+
return boost::none;
65+
}
66+
67+
const auto &credentials = outcome.GetResult().GetCredentials();
68+
return Aws::Auth::AWSCredentials(
69+
credentials.GetAccessKeyId(),
70+
credentials.GetSecretAccessKey(),
71+
credentials.GetSessionToken()
72+
);
73+
}
74+
75+
// Assumes a chain of AWS roles sequentially
76+
boost::optional<Aws::Auth::AWSCredentials> assumeAwsRoleChain(
77+
const Aws::Auth::AWSCredentials &initialCreds,
78+
const std::vector<std::string> &roleArnChain) {
79+
80+
if (roleArnChain.empty()) {
81+
CXX_LOG_ERROR("Role ARN chain is empty");
82+
return boost::none;
83+
}
84+
85+
Aws::Auth::AWSCredentials currentCreds = initialCreds;
3186

32-
auto regionOpt = config.awsSdkWrapper->getEC2Region();
33-
if (!regionOpt) {
34-
CXX_LOG_INFO("Failed to get AWS region");
87+
for (const auto &roleArn: roleArnChain) {
88+
auto assumedCredsOpt = assumeAwsRole(currentCreds, roleArn);
89+
if (!assumedCredsOpt) {
90+
CXX_LOG_ERROR("Failed to assume role in chain: %s", roleArn.c_str());
3591
return boost::none;
3692
}
37-
const std::string& region = regionOpt.get();
38-
const std::string domain = AwsUtils::getDomainSuffixForRegionalUrl(region);
39-
const std::string host = std::string("sts") + "." + region + "." + domain;
40-
const std::string url = std::string("https://") + host + "/?Action=GetCallerIdentity&Version=2011-06-15";
41-
42-
auto request = Aws::Http::CreateHttpRequest(
43-
Aws::String(url),
44-
Aws::Http::HttpMethod::HTTP_POST,
45-
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod
46-
);
47-
48-
request->SetHeaderValue("Host", host);
49-
request->SetHeaderValue("X-Snowflake-Audience", "snowflakecomputing.com");
50-
51-
auto simpleCredProvider = std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(creds);
52-
AWSAuthV4SignerNoPayload signer(simpleCredProvider, "sts", region);
53-
54-
// Sign the request
55-
if (!signer.SignRequest(*request)) {
56-
CXX_LOG_ERROR("Failed to sign request");
93+
currentCreds = assumedCredsOpt.get();
94+
}
95+
96+
return currentCreds;
97+
}
98+
99+
boost::optional<Attestation> createAwsAttestation(const AttestationConfig &config) {
100+
auto awsSdkInit = AwsUtils::initAwsSdk();
101+
auto creds = config.awsSdkWrapper->getCredentials();
102+
if (creds.IsEmpty()) {
103+
CXX_LOG_INFO("Failed to get AWS credentials");
104+
return boost::none;
105+
}
106+
107+
auto regionOpt = config.awsSdkWrapper->getEC2Region();
108+
if (!regionOpt) {
109+
CXX_LOG_INFO("Failed to get AWS region");
110+
return boost::none;
111+
}
112+
const std::string &region = regionOpt.get();
113+
114+
// Check if role assumption chain is configured
115+
if (config.workloadIdentityImpersonationPath &&
116+
!config.workloadIdentityImpersonationPath.get().empty()) {
117+
118+
CXX_LOG_INFO("Using AWS role assumption chain for impersonation");
119+
120+
auto roleArnChain = parseRoleArnChain(config.workloadIdentityImpersonationPath.get());
121+
122+
if (roleArnChain.empty()) {
123+
CXX_LOG_ERROR("Failed to parse role ARN chain");
57124
return boost::none;
58125
}
59126

60-
picojson::object obj;
61-
obj["url"] = picojson::value(request->GetURIString());
62-
obj["method"] = picojson::value(Aws::Http::HttpMethodMapper::GetNameForHttpMethod(request->GetMethod()));
63-
picojson::object headers;
64-
for (const auto &h: request->GetHeaders()) {
65-
headers[h.first] = picojson::value(h.second);
127+
CXX_LOG_DEBUG("Role ARN chain size: %zu", roleArnChain.size());
128+
129+
auto assumedCredsOpt = assumeAwsRoleChain(creds, roleArnChain);
130+
if (!assumedCredsOpt) {
131+
CXX_LOG_ERROR("Failed to assume role chain");
132+
return boost::none;
66133
}
67-
obj["headers"] = picojson::value(headers);
68-
std::string json = picojson::value(obj).serialize(true);
69-
std::string base64;
70-
Util::Base64::encodePadding(json.begin(), json.end(), std::back_inserter(base64));
71-
return Attestation::makeAws(base64);
134+
135+
creds = assumedCredsOpt.get();
136+
}
137+
138+
const std::string domain = AwsUtils::getDomainSuffixForRegionalUrl(region);
139+
const std::string host = std::string("sts") + "." + region + "." + domain;
140+
const std::string url = std::string("https://") + host + "/?Action=GetCallerIdentity&Version=2011-06-15";
141+
142+
auto request = Aws::Http::CreateHttpRequest(
143+
Aws::String(url),
144+
Aws::Http::HttpMethod::HTTP_POST,
145+
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod
146+
);
147+
148+
request->SetHeaderValue("Host", host);
149+
request->SetHeaderValue("X-Snowflake-Audience", "snowflakecomputing.com");
150+
151+
auto simpleCredProvider = std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(creds);
152+
AWSAuthV4SignerNoPayload signer(simpleCredProvider, "sts", region);
153+
154+
// Sign the request
155+
if (!signer.SignRequest(*request)) {
156+
CXX_LOG_ERROR("Failed to sign request");
157+
return boost::none;
158+
}
159+
160+
picojson::object obj;
161+
obj["url"] = picojson::value(request->GetURIString());
162+
obj["method"] = picojson::value(Aws::Http::HttpMethodMapper::GetNameForHttpMethod(request->GetMethod()));
163+
picojson::object headers;
164+
for (const auto &h: request->GetHeaders()) {
165+
headers[h.first] = picojson::value(h.second);
72166
}
167+
obj["headers"] = picojson::value(headers);
168+
std::string json = picojson::value(obj).serialize(true);
169+
std::string base64;
170+
Util::Base64::encodePadding(json.begin(), json.end(), std::back_inserter(base64));
171+
return Attestation::makeAws(base64);
73172
}
74173
}

0 commit comments

Comments
 (0)