package ec2rolecreds
import (
"bufio"
"context"
"encoding/json"
"fmt"
"math"
"path"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
"github.com/aws/smithy-go"
"github.com/aws/smithy-go/logging"
"github.com/aws/smithy-go/middleware"
)
// ProviderName provides a name of EC2Role provider
const ProviderName = "EC2RoleProvider"
// GetMetadataAPIClient provides the interface for an EC2 IMDS API client for the
// GetMetadata operation.
type GetMetadataAPIClient interface {
GetMetadata ( context . Context , * imds . GetMetadataInput , ... func ( * imds . Options ) ) ( * imds . GetMetadataOutput , error )
}
// A Provider retrieves credentials from the EC2 service, and keeps track if
// those credentials are expired.
//
// The New function must be used to create the with a custom EC2 IMDS client.
//
// p := &ec2rolecreds.New(func(o *ec2rolecreds.Options{
// o.Client = imds.New(imds.Options{/* custom options */})
// })
type Provider struct {
options Options
}
// Options is a list of user settable options for setting the behavior of the Provider.
type Options struct {
// The API client that will be used by the provider to make GetMetadata API
// calls to EC2 IMDS.
//
// If nil, the provider will default to the EC2 IMDS client.
Client GetMetadataAPIClient
}
// New returns an initialized Provider value configured to retrieve
// credentials from EC2 Instance Metadata service.
func New ( optFns ... func ( * Options ) ) * Provider {
options := Options { }
for _ , fn := range optFns {
fn ( & options )
}
if options . Client == nil {
options . Client = imds . New ( imds . Options { } )
}
return & Provider {
options : options ,
}
}
// Retrieve retrieves credentials from the EC2 service. Error will be returned
// if the request fails, or unable to extract the desired credentials.
func ( p * Provider ) Retrieve ( ctx context . Context ) ( aws . Credentials , error ) {
credsList , err := requestCredList ( ctx , p . options . Client )
if err != nil {
return aws . Credentials { Source : ProviderName } , err
}
if len ( credsList ) == 0 {
return aws . Credentials { Source : ProviderName } ,
fmt . Errorf ( "unexpected empty EC2 IMDS role list" )
}
credsName := credsList [ 0 ]
roleCreds , err := requestCred ( ctx , p . options . Client , credsName )
if err != nil {
return aws . Credentials { Source : ProviderName } , err
}
creds := aws . Credentials {
AccessKeyID : roleCreds . AccessKeyID ,
SecretAccessKey : roleCreds . SecretAccessKey ,
SessionToken : roleCreds . Token ,
Source : ProviderName ,
CanExpire : true ,
Expires : roleCreds . Expiration ,
}
// Cap role credentials Expires to 1 hour so they can be refreshed more
// often. Jitter will be applied credentials cache if being used.
if anHour := sdk . NowTime ( ) . Add ( 1 * time . Hour ) ; creds . Expires . After ( anHour ) {
creds . Expires = anHour
}
return creds , nil
}
// HandleFailToRefresh will extend the credentials Expires time if it it is
// expired. If the credentials will not expire within the minimum time, they
// will be returned.
//
// If the credentials cannot expire, the original error will be returned.
func ( p * Provider ) HandleFailToRefresh ( ctx context . Context , prevCreds aws . Credentials , err error ) (
aws . Credentials , error ,
) {
if ! prevCreds . CanExpire {
return aws . Credentials { } , err
}
if prevCreds . Expires . After ( sdk . NowTime ( ) . Add ( 5 * time . Minute ) ) {
return prevCreds , nil
}
newCreds := prevCreds
randFloat64 , err := sdkrand . CryptoRandFloat64 ( )
if err != nil {
return aws . Credentials { } , fmt . Errorf ( "failed to get random float, %w" , err )
}
// Random distribution of [5,15) minutes.
expireOffset := time . Duration ( randFloat64 * float64 ( 10 * time . Minute ) ) + 5 * time . Minute
newCreds . Expires = sdk . NowTime ( ) . Add ( expireOffset )
logger := middleware . GetLogger ( ctx )
logger . Logf ( logging . Warn , "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes." , math . Floor ( expireOffset . Minutes ( ) ) )
return newCreds , nil
}
// AdjustExpiresBy will adds the passed in duration to the passed in
// credential's Expires time, unless the time until Expires is less than 15
// minutes. Returns the credentials, even if not updated.
func ( p * Provider ) AdjustExpiresBy ( creds aws . Credentials , dur time . Duration ) (
aws . Credentials , error ,
) {
if ! creds . CanExpire {
return creds , nil
}
if creds . Expires . Before ( sdk . NowTime ( ) . Add ( 15 * time . Minute ) ) {
return creds , nil
}
creds . Expires = creds . Expires . Add ( dur )
return creds , nil
}
// ec2RoleCredRespBody provides the shape for unmarshaling credential
// request responses.
type ec2RoleCredRespBody struct {
// Success State
Expiration time . Time
AccessKeyID string
SecretAccessKey string
Token string
// Error state
Code string
Message string
}
const iamSecurityCredsPath = "/iam/security-credentials/"
// requestCredList requests a list of credentials from the EC2 service. If
// there are no credentials, or there is an error making or receiving the
// request
func requestCredList ( ctx context . Context , client GetMetadataAPIClient ) ( [ ] string , error ) {
resp , err := client . GetMetadata ( ctx , & imds . GetMetadataInput {
Path : iamSecurityCredsPath ,
} )
if err != nil {
return nil , fmt . Errorf ( "no EC2 IMDS role found, %w" , err )
}
defer resp . Content . Close ( )
credsList := [ ] string { }
s := bufio . NewScanner ( resp . Content )
for s . Scan ( ) {
credsList = append ( credsList , s . Text ( ) )
}
if err := s . Err ( ) ; err != nil {
return nil , fmt . Errorf ( "failed to read EC2 IMDS role, %w" , err )
}
return credsList , nil
}
// requestCred requests the credentials for a specific credentials from the EC2 service.
//
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred ( ctx context . Context , client GetMetadataAPIClient , credsName string ) ( ec2RoleCredRespBody , error ) {
resp , err := client . GetMetadata ( ctx , & imds . GetMetadataInput {
Path : path . Join ( iamSecurityCredsPath , credsName ) ,
} )
if err != nil {
return ec2RoleCredRespBody { } ,
fmt . Errorf ( "failed to get %s EC2 IMDS role credentials, %w" ,
credsName , err )
}
defer resp . Content . Close ( )
var respCreds ec2RoleCredRespBody
if err := json . NewDecoder ( resp . Content ) . Decode ( & respCreds ) ; err != nil {
return ec2RoleCredRespBody { } ,
fmt . Errorf ( "failed to decode %s EC2 IMDS role credentials, %w" ,
credsName , err )
}
if ! strings . EqualFold ( respCreds . Code , "Success" ) {
// If an error code was returned something failed requesting the role.
return ec2RoleCredRespBody { } ,
fmt . Errorf ( "failed to get %s EC2 IMDS role credentials, %w" ,
credsName ,
& smithy . GenericAPIError { Code : respCreds . Code , Message : respCreds . Message } )
}
return respCreds , nil
}