You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
185 lines
4.9 KiB
Go
185 lines
4.9 KiB
Go
2 years ago
|
package ssocreds
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/sha1"
|
||
|
"encoding/hex"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"io/ioutil"
|
||
|
"path/filepath"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||
|
"github.com/aws/aws-sdk-go-v2/internal/sdk"
|
||
|
"github.com/aws/aws-sdk-go-v2/service/sso"
|
||
|
)
|
||
|
|
||
|
// ProviderName is the name of the provider used to specify the source of
// credentials. Retrieve sets it as the Source of every aws.Credentials value
// it returns.
const ProviderName = "SSOProvider"
|
||
|
|
||
|
// defaultCacheLocation returns the directory that loadTokenFile reads cached
// SSO token files from. It is a package-level function variable assigned in
// init, which lets the lookup be swapped out.
// NOTE(review): presumably replaced by tests — confirm against the test files.
var defaultCacheLocation func() string
|
||
|
|
||
|
func defaultCacheLocationImpl() string {
|
||
|
return filepath.Join(getHomeDirectory(), ".aws", "sso", "cache")
|
||
|
}
|
||
|
|
||
|
func init() {
|
||
|
defaultCacheLocation = defaultCacheLocationImpl
|
||
|
}
|
||
|
|
||
|
// GetRoleCredentialsAPIClient is a API client that implements the GetRoleCredentials operation.
|
||
|
type GetRoleCredentialsAPIClient interface {
|
||
|
GetRoleCredentials(ctx context.Context, params *sso.GetRoleCredentialsInput, optFns ...func(*sso.Options)) (*sso.GetRoleCredentialsOutput, error)
|
||
|
}
|
||
|
|
||
|
// Options is the Provider options structure.
|
||
|
type Options struct {
|
||
|
// The Client which is configured for the AWS Region where the AWS SSO user portal is located.
|
||
|
Client GetRoleCredentialsAPIClient
|
||
|
|
||
|
// The AWS account that is assigned to the user.
|
||
|
AccountID string
|
||
|
|
||
|
// The role name that is assigned to the user.
|
||
|
RoleName string
|
||
|
|
||
|
// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal.
|
||
|
StartURL string
|
||
|
}
|
||
|
|
||
|
// Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token.
|
||
|
type Provider struct {
|
||
|
options Options
|
||
|
}
|
||
|
|
||
|
// New returns a new AWS Single Sign-On (AWS SSO) credential provider. The provided client is expected to be configured
|
||
|
// for the AWS Region where the AWS SSO user portal is located.
|
||
|
func New(client GetRoleCredentialsAPIClient, accountID, roleName, startURL string, optFns ...func(options *Options)) *Provider {
|
||
|
options := Options{
|
||
|
Client: client,
|
||
|
AccountID: accountID,
|
||
|
RoleName: roleName,
|
||
|
StartURL: startURL,
|
||
|
}
|
||
|
|
||
|
for _, fn := range optFns {
|
||
|
fn(&options)
|
||
|
}
|
||
|
|
||
|
return &Provider{
|
||
|
options: options,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
|
||
|
// by exchanging the accessToken present in ~/.aws/sso/cache.
|
||
|
func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
|
||
|
tokenFile, err := loadTokenFile(p.options.StartURL)
|
||
|
if err != nil {
|
||
|
return aws.Credentials{}, err
|
||
|
}
|
||
|
|
||
|
output, err := p.options.Client.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{
|
||
|
AccessToken: &tokenFile.AccessToken,
|
||
|
AccountId: &p.options.AccountID,
|
||
|
RoleName: &p.options.RoleName,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return aws.Credentials{}, err
|
||
|
}
|
||
|
|
||
|
return aws.Credentials{
|
||
|
AccessKeyID: aws.ToString(output.RoleCredentials.AccessKeyId),
|
||
|
SecretAccessKey: aws.ToString(output.RoleCredentials.SecretAccessKey),
|
||
|
SessionToken: aws.ToString(output.RoleCredentials.SessionToken),
|
||
|
Expires: time.Unix(0, output.RoleCredentials.Expiration*int64(time.Millisecond)).UTC(),
|
||
|
CanExpire: true,
|
||
|
Source: ProviderName,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
// getCacheFileName derives the cache file name for an SSO start URL: the
// lowercase hexadecimal SHA-1 digest of the URL's bytes followed by ".json".
//
// SHA-1 serves only as a stable file-name key here, not as a security
// measure. The error result is kept so existing callers compile unchanged; it
// is always nil because computing the digest cannot fail.
func getCacheFileName(url string) (string, error) {
	digest := sha1.Sum([]byte(url))

	// hex.EncodeToString already emits lowercase digits; ToLower is retained
	// as an explicit guarantee of the lowercase file-name format.
	return strings.ToLower(hex.EncodeToString(digest[:])) + ".json", nil
}
|
||
|
|
||
|
// rfc3339 is a time.Time that decodes from a JSON string containing an
// RFC 3339 formatted timestamp.
type rfc3339 time.Time

// UnmarshalJSON decodes bytes as a JSON string and parses that string with
// the time.RFC3339 layout, storing the result in r.
//
// The json.Unmarshal error is returned unchanged when bytes is not a JSON
// string. A string that is not a valid RFC 3339 timestamp yields an error
// prefixed with "expected RFC3339 timestamp" that wraps the parse error. r is
// left untouched on any error.
func (r *rfc3339) UnmarshalJSON(bytes []byte) error {
	var raw string
	if err := json.Unmarshal(bytes, &raw); err != nil {
		return err
	}

	parsed, err := time.Parse(time.RFC3339, raw)
	if err != nil {
		return fmt.Errorf("expected RFC3339 timestamp: %w", err)
	}

	*r = rfc3339(parsed)
	return nil
}
|
||
|
|
||
|
type token struct {
|
||
|
AccessToken string `json:"accessToken"`
|
||
|
ExpiresAt rfc3339 `json:"expiresAt"`
|
||
|
Region string `json:"region,omitempty"`
|
||
|
StartURL string `json:"startUrl,omitempty"`
|
||
|
}
|
||
|
|
||
|
func (t token) Expired() bool {
|
||
|
return sdk.NowTime().Round(0).After(time.Time(t.ExpiresAt))
|
||
|
}
|
||
|
|
||
|
// InvalidTokenError is the error type that is returned if loaded token has
// expired or is otherwise invalid. To refresh the SSO session run aws sso
// login with the corresponding profile.
type InvalidTokenError struct {
	// Err is the underlying cause, if any. It may be nil.
	Err error
}

// Unwrap returns the underlying cause so that errors.Is and errors.As can
// inspect it. The result is nil when no cause was recorded.
func (i *InvalidTokenError) Unwrap() error {
	return i.Err
}

// Error returns a fixed description of the failure, followed by ": " and the
// underlying cause's message when a cause is present.
func (i *InvalidTokenError) Error() string {
	const msg = "the SSO session has expired or is invalid"
	if i.Err != nil {
		return msg + ": " + i.Err.Error()
	}
	return msg
}
|
||
|
|
||
|
func loadTokenFile(startURL string) (t token, err error) {
|
||
|
key, err := getCacheFileName(startURL)
|
||
|
if err != nil {
|
||
|
return token{}, &InvalidTokenError{Err: err}
|
||
|
}
|
||
|
|
||
|
fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key))
|
||
|
if err != nil {
|
||
|
return token{}, &InvalidTokenError{Err: err}
|
||
|
}
|
||
|
|
||
|
if err := json.Unmarshal(fileBytes, &t); err != nil {
|
||
|
return token{}, &InvalidTokenError{Err: err}
|
||
|
}
|
||
|
|
||
|
if len(t.AccessToken) == 0 {
|
||
|
return token{}, &InvalidTokenError{}
|
||
|
}
|
||
|
|
||
|
if t.Expired() {
|
||
|
return token{}, &InvalidTokenError{Err: fmt.Errorf("access token is expired")}
|
||
|
}
|
||
|
|
||
|
return t, nil
|
||
|
}
|