Skip to content

Commit d55d6d2

Browse files
Add support to assume an AWS role and renew expired credentials (#653)
Co-authored-by: Christoph Burmeister <[email protected]> Signed-off-by: Steve Teuber <[email protected]> Signed-off-by: Steve Teuber <[email protected]> Co-authored-by: Christoph Burmeister <[email protected]>
1 parent c2b9c71 commit d55d6d2

File tree

4 files changed

+32
-12
lines changed

4 files changed

+32
-12
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ elasticsearch_exporter --help
4949
```
5050

5151
| Argument | Introduced in Version | Description | Default |
52-
| -------- | --------------------- | ----------- | ----------- |
52+
| ----------------------- | --------------------- | ----------- | ----------- |
5353
| es.uri | 1.0.2 | Address (host and port) of the Elasticsearch node we should connect to. This could be a local node (`localhost:9200`, for instance), or the address of a remote Elasticsearch server. When basic auth is needed, specify as: `<proto>://<user>:<password>@<host>:<port>`. E.G., `http://admin:pass@localhost:9200`. Special characters in the user credentials need to be URL-encoded. | <http://localhost:9200> |
5454
| es.all | 1.0.2 | If true, query stats for all nodes in the cluster, rather than just the node we connect to. | false |
5555
| es.cluster_settings | 1.1.0rc1 | If true, query stats for cluster settings. | false |
@@ -70,6 +70,7 @@ elasticsearch_exporter --help
7070
| web.listen-address | 1.0.2 | Address to listen on for web interface and telemetry. | :9114 |
7171
| web.telemetry-path | 1.0.2 | Path under which to expose metrics. | /metrics |
7272
| aws.region | 1.5.0 | Region for AWS elasticsearch | |
73+
| aws.role-arn | 1.6.0 | Role ARN of an IAM role to assume. | |
7374
| version | 1.0.2 | Show version info on stdout and exit. | |
7475

7576
Commandline parameters start with a single `-` for versions less than `1.1.0rc1`.

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ go 1.19
55
require (
66
github.com/aws/aws-sdk-go-v2 v1.17.3
77
github.com/aws/aws-sdk-go-v2/config v1.18.7
8+
github.com/aws/aws-sdk-go-v2/credentials v1.13.7
9+
github.com/aws/aws-sdk-go-v2/service/sts v1.17.7
810
github.com/blang/semver/v4 v4.0.0
911
github.com/go-kit/log v0.2.1
1012
github.com/imdario/mergo v0.3.13
@@ -17,15 +19,13 @@ require (
1719
require (
1820
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect
1921
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect
20-
github.com/aws/aws-sdk-go-v2/credentials v1.13.7 // indirect
2122
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.21 // indirect
2223
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.27 // indirect
2324
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.21 // indirect
2425
github.com/aws/aws-sdk-go-v2/internal/ini v1.3.28 // indirect
2526
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.21 // indirect
2627
github.com/aws/aws-sdk-go-v2/service/sso v1.11.28 // indirect
2728
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.11 // indirect
28-
github.com/aws/aws-sdk-go-v2/service/sts v1.17.7 // indirect
2929
github.com/aws/smithy-go v1.13.5 // indirect
3030
github.com/beorn7/perks v1.0.1 // indirect
3131
github.com/cespare/xxhash/v2 v2.1.2 // indirect

main.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ func main() {
125125
awsRegion = kingpin.Flag("aws.region",
126126
"Region for AWS elasticsearch").
127127
Default("").String()
128+
awsRoleArn = kingpin.Flag("aws.role-arn",
129+
"Role ARN of an IAM role to assume.").
130+
Default("").String()
128131
)
129132

130133
kingpin.Version(version.Print(name))
@@ -174,7 +177,7 @@ func main() {
174177
}
175178

176179
if *awsRegion != "" {
177-
httpClient.Transport, err = roundtripper.NewAWSSigningTransport(httpTransport, *awsRegion, logger)
180+
httpClient.Transport, err = roundtripper.NewAWSSigningTransport(httpTransport, *awsRegion, *awsRoleArn, logger)
178181
if err != nil {
179182
_ = level.Error(logger).Log("msg", "failed to create AWS transport", "err", err)
180183
os.Exit(1)

pkg/roundtripper/roundtripper.go

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ import (
2626
"github.com/aws/aws-sdk-go-v2/aws"
2727
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
2828
"github.com/aws/aws-sdk-go-v2/config"
29+
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
30+
"github.com/aws/aws-sdk-go-v2/service/sts"
2931
"github.com/go-kit/log"
3032
"github.com/go-kit/log/level"
3133
)
@@ -36,21 +38,28 @@ const (
3638

3739
type AWSSigningTransport struct {
3840
t http.RoundTripper
39-
creds aws.Credentials
41+
creds aws.CredentialsProvider
4042
region string
4143
log log.Logger
4244
}
4345

44-
func NewAWSSigningTransport(transport http.RoundTripper, region string, log log.Logger) (*AWSSigningTransport, error) {
46+
func NewAWSSigningTransport(transport http.RoundTripper, region string, roleArn string, log log.Logger) (*AWSSigningTransport, error) {
4547
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region))
4648
if err != nil {
47-
_ = level.Error(log).Log("msg", "fail to load aws default config", "err", err)
49+
_ = level.Error(log).Log("msg", "failed to load aws default config", "err", err)
4850
return nil, err
4951
}
5052

51-
creds, err := cfg.Credentials.Retrieve(context.Background())
53+
if roleArn != "" {
54+
cfg.Credentials = stscreds.NewAssumeRoleProvider(sts.NewFromConfig(cfg), roleArn)
55+
}
56+
57+
creds := aws.NewCredentialsCache(cfg.Credentials)
58+
// Run a single fetch credentials operation to ensure that the credentials
59+
// are valid before returning the transport.
60+
_, err = cfg.Credentials.Retrieve(context.Background())
5261
if err != nil {
53-
_ = level.Error(log).Log("msg", "fail to retrive aws credentials", "err", err)
62+
_ = level.Error(log).Log("msg", "failed to retrive aws credentials", "err", err)
5463
return nil, err
5564
}
5665

@@ -66,13 +75,20 @@ func (a *AWSSigningTransport) RoundTrip(req *http.Request) (*http.Response, erro
6675
signer := v4.NewSigner()
6776
payloadHash, newReader, err := hashPayload(req.Body)
6877
if err != nil {
69-
_ = level.Error(a.log).Log("msg", "fail to hash request body", "err", err)
78+
_ = level.Error(a.log).Log("msg", "failed to hash request body", "err", err)
7079
return nil, err
7180
}
7281
req.Body = newReader
73-
err = signer.SignHTTP(context.Background(), a.creds, req, payloadHash, service, a.region, time.Now())
82+
83+
creds, err := a.creds.Retrieve(context.Background())
84+
if err != nil {
85+
_ = level.Error(a.log).Log("msg", "failed to retrieve aws credentials", "err", err)
86+
return nil, err
87+
}
88+
89+
err = signer.SignHTTP(context.Background(), creds, req, payloadHash, service, a.region, time.Now())
7490
if err != nil {
75-
_ = level.Error(a.log).Log("msg", "fail to sign request body", "err", err)
91+
_ = level.Error(a.log).Log("msg", "failed to sign request body", "err", err)
7692
return nil, err
7793
}
7894
return a.t.RoundTrip(req)

0 commit comments

Comments
 (0)