feat(train): add new command to interact with aws and train models

This commit is contained in:
2021-10-17 19:15:44 +02:00
parent 5436dfebc2
commit 538cea18f2
1064 changed files with 282251 additions and 89305 deletions

View File

@ -0,0 +1,106 @@
package arn
import (
"fmt"
"strings"
awsarn "github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn"
)
const (
s3Namespace = "s3"
s3ObjectsLambdaNamespace = "s3-object-lambda"
s3OutpostsNamespace = "s3-outposts"
)
// ParseEndpointARN parses a given generic aws ARN into a s3 arn resource.
func ParseEndpointARN(v awsarn.ARN) (arn.Resource, error) {
return arn.ParseResource(v, accessPointResourceParser)
}
func accessPointResourceParser(a awsarn.ARN) (arn.Resource, error) {
resParts := arn.SplitResource(a.Resource)
switch resParts[0] {
case "accesspoint":
switch a.Service {
case s3Namespace:
return arn.ParseAccessPointResource(a, resParts[1:])
case s3ObjectsLambdaNamespace:
return parseS3ObjectLambdaAccessPointResource(a, resParts)
default:
return arn.AccessPointARN{}, arn.InvalidARNError{ARN: a, Reason: fmt.Sprintf("service is not %s or %s", s3Namespace, s3ObjectsLambdaNamespace)}
}
case "outpost":
if a.Service != s3OutpostsNamespace {
return arn.OutpostAccessPointARN{}, arn.InvalidARNError{ARN: a, Reason: "service is not %s"}
}
return parseOutpostAccessPointResource(a, resParts[1:])
default:
return nil, arn.InvalidARNError{ARN: a, Reason: "unknown resource type"}
}
}
func parseOutpostAccessPointResource(a awsarn.ARN, resParts []string) (arn.OutpostAccessPointARN, error) {
// outpost accesspoint arn is only valid if service is s3-outposts
if a.Service != "s3-outposts" {
return arn.OutpostAccessPointARN{}, arn.InvalidARNError{ARN: a, Reason: "service is not s3-outposts"}
}
if len(resParts) == 0 {
return arn.OutpostAccessPointARN{}, arn.InvalidARNError{ARN: a, Reason: "outpost resource-id not set"}
}
if len(resParts) < 3 {
return arn.OutpostAccessPointARN{}, arn.InvalidARNError{
ARN: a, Reason: "access-point resource not set in Outpost ARN",
}
}
resID := strings.TrimSpace(resParts[0])
if len(resID) == 0 {
return arn.OutpostAccessPointARN{}, arn.InvalidARNError{ARN: a, Reason: "outpost resource-id not set"}
}
var outpostAccessPointARN = arn.OutpostAccessPointARN{}
switch resParts[1] {
case "accesspoint":
// Do not allow region-less outpost access-point arns.
if len(a.Region) == 0 {
return arn.OutpostAccessPointARN{}, arn.InvalidARNError{ARN: a, Reason: "region is not set"}
}
accessPointARN, err := arn.ParseAccessPointResource(a, resParts[2:])
if err != nil {
return arn.OutpostAccessPointARN{}, err
}
// set access-point arn
outpostAccessPointARN.AccessPointARN = accessPointARN
default:
return arn.OutpostAccessPointARN{}, arn.InvalidARNError{ARN: a, Reason: "access-point resource not set in Outpost ARN"}
}
// set outpost id
outpostAccessPointARN.OutpostID = resID
return outpostAccessPointARN, nil
}
func parseS3ObjectLambdaAccessPointResource(a awsarn.ARN, resParts []string) (arn.S3ObjectLambdaAccessPointARN, error) {
if a.Service != s3ObjectsLambdaNamespace {
return arn.S3ObjectLambdaAccessPointARN{}, arn.InvalidARNError{ARN: a, Reason: fmt.Sprintf("service is not %s", s3ObjectsLambdaNamespace)}
}
if len(a.Region) == 0 {
return arn.S3ObjectLambdaAccessPointARN{}, arn.InvalidARNError{ARN: a, Reason: fmt.Sprintf("%s region not set", s3ObjectsLambdaNamespace)}
}
accessPointARN, err := arn.ParseAccessPointResource(a, resParts[1:])
if err != nil {
return arn.S3ObjectLambdaAccessPointARN{}, err
}
return arn.S3ObjectLambdaAccessPointARN{
AccessPointARN: accessPointARN,
}, nil
}

View File

@ -0,0 +1,87 @@
/*
Package customizations provides customizations for the Amazon S3 API client.
This package provides support for following S3 customizations
ProcessARN Middleware: processes an ARN if provided as input and updates the endpoint as per the arn type
UpdateEndpoint Middleware: resolves a custom endpoint as per s3 config options
RemoveBucket Middleware: removes a serialized bucket name from request url path
processResponseWith200Error Middleware: Deserializing response error with 200 status code
Virtual Host style url addressing
Since serializers serialize by default as path style url, we use customization
to modify the endpoint url when `UsePathStyle` option on S3Client is unset or
false. This flag will be ignored if `UseAccelerate` option is set to true.
If UseAccelerate is not enabled, and the bucket name is not a valid hostname
label, they SDK will fallback to forcing the request to be made as if
UsePathStyle was enabled. This behavior is also used if UseDualStack is enabled.
https://docs.aws.amazon.com/AmazonS3/latest/dev/dual-stack-endpoints.html#dual-stack-endpoints-description
Transfer acceleration
By default S3 Transfer acceleration support is disabled. By enabling `UseAccelerate`
option on S3Client, one can enable s3 transfer acceleration support. Transfer
acceleration only works with Virtual Host style addressing, and thus `UsePathStyle`
option if set is ignored. Transfer acceleration is not supported for S3 operations
DeleteBucket, ListBuckets, and CreateBucket.
Dualstack support
By default dualstack support for s3 client is disabled. By enabling `UseDualstack`
option on s3 client, you can enable dualstack endpoint support.
Endpoint customizations
Customizations to lookup ARN, process ARN needs to happen before request serialization.
UpdateEndpoint middleware which mutates resources based on Options such as
UseDualstack, UseAccelerate for modifying resolved endpoint are executed after
request serialization. Remove bucket middleware is executed after
an request is serialized, and removes the serialized bucket name from request path
Middleware layering:
Initialize : HTTP Request -> ARN Lookup -> Input-Validation -> Serialize step
Serialize : HTTP Request -> Process ARN -> operation serializer -> Update-Endpoint customization -> Remove-Bucket -> next middleware
Customization options:
UseARNRegion (Disabled by Default)
UsePathStyle (Disabled by Default)
UseAccelerate (Disabled by Default)
UseDualstack (Disabled by Default)
Handle Error response with 200 status code
S3 operations: CopyObject, CompleteMultipartUpload, UploadPartCopy can have an
error Response with status code 2xx. The processResponseWith200Error middleware
customizations enables SDK to check for an error within response body prior to
deserialization.
As the check for 2xx response containing an error needs to be performed earlier
than response deserialization. Since the behavior of Deserialization is in
reverse order to the other stack steps its easier to consider that "after" means
"before".
Middleware layering:
HTTP Response -> handle 200 error customization -> deserialize
*/
package customizations

View File

@ -0,0 +1,74 @@
package customizations
import (
"bytes"
"context"
"encoding/xml"
"fmt"
"io"
"io/ioutil"
"strings"
"github.com/aws/smithy-go"
smithyxml "github.com/aws/smithy-go/encoding/xml"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// HandleResponseErrorWith200Status check for S3 200 error response.
// If an s3 200 error is found, status code for the response is modified temporarily to
// 5xx response status code.
func HandleResponseErrorWith200Status(stack *middleware.Stack) error {
return stack.Deserialize.Insert(&processResponseFor200ErrorMiddleware{}, "OperationDeserializer", middleware.After)
}
// middleware to process raw response and look for error response with 200 status code
type processResponseFor200ErrorMiddleware struct{}
// ID returns the middleware ID.
func (*processResponseFor200ErrorMiddleware) ID() string {
return "S3:ProcessResponseFor200Error"
}
func (m *processResponseFor200ErrorMiddleware) HandleDeserialize(
ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
) {
out, metadata, err = next.HandleDeserialize(ctx, in)
if err != nil {
return out, metadata, err
}
response, ok := out.RawResponse.(*smithyhttp.Response)
if !ok {
return out, metadata, &smithy.DeserializationError{Err: fmt.Errorf("unknown transport type %T", out.RawResponse)}
}
// check if response status code is 2xx.
if response.StatusCode < 200 || response.StatusCode >= 300 {
return
}
var readBuff bytes.Buffer
body := io.TeeReader(response.Body, &readBuff)
rootDecoder := xml.NewDecoder(body)
t, err := smithyxml.FetchRootElement(rootDecoder)
if err == io.EOF {
return out, metadata, &smithy.DeserializationError{
Err: fmt.Errorf("received empty response payload"),
}
}
// rewind response body
response.Body = ioutil.NopCloser(io.MultiReader(&readBuff, response.Body))
// if start tag is "Error", the response is consider error response.
if strings.EqualFold(t.Name.Local, "Error") {
// according to https://aws.amazon.com/premiumsupport/knowledge-center/s3-resolve-200-internalerror/
// 200 error responses are similar to 5xx errors.
response.StatusCode = 500
}
return out, metadata, err
}

View File

@ -0,0 +1,22 @@
package customizations
import (
"github.com/aws/smithy-go/transport/http"
"strings"
)
func updateS3HostForS3AccessPoint(req *http.Request) {
updateHostPrefix(req, "s3", s3AccessPoint)
}
func updateS3HostForS3ObjectLambda(req *http.Request) {
updateHostPrefix(req, "s3", s3ObjectLambda)
}
func updateHostPrefix(req *http.Request, oldEndpointPrefix, newEndpointPrefix string) {
host := req.URL.Host
if strings.HasPrefix(host, oldEndpointPrefix) {
// For example if oldEndpointPrefix=s3 would replace to newEndpointPrefix
req.URL.Host = newEndpointPrefix + host[len(oldEndpointPrefix):]
}
}

View File

@ -0,0 +1,49 @@
package customizations
import (
"context"
"fmt"
"strconv"
"time"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// AddExpiresOnPresignedURL represents a build middleware used to assign
// expiration on a presigned URL.
type AddExpiresOnPresignedURL struct {
// Expires is time.Duration within which presigned url should be expired.
// This should be the duration in seconds the presigned URL should be considered valid for.
// By default the S3 presigned url expires in 15 minutes ie. 900 seconds.
Expires time.Duration
}
// ID representing the middleware
func (*AddExpiresOnPresignedURL) ID() string {
return "S3:AddExpiresOnPresignedURL"
}
// HandleBuild handles the build step middleware behavior
func (m *AddExpiresOnPresignedURL) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (
out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
// if expiration is unset skip this middleware
if m.Expires == 0 {
// default to 15 * time.Minutes
m.Expires = 15 * time.Minute
}
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown transport type %T", req)
}
// set S3 X-AMZ-Expires header
query := req.URL.Query()
query.Set("X-Amz-Expires", strconv.FormatInt(int64(m.Expires/time.Second), 10))
req.URL.RawQuery = query.Encode()
return next.HandleBuild(ctx, in)
}

View File

@ -0,0 +1,588 @@
package customizations
import (
"context"
"fmt"
"net/url"
"strings"
"github.com/aws/smithy-go/middleware"
"github.com/aws/smithy-go/transport/http"
"github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/aws-sdk-go-v2/service/internal/s3shared"
"github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn"
s3arn "github.com/aws/aws-sdk-go-v2/service/s3/internal/arn"
"github.com/aws/aws-sdk-go-v2/service/s3/internal/endpoints"
"github.com/aws/aws-sdk-go-v2/service/s3/internal/v4a"
)
const (
s3AccessPoint = "s3-accesspoint"
s3ObjectLambda = "s3-object-lambda"
)
// processARNResource is used to process an ARN resource.
type processARNResource struct {
// UseARNRegion indicates if region parsed from an ARN should be used.
UseARNRegion bool
// UseAccelerate indicates if s3 transfer acceleration is enabled
UseAccelerate bool
// UseDualstack instructs if s3 dualstack endpoint config is enabled
UseDualstack bool
// EndpointResolver used to resolve endpoints. This may be a custom endpoint resolver
EndpointResolver EndpointResolver
// EndpointResolverOptions used by endpoint resolver
EndpointResolverOptions EndpointResolverOptions
// DisableMultiRegionAccessPoints indicates multi-region access point support is disabled
DisableMultiRegionAccessPoints bool
}
// ID returns the middleware ID.
func (*processARNResource) ID() string { return "S3:ProcessARNResource" }
func (m *processARNResource) HandleSerialize(
ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
// check if arn was provided, if not skip this middleware
arnValue, ok := s3shared.GetARNResourceFromContext(ctx)
if !ok {
return next.HandleSerialize(ctx, in)
}
req, ok := in.Request.(*http.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}
// parse arn into an endpoint arn wrt to service
resource, err := s3arn.ParseEndpointARN(arnValue)
if err != nil {
return out, metadata, err
}
// build a resource request struct
resourceRequest := s3shared.ResourceRequest{
Resource: resource,
UseARNRegion: m.UseARNRegion,
RequestRegion: awsmiddleware.GetRegion(ctx),
SigningRegion: awsmiddleware.GetSigningRegion(ctx),
PartitionID: awsmiddleware.GetPartitionID(ctx),
}
// switch to correct endpoint updater
switch tv := resource.(type) {
case arn.AccessPointARN:
// multi-region arns do not need to validate for cross partition request
if len(tv.Region) != 0 {
// validate resource request
if err := validateRegionForResourceRequest(resourceRequest); err != nil {
return out, metadata, err
}
}
// Special handling for region-less ap-arns.
if len(tv.Region) == 0 {
// check if multi-region arn support is disabled
if m.DisableMultiRegionAccessPoints {
return out, metadata, fmt.Errorf("Invalid configuration, Multi-Region access point ARNs are disabled")
}
// Do not allow dual-stack configuration with multi-region arns.
if m.UseDualstack {
return out, metadata, s3shared.NewClientConfiguredForDualStackError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
}
// check if accelerate
if m.UseAccelerate {
return out, metadata, s3shared.NewClientConfiguredForAccelerateError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// fetch arn region to resolve request
resolveRegion := tv.Region
// check if request region is FIPS
if resourceRequest.UseFips() {
// Do not allow Fips support within multi-region arns.
if len(resolveRegion) == 0 {
return out, metadata, s3shared.NewClientConfiguredForFIPSError(
tv, resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// if use arn region is enabled and request signing region is not same as arn region
if m.UseARNRegion && resourceRequest.IsCrossRegion() {
// FIPS with cross region is not supported, the SDK must fail
// because there is no well defined method for SDK to construct a
// correct FIPS endpoint.
return out, metadata,
s3shared.NewClientConfiguredForCrossRegionFIPSError(
tv,
resourceRequest.PartitionID,
resourceRequest.RequestRegion,
nil,
)
}
// if use arn region is NOT set, we should use the request region
resolveRegion = resourceRequest.RequestRegion
}
var requestBuilder func(context.Context, accesspointOptions) (context.Context, error)
if len(resolveRegion) == 0 {
requestBuilder = buildMultiRegionAccessPointsRequest
} else {
requestBuilder = buildAccessPointRequest
}
// build request as per accesspoint builder
ctx, err = requestBuilder(ctx, accesspointOptions{
processARNResource: *m,
request: req,
resource: tv,
resolveRegion: resolveRegion,
partitionID: resourceRequest.PartitionID,
requestRegion: resourceRequest.RequestRegion,
})
if err != nil {
return out, metadata, err
}
case arn.S3ObjectLambdaAccessPointARN:
// validate region for resource request
if err := validateRegionForResourceRequest(resourceRequest); err != nil {
return out, metadata, err
}
// check if accelerate
if m.UseAccelerate {
return out, metadata, s3shared.NewClientConfiguredForAccelerateError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// check if dualstack
if m.UseDualstack {
return out, metadata, s3shared.NewClientConfiguredForDualStackError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// fetch arn region to resolve request
resolveRegion := tv.Region
if resourceRequest.UseFips() {
// if use arn region is enabled and request signing region is not same as arn region
if m.UseARNRegion && resourceRequest.IsCrossRegion() {
// FIPS with cross region is not supported, the SDK must fail
// because there is no well defined method for SDK to construct a
// correct FIPS endpoint.
return out, metadata,
s3shared.NewClientConfiguredForCrossRegionFIPSError(
tv,
resourceRequest.PartitionID,
resourceRequest.RequestRegion,
nil,
)
}
// if use arn region is NOT set, we should use the request region
resolveRegion = resourceRequest.RequestRegion
}
// build access point request
ctx, err = buildS3ObjectLambdaAccessPointRequest(ctx, accesspointOptions{
processARNResource: *m,
request: req,
resource: tv.AccessPointARN,
resolveRegion: resolveRegion,
partitionID: resourceRequest.PartitionID,
requestRegion: resourceRequest.RequestRegion,
})
if err != nil {
return out, metadata, err
}
// process outpost accesspoint ARN
case arn.OutpostAccessPointARN:
// validate region for resource request
if err := validateRegionForResourceRequest(resourceRequest); err != nil {
return out, metadata, err
}
// check if accelerate
if m.UseAccelerate {
return out, metadata, s3shared.NewClientConfiguredForAccelerateError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// check if dual stack
if m.UseDualstack {
return out, metadata, s3shared.NewClientConfiguredForDualStackError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// check if request region is FIPS
if resourceRequest.UseFips() {
return out, metadata, s3shared.NewFIPSConfigurationError(tv, resourceRequest.PartitionID,
resourceRequest.RequestRegion, nil)
}
// build outpost access point request
ctx, err = buildOutpostAccessPointRequest(ctx, outpostAccessPointOptions{
processARNResource: *m,
resource: tv,
request: req,
partitionID: resourceRequest.PartitionID,
requestRegion: resourceRequest.RequestRegion,
})
if err != nil {
return out, metadata, err
}
default:
return out, metadata, s3shared.NewInvalidARNError(resource, nil)
}
return next.HandleSerialize(ctx, in)
}
// validate if s3 resource and request region config is compatible.
func validateRegionForResourceRequest(resourceRequest s3shared.ResourceRequest) error {
// check if resourceRequest leads to a cross partition error
v, err := resourceRequest.IsCrossPartition()
if err != nil {
return err
}
if v {
// if cross partition
return s3shared.NewClientPartitionMismatchError(resourceRequest.Resource,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// check if resourceRequest leads to a cross region error
if !resourceRequest.AllowCrossRegion() && resourceRequest.IsCrossRegion() {
// if cross region, but not use ARN region is not enabled
return s3shared.NewClientRegionMismatchError(resourceRequest.Resource,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
return nil
}
// === Accesspoint ==========
type accesspointOptions struct {
processARNResource
request *http.Request
resource arn.AccessPointARN
resolveRegion string
partitionID string
requestRegion string
}
func buildAccessPointRequest(ctx context.Context, options accesspointOptions) (context.Context, error) {
tv := options.resource
req := options.request
resolveRegion := options.resolveRegion
resolveService := tv.Service
// resolve endpoint
endpoint, err := options.EndpointResolver.ResolveEndpoint(resolveRegion, options.EndpointResolverOptions)
if err != nil {
return ctx, s3shared.NewFailedToResolveEndpointError(
tv,
options.partitionID,
options.requestRegion,
err,
)
}
// assign resolved endpoint url to request url
req.URL, err = url.Parse(endpoint.URL)
if err != nil {
return ctx, fmt.Errorf("failed to parse endpoint URL: %w", err)
}
if len(endpoint.SigningName) != 0 && endpoint.Source == aws.EndpointSourceCustom {
ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
} else {
// Must sign with s3-object-lambda
ctx = awsmiddleware.SetSigningName(ctx, resolveService)
}
if len(endpoint.SigningRegion) != 0 {
ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
} else {
ctx = awsmiddleware.SetSigningRegion(ctx, resolveRegion)
}
// update serviceID to "s3-accesspoint"
ctx = awsmiddleware.SetServiceID(ctx, s3AccessPoint)
// disable host prefix behavior
ctx = http.DisableEndpointHostPrefix(ctx, true)
// remove the serialized arn in place of /{Bucket}
ctx = setBucketToRemoveOnContext(ctx, tv.String())
// skip arn processing, if arn region resolves to a immutable endpoint
if endpoint.HostnameImmutable {
return ctx, nil
}
updateS3HostForS3AccessPoint(req)
ctx, err = buildAccessPointHostPrefix(ctx, req, tv)
if err != nil {
return ctx, err
}
return ctx, nil
}
func buildS3ObjectLambdaAccessPointRequest(ctx context.Context, options accesspointOptions) (context.Context, error) {
tv := options.resource
req := options.request
resolveRegion := options.resolveRegion
resolveService := tv.Service
// resolve endpoint
endpoint, err := options.EndpointResolver.ResolveEndpoint(resolveRegion, options.EndpointResolverOptions)
if err != nil {
return ctx, s3shared.NewFailedToResolveEndpointError(
tv,
options.partitionID,
options.requestRegion,
err,
)
}
// assign resolved endpoint url to request url
req.URL, err = url.Parse(endpoint.URL)
if err != nil {
return ctx, fmt.Errorf("failed to parse endpoint URL: %w", err)
}
if len(endpoint.SigningName) != 0 && endpoint.Source == aws.EndpointSourceCustom {
ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
} else {
// Must sign with s3-object-lambda
ctx = awsmiddleware.SetSigningName(ctx, resolveService)
}
if len(endpoint.SigningRegion) != 0 {
ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
} else {
ctx = awsmiddleware.SetSigningRegion(ctx, resolveRegion)
}
// update serviceID to "s3-object-lambda"
ctx = awsmiddleware.SetServiceID(ctx, s3ObjectLambda)
// disable host prefix behavior
ctx = http.DisableEndpointHostPrefix(ctx, true)
// remove the serialized arn in place of /{Bucket}
ctx = setBucketToRemoveOnContext(ctx, tv.String())
// skip arn processing, if arn region resolves to a immutable endpoint
if endpoint.HostnameImmutable {
return ctx, nil
}
if endpoint.Source == aws.EndpointSourceServiceMetadata {
updateS3HostForS3ObjectLambda(req)
}
ctx, err = buildAccessPointHostPrefix(ctx, req, tv)
if err != nil {
return ctx, err
}
return ctx, nil
}
func buildMultiRegionAccessPointsRequest(ctx context.Context, options accesspointOptions) (context.Context, error) {
const s3GlobalLabel = "s3-global."
const accesspointLabel = "accesspoint."
tv := options.resource
req := options.request
resolveService := tv.Service
resolveRegion := options.requestRegion
arnPartition := tv.Partition
// resolve endpoint
endpoint, err := options.EndpointResolver.ResolveEndpoint(resolveRegion, options.EndpointResolverOptions)
if err != nil {
return ctx, s3shared.NewFailedToResolveEndpointError(
tv,
options.partitionID,
options.requestRegion,
err,
)
}
// set signing region and version for MRAP
endpoint.SigningRegion = "*"
ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
ctx = SetSignerVersion(ctx, v4a.Version)
if len(endpoint.SigningName) != 0 {
ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
} else {
ctx = awsmiddleware.SetSigningName(ctx, resolveService)
}
// skip arn processing, if arn region resolves to a immutable endpoint
if endpoint.HostnameImmutable {
return ctx, nil
}
// modify endpoint host to use s3-global host prefix
scheme := strings.SplitN(endpoint.URL, "://", 2)
dnsSuffix, err := endpoints.GetDNSSuffix(arnPartition)
if err != nil {
return ctx, fmt.Errorf("Error determining dns suffix from arn partition, %w", err)
}
// set url as per partition
endpoint.URL = scheme[0] + "://" + s3GlobalLabel + dnsSuffix
// assign resolved endpoint url to request url
req.URL, err = url.Parse(endpoint.URL)
if err != nil {
return ctx, fmt.Errorf("failed to parse endpoint URL: %w", err)
}
// build access point host prefix
accessPointHostPrefix := tv.AccessPointName + "." + accesspointLabel
// add host prefix to url
req.URL.Host = accessPointHostPrefix + req.URL.Host
if len(req.Host) > 0 {
req.Host = accessPointHostPrefix + req.Host
}
// validate the endpoint host
if err := http.ValidateEndpointHost(req.URL.Host); err != nil {
return ctx, fmt.Errorf("endpoint validation error: %w, when using arn %v", err, tv)
}
// disable host prefix behavior
ctx = http.DisableEndpointHostPrefix(ctx, true)
// remove the serialized arn in place of /{Bucket}
ctx = setBucketToRemoveOnContext(ctx, tv.String())
return ctx, nil
}
func buildAccessPointHostPrefix(ctx context.Context, req *http.Request, tv arn.AccessPointARN) (context.Context, error) {
// add host prefix for access point
accessPointHostPrefix := tv.AccessPointName + "-" + tv.AccountID + "."
req.URL.Host = accessPointHostPrefix + req.URL.Host
if len(req.Host) > 0 {
req.Host = accessPointHostPrefix + req.Host
}
// validate the endpoint host
if err := http.ValidateEndpointHost(req.URL.Host); err != nil {
return ctx, s3shared.NewInvalidARNError(tv, err)
}
return ctx, nil
}
// ====== Outpost Accesspoint ========
type outpostAccessPointOptions struct {
processARNResource
request *http.Request
resource arn.OutpostAccessPointARN
partitionID string
requestRegion string
}
func buildOutpostAccessPointRequest(ctx context.Context, options outpostAccessPointOptions) (context.Context, error) {
tv := options.resource
req := options.request
resolveRegion := tv.Region
resolveService := tv.Service
endpointsID := resolveService
if strings.EqualFold(resolveService, "s3-outposts") {
// assign endpoints ID as "S3"
endpointsID = "s3"
}
// resolve regional endpoint for resolved region.
endpoint, err := options.EndpointResolver.ResolveEndpoint(resolveRegion, options.EndpointResolverOptions)
if err != nil {
return ctx, s3shared.NewFailedToResolveEndpointError(
tv,
options.partitionID,
options.requestRegion,
err,
)
}
// assign resolved endpoint url to request url
req.URL, err = url.Parse(endpoint.URL)
if err != nil {
return ctx, fmt.Errorf("failed to parse endpoint URL: %w", err)
}
// assign resolved service from arn as signing name
if len(endpoint.SigningName) != 0 && endpoint.Source == aws.EndpointSourceCustom {
ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
} else {
ctx = awsmiddleware.SetSigningName(ctx, resolveService)
}
if len(endpoint.SigningRegion) != 0 {
// redirect signer to use resolved endpoint signing name and region
ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
} else {
ctx = awsmiddleware.SetSigningRegion(ctx, resolveRegion)
}
// update serviceID to resolved service id
ctx = awsmiddleware.SetServiceID(ctx, resolveService)
// disable host prefix behavior
ctx = http.DisableEndpointHostPrefix(ctx, true)
// remove the serialized arn in place of /{Bucket}
ctx = setBucketToRemoveOnContext(ctx, tv.String())
// skip further customizations, if arn region resolves to a immutable endpoint
if endpoint.HostnameImmutable {
return ctx, nil
}
updateHostPrefix(req, endpointsID, resolveService)
// add host prefix for s3-outposts
outpostAPHostPrefix := tv.AccessPointName + "-" + tv.AccountID + "." + tv.OutpostID + "."
req.URL.Host = outpostAPHostPrefix + req.URL.Host
if len(req.Host) > 0 {
req.Host = outpostAPHostPrefix + req.Host
}
// validate the endpoint host
if err := http.ValidateEndpointHost(req.URL.Host); err != nil {
return ctx, s3shared.NewInvalidARNError(tv, err)
}
return ctx, nil
}

View File

@ -0,0 +1,58 @@
package customizations
import (
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
"github.com/aws/smithy-go/transport/http"
)
// removeBucketFromPathMiddleware needs to be executed after serialize step is performed
type removeBucketFromPathMiddleware struct {
}
func (m *removeBucketFromPathMiddleware) ID() string {
return "S3:RemoveBucketFromPathMiddleware"
}
func (m *removeBucketFromPathMiddleware) HandleSerialize(
ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
// check if a bucket removal from HTTP path is required
bucket, ok := getRemoveBucketFromPath(ctx)
if !ok {
return next.HandleSerialize(ctx, in)
}
req, ok := in.Request.(*http.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}
removeBucketFromPath(req.URL, bucket)
return next.HandleSerialize(ctx, in)
}
type removeBucketKey struct {
bucket string
}
// setBucketToRemoveOnContext sets the bucket name to be removed.
//
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
// to clear all stack values.
func setBucketToRemoveOnContext(ctx context.Context, bucket string) context.Context {
return middleware.WithStackValue(ctx, removeBucketKey{}, bucket)
}
// getRemoveBucketFromPath returns the bucket name to remove from the path.
//
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
// to clear all stack values.
func getRemoveBucketFromPath(ctx context.Context) (string, bool) {
v, ok := middleware.GetStackValue(ctx, removeBucketKey{}).(string)
return v, ok
}

View File

@ -0,0 +1,85 @@
package customizations
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/smithy-go/middleware"
"github.com/aws/smithy-go/transport/http"
"net/url"
)
type s3ObjectLambdaEndpoint struct {
// whether the operation should use the s3-object-lambda endpoint
UseEndpoint bool
// use dualstack
UseDualstack bool
// use transfer acceleration
UseAccelerate bool
EndpointResolver EndpointResolver
EndpointResolverOptions EndpointResolverOptions
}
func (t *s3ObjectLambdaEndpoint) ID() string {
return "S3:ObjectLambdaEndpoint"
}
func (t *s3ObjectLambdaEndpoint) HandleSerialize(
ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
if !t.UseEndpoint {
return next.HandleSerialize(ctx, in)
}
req, ok := in.Request.(*http.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown transport type: %T", in.Request)
}
if t.UseDualstack {
return out, metadata, fmt.Errorf("client configured for dualstack but not supported for operation")
}
if t.UseAccelerate {
return out, metadata, fmt.Errorf("client configured for accelerate but not supported for operation")
}
region := awsmiddleware.GetRegion(ctx)
endpoint, err := t.EndpointResolver.ResolveEndpoint(region, t.EndpointResolverOptions)
if err != nil {
return out, metadata, err
}
// Set the ServiceID and SigningName
ctx = awsmiddleware.SetServiceID(ctx, s3ObjectLambda)
if len(endpoint.SigningName) > 0 && endpoint.Source == aws.EndpointSourceCustom {
ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
} else {
ctx = awsmiddleware.SetSigningName(ctx, s3ObjectLambda)
}
req.URL, err = url.Parse(endpoint.URL)
if err != nil {
return out, metadata, err
}
if len(endpoint.SigningRegion) > 0 {
ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
} else {
ctx = awsmiddleware.SetSigningRegion(ctx, region)
}
if endpoint.Source == aws.EndpointSourceServiceMetadata || !endpoint.HostnameImmutable {
updateS3HostForS3ObjectLambda(req)
}
return next.HandleSerialize(ctx, in)
}

View File

@ -0,0 +1,213 @@
package customizations
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/service/s3/internal/v4a"
"github.com/aws/smithy-go/middleware"
)
type signerVersionKey struct{}
// GetSignerVersion retrieves the signer version to use for signing
//
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
// to clear all stack values.
func GetSignerVersion(ctx context.Context) (v string) {
v, _ = middleware.GetStackValue(ctx, signerVersionKey{}).(string)
return v
}
// SetSignerVersion sets the signer version to be used for signing the request
//
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
// to clear all stack values.
func SetSignerVersion(ctx context.Context, version string) context.Context {
return middleware.WithStackValue(ctx, signerVersionKey{}, version)
}
// SignHTTPRequestMiddlewareOptions is the configuration options for the SignHTTPRequestMiddleware middleware.
type SignHTTPRequestMiddlewareOptions struct {
// credential provider
CredentialsProvider aws.CredentialsProvider
// log signing
LogSigning bool
// v4 signer
V4Signer v4.HTTPSigner
//v4a signer
V4aSigner v4a.HTTPSigner
}
// NewSignHTTPRequestMiddleware constructs a SignHTTPRequestMiddleware using the given Signer for signing requests
func NewSignHTTPRequestMiddleware(options SignHTTPRequestMiddlewareOptions) *SignHTTPRequestMiddleware {
return &SignHTTPRequestMiddleware{
credentialsProvider: options.CredentialsProvider,
v4Signer: options.V4Signer,
v4aSigner: options.V4aSigner,
logSigning: options.LogSigning,
}
}
// SignHTTPRequestMiddleware is a `FinalizeMiddleware` implementation to select HTTP Signing method
type SignHTTPRequestMiddleware struct {
// credential provider
credentialsProvider aws.CredentialsProvider
// log signing
logSigning bool
// v4 signer
v4Signer v4.HTTPSigner
//v4a signer
v4aSigner v4a.HTTPSigner
}
// ID is the SignHTTPRequestMiddleware identifier
func (s *SignHTTPRequestMiddleware) ID() string {
return "Signing"
}
// HandleFinalize will take the provided input and sign the request using the SigV4 authentication scheme
func (s *SignHTTPRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
// fetch signer type from context
signerVersion := GetSignerVersion(ctx)
switch signerVersion {
case v4a.Version:
v4aCredentialProvider, ok := s.credentialsProvider.(v4a.CredentialsProvider)
if !ok {
return out, metadata, fmt.Errorf("invalid credential-provider provided for sigV4a Signer")
}
mw := v4a.NewSignHTTPRequestMiddleware(v4a.SignHTTPRequestMiddlewareOptions{
Credentials: v4aCredentialProvider,
Signer: s.v4aSigner,
LogSigning: s.logSigning,
})
return mw.HandleFinalize(ctx, in, next)
default:
mw := v4.NewSignHTTPRequestMiddleware(v4.SignHTTPRequestMiddlewareOptions{
CredentialsProvider: s.credentialsProvider,
Signer: s.v4Signer,
LogSigning: s.logSigning,
})
return mw.HandleFinalize(ctx, in, next)
}
}
// RegisterSigningMiddleware registers the wrapper signing middleware to the stack. If a signing middleware is already
// present, this provided middleware will be swapped. Otherwise the middleware will be added at the tail of the
// finalize step.
func RegisterSigningMiddleware(stack *middleware.Stack, signingMiddleware *SignHTTPRequestMiddleware) (err error) {
const signedID = "Signing"
_, present := stack.Finalize.Get(signedID)
if present {
_, err = stack.Finalize.Swap(signedID, signingMiddleware)
} else {
err = stack.Finalize.Add(signingMiddleware, middleware.After)
}
return err
}
// PresignHTTPRequestMiddlewareOptions is the options for the PresignHTTPRequestMiddleware middleware.
type PresignHTTPRequestMiddlewareOptions struct {
CredentialsProvider aws.CredentialsProvider
V4Presigner v4.HTTPPresigner
V4aPresigner v4a.HTTPPresigner
LogSigning bool
}
// PresignHTTPRequestMiddleware provides the Finalize middleware for creating a
// presigned URL for an HTTP request.
//
// Will short circuit the middleware stack and not forward onto the next
// Finalize handler.
type PresignHTTPRequestMiddleware struct {
// cred provider and signer for sigv4
credentialsProvider aws.CredentialsProvider
// sigV4 signer
v4Signer v4.HTTPPresigner
// sigV4a signer
v4aSigner v4a.HTTPPresigner
// log signing
logSigning bool
}
// NewPresignHTTPRequestMiddleware constructs a PresignHTTPRequestMiddleware using the given Signer for signing requests
func NewPresignHTTPRequestMiddleware(options PresignHTTPRequestMiddlewareOptions) *PresignHTTPRequestMiddleware {
return &PresignHTTPRequestMiddleware{
credentialsProvider: options.CredentialsProvider,
v4Signer: options.V4Presigner,
v4aSigner: options.V4aPresigner,
logSigning: options.LogSigning,
}
}
// ID provides the middleware ID.
func (*PresignHTTPRequestMiddleware) ID() string { return "PresignHTTPRequest" }
// HandleFinalize will take the provided input and create a presigned url for
// the http request using the SigV4 or SigV4a presign authentication scheme.
//
// Since the signed request is not a valid HTTP request
func (p *PresignHTTPRequestMiddleware) HandleFinalize(
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
// fetch signer type from context
signerVersion := GetSignerVersion(ctx)
switch signerVersion {
case v4a.Version:
v4aCredentialProvider, ok := p.credentialsProvider.(v4a.CredentialsProvider)
if !ok {
return out, metadata, fmt.Errorf("invalid credential-provider provided for sigV4a Signer")
}
mw := v4a.NewPresignHTTPRequestMiddleware(v4a.PresignHTTPRequestMiddlewareOptions{
CredentialsProvider: v4aCredentialProvider,
Presigner: p.v4aSigner,
LogSigning: p.logSigning,
})
return mw.HandleFinalize(ctx, in, next)
default:
mw := v4.NewPresignHTTPRequestMiddleware(v4.PresignHTTPRequestMiddlewareOptions{
CredentialsProvider: p.credentialsProvider,
Presigner: p.v4Signer,
LogSigning: p.logSigning,
})
return mw.HandleFinalize(ctx, in, next)
}
}
// RegisterPreSigningMiddleware registers the wrapper pre-signing middleware to the stack. If a pre-signing middleware is already
// present, this provided middleware will be swapped. Otherwise the middleware will be added at the tail of the
// finalize step.
func RegisterPreSigningMiddleware(stack *middleware.Stack, signingMiddleware *PresignHTTPRequestMiddleware) (err error) {
const signedID = "PresignHTTPRequest"
_, present := stack.Finalize.Get(signedID)
if present {
_, err = stack.Finalize.Swap(signedID, signingMiddleware)
} else {
err = stack.Finalize.Add(signingMiddleware, middleware.After)
}
return err
}

View File

@ -0,0 +1,318 @@
package customizations
import (
"context"
"fmt"
"github.com/aws/smithy-go/encoding/httpbinding"
"log"
"net/url"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/aws-sdk-go-v2/service/internal/s3shared"
internalendpoints "github.com/aws/aws-sdk-go-v2/service/s3/internal/endpoints"
)
// EndpointResolver interface for resolving service endpoints.
type EndpointResolver interface {
ResolveEndpoint(region string, options EndpointResolverOptions) (aws.Endpoint, error)
}
// EndpointResolverOptions is the service endpoint resolver options
type EndpointResolverOptions = internalendpoints.Options
// UpdateEndpointParameterAccessor represents accessor functions used by the middleware
type UpdateEndpointParameterAccessor struct {
// functional pointer to fetch bucket name from provided input.
// The function is intended to take an input value, and
// return a string pointer to value of string, and bool if
// input has no bucket member.
GetBucketFromInput func(interface{}) (*string, bool)
}
// UpdateEndpointOptions provides the options for the UpdateEndpoint middleware setup.
type UpdateEndpointOptions struct {
// Accessor are parameter accessors used by the middleware
Accessor UpdateEndpointParameterAccessor
// use path style
UsePathStyle bool
// use transfer acceleration
UseAccelerate bool
// indicates if an operation supports s3 transfer acceleration.
SupportsAccelerate bool
// use dualstack
UseDualstack bool
// use ARN region
UseARNRegion bool
// Indicates that the operation should target the s3-object-lambda endpoint.
// Used to direct operations that do not route based on an input ARN.
TargetS3ObjectLambda bool
// EndpointResolver used to resolve endpoints. This may be a custom endpoint resolver
EndpointResolver EndpointResolver
// EndpointResolverOptions used by endpoint resolver
EndpointResolverOptions EndpointResolverOptions
// DisableMultiRegionAccessPoints indicates multi-region access point support is disabled
DisableMultiRegionAccessPoints bool
}
// UpdateEndpoint adds the middleware to the middleware stack based on the UpdateEndpointOptions.
func UpdateEndpoint(stack *middleware.Stack, options UpdateEndpointOptions) (err error) {
// initial arn look up middleware
err = stack.Initialize.Add(&s3shared.ARNLookup{
GetARNValue: options.Accessor.GetBucketFromInput,
}, middleware.Before)
if err != nil {
return err
}
// process arn
err = stack.Serialize.Insert(&processARNResource{
UseARNRegion: options.UseARNRegion,
UseAccelerate: options.UseAccelerate,
UseDualstack: options.UseDualstack,
EndpointResolver: options.EndpointResolver,
EndpointResolverOptions: options.EndpointResolverOptions,
DisableMultiRegionAccessPoints: options.DisableMultiRegionAccessPoints,
}, "OperationSerializer", middleware.Before)
if err != nil {
return err
}
// process whether the operation requires the s3-object-lambda endpoint
// Occurs before operation serializer so that hostPrefix mutations
// can be handled correctly.
err = stack.Serialize.Insert(&s3ObjectLambdaEndpoint{
UseEndpoint: options.TargetS3ObjectLambda,
UseAccelerate: options.UseAccelerate,
UseDualstack: options.UseDualstack,
EndpointResolver: options.EndpointResolver,
EndpointResolverOptions: options.EndpointResolverOptions,
}, "OperationSerializer", middleware.Before)
if err != nil {
return err
}
// remove bucket arn middleware
err = stack.Serialize.Insert(&removeBucketFromPathMiddleware{}, "OperationSerializer", middleware.After)
if err != nil {
return err
}
// enable dual stack support
err = stack.Serialize.Insert(&s3shared.EnableDualstack{
UseDualstack: options.UseDualstack,
DefaultServiceID: "s3",
}, "OperationSerializer", middleware.After)
if err != nil {
return err
}
// update endpoint to use options for path style and accelerate
err = stack.Serialize.Insert(&updateEndpoint{
usePathStyle: options.UsePathStyle,
getBucketFromInput: options.Accessor.GetBucketFromInput,
useAccelerate: options.UseAccelerate,
supportsAccelerate: options.SupportsAccelerate,
}, (*s3shared.EnableDualstack)(nil).ID(), middleware.After)
if err != nil {
return err
}
return err
}
type updateEndpoint struct {
// path style options
usePathStyle bool
getBucketFromInput func(interface{}) (*string, bool)
// accelerate options
useAccelerate bool
supportsAccelerate bool
}
// ID returns the middleware ID.
func (*updateEndpoint) ID() string {
return "S3:UpdateEndpoint"
}
func (u *updateEndpoint) HandleSerialize(
ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
// if arn was processed, skip this middleware
if _, ok := s3shared.GetARNResourceFromContext(ctx); ok {
return next.HandleSerialize(ctx, in)
}
// skip this customization if host name is set as immutable
if smithyhttp.GetHostnameImmutable(ctx) {
return next.HandleSerialize(ctx, in)
}
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}
// check if accelerate is supported
if u.useAccelerate && !u.supportsAccelerate {
// accelerate is not supported, thus will be ignored
log.Println("Transfer acceleration is not supported for the operation, ignoring UseAccelerate.")
u.useAccelerate = false
}
// transfer acceleration is not supported with path style urls
if u.useAccelerate && u.usePathStyle {
log.Println("UseAccelerate is not compatible with UsePathStyle, ignoring UsePathStyle.")
u.usePathStyle = false
}
if u.getBucketFromInput != nil {
// Below customization only apply if bucket name is provided
bucket, ok := u.getBucketFromInput(in.Parameters)
if ok && bucket != nil {
region := awsmiddleware.GetRegion(ctx)
if err := u.updateEndpointFromConfig(req, *bucket, region); err != nil {
return out, metadata, err
}
}
}
return next.HandleSerialize(ctx, in)
}
func (u updateEndpoint) updateEndpointFromConfig(req *smithyhttp.Request, bucket string, region string) error {
// do nothing if path style is enforced
if u.usePathStyle {
return nil
}
if !hostCompatibleBucketName(req.URL, bucket) {
// bucket name must be valid to put into the host for accelerate operations.
// For non-accelerate operations the bucket name can stay in the path if
// not valid hostname.
var err error
if u.useAccelerate {
err = fmt.Errorf("bucket name %s is not compatible with S3", bucket)
}
// No-Op if not using accelerate.
return err
}
// accelerate is only supported if use path style is disabled
if u.useAccelerate {
parts := strings.Split(req.URL.Host, ".")
if len(parts) < 3 {
return fmt.Errorf("unable to update endpoint host for S3 accelerate, hostname invalid, %s", req.URL.Host)
}
if parts[0] == "s3" || strings.HasPrefix(parts[0], "s3-") {
parts[0] = "s3-accelerate"
}
for i := 1; i+1 < len(parts); i++ {
if strings.EqualFold(parts[i], region) {
parts = append(parts[:i], parts[i+1:]...)
break
}
}
// construct the url host
req.URL.Host = strings.Join(parts, ".")
}
// move bucket to follow virtual host style
moveBucketNameToHost(req.URL, bucket)
return nil
}
// updates endpoint to use virtual host styling
func moveBucketNameToHost(u *url.URL, bucket string) {
u.Host = bucket + "." + u.Host
removeBucketFromPath(u, bucket)
}
// remove bucket from url
func removeBucketFromPath(u *url.URL, bucket string) {
if strings.HasPrefix(u.Path, "/"+bucket) {
// modify url path
u.Path = strings.Replace(u.Path, "/"+bucket, "", 1)
// modify url raw path
u.RawPath = strings.Replace(u.RawPath, "/"+httpbinding.EscapePath(bucket, true), "", 1)
}
if u.Path == "" {
u.Path = "/"
}
if u.RawPath == "" {
u.RawPath = "/"
}
}
// hostCompatibleBucketName returns true if the request should
// put the bucket in the host. This is false if S3ForcePathStyle is
// explicitly set or if the bucket is not DNS compatible.
func hostCompatibleBucketName(u *url.URL, bucket string) bool {
// Bucket might be DNS compatible but dots in the hostname will fail
// certificate validation, so do not use host-style.
if u.Scheme == "https" && strings.Contains(bucket, ".") {
return false
}
// if the bucket is DNS compatible
return dnsCompatibleBucketName(bucket)
}
// dnsCompatibleBucketName returns true if the bucket name is DNS compatible.
// Buckets created outside of the classic region MUST be DNS compatible.
func dnsCompatibleBucketName(bucket string) bool {
if strings.Contains(bucket, "..") {
return false
}
// checks for `^[a-z0-9][a-z0-9\.\-]{1,61}[a-z0-9]$` domain mapping
if !((bucket[0] > 96 && bucket[0] < 123) || (bucket[0] > 47 && bucket[0] < 58)) {
return false
}
for _, c := range bucket[1:] {
if !((c > 96 && c < 123) || (c > 47 && c < 58) || c == 46 || c == 45) {
return false
}
}
// checks for `^(\d+\.){3}\d+$` IPaddressing
v := strings.SplitN(bucket, ".", -1)
if len(v) == 4 {
for _, c := range bucket {
if !((c > 47 && c < 58) || c == 46) {
// we confirm that this is not a IP address
return true
}
}
// this is a IP address
return false
}
return true
}

View File

@ -0,0 +1,375 @@
// Code generated by smithy-go-codegen DO NOT EDIT.
package endpoints
import (
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/internal/endpoints"
"regexp"
"strings"
)
// Options is the endpoint resolver configuration options
type Options struct {
DisableHTTPS bool
}
// Resolver S3 endpoint resolver
type Resolver struct {
partitions endpoints.Partitions
}
// ResolveEndpoint resolves the service endpoint for the given region and options
func (r *Resolver) ResolveEndpoint(region string, options Options) (endpoint aws.Endpoint, err error) {
if len(region) == 0 {
return endpoint, &aws.MissingRegionError{}
}
opt := endpoints.Options{
DisableHTTPS: options.DisableHTTPS,
}
return r.partitions.ResolveEndpoint(region, opt)
}
// New returns a new Resolver
func New() *Resolver {
return &Resolver{
partitions: defaultPartitions,
}
}
var partitionRegexp = struct {
Aws *regexp.Regexp
AwsCn *regexp.Regexp
AwsIso *regexp.Regexp
AwsIsoB *regexp.Regexp
AwsUsGov *regexp.Regexp
}{
Aws: regexp.MustCompile("^(us|eu|ap|sa|ca|me|af)\\-\\w+\\-\\d+$"),
AwsCn: regexp.MustCompile("^cn\\-\\w+\\-\\d+$"),
AwsIso: regexp.MustCompile("^us\\-iso\\-\\w+\\-\\d+$"),
AwsIsoB: regexp.MustCompile("^us\\-isob\\-\\w+\\-\\d+$"),
AwsUsGov: regexp.MustCompile("^us\\-gov\\-\\w+\\-\\d+$"),
}
var defaultPartitions = endpoints.Partitions{
{
ID: "aws",
Defaults: endpoints.Endpoint{
Hostname: "s3.{region}.amazonaws.com",
Protocols: []string{"http", "https"},
SignatureVersions: []string{"s3v4"},
},
RegionRegex: partitionRegexp.Aws,
IsRegionalized: true,
Endpoints: endpoints.Endpoints{
"accesspoint-af-south-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.af-south-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-ap-east-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.ap-east-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-ap-northeast-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.ap-northeast-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-ap-northeast-2": endpoints.Endpoint{
Hostname: "s3-accesspoint.ap-northeast-2.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-ap-northeast-3": endpoints.Endpoint{
Hostname: "s3-accesspoint.ap-northeast-3.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-ap-south-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.ap-south-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-ap-southeast-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.ap-southeast-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-ap-southeast-2": endpoints.Endpoint{
Hostname: "s3-accesspoint.ap-southeast-2.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-ca-central-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.ca-central-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-eu-central-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.eu-central-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-eu-north-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.eu-north-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-eu-south-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.eu-south-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-eu-west-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.eu-west-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-eu-west-2": endpoints.Endpoint{
Hostname: "s3-accesspoint.eu-west-2.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-eu-west-3": endpoints.Endpoint{
Hostname: "s3-accesspoint.eu-west-3.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-me-south-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.me-south-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-sa-east-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.sa-east-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-us-east-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.us-east-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-us-east-2": endpoints.Endpoint{
Hostname: "s3-accesspoint.us-east-2.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-us-west-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.us-west-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-us-west-2": endpoints.Endpoint{
Hostname: "s3-accesspoint.us-west-2.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"af-south-1": endpoints.Endpoint{},
"ap-east-1": endpoints.Endpoint{},
"ap-northeast-1": endpoints.Endpoint{
Hostname: "s3.ap-northeast-1.amazonaws.com",
SignatureVersions: []string{"s3", "s3v4"},
},
"ap-northeast-2": endpoints.Endpoint{},
"ap-northeast-3": endpoints.Endpoint{},
"ap-south-1": endpoints.Endpoint{},
"ap-southeast-1": endpoints.Endpoint{
Hostname: "s3.ap-southeast-1.amazonaws.com",
SignatureVersions: []string{"s3", "s3v4"},
},
"ap-southeast-2": endpoints.Endpoint{
Hostname: "s3.ap-southeast-2.amazonaws.com",
SignatureVersions: []string{"s3", "s3v4"},
},
"aws-global": endpoints.Endpoint{
Hostname: "s3.amazonaws.com",
SignatureVersions: []string{"s3", "s3v4"},
CredentialScope: endpoints.CredentialScope{
Region: "us-east-1",
},
},
"ca-central-1": endpoints.Endpoint{},
"eu-central-1": endpoints.Endpoint{},
"eu-north-1": endpoints.Endpoint{},
"eu-south-1": endpoints.Endpoint{},
"eu-west-1": endpoints.Endpoint{
Hostname: "s3.eu-west-1.amazonaws.com",
SignatureVersions: []string{"s3", "s3v4"},
},
"eu-west-2": endpoints.Endpoint{},
"eu-west-3": endpoints.Endpoint{},
"fips-accesspoint-ca-central-1": endpoints.Endpoint{
Hostname: "s3-accesspoint-fips.ca-central-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"fips-accesspoint-us-east-1": endpoints.Endpoint{
Hostname: "s3-accesspoint-fips.us-east-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"fips-accesspoint-us-east-2": endpoints.Endpoint{
Hostname: "s3-accesspoint-fips.us-east-2.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"fips-accesspoint-us-west-1": endpoints.Endpoint{
Hostname: "s3-accesspoint-fips.us-west-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"fips-accesspoint-us-west-2": endpoints.Endpoint{
Hostname: "s3-accesspoint-fips.us-west-2.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"me-south-1": endpoints.Endpoint{},
"s3-external-1": endpoints.Endpoint{
Hostname: "s3-external-1.amazonaws.com",
SignatureVersions: []string{"s3", "s3v4"},
CredentialScope: endpoints.CredentialScope{
Region: "us-east-1",
},
},
"sa-east-1": endpoints.Endpoint{
Hostname: "s3.sa-east-1.amazonaws.com",
SignatureVersions: []string{"s3", "s3v4"},
},
"us-east-1": endpoints.Endpoint{
Hostname: "s3.us-east-1.amazonaws.com",
SignatureVersions: []string{"s3", "s3v4"},
},
"us-east-2": endpoints.Endpoint{},
"us-west-1": endpoints.Endpoint{
Hostname: "s3.us-west-1.amazonaws.com",
SignatureVersions: []string{"s3", "s3v4"},
},
"us-west-2": endpoints.Endpoint{
Hostname: "s3.us-west-2.amazonaws.com",
SignatureVersions: []string{"s3", "s3v4"},
},
},
},
{
ID: "aws-cn",
Defaults: endpoints.Endpoint{
Hostname: "s3.{region}.amazonaws.com.cn",
Protocols: []string{"http", "https"},
SignatureVersions: []string{"s3v4"},
},
RegionRegex: partitionRegexp.AwsCn,
IsRegionalized: true,
Endpoints: endpoints.Endpoints{
"accesspoint-cn-north-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.cn-north-1.amazonaws.com.cn",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-cn-northwest-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.cn-northwest-1.amazonaws.com.cn",
SignatureVersions: []string{"s3v4"},
},
"cn-north-1": endpoints.Endpoint{},
"cn-northwest-1": endpoints.Endpoint{},
},
},
{
ID: "aws-iso",
Defaults: endpoints.Endpoint{
Hostname: "s3.{region}.c2s.ic.gov",
Protocols: []string{"https"},
SignatureVersions: []string{"s3v4"},
},
RegionRegex: partitionRegexp.AwsIso,
IsRegionalized: true,
Endpoints: endpoints.Endpoints{
"us-iso-east-1": endpoints.Endpoint{
Protocols: []string{"http", "https"},
SignatureVersions: []string{"s3v4"},
},
},
},
{
ID: "aws-iso-b",
Defaults: endpoints.Endpoint{
Hostname: "s3.{region}.sc2s.sgov.gov",
Protocols: []string{"http", "https"},
SignatureVersions: []string{"s3v4"},
},
RegionRegex: partitionRegexp.AwsIsoB,
IsRegionalized: true,
Endpoints: endpoints.Endpoints{
"us-isob-east-1": endpoints.Endpoint{},
},
},
{
ID: "aws-us-gov",
Defaults: endpoints.Endpoint{
Hostname: "s3.{region}.amazonaws.com",
Protocols: []string{"https"},
SignatureVersions: []string{"s3", "s3v4"},
},
RegionRegex: partitionRegexp.AwsUsGov,
IsRegionalized: true,
Endpoints: endpoints.Endpoints{
"accesspoint-us-gov-east-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.us-gov-east-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"accesspoint-us-gov-west-1": endpoints.Endpoint{
Hostname: "s3-accesspoint.us-gov-west-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"fips-accesspoint-us-gov-east-1": endpoints.Endpoint{
Hostname: "s3-accesspoint-fips.us-gov-east-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"fips-accesspoint-us-gov-west-1": endpoints.Endpoint{
Hostname: "s3-accesspoint-fips.us-gov-west-1.amazonaws.com",
SignatureVersions: []string{"s3v4"},
},
"fips-us-gov-west-1": endpoints.Endpoint{
Hostname: "s3-fips.us-gov-west-1.amazonaws.com",
CredentialScope: endpoints.CredentialScope{
Region: "us-gov-west-1",
},
},
"us-gov-east-1": endpoints.Endpoint{
Hostname: "s3.us-gov-east-1.amazonaws.com",
Protocols: []string{"http", "https"},
},
"us-gov-west-1": endpoints.Endpoint{
Hostname: "s3.us-gov-west-1.amazonaws.com",
Protocols: []string{"http", "https"},
},
},
},
}
// GetDNSSuffix returns the dnsSuffix URL component for the given partition id
func GetDNSSuffix(id string) (string, error) {
switch {
case strings.EqualFold(id, "aws"):
return "amazonaws.com", nil
case strings.EqualFold(id, "aws-cn"):
return "amazonaws.com.cn", nil
case strings.EqualFold(id, "aws-iso"):
return "c2s.ic.gov", nil
case strings.EqualFold(id, "aws-iso-b"):
return "sc2s.sgov.gov", nil
case strings.EqualFold(id, "aws-us-gov"):
return "amazonaws.com", nil
default:
return "", fmt.Errorf("unknown partition")
}
}
// GetDNSSuffixFromRegion returns the dnsSuffix URL component for the given
// partition id
func GetDNSSuffixFromRegion(region string) (string, error) {
switch {
case partitionRegexp.Aws.MatchString(region):
return "amazonaws.com", nil
case partitionRegexp.AwsCn.MatchString(region):
return "amazonaws.com.cn", nil
case partitionRegexp.AwsIso.MatchString(region):
return "c2s.ic.gov", nil
case partitionRegexp.AwsIsoB.MatchString(region):
return "sc2s.sgov.gov", nil
case partitionRegexp.AwsUsGov.MatchString(region):
return "amazonaws.com", nil
default:
return "", fmt.Errorf("unknown region partition")
}
}

View File

@ -0,0 +1,141 @@
package v4a
import (
"context"
"crypto/ecdsa"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
)
// Credentials is Context, ECDSA, and Optional Session Token that can be used
// to sign requests using SigV4a
type Credentials struct {
Context string
PrivateKey *ecdsa.PrivateKey
SessionToken string
// Time the credentials will expire.
CanExpire bool
Expires time.Time
}
// Expired returns if the credentials have expired.
func (v Credentials) Expired() bool {
if v.CanExpire {
return !v.Expires.After(sdk.NowTime())
}
return false
}
// HasKeys returns if the credentials keys are set.
func (v Credentials) HasKeys() bool {
return len(v.Context) > 0 && v.PrivateKey != nil
}
// SymmetricCredentialAdaptor wraps a SigV4 AccessKey/SecretKey provider and adapts the credentials
// to a ECDSA PrivateKey for signing with SiV4a
type SymmetricCredentialAdaptor struct {
SymmetricProvider aws.CredentialsProvider
asymmetric atomic.Value
m sync.Mutex
}
// Retrieve retrieves symmetric credentials from the underlying provider.
func (s *SymmetricCredentialAdaptor) Retrieve(ctx context.Context) (aws.Credentials, error) {
symCreds, err := s.retrieveFromSymmetricProvider(ctx)
if err != nil {
return aws.Credentials{}, nil
}
if asymCreds := s.getCreds(); asymCreds == nil {
return symCreds, nil
}
s.m.Lock()
defer s.m.Unlock()
asymCreds := s.getCreds()
if asymCreds == nil {
return symCreds, nil
}
// if the context does not match the access key id clear it
if asymCreds.Context != symCreds.AccessKeyID {
s.asymmetric.Store((*Credentials)(nil))
}
return symCreds, nil
}
// RetrievePrivateKey returns credentials suitable for SigV4a signing
func (s *SymmetricCredentialAdaptor) RetrievePrivateKey(ctx context.Context) (Credentials, error) {
if asymCreds := s.getCreds(); asymCreds != nil {
return *asymCreds, nil
}
s.m.Lock()
defer s.m.Unlock()
if asymCreds := s.getCreds(); asymCreds != nil {
return *asymCreds, nil
}
symmetricCreds, err := s.retrieveFromSymmetricProvider(ctx)
if err != nil {
return Credentials{}, fmt.Errorf("failed to retrieve symmetric credentials: %v", err)
}
privateKey, err := deriveKeyFromAccessKeyPair(symmetricCreds.AccessKeyID, symmetricCreds.SecretAccessKey)
if err != nil {
return Credentials{}, fmt.Errorf("failed to derive assymetric key from credentials")
}
creds := Credentials{
Context: symmetricCreds.AccessKeyID,
PrivateKey: privateKey,
SessionToken: symmetricCreds.SessionToken,
CanExpire: symmetricCreds.CanExpire,
Expires: symmetricCreds.Expires,
}
s.asymmetric.Store(&creds)
return creds, nil
}
func (s *SymmetricCredentialAdaptor) getCreds() *Credentials {
v := s.asymmetric.Load()
if v == nil {
return nil
}
c := v.(*Credentials)
if c != nil && c.HasKeys() && !c.Expired() {
return c
}
return nil
}
func (s *SymmetricCredentialAdaptor) retrieveFromSymmetricProvider(ctx context.Context) (aws.Credentials, error) {
credentials, err := s.SymmetricProvider.Retrieve(ctx)
if err != nil {
return aws.Credentials{}, err
}
return credentials, nil
}
// CredentialsProvider is the interface for a provider to retrieve credentials
// to sign requests with.
type CredentialsProvider interface {
RetrievePrivateKey(context.Context) (Credentials, error)
}

View File

@ -0,0 +1,17 @@
package v4a
import "fmt"
// SigningError indicates an error condition occurred while performing SigV4a signing
type SigningError struct {
Err error
}
func (e *SigningError) Error() string {
return fmt.Sprintf("failed to sign request: %v", e.Err)
}
// Unwrap returns the underlying error cause
func (e *SigningError) Unwrap() error {
return e.Err
}

View File

@ -0,0 +1,30 @@
package crypto
import "fmt"
// ConstantTimeByteCompare is a constant-time byte comparison of x and y. This function performs an absolute comparison
// if the two byte slices assuming they represent a big-endian number.
//
// error if len(x) != len(y)
// -1 if x < y
// 0 if x == y
// +1 if x > y
func ConstantTimeByteCompare(x, y []byte) (int, error) {
if len(x) != len(y) {
return 0, fmt.Errorf("slice lengths do not match")
}
xLarger, yLarger := 0, 0
for i := 0; i < len(x); i++ {
xByte, yByte := int(x[i]), int(y[i])
x := ((yByte - xByte) >> 8) & 1
y := ((xByte - yByte) >> 8) & 1
xLarger |= x &^ yLarger
yLarger |= y &^ xLarger
}
return xLarger - yLarger, nil
}

View File

@ -0,0 +1,113 @@
package crypto
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/hmac"
"encoding/asn1"
"encoding/binary"
"fmt"
"hash"
"math"
"math/big"
)
type ecdsaSignature struct {
R, S *big.Int
}
// ECDSAKey takes the given elliptic curve, and private key (d) byte slice
// and returns the private ECDSA key.
func ECDSAKey(curve elliptic.Curve, d []byte) *ecdsa.PrivateKey {
return ECDSAKeyFromPoint(curve, (&big.Int{}).SetBytes(d))
}
// ECDSAKeyFromPoint takes the given elliptic curve and point and returns the
// private and public keypair
func ECDSAKeyFromPoint(curve elliptic.Curve, d *big.Int) *ecdsa.PrivateKey {
pX, pY := curve.ScalarBaseMult(d.Bytes())
privKey := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: curve,
X: pX,
Y: pY,
},
D: d,
}
return privKey
}
// ECDSAPublicKey takes the provide curve and (x, y) coordinates and returns
// *ecdsa.PublicKey. Returns an error if the given points are not on the curve.
func ECDSAPublicKey(curve elliptic.Curve, x, y []byte) (*ecdsa.PublicKey, error) {
xPoint := (&big.Int{}).SetBytes(x)
yPoint := (&big.Int{}).SetBytes(y)
if !curve.IsOnCurve(xPoint, yPoint) {
return nil, fmt.Errorf("point(%v, %v) is not on the given curve", xPoint.String(), yPoint.String())
}
return &ecdsa.PublicKey{
Curve: curve,
X: xPoint,
Y: yPoint,
}, nil
}
// VerifySignature takes the provided public key, hash, and asn1 encoded signature and returns
// whether the given signature is valid.
func VerifySignature(key *ecdsa.PublicKey, hash []byte, signature []byte) (bool, error) {
var ecdsaSignature ecdsaSignature
_, err := asn1.Unmarshal(signature, &ecdsaSignature)
if err != nil {
return false, err
}
return ecdsa.Verify(key, hash, ecdsaSignature.R, ecdsaSignature.S), nil
}
// HMACKeyDerivation provides an implementation of a NIST-800-108 of a KDF (Key Derivation Function) in Counter Mode.
// For the purposes of this implantation HMAC is used as the PRF (Pseudorandom function), where the value of
// `r` is defined as a 4 byte counter.
func HMACKeyDerivation(hash func() hash.Hash, bitLen int, key []byte, label, context []byte) ([]byte, error) {
// verify that we won't overflow the counter
n := int64(math.Ceil((float64(bitLen) / 8) / float64(hash().Size())))
if n > 0x7FFFFFFF {
return nil, fmt.Errorf("unable to derive key of size %d using 32-bit counter", bitLen)
}
// verify the requested bit length is not larger then the length encoding size
if int64(bitLen) > 0x7FFFFFFF {
return nil, fmt.Errorf("bitLen is greater than 32-bits")
}
fixedInput := bytes.NewBuffer(nil)
fixedInput.Write(label)
fixedInput.WriteByte(0x00)
fixedInput.Write(context)
if err := binary.Write(fixedInput, binary.BigEndian, int32(bitLen)); err != nil {
return nil, fmt.Errorf("failed to write bit length to fixed input string: %v", err)
}
var output []byte
h := hmac.New(hash, key)
for i := int64(1); i <= n; i++ {
h.Reset()
if err := binary.Write(h, binary.BigEndian, int32(i)); err != nil {
return nil, err
}
_, err := h.Write(fixedInput.Bytes())
if err != nil {
return nil, err
}
output = append(output, h.Sum(nil)...)
}
return output[:bitLen/8], nil
}

View File

@ -0,0 +1,36 @@
package v4
const (
// EmptyStringSHA256 is the hex encoded sha256 value of an empty string
EmptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
// UnsignedPayload indicates that the request payload body is unsigned
UnsignedPayload = "UNSIGNED-PAYLOAD"
// AmzAlgorithmKey indicates the signing algorithm
AmzAlgorithmKey = "X-Amz-Algorithm"
// AmzSecurityTokenKey indicates the security token to be used with temporary credentials
AmzSecurityTokenKey = "X-Amz-Security-Token"
// AmzDateKey is the UTC timestamp for the request in the format YYYYMMDD'T'HHMMSS'Z'
AmzDateKey = "X-Amz-Date"
// AmzCredentialKey is the access key ID and credential scope
AmzCredentialKey = "X-Amz-Credential"
// AmzSignedHeadersKey is the set of headers signed for the request
AmzSignedHeadersKey = "X-Amz-SignedHeaders"
// AmzSignatureKey is the query parameter to store the SigV4 signature
AmzSignatureKey = "X-Amz-Signature"
// TimeFormat is the time format to be used in the X-Amz-Date header or query parameter
TimeFormat = "20060102T150405Z"
// ShortTimeFormat is the shorten time format used in the credential scope
ShortTimeFormat = "20060102"
// ContentSHAKey is the SHA256 of request body
ContentSHAKey = "X-Amz-Content-Sha256"
)

View File

@ -0,0 +1,82 @@
package v4
import (
sdkstrings "github.com/aws/aws-sdk-go-v2/internal/strings"
)
// Rules houses a set of Rule needed for validation of a
// string value
type Rules []Rule
// Rule interface allows for more flexible rules and just simply
// checks whether or not a value adheres to that Rule
type Rule interface {
IsValid(value string) bool
}
// IsValid will iterate through all rules and see if any rules
// apply to the value and supports nested rules
func (r Rules) IsValid(value string) bool {
for _, rule := range r {
if rule.IsValid(value) {
return true
}
}
return false
}
// MapRule generic Rule for maps
type MapRule map[string]struct{}
// IsValid for the map Rule satisfies whether it exists in the map
func (m MapRule) IsValid(value string) bool {
_, ok := m[value]
return ok
}
// AllowList is a generic Rule for whitelisting
type AllowList struct {
Rule
}
// IsValid for AllowList checks if the value is within the AllowList
func (w AllowList) IsValid(value string) bool {
return w.Rule.IsValid(value)
}
// DenyList is a generic Rule for blacklisting
type DenyList struct {
Rule
}
// IsValid for AllowList checks if the value is within the AllowList
func (b DenyList) IsValid(value string) bool {
return !b.Rule.IsValid(value)
}
// Patterns is a list of strings to match against
type Patterns []string
// IsValid for Patterns checks each pattern and returns if a match has
// been found
func (p Patterns) IsValid(value string) bool {
for _, pattern := range p {
if sdkstrings.HasPrefixFold(value, pattern) {
return true
}
}
return false
}
// InclusiveRules rules allow for rules to depend on one another
type InclusiveRules []Rule
// IsValid will return true if all rules are true
func (r InclusiveRules) IsValid(value string) bool {
for _, rule := range r {
if !rule.IsValid(value) {
return false
}
}
return true
}

View File

@ -0,0 +1,67 @@
package v4
// IgnoredHeaders is a list of headers that are ignored during signing
var IgnoredHeaders = Rules{
DenyList{
MapRule{
"Authorization": struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
},
},
}
// RequiredSignedHeaders is a whitelist for Build canonical headers.
var RequiredSignedHeaders = Rules{
AllowList{
MapRule{
"Cache-Control": struct{}{},
"Content-Disposition": struct{}{},
"Content-Encoding": struct{}{},
"Content-Language": struct{}{},
"Content-Md5": struct{}{},
"Content-Type": struct{}{},
"Expires": struct{}{},
"If-Match": struct{}{},
"If-Modified-Since": struct{}{},
"If-None-Match": struct{}{},
"If-Unmodified-Since": struct{}{},
"Range": struct{}{},
"X-Amz-Acl": struct{}{},
"X-Amz-Copy-Source": struct{}{},
"X-Amz-Copy-Source-If-Match": struct{}{},
"X-Amz-Copy-Source-If-Modified-Since": struct{}{},
"X-Amz-Copy-Source-If-None-Match": struct{}{},
"X-Amz-Copy-Source-If-Unmodified-Since": struct{}{},
"X-Amz-Copy-Source-Range": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Grant-Full-control": struct{}{},
"X-Amz-Grant-Read": struct{}{},
"X-Amz-Grant-Read-Acp": struct{}{},
"X-Amz-Grant-Write": struct{}{},
"X-Amz-Grant-Write-Acp": struct{}{},
"X-Amz-Metadata-Directive": struct{}{},
"X-Amz-Mfa": struct{}{},
"X-Amz-Request-Payer": struct{}{},
"X-Amz-Server-Side-Encryption": struct{}{},
"X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Algorithm": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Storage-Class": struct{}{},
"X-Amz-Website-Redirect-Location": struct{}{},
"X-Amz-Content-Sha256": struct{}{},
"X-Amz-Tagging": struct{}{},
},
},
Patterns{"X-Amz-Meta-"},
}
// AllowedQueryHoisting is a whitelist for Build query headers. The boolean value
// represents whether or not it is a pattern.
var AllowedQueryHoisting = InclusiveRules{
DenyList{RequiredSignedHeaders},
Patterns{"X-Amz-"},
}

View File

@ -0,0 +1,13 @@
package v4
import (
"crypto/hmac"
"crypto/sha256"
)
// HMACSHA256 computes a HMAC-SHA256 of data given the provided key.
func HMACSHA256(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}

View File

@ -0,0 +1,75 @@
package v4
import (
"net/http"
"strings"
)
// SanitizeHostForHeader removes default port from host and updates request.Host
func SanitizeHostForHeader(r *http.Request) {
host := getHost(r)
port := portOnly(host)
if port != "" && isDefaultPort(r.URL.Scheme, port) {
r.Host = stripPort(host)
}
}
// Returns host from request
func getHost(r *http.Request) string {
if r.Host != "" {
return r.Host
}
return r.URL.Host
}
// Hostname returns u.Host, without any port number.
//
// If Host is an IPv6 literal with a port number, Hostname returns the
// IPv6 literal without the square brackets. IPv6 literals may include
// a zone identifier.
//
// Copied from the Go 1.8 standard library (net/url)
func stripPort(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return hostport
}
if i := strings.IndexByte(hostport, ']'); i != -1 {
return strings.TrimPrefix(hostport[:i], "[")
}
return hostport[:colon]
}
// Port returns the port part of u.Host, without the leading colon.
// If u.Host doesn't contain a port, Port returns an empty string.
//
// Copied from the Go 1.8 standard library (net/url)
func portOnly(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return ""
}
if i := strings.Index(hostport, "]:"); i != -1 {
return hostport[i+len("]:"):]
}
if strings.Contains(hostport, "]") {
return ""
}
return hostport[colon+len(":"):]
}
// Returns true if the specified URI is using the standard port
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
func isDefaultPort(scheme, port string) bool {
if port == "" {
return true
}
lowerCaseScheme := strings.ToLower(scheme)
if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
return true
}
return false
}

View File

@ -0,0 +1,36 @@
package v4
import "time"
// SigningTime provides a wrapper around a time.Time which provides cached values for SigV4 signing.
type SigningTime struct {
time.Time
timeFormat string
shortTimeFormat string
}
// NewSigningTime creates a new SigningTime given a time.Time
func NewSigningTime(t time.Time) SigningTime {
return SigningTime{
Time: t,
}
}
// TimeFormat provides a time formatted in the X-Amz-Date format.
func (m *SigningTime) TimeFormat() string {
return m.format(&m.timeFormat, TimeFormat)
}
// ShortTimeFormat provides a time formatted of 20060102.
func (m *SigningTime) ShortTimeFormat() string {
return m.format(&m.shortTimeFormat, ShortTimeFormat)
}
func (m *SigningTime) format(target *string, format string) string {
if len(*target) > 0 {
return *target
}
v := m.Time.Format(format)
*target = v
return v
}

View File

@ -0,0 +1,64 @@
package v4
import (
"net/url"
"strings"
)
const doubleSpace = " "
// StripExcessSpaces will rewrite the passed in slice's string values to not
// contain muliple side-by-side spaces.
func StripExcessSpaces(str string) string {
var j, k, l, m, spaces int
// Trim trailing spaces
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
}
// Trim leading spaces
for k = 0; k < j && str[k] == ' '; k++ {
}
str = str[k : j+1]
// Strip multiple spaces.
j = strings.Index(str, doubleSpace)
if j < 0 {
return str
}
buf := []byte(str)
for k, m, l = j, j, len(buf); k < l; k++ {
if buf[k] == ' ' {
if spaces == 0 {
// First space.
buf[m] = buf[k]
m++
}
spaces++
} else {
// End of multiple spaces.
spaces = 0
buf[m] = buf[k]
m++
}
}
return string(buf[:m])
}
// GetURIPath returns the escaped URI component from the provided URL
func GetURIPath(u *url.URL) string {
var uri string
if len(u.Opaque) > 0 {
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
} else {
uri = u.EscapedPath()
}
if len(uri) == 0 {
uri = "/"
}
return uri
}

View File

@ -0,0 +1,105 @@
package v4a
import (
"context"
"fmt"
"net/http"
"time"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// HTTPSigner is SigV4a HTTP signer implementation
type HTTPSigner interface {
SignHTTP(ctx context.Context, credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, optfns ...func(*SignerOptions)) error
}
// SignHTTPRequestMiddlewareOptions is the middleware options for constructing a SignHTTPRequestMiddleware.
type SignHTTPRequestMiddlewareOptions struct {
Credentials CredentialsProvider
Signer HTTPSigner
LogSigning bool
}
// SignHTTPRequestMiddleware is a middleware for signing an HTTP request using SigV4a.
type SignHTTPRequestMiddleware struct {
credentials CredentialsProvider
signer HTTPSigner
logSigning bool
}
// NewSignHTTPRequestMiddleware constructs a SignHTTPRequestMiddleware using the given SignHTTPRequestMiddlewareOptions.
func NewSignHTTPRequestMiddleware(options SignHTTPRequestMiddlewareOptions) *SignHTTPRequestMiddleware {
return &SignHTTPRequestMiddleware{
credentials: options.Credentials,
signer: options.Signer,
logSigning: options.LogSigning,
}
}
// ID the middleware identifier.
func (s *SignHTTPRequestMiddleware) ID() string {
return "Signing"
}
// HandleFinalize signs an HTTP request using SigV4a.
func (s *SignHTTPRequestMiddleware) HandleFinalize(
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
if !hasCredentialProvider(s.credentials) {
return next.HandleFinalize(ctx, in)
}
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, fmt.Errorf("unexpected request middleware type %T", in.Request)
}
signingName, signingRegion := awsmiddleware.GetSigningName(ctx), awsmiddleware.GetSigningRegion(ctx)
payloadHash := v4.GetPayloadHash(ctx)
if len(payloadHash) == 0 {
return out, metadata, &SigningError{Err: fmt.Errorf("computed payload hash missing from context")}
}
credentials, err := s.credentials.RetrievePrivateKey(ctx)
if err != nil {
return out, metadata, &SigningError{Err: fmt.Errorf("failed to retrieve credentials: %w", err)}
}
err = s.signer.SignHTTP(ctx, credentials, req.Request, payloadHash, signingName, []string{signingRegion}, time.Now().UTC(), func(o *SignerOptions) {
o.Logger = middleware.GetLogger(ctx)
o.LogSigning = s.logSigning
})
if err != nil {
return out, metadata, &SigningError{Err: fmt.Errorf("failed to sign http request, %w", err)}
}
return next.HandleFinalize(ctx, in)
}
func hasCredentialProvider(p CredentialsProvider) bool {
if p == nil {
return false
}
return true
}
// RegisterSigningMiddleware registers the SigV4a signing middleware to the stack. If a signing middleware is already
// present, this provided middleware will be swapped. Otherwise the middleware will be added at the tail of the
// finalize step.
func RegisterSigningMiddleware(stack *middleware.Stack, signingMiddleware *SignHTTPRequestMiddleware) (err error) {
const signedID = "Signing"
_, present := stack.Finalize.Get(signedID)
if present {
_, err = stack.Finalize.Swap(signedID, signingMiddleware)
} else {
err = stack.Finalize.Add(signingMiddleware, middleware.After)
}
return err
}

View File

@ -0,0 +1,117 @@
package v4a
import (
"context"
"fmt"
"net/http"
"time"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
"github.com/aws/smithy-go/middleware"
smithyHTTP "github.com/aws/smithy-go/transport/http"
)
// HTTPPresigner is an interface to a SigV4a signer that can sign create a
// presigned URL for a HTTP requests.
type HTTPPresigner interface {
PresignHTTP(
ctx context.Context, credentials Credentials, r *http.Request,
payloadHash string, service string, regionSet []string, signingTime time.Time,
optFns ...func(*SignerOptions),
) (url string, signedHeader http.Header, err error)
}
// PresignHTTPRequestMiddlewareOptions is the options for the PresignHTTPRequestMiddleware middleware.
type PresignHTTPRequestMiddlewareOptions struct {
CredentialsProvider CredentialsProvider
Presigner HTTPPresigner
LogSigning bool
}
// PresignHTTPRequestMiddleware provides the Finalize middleware for creating a
// presigned URL for an HTTP request.
//
// Will short circuit the middleware stack and not forward onto the next
// Finalize handler.
type PresignHTTPRequestMiddleware struct {
credentialsProvider CredentialsProvider
presigner HTTPPresigner
logSigning bool
}
// NewPresignHTTPRequestMiddleware returns a new PresignHTTPRequestMiddleware
// initialized with the presigner.
func NewPresignHTTPRequestMiddleware(options PresignHTTPRequestMiddlewareOptions) *PresignHTTPRequestMiddleware {
return &PresignHTTPRequestMiddleware{
credentialsProvider: options.CredentialsProvider,
presigner: options.Presigner,
logSigning: options.LogSigning,
}
}
// ID provides the middleware ID.
func (*PresignHTTPRequestMiddleware) ID() string { return "PresignHTTPRequest" }
// HandleFinalize will take the provided input and create a presigned url for
// the http request using the SigV4 presign authentication scheme.
func (s *PresignHTTPRequestMiddleware) HandleFinalize(
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
req, ok := in.Request.(*smithyHTTP.Request)
if !ok {
return out, metadata, &SigningError{
Err: fmt.Errorf("unexpected request middleware type %T", in.Request),
}
}
httpReq := req.Build(ctx)
if !hasCredentialProvider(s.credentialsProvider) {
out.Result = &v4.PresignedHTTPRequest{
URL: httpReq.URL.String(),
Method: httpReq.Method,
SignedHeader: http.Header{},
}
return out, metadata, nil
}
signingName := awsmiddleware.GetSigningName(ctx)
signingRegion := awsmiddleware.GetSigningRegion(ctx)
payloadHash := v4.GetPayloadHash(ctx)
if len(payloadHash) == 0 {
return out, metadata, &SigningError{
Err: fmt.Errorf("computed payload hash missing from context"),
}
}
credentials, err := s.credentialsProvider.RetrievePrivateKey(ctx)
if err != nil {
return out, metadata, &SigningError{
Err: fmt.Errorf("failed to retrieve credentials: %w", err),
}
}
u, h, err := s.presigner.PresignHTTP(ctx, credentials,
httpReq, payloadHash, signingName, []string{signingRegion}, sdk.NowTime(),
func(o *SignerOptions) {
o.Logger = middleware.GetLogger(ctx)
o.LogSigning = s.logSigning
})
if err != nil {
return out, metadata, &SigningError{
Err: fmt.Errorf("failed to sign http request, %w", err),
}
}
out.Result = &v4.PresignedHTTPRequest{
URL: u,
Method: httpReq.Method,
SignedHeader: h,
}
return out, metadata, nil
}

View File

@ -0,0 +1,514 @@
// TODO(GOSDK-1220): This signer has removed the conceptual knowledge of UNSIGNED-PAYLOAD and X-Amz-Content-Sha256
package v4a
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"hash"
"math/big"
"net/http"
"net/textproto"
"net/url"
"sort"
"strconv"
"strings"
"time"
signerCrypto "github.com/aws/aws-sdk-go-v2/service/s3/internal/v4a/internal/crypto"
v4Internal "github.com/aws/aws-sdk-go-v2/service/s3/internal/v4a/internal/v4"
"github.com/aws/smithy-go/encoding/httpbinding"
"github.com/aws/smithy-go/logging"
)
const (
// AmzRegionSetKey represents the region set header used for sigv4a
AmzRegionSetKey = "X-Amz-Region-Set"
amzAlgorithmKey = v4Internal.AmzAlgorithmKey
amzSecurityTokenKey = v4Internal.AmzSecurityTokenKey
amzDateKey = v4Internal.AmzDateKey
amzCredentialKey = v4Internal.AmzCredentialKey
amzSignedHeadersKey = v4Internal.AmzSignedHeadersKey
authorizationHeader = "Authorization"
signingAlgorithm = "AWS4-ECDSA-P256-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
// EmptyStringSHA256 is a hex encoded SHA-256 hash of an empty string
EmptyStringSHA256 = v4Internal.EmptyStringSHA256
// Version of signing v4a
Version = "SigV4A"
)
var (
p256 elliptic.Curve
nMinusTwoP256 *big.Int
one = new(big.Int).SetInt64(1)
)
func init() {
// Ensure the elliptic curve parameters are initialized on package import rather then on first usage
p256 = elliptic.P256()
nMinusTwoP256 = new(big.Int).SetBytes(p256.Params().N.Bytes())
nMinusTwoP256 = nMinusTwoP256.Sub(nMinusTwoP256, new(big.Int).SetInt64(2))
}
// SignerOptions is the SigV4a signing options for constructing a Signer.
type SignerOptions struct {
Logger logging.Logger
LogSigning bool
// Disables the Signer's moving HTTP header key/value pairs from the HTTP
// request header to the request's query string. This is most commonly used
// with pre-signed requests preventing headers from being added to the
// request's query string.
DisableHeaderHoisting bool
// Disables the automatic escaping of the URI path of the request for the
// siganture's canonical string's path. For services that do not need additional
// escaping then use this to disable the signer escaping the path.
//
// S3 is an example of a service that does not need additional escaping.
//
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
DisableURIPathEscaping bool
}
// Signer is a SigV4a HTTP signing implementation
type Signer struct {
options SignerOptions
}
// NewSigner constructs a SigV4a Signer.
func NewSigner(optFns ...func(*SignerOptions)) *Signer {
options := SignerOptions{}
for _, fn := range optFns {
fn(&options)
}
return &Signer{options: options}
}
// deriveKeyFromAccessKeyPair derives a NIST P-256 PrivateKey from the given
// IAM AccessKey and SecretKey pair.
//
// Based on FIPS.186-4 Appendix B.4.2
func deriveKeyFromAccessKeyPair(accessKey, secretKey string) (*ecdsa.PrivateKey, error) {
params := p256.Params()
bitLen := params.BitSize // Testing random candidates does not require an additional 64 bits
counter := 0x01
buffer := make([]byte, 1+len(accessKey)) // 1 byte counter + len(accessKey)
kdfContext := bytes.NewBuffer(buffer)
inputKey := append([]byte("AWS4A"), []byte(secretKey)...)
d := new(big.Int)
for {
kdfContext.Reset()
kdfContext.WriteString(accessKey)
kdfContext.WriteByte(byte(counter))
key, err := signerCrypto.HMACKeyDerivation(sha256.New, bitLen, inputKey, []byte(signingAlgorithm), kdfContext.Bytes())
if err != nil {
return nil, err
}
// Check key first before calling SetBytes if key key is in fact a valid candidate.
// This ensures the byte slice is the correct length (32-bytes) to compare in constant-time
cmp, err := signerCrypto.ConstantTimeByteCompare(key, nMinusTwoP256.Bytes())
if err != nil {
return nil, err
}
if cmp == -1 {
d.SetBytes(key)
break
}
counter++
if counter > 0xFF {
return nil, fmt.Errorf("exhausted single byte external counter")
}
}
d = d.Add(d, one)
priv := new(ecdsa.PrivateKey)
priv.PublicKey.Curve = p256
priv.D = d
priv.PublicKey.X, priv.PublicKey.Y = p256.ScalarBaseMult(d.Bytes())
return priv, nil
}
type httpSigner struct {
Request *http.Request
ServiceName string
RegionSet []string
Time time.Time
Credentials Credentials
IsPreSign bool
Logger logging.Logger
Debug bool
// PayloadHash is the hex encoded SHA-256 hash of the request payload
// If len(PayloadHash) == 0 the signer will attempt to send the request
// as an unsigned payload. Note: Unsigned payloads only work for a subset of services.
PayloadHash string
DisableHeaderHoisting bool
DisableURIPathEscaping bool
}
// SignHTTP takes the provided http.Request, payload hash, service, regionSet, and time and signs using SigV4a.
// The passed in request will be modified in place.
func (s *Signer) SignHTTP(ctx context.Context, credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, optFns ...func(*SignerOptions)) error {
options := s.options
for _, fn := range optFns {
fn(&options)
}
signer := &httpSigner{
Request: r,
PayloadHash: payloadHash,
ServiceName: service,
RegionSet: regionSet,
Credentials: credentials,
Time: signingTime.UTC(),
DisableHeaderHoisting: options.DisableHeaderHoisting,
DisableURIPathEscaping: options.DisableURIPathEscaping,
}
signedRequest, err := signer.Build()
if err != nil {
return err
}
logHTTPSigningInfo(ctx, options, signedRequest)
return nil
}
// PresignHTTP takes the provided http.Request, payload hash, service, regionSet, and time and presigns using SigV4a
// Returns the presigned URL along with the headers that were signed with the request.
//
// PresignHTTP will not set the expires time of the presigned request
// automatically. To specify the expire duration for a request add the
// "X-Amz-Expires" query parameter on the request with the value as the
// duration in seconds the presigned URL should be considered valid for. This
// parameter is not used by all AWS services, and is most notable used by
// Amazon S3 APIs.
func (s *Signer) PresignHTTP(ctx context.Context, credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, optFns ...func(*SignerOptions)) (signedURI string, signedHeaders http.Header, err error) {
options := s.options
for _, fn := range optFns {
fn(&options)
}
signer := &httpSigner{
Request: r,
PayloadHash: payloadHash,
ServiceName: service,
RegionSet: regionSet,
Credentials: credentials,
Time: signingTime.UTC(),
IsPreSign: true,
DisableHeaderHoisting: options.DisableHeaderHoisting,
DisableURIPathEscaping: options.DisableURIPathEscaping,
}
signedRequest, err := signer.Build()
if err != nil {
return "", nil, err
}
logHTTPSigningInfo(ctx, options, signedRequest)
signedHeaders = make(http.Header)
// For the signed headers we canonicalize the header keys in the returned map.
// This avoids situations where can standard library double headers like host header. For example the standard
// library will set the Host header, even if it is present in lower-case form.
for k, v := range signedRequest.SignedHeaders {
key := textproto.CanonicalMIMEHeaderKey(k)
signedHeaders[key] = append(signedHeaders[key], v...)
}
return signedRequest.Request.URL.String(), signedHeaders, nil
}
func (s *httpSigner) setRequiredSigningFields(headers http.Header, query url.Values) {
amzDate := s.Time.Format(timeFormat)
if s.IsPreSign {
query.Set(AmzRegionSetKey, strings.Join(s.RegionSet, ","))
query.Set(amzDateKey, amzDate)
query.Set(amzAlgorithmKey, signingAlgorithm)
if len(s.Credentials.SessionToken) > 0 {
query.Set(amzSecurityTokenKey, s.Credentials.SessionToken)
}
return
}
headers.Set(AmzRegionSetKey, strings.Join(s.RegionSet, ","))
headers.Set(amzDateKey, amzDate)
if len(s.Credentials.SessionToken) > 0 {
headers.Set(amzSecurityTokenKey, s.Credentials.SessionToken)
}
}
func (s *httpSigner) Build() (signedRequest, error) {
req := s.Request
query := req.URL.Query()
headers := req.Header
s.setRequiredSigningFields(headers, query)
// Sort Each Query Key's Values
for key := range query {
sort.Strings(query[key])
}
v4Internal.SanitizeHostForHeader(req)
credentialScope := s.buildCredentialScope()
credentialStr := s.Credentials.Context + "/" + credentialScope
if s.IsPreSign {
query.Set(amzCredentialKey, credentialStr)
}
unsignedHeaders := headers
if s.IsPreSign && !s.DisableHeaderHoisting {
urlValues := url.Values{}
urlValues, unsignedHeaders = buildQuery(v4Internal.AllowedQueryHoisting, unsignedHeaders)
for k := range urlValues {
query[k] = urlValues[k]
}
}
host := req.URL.Host
if len(req.Host) > 0 {
host = req.Host
}
signedHeaders, signedHeadersStr, canonicalHeaderStr := s.buildCanonicalHeaders(host, v4Internal.IgnoredHeaders, unsignedHeaders, s.Request.ContentLength)
if s.IsPreSign {
query.Set(amzSignedHeadersKey, signedHeadersStr)
}
rawQuery := strings.Replace(query.Encode(), "+", "%20", -1)
canonicalURI := v4Internal.GetURIPath(req.URL)
if !s.DisableURIPathEscaping {
canonicalURI = httpbinding.EscapePath(canonicalURI, false)
}
canonicalString := s.buildCanonicalString(
req.Method,
canonicalURI,
rawQuery,
signedHeadersStr,
canonicalHeaderStr,
)
strToSign := s.buildStringToSign(credentialScope, canonicalString)
signingSignature, err := s.buildSignature(strToSign)
if err != nil {
return signedRequest{}, err
}
if s.IsPreSign {
rawQuery += "&X-Amz-Signature=" + signingSignature
} else {
headers[authorizationHeader] = append(headers[authorizationHeader][:0], buildAuthorizationHeader(credentialStr, signedHeadersStr, signingSignature))
}
req.URL.RawQuery = rawQuery
return signedRequest{
Request: req,
SignedHeaders: signedHeaders,
CanonicalString: canonicalString,
StringToSign: strToSign,
PreSigned: s.IsPreSign,
}, nil
}
func buildAuthorizationHeader(credentialStr, signedHeadersStr, signingSignature string) string {
const credential = "Credential="
const signedHeaders = "SignedHeaders="
const signature = "Signature="
const commaSpace = ", "
var parts strings.Builder
parts.Grow(len(signingAlgorithm) + 1 +
len(credential) + len(credentialStr) + len(commaSpace) +
len(signedHeaders) + len(signedHeadersStr) + len(commaSpace) +
len(signature) + len(signingSignature),
)
parts.WriteString(signingAlgorithm)
parts.WriteRune(' ')
parts.WriteString(credential)
parts.WriteString(credentialStr)
parts.WriteString(commaSpace)
parts.WriteString(signedHeaders)
parts.WriteString(signedHeadersStr)
parts.WriteString(commaSpace)
parts.WriteString(signature)
parts.WriteString(signingSignature)
return parts.String()
}
func (s *httpSigner) buildCredentialScope() string {
return strings.Join([]string{
s.Time.Format(shortTimeFormat),
s.ServiceName,
"aws4_request",
}, "/")
}
func buildQuery(r v4Internal.Rule, header http.Header) (url.Values, http.Header) {
query := url.Values{}
unsignedHeaders := http.Header{}
for k, h := range header {
if r.IsValid(k) {
query[k] = h
} else {
unsignedHeaders[k] = h
}
}
return query, unsignedHeaders
}
func (s *httpSigner) buildCanonicalHeaders(host string, rule v4Internal.Rule, header http.Header, length int64) (signed http.Header, signedHeaders, canonicalHeadersStr string) {
signed = make(http.Header)
var headers []string
const hostHeader = "host"
headers = append(headers, hostHeader)
signed[hostHeader] = append(signed[hostHeader], host)
if length > 0 {
const contentLengthHeader = "content-length"
headers = append(headers, contentLengthHeader)
signed[contentLengthHeader] = append(signed[contentLengthHeader], strconv.FormatInt(length, 10))
}
for k, v := range header {
if !rule.IsValid(k) {
continue // ignored header
}
lowerCaseKey := strings.ToLower(k)
if _, ok := signed[lowerCaseKey]; ok {
// include additional values
signed[lowerCaseKey] = append(signed[lowerCaseKey], v...)
continue
}
headers = append(headers, lowerCaseKey)
signed[lowerCaseKey] = v
}
sort.Strings(headers)
signedHeaders = strings.Join(headers, ";")
var canonicalHeaders strings.Builder
n := len(headers)
const colon = ':'
for i := 0; i < n; i++ {
if headers[i] == hostHeader {
canonicalHeaders.WriteString(hostHeader)
canonicalHeaders.WriteRune(colon)
canonicalHeaders.WriteString(v4Internal.StripExcessSpaces(host))
} else {
canonicalHeaders.WriteString(headers[i])
canonicalHeaders.WriteRune(colon)
canonicalHeaders.WriteString(strings.Join(signed[headers[i]], ","))
}
canonicalHeaders.WriteRune('\n')
}
canonicalHeadersStr = canonicalHeaders.String()
return signed, signedHeaders, canonicalHeadersStr
}
func (s *httpSigner) buildCanonicalString(method, uri, query, signedHeaders, canonicalHeaders string) string {
return strings.Join([]string{
method,
uri,
query,
canonicalHeaders,
signedHeaders,
s.PayloadHash,
}, "\n")
}
func (s *httpSigner) buildStringToSign(credentialScope, canonicalRequestString string) string {
return strings.Join([]string{
signingAlgorithm,
s.Time.Format(timeFormat),
credentialScope,
hex.EncodeToString(makeHash(sha256.New(), []byte(canonicalRequestString))),
}, "\n")
}
func makeHash(hash hash.Hash, b []byte) []byte {
hash.Reset()
hash.Write(b)
return hash.Sum(nil)
}
func (s *httpSigner) buildSignature(strToSign string) (string, error) {
sig, err := s.Credentials.PrivateKey.Sign(rand.Reader, makeHash(sha256.New(), []byte(strToSign)), crypto.SHA256)
if err != nil {
return "", err
}
return hex.EncodeToString(sig), nil
}
const logSignInfoMsg = `Request Signature:
---[ CANONICAL STRING ]-----------------------------
%s
---[ STRING TO SIGN ]--------------------------------
%s%s
-----------------------------------------------------`
const logSignedURLMsg = `
---[ SIGNED URL ]------------------------------------
%s`
func logHTTPSigningInfo(ctx context.Context, options SignerOptions, r signedRequest) {
if !options.LogSigning {
return
}
signedURLMsg := ""
if r.PreSigned {
signedURLMsg = fmt.Sprintf(logSignedURLMsg, r.Request.URL.String())
}
logger := logging.WithContext(ctx, options.Logger)
logger.Logf(logging.Debug, logSignInfoMsg, r.CanonicalString, r.StringToSign, signedURLMsg)
}
type signedRequest struct {
Request *http.Request
SignedHeaders http.Header
CanonicalString string
StringToSign string
PreSigned bool
}