package ssocreds import ( "context" "crypto/sha1" "encoding/hex" "encoding/json" "fmt" "io/ioutil" "path/filepath" "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/internal/sdk" "github.com/aws/aws-sdk-go-v2/service/sso" ) // ProviderName is the name of the provider used to specify the source of credentials. const ProviderName = "SSOProvider" var defaultCacheLocation func() string func defaultCacheLocationImpl() string { return filepath.Join(getHomeDirectory(), ".aws", "sso", "cache") } func init() { defaultCacheLocation = defaultCacheLocationImpl } // GetRoleCredentialsAPIClient is a API client that implements the GetRoleCredentials operation. type GetRoleCredentialsAPIClient interface { GetRoleCredentials(ctx context.Context, params *sso.GetRoleCredentialsInput, optFns ...func(*sso.Options)) (*sso.GetRoleCredentialsOutput, error) } // Options is the Provider options structure. type Options struct { // The Client which is configured for the AWS Region where the AWS SSO user portal is located. Client GetRoleCredentialsAPIClient // The AWS account that is assigned to the user. AccountID string // The role name that is assigned to the user. RoleName string // The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal. StartURL string } // Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token. type Provider struct { options Options } // New returns a new AWS Single Sign-On (AWS SSO) credential provider. The provided client is expected to be configured // for the AWS Region where the AWS SSO user portal is located. func New(client GetRoleCredentialsAPIClient, accountID, roleName, startURL string, optFns ...func(options *Options)) *Provider { options := Options{ Client: client, AccountID: accountID, RoleName: roleName, StartURL: startURL, } for _, fn := range optFns { fn(&options) } return &Provider{ options: options, } } // Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal // by exchanging the accessToken present in ~/.aws/sso/cache. func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { tokenFile, err := loadTokenFile(p.options.StartURL) if err != nil { return aws.Credentials{}, err } output, err := p.options.Client.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{ AccessToken: &tokenFile.AccessToken, AccountId: &p.options.AccountID, RoleName: &p.options.RoleName, }) if err != nil { return aws.Credentials{}, err } return aws.Credentials{ AccessKeyID: aws.ToString(output.RoleCredentials.AccessKeyId), SecretAccessKey: aws.ToString(output.RoleCredentials.SecretAccessKey), SessionToken: aws.ToString(output.RoleCredentials.SessionToken), Expires: time.Unix(0, output.RoleCredentials.Expiration*int64(time.Millisecond)).UTC(), CanExpire: true, Source: ProviderName, }, nil } func getCacheFileName(url string) (string, error) { hash := sha1.New() _, err := hash.Write([]byte(url)) if err != nil { return "", err } return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil } type rfc3339 time.Time func (r *rfc3339) UnmarshalJSON(bytes []byte) error { var value string if err := json.Unmarshal(bytes, &value); err != nil { return err } parse, err := time.Parse(time.RFC3339, value) if err != nil { return fmt.Errorf("expected RFC3339 timestamp: %w", err) } *r = rfc3339(parse) return nil } type token struct { AccessToken string `json:"accessToken"` ExpiresAt rfc3339 `json:"expiresAt"` Region string `json:"region,omitempty"` StartURL string `json:"startUrl,omitempty"` } func (t token) Expired() bool { return sdk.NowTime().Round(0).After(time.Time(t.ExpiresAt)) } // InvalidTokenError is the error type that is returned if loaded token has expired or is otherwise invalid. // To refresh the SSO session run aws sso login with the corresponding profile. type InvalidTokenError struct { Err error } func (i *InvalidTokenError) Unwrap() error { return i.Err } func (i *InvalidTokenError) Error() string { const msg = "the SSO session has expired or is invalid" if i.Err == nil { return msg } return msg + ": " + i.Err.Error() } func loadTokenFile(startURL string) (t token, err error) { key, err := getCacheFileName(startURL) if err != nil { return token{}, &InvalidTokenError{Err: err} } fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key)) if err != nil { return token{}, &InvalidTokenError{Err: err} } if err := json.Unmarshal(fileBytes, &t); err != nil { return token{}, &InvalidTokenError{Err: err} } if len(t.AccessToken) == 0 { return token{}, &InvalidTokenError{} } if t.Expired() { return token{}, &InvalidTokenError{Err: fmt.Errorf("access token is expired")} } return t, nil }