You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
buildx/vendor/github.com/aws/aws-sdk-go-v2/aws/credential_cache.go

225 lines
7.8 KiB
Go

package aws
import (
"context"
"fmt"
"sync/atomic"
"time"
sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
)
// CredentialsCacheOptions are the options
type CredentialsCacheOptions struct {
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// An ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired. This can cause an
// increased number of requests to refresh the credentials to occur.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// ExpiryWindowJitterFrac provides a mechanism for randomizing the
// expiration of credentials within the configured ExpiryWindow by a random
// percentage. Valid values are between 0.0 and 1.0.
//
// As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac
// is 0.5 then credentials will be set to expire between 30 to 60 seconds
// prior to their actual expiration time.
//
// If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored.
// If ExpiryWindowJitterFrac is 0 then no randomization will be applied to the window.
// If ExpiryWindowJitterFrac < 0 the value will be treated as 0.
// If ExpiryWindowJitterFrac > 1 the value will be treated as 1.
ExpiryWindowJitterFrac float64
}
// CredentialsCache provides caching and concurrency safe credentials retrieval
// via the provider's retrieve method.
//
// CredentialsCache will look for optional interfaces on the Provider to adjust
// how the credential cache handles credentials caching.
//
// - HandleFailRefreshCredentialsCacheStrategy - Allows provider to handle
// credential refresh failures. This could return an updated Credentials
// value, or attempt another means of retrieving credentials.
//
// - AdjustExpiresByCredentialsCacheStrategy - Allows provider to adjust how
// credentials Expires is modified. This could modify how the Credentials
// Expires is adjusted based on the CredentialsCache ExpiryWindow option.
// Such as providing a floor not to reduce the Expires below.
type CredentialsCache struct {
provider CredentialsProvider
options CredentialsCacheOptions
creds atomic.Value
sf singleflight.Group
}
// NewCredentialsCache returns a CredentialsCache that wraps provider. Provider
// is expected to not be nil. A variadic list of one or more functions can be
// provided to modify the CredentialsCache configuration. This allows for
// configuration of credential expiry window and jitter.
func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache {
options := CredentialsCacheOptions{}
for _, fn := range optFns {
fn(&options)
}
if options.ExpiryWindow < 0 {
options.ExpiryWindow = 0
}
if options.ExpiryWindowJitterFrac < 0 {
options.ExpiryWindowJitterFrac = 0
} else if options.ExpiryWindowJitterFrac > 1 {
options.ExpiryWindowJitterFrac = 1
}
return &CredentialsCache{
provider: provider,
options: options,
}
}
// Retrieve returns the credentials. If the credentials have already been
// retrieved, and not expired the cached credentials will be returned. If the
// credentials have not been retrieved yet, or expired the provider's Retrieve
// method will be called.
//
// Returns and error if the provider's retrieve method returns an error.
func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) {
if creds, ok := p.getCreds(); ok && !creds.Expired() {
return creds, nil
}
resCh := p.sf.DoChan("", func() (interface{}, error) {
return p.singleRetrieve(&suppressedContext{ctx})
})
select {
case res := <-resCh:
return res.Val.(Credentials), res.Err
case <-ctx.Done():
return Credentials{}, &RequestCanceledError{Err: ctx.Err()}
}
}
func (p *CredentialsCache) singleRetrieve(ctx context.Context) (interface{}, error) {
currCreds, ok := p.getCreds()
if ok && !currCreds.Expired() {
return currCreds, nil
}
newCreds, err := p.provider.Retrieve(ctx)
if err != nil {
handleFailToRefresh := defaultHandleFailToRefresh
if cs, ok := p.provider.(HandleFailRefreshCredentialsCacheStrategy); ok {
handleFailToRefresh = cs.HandleFailToRefresh
}
newCreds, err = handleFailToRefresh(ctx, currCreds, err)
if err != nil {
return Credentials{}, fmt.Errorf("failed to refresh cached credentials, %w", err)
}
}
if newCreds.CanExpire && p.options.ExpiryWindow > 0 {
adjustExpiresBy := defaultAdjustExpiresBy
if cs, ok := p.provider.(AdjustExpiresByCredentialsCacheStrategy); ok {
adjustExpiresBy = cs.AdjustExpiresBy
}
randFloat64, err := sdkrand.CryptoRandFloat64()
if err != nil {
return Credentials{}, fmt.Errorf("failed to get random provider, %w", err)
}
var jitter time.Duration
if p.options.ExpiryWindowJitterFrac > 0 {
jitter = time.Duration(randFloat64 *
p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow))
}
newCreds, err = adjustExpiresBy(newCreds, -(p.options.ExpiryWindow - jitter))
if err != nil {
return Credentials{}, fmt.Errorf("failed to adjust credentials expires, %w", err)
}
}
p.creds.Store(&newCreds)
return newCreds, nil
}
// getCreds returns the currently stored credentials and true. Returning false
// if no credentials were stored.
func (p *CredentialsCache) getCreds() (Credentials, bool) {
v := p.creds.Load()
if v == nil {
return Credentials{}, false
}
c := v.(*Credentials)
if c == nil || !c.HasKeys() {
return Credentials{}, false
}
return *c, true
}
// Invalidate will invalidate the cached credentials. The next call to Retrieve
// will cause the provider's Retrieve method to be called.
func (p *CredentialsCache) Invalidate() {
p.creds.Store((*Credentials)(nil))
}
// IsCredentialsProvider returns whether credential provider wrapped by CredentialsCache
// matches the target provider type.
func (p *CredentialsCache) IsCredentialsProvider(target CredentialsProvider) bool {
return IsCredentialsProvider(p.provider, target)
}
// HandleFailRefreshCredentialsCacheStrategy is an interface for
// CredentialsCache to allow CredentialsProvider how failed to refresh
// credentials is handled.
type HandleFailRefreshCredentialsCacheStrategy interface {
// Given the previously cached Credentials, if any, and refresh error, may
// returns new or modified set of Credentials, or error.
//
// Credential caches may use default implementation if nil.
HandleFailToRefresh(context.Context, Credentials, error) (Credentials, error)
}
// defaultHandleFailToRefresh returns the passed in error.
func defaultHandleFailToRefresh(ctx context.Context, _ Credentials, err error) (Credentials, error) {
return Credentials{}, err
}
// AdjustExpiresByCredentialsCacheStrategy is an interface for CredentialCache
// to allow CredentialsProvider to intercept adjustments to Credentials expiry
// based on expectations and use cases of CredentialsProvider.
//
// Credential caches may use default implementation if nil.
type AdjustExpiresByCredentialsCacheStrategy interface {
// Given a Credentials as input, applying any mutations and
// returning the potentially updated Credentials, or error.
AdjustExpiresBy(Credentials, time.Duration) (Credentials, error)
}
// defaultAdjustExpiresBy adds the duration to the passed in credentials Expires,
// and returns the updated credentials value. If Credentials value's CanExpire
// is false, the passed in credentials are returned unchanged.
func defaultAdjustExpiresBy(creds Credentials, dur time.Duration) (Credentials, error) {
if !creds.CanExpire {
return creds, nil
}
creds.Expires = creds.Expires.Add(dur)
return creds, nil
}