package bearer import ( "context" "fmt" "sync/atomic" "time" smithycontext "github.com/aws/smithy-go/context" "github.com/aws/smithy-go/internal/sync/singleflight" ) // package variable that can be override in unit tests. var timeNow = time.Now // TokenCacheOptions provides a set of optional configuration options for the // TokenCache TokenProvider. type TokenCacheOptions struct { // The duration before the token will expire when the credentials will be // refreshed. If DisableAsyncRefresh is true, the RetrieveBearerToken calls // will be blocking. // // Asynchronous refreshes are deduplicated, and only one will be in-flight // at a time. If the token expires while an asynchronous refresh is in // flight, the next call to RetrieveBearerToken will block on that refresh // to return. RefreshBeforeExpires time.Duration // The timeout the underlying TokenProvider's RetrieveBearerToken call must // return within, or will be canceled. Defaults to 0, no timeout. // // If 0 timeout, its possible for the underlying tokenProvider's // RetrieveBearerToken call to block forever. Preventing subsequent // TokenCache attempts to refresh the token. // // If this timeout is reached all pending deduplicated calls to // TokenCache RetrieveBearerToken will fail with an error. RetrieveBearerTokenTimeout time.Duration // The minimum duration between asynchronous refresh attempts. If the next // asynchronous recent refresh attempt was within the minimum delay // duration, the call to retrieve will return the current cached token, if // not expired. // // The asynchronous retrieve is deduplicated across multiple calls when // RetrieveBearerToken is called. The asynchronous retrieve is not a // periodic task. It is only performed when the token has not yet expired, // and the current item is within the RefreshBeforeExpires window, and the // TokenCache's RetrieveBearerToken method is called. // // If 0, (default) there will be no minimum delay between asynchronous // refresh attempts. // // If DisableAsyncRefresh is true, this option is ignored. AsyncRefreshMinimumDelay time.Duration // Sets if the TokenCache will attempt to refresh the token in the // background asynchronously instead of blocking for credentials to be // refreshed. If disabled token refresh will be blocking. // // The first call to RetrieveBearerToken will always be blocking, because // there is no cached token. DisableAsyncRefresh bool } // TokenCache provides an utility to cache Bearer Authentication tokens from a // wrapped TokenProvider. The TokenCache can be has options to configure the // cache's early and asynchronous refresh of the token. type TokenCache struct { options TokenCacheOptions provider TokenProvider cachedToken atomic.Value lastRefreshAttemptTime atomic.Value sfGroup singleflight.Group } // NewTokenCache returns a initialized TokenCache that implements the // TokenProvider interface. Wrapping the provider passed in. Also taking a set // of optional functional option parameters to configure the token cache. func NewTokenCache(provider TokenProvider, optFns ...func(*TokenCacheOptions)) *TokenCache { var options TokenCacheOptions for _, fn := range optFns { fn(&options) } return &TokenCache{ options: options, provider: provider, } } // RetrieveBearerToken returns the token if it could be obtained, or error if a // valid token could not be retrieved. // // The passed in Context's cancel/deadline/timeout will impacting only this // individual retrieve call and not any other already queued up calls. This // means underlying provider's RetrieveBearerToken calls could block for ever, // and not be canceled with the Context. Set RetrieveBearerTokenTimeout to // provide a timeout, preventing the underlying TokenProvider blocking forever. // // By default, if the passed in Context is canceled, all of its values will be // considered expired. The wrapped TokenProvider will not be able to lookup the // values from the Context once it is expired. This is done to protect against // expired values no longer being valid. To disable this behavior, use // smithy-go's context.WithPreserveExpiredValues to add a value to the Context // before calling RetrieveBearerToken to enable support for expired values. // // Without RetrieveBearerTokenTimeout there is the potential for a underlying // Provider's RetrieveBearerToken call to sit forever. Blocking in subsequent // attempts at refreshing the token. func (p *TokenCache) RetrieveBearerToken(ctx context.Context) (Token, error) { cachedToken, ok := p.getCachedToken() if !ok || cachedToken.Expired(timeNow()) { return p.refreshBearerToken(ctx) } // Check if the token should be refreshed before it expires. refreshToken := cachedToken.Expired(timeNow().Add(p.options.RefreshBeforeExpires)) if !refreshToken { return cachedToken, nil } if p.options.DisableAsyncRefresh { return p.refreshBearerToken(ctx) } p.tryAsyncRefresh(ctx) return cachedToken, nil } // tryAsyncRefresh attempts to asynchronously refresh the token returning the // already cached token. If it AsyncRefreshMinimumDelay option is not zero, and // the duration since the last refresh is less than that value, nothing will be // done. func (p *TokenCache) tryAsyncRefresh(ctx context.Context) { if p.options.AsyncRefreshMinimumDelay != 0 { var lastRefreshAttempt time.Time if v := p.lastRefreshAttemptTime.Load(); v != nil { lastRefreshAttempt = v.(time.Time) } if timeNow().Before(lastRefreshAttempt.Add(p.options.AsyncRefreshMinimumDelay)) { return } } // Ignore the returned channel so this won't be blocking, and limit the // number of additional goroutines created. p.sfGroup.DoChan("async-refresh", func() (interface{}, error) { res, err := p.refreshBearerToken(ctx) if p.options.AsyncRefreshMinimumDelay != 0 { var refreshAttempt time.Time if err != nil { refreshAttempt = timeNow() } p.lastRefreshAttemptTime.Store(refreshAttempt) } return res, err }) } func (p *TokenCache) refreshBearerToken(ctx context.Context) (Token, error) { resCh := p.sfGroup.DoChan("refresh-token", func() (interface{}, error) { ctx := smithycontext.WithSuppressCancel(ctx) if v := p.options.RetrieveBearerTokenTimeout; v != 0 { var cancel func() ctx, cancel = context.WithTimeout(ctx, v) defer cancel() } return p.singleRetrieve(ctx) }) select { case res := <-resCh: return res.Val.(Token), res.Err case <-ctx.Done(): return Token{}, fmt.Errorf("retrieve bearer token canceled, %w", ctx.Err()) } } func (p *TokenCache) singleRetrieve(ctx context.Context) (interface{}, error) { token, err := p.provider.RetrieveBearerToken(ctx) if err != nil { return Token{}, fmt.Errorf("failed to retrieve bearer token, %w", err) } p.cachedToken.Store(&token) return token, nil } // getCachedToken returns the currently cached token and true if found. Returns // false if no token is cached. func (p *TokenCache) getCachedToken() (Token, bool) { v := p.cachedToken.Load() if v == nil { return Token{}, false } t := v.(*Token) if t == nil || t.Value == "" { return Token{}, false } return *t, true }