feat(train): add new command to interact with aws and train models

This commit is contained in:
2021-10-17 19:15:44 +02:00
parent 5436dfebc2
commit 538cea18f2
1064 changed files with 282251 additions and 89305 deletions

View File

@ -0,0 +1,87 @@
/*
Package customizations provides customizations for the Amazon S3 API client.
This package provides support for following S3 customizations
ProcessARN Middleware: processes an ARN if provided as input and updates the endpoint as per the arn type
UpdateEndpoint Middleware: resolves a custom endpoint as per s3 config options
RemoveBucket Middleware: removes a serialized bucket name from request url path
processResponseWith200Error Middleware: Deserializing response error with 200 status code
Virtual Host style url addressing
Since serializers serialize by default as path style url, we use customization
to modify the endpoint url when `UsePathStyle` option on S3Client is unset or
false. This flag will be ignored if `UseAccelerate` option is set to true.
If UseAccelerate is not enabled, and the bucket name is not a valid hostname
label, they SDK will fallback to forcing the request to be made as if
UsePathStyle was enabled. This behavior is also used if UseDualStack is enabled.
https://docs.aws.amazon.com/AmazonS3/latest/dev/dual-stack-endpoints.html#dual-stack-endpoints-description
Transfer acceleration
By default S3 Transfer acceleration support is disabled. By enabling `UseAccelerate`
option on S3Client, one can enable s3 transfer acceleration support. Transfer
acceleration only works with Virtual Host style addressing, and thus `UsePathStyle`
option if set is ignored. Transfer acceleration is not supported for S3 operations
DeleteBucket, ListBuckets, and CreateBucket.
Dualstack support
By default dualstack support for s3 client is disabled. By enabling `UseDualstack`
option on s3 client, you can enable dualstack endpoint support.
Endpoint customizations
Customizations to lookup ARN, process ARN needs to happen before request serialization.
UpdateEndpoint middleware which mutates resources based on Options such as
UseDualstack, UseAccelerate for modifying resolved endpoint are executed after
request serialization. Remove bucket middleware is executed after
an request is serialized, and removes the serialized bucket name from request path
Middleware layering:
Initialize : HTTP Request -> ARN Lookup -> Input-Validation -> Serialize step
Serialize : HTTP Request -> Process ARN -> operation serializer -> Update-Endpoint customization -> Remove-Bucket -> next middleware
Customization options:
UseARNRegion (Disabled by Default)
UsePathStyle (Disabled by Default)
UseAccelerate (Disabled by Default)
UseDualstack (Disabled by Default)
Handle Error response with 200 status code
S3 operations: CopyObject, CompleteMultipartUpload, UploadPartCopy can have an
error Response with status code 2xx. The processResponseWith200Error middleware
customizations enables SDK to check for an error within response body prior to
deserialization.
As the check for 2xx response containing an error needs to be performed earlier
than response deserialization. Since the behavior of Deserialization is in
reverse order to the other stack steps its easier to consider that "after" means
"before".
Middleware layering:
HTTP Response -> handle 200 error customization -> deserialize
*/
package customizations

View File

@ -0,0 +1,74 @@
package customizations
import (
"bytes"
"context"
"encoding/xml"
"fmt"
"io"
"io/ioutil"
"strings"
"github.com/aws/smithy-go"
smithyxml "github.com/aws/smithy-go/encoding/xml"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// HandleResponseErrorWith200Status check for S3 200 error response.
// If an s3 200 error is found, status code for the response is modified temporarily to
// 5xx response status code.
func HandleResponseErrorWith200Status(stack *middleware.Stack) error {
return stack.Deserialize.Insert(&processResponseFor200ErrorMiddleware{}, "OperationDeserializer", middleware.After)
}
// middleware to process raw response and look for error response with 200 status code
type processResponseFor200ErrorMiddleware struct{}
// ID returns the middleware ID.
func (*processResponseFor200ErrorMiddleware) ID() string {
return "S3:ProcessResponseFor200Error"
}
func (m *processResponseFor200ErrorMiddleware) HandleDeserialize(
ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
) {
out, metadata, err = next.HandleDeserialize(ctx, in)
if err != nil {
return out, metadata, err
}
response, ok := out.RawResponse.(*smithyhttp.Response)
if !ok {
return out, metadata, &smithy.DeserializationError{Err: fmt.Errorf("unknown transport type %T", out.RawResponse)}
}
// check if response status code is 2xx.
if response.StatusCode < 200 || response.StatusCode >= 300 {
return
}
var readBuff bytes.Buffer
body := io.TeeReader(response.Body, &readBuff)
rootDecoder := xml.NewDecoder(body)
t, err := smithyxml.FetchRootElement(rootDecoder)
if err == io.EOF {
return out, metadata, &smithy.DeserializationError{
Err: fmt.Errorf("received empty response payload"),
}
}
// rewind response body
response.Body = ioutil.NopCloser(io.MultiReader(&readBuff, response.Body))
// if start tag is "Error", the response is consider error response.
if strings.EqualFold(t.Name.Local, "Error") {
// according to https://aws.amazon.com/premiumsupport/knowledge-center/s3-resolve-200-internalerror/
// 200 error responses are similar to 5xx errors.
response.StatusCode = 500
}
return out, metadata, err
}

View File

@ -0,0 +1,22 @@
package customizations
import (
"github.com/aws/smithy-go/transport/http"
"strings"
)
func updateS3HostForS3AccessPoint(req *http.Request) {
updateHostPrefix(req, "s3", s3AccessPoint)
}
func updateS3HostForS3ObjectLambda(req *http.Request) {
updateHostPrefix(req, "s3", s3ObjectLambda)
}
func updateHostPrefix(req *http.Request, oldEndpointPrefix, newEndpointPrefix string) {
host := req.URL.Host
if strings.HasPrefix(host, oldEndpointPrefix) {
// For example if oldEndpointPrefix=s3 would replace to newEndpointPrefix
req.URL.Host = newEndpointPrefix + host[len(oldEndpointPrefix):]
}
}

View File

@ -0,0 +1,49 @@
package customizations
import (
"context"
"fmt"
"strconv"
"time"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// AddExpiresOnPresignedURL represents a build middleware used to assign
// expiration on a presigned URL.
type AddExpiresOnPresignedURL struct {
// Expires is time.Duration within which presigned url should be expired.
// This should be the duration in seconds the presigned URL should be considered valid for.
// By default the S3 presigned url expires in 15 minutes ie. 900 seconds.
Expires time.Duration
}
// ID representing the middleware
func (*AddExpiresOnPresignedURL) ID() string {
return "S3:AddExpiresOnPresignedURL"
}
// HandleBuild handles the build step middleware behavior
func (m *AddExpiresOnPresignedURL) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (
out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
// if expiration is unset skip this middleware
if m.Expires == 0 {
// default to 15 * time.Minutes
m.Expires = 15 * time.Minute
}
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown transport type %T", req)
}
// set S3 X-AMZ-Expires header
query := req.URL.Query()
query.Set("X-Amz-Expires", strconv.FormatInt(int64(m.Expires/time.Second), 10))
req.URL.RawQuery = query.Encode()
return next.HandleBuild(ctx, in)
}

View File

@ -0,0 +1,588 @@
package customizations
import (
"context"
"fmt"
"net/url"
"strings"
"github.com/aws/smithy-go/middleware"
"github.com/aws/smithy-go/transport/http"
"github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/aws-sdk-go-v2/service/internal/s3shared"
"github.com/aws/aws-sdk-go-v2/service/internal/s3shared/arn"
s3arn "github.com/aws/aws-sdk-go-v2/service/s3/internal/arn"
"github.com/aws/aws-sdk-go-v2/service/s3/internal/endpoints"
"github.com/aws/aws-sdk-go-v2/service/s3/internal/v4a"
)
const (
s3AccessPoint = "s3-accesspoint"
s3ObjectLambda = "s3-object-lambda"
)
// processARNResource is used to process an ARN resource.
type processARNResource struct {
// UseARNRegion indicates if region parsed from an ARN should be used.
UseARNRegion bool
// UseAccelerate indicates if s3 transfer acceleration is enabled
UseAccelerate bool
// UseDualstack instructs if s3 dualstack endpoint config is enabled
UseDualstack bool
// EndpointResolver used to resolve endpoints. This may be a custom endpoint resolver
EndpointResolver EndpointResolver
// EndpointResolverOptions used by endpoint resolver
EndpointResolverOptions EndpointResolverOptions
// DisableMultiRegionAccessPoints indicates multi-region access point support is disabled
DisableMultiRegionAccessPoints bool
}
// ID returns the middleware ID.
func (*processARNResource) ID() string { return "S3:ProcessARNResource" }
func (m *processARNResource) HandleSerialize(
ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
// check if arn was provided, if not skip this middleware
arnValue, ok := s3shared.GetARNResourceFromContext(ctx)
if !ok {
return next.HandleSerialize(ctx, in)
}
req, ok := in.Request.(*http.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}
// parse arn into an endpoint arn wrt to service
resource, err := s3arn.ParseEndpointARN(arnValue)
if err != nil {
return out, metadata, err
}
// build a resource request struct
resourceRequest := s3shared.ResourceRequest{
Resource: resource,
UseARNRegion: m.UseARNRegion,
RequestRegion: awsmiddleware.GetRegion(ctx),
SigningRegion: awsmiddleware.GetSigningRegion(ctx),
PartitionID: awsmiddleware.GetPartitionID(ctx),
}
// switch to correct endpoint updater
switch tv := resource.(type) {
case arn.AccessPointARN:
// multi-region arns do not need to validate for cross partition request
if len(tv.Region) != 0 {
// validate resource request
if err := validateRegionForResourceRequest(resourceRequest); err != nil {
return out, metadata, err
}
}
// Special handling for region-less ap-arns.
if len(tv.Region) == 0 {
// check if multi-region arn support is disabled
if m.DisableMultiRegionAccessPoints {
return out, metadata, fmt.Errorf("Invalid configuration, Multi-Region access point ARNs are disabled")
}
// Do not allow dual-stack configuration with multi-region arns.
if m.UseDualstack {
return out, metadata, s3shared.NewClientConfiguredForDualStackError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
}
// check if accelerate
if m.UseAccelerate {
return out, metadata, s3shared.NewClientConfiguredForAccelerateError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// fetch arn region to resolve request
resolveRegion := tv.Region
// check if request region is FIPS
if resourceRequest.UseFips() {
// Do not allow Fips support within multi-region arns.
if len(resolveRegion) == 0 {
return out, metadata, s3shared.NewClientConfiguredForFIPSError(
tv, resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// if use arn region is enabled and request signing region is not same as arn region
if m.UseARNRegion && resourceRequest.IsCrossRegion() {
// FIPS with cross region is not supported, the SDK must fail
// because there is no well defined method for SDK to construct a
// correct FIPS endpoint.
return out, metadata,
s3shared.NewClientConfiguredForCrossRegionFIPSError(
tv,
resourceRequest.PartitionID,
resourceRequest.RequestRegion,
nil,
)
}
// if use arn region is NOT set, we should use the request region
resolveRegion = resourceRequest.RequestRegion
}
var requestBuilder func(context.Context, accesspointOptions) (context.Context, error)
if len(resolveRegion) == 0 {
requestBuilder = buildMultiRegionAccessPointsRequest
} else {
requestBuilder = buildAccessPointRequest
}
// build request as per accesspoint builder
ctx, err = requestBuilder(ctx, accesspointOptions{
processARNResource: *m,
request: req,
resource: tv,
resolveRegion: resolveRegion,
partitionID: resourceRequest.PartitionID,
requestRegion: resourceRequest.RequestRegion,
})
if err != nil {
return out, metadata, err
}
case arn.S3ObjectLambdaAccessPointARN:
// validate region for resource request
if err := validateRegionForResourceRequest(resourceRequest); err != nil {
return out, metadata, err
}
// check if accelerate
if m.UseAccelerate {
return out, metadata, s3shared.NewClientConfiguredForAccelerateError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// check if dualstack
if m.UseDualstack {
return out, metadata, s3shared.NewClientConfiguredForDualStackError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// fetch arn region to resolve request
resolveRegion := tv.Region
if resourceRequest.UseFips() {
// if use arn region is enabled and request signing region is not same as arn region
if m.UseARNRegion && resourceRequest.IsCrossRegion() {
// FIPS with cross region is not supported, the SDK must fail
// because there is no well defined method for SDK to construct a
// correct FIPS endpoint.
return out, metadata,
s3shared.NewClientConfiguredForCrossRegionFIPSError(
tv,
resourceRequest.PartitionID,
resourceRequest.RequestRegion,
nil,
)
}
// if use arn region is NOT set, we should use the request region
resolveRegion = resourceRequest.RequestRegion
}
// build access point request
ctx, err = buildS3ObjectLambdaAccessPointRequest(ctx, accesspointOptions{
processARNResource: *m,
request: req,
resource: tv.AccessPointARN,
resolveRegion: resolveRegion,
partitionID: resourceRequest.PartitionID,
requestRegion: resourceRequest.RequestRegion,
})
if err != nil {
return out, metadata, err
}
// process outpost accesspoint ARN
case arn.OutpostAccessPointARN:
// validate region for resource request
if err := validateRegionForResourceRequest(resourceRequest); err != nil {
return out, metadata, err
}
// check if accelerate
if m.UseAccelerate {
return out, metadata, s3shared.NewClientConfiguredForAccelerateError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// check if dual stack
if m.UseDualstack {
return out, metadata, s3shared.NewClientConfiguredForDualStackError(tv,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// check if request region is FIPS
if resourceRequest.UseFips() {
return out, metadata, s3shared.NewFIPSConfigurationError(tv, resourceRequest.PartitionID,
resourceRequest.RequestRegion, nil)
}
// build outpost access point request
ctx, err = buildOutpostAccessPointRequest(ctx, outpostAccessPointOptions{
processARNResource: *m,
resource: tv,
request: req,
partitionID: resourceRequest.PartitionID,
requestRegion: resourceRequest.RequestRegion,
})
if err != nil {
return out, metadata, err
}
default:
return out, metadata, s3shared.NewInvalidARNError(resource, nil)
}
return next.HandleSerialize(ctx, in)
}
// validate if s3 resource and request region config is compatible.
func validateRegionForResourceRequest(resourceRequest s3shared.ResourceRequest) error {
// check if resourceRequest leads to a cross partition error
v, err := resourceRequest.IsCrossPartition()
if err != nil {
return err
}
if v {
// if cross partition
return s3shared.NewClientPartitionMismatchError(resourceRequest.Resource,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
// check if resourceRequest leads to a cross region error
if !resourceRequest.AllowCrossRegion() && resourceRequest.IsCrossRegion() {
// if cross region, but not use ARN region is not enabled
return s3shared.NewClientRegionMismatchError(resourceRequest.Resource,
resourceRequest.PartitionID, resourceRequest.RequestRegion, nil)
}
return nil
}
// === Accesspoint ==========
type accesspointOptions struct {
processARNResource
request *http.Request
resource arn.AccessPointARN
resolveRegion string
partitionID string
requestRegion string
}
func buildAccessPointRequest(ctx context.Context, options accesspointOptions) (context.Context, error) {
tv := options.resource
req := options.request
resolveRegion := options.resolveRegion
resolveService := tv.Service
// resolve endpoint
endpoint, err := options.EndpointResolver.ResolveEndpoint(resolveRegion, options.EndpointResolverOptions)
if err != nil {
return ctx, s3shared.NewFailedToResolveEndpointError(
tv,
options.partitionID,
options.requestRegion,
err,
)
}
// assign resolved endpoint url to request url
req.URL, err = url.Parse(endpoint.URL)
if err != nil {
return ctx, fmt.Errorf("failed to parse endpoint URL: %w", err)
}
if len(endpoint.SigningName) != 0 && endpoint.Source == aws.EndpointSourceCustom {
ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
} else {
// Must sign with s3-object-lambda
ctx = awsmiddleware.SetSigningName(ctx, resolveService)
}
if len(endpoint.SigningRegion) != 0 {
ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
} else {
ctx = awsmiddleware.SetSigningRegion(ctx, resolveRegion)
}
// update serviceID to "s3-accesspoint"
ctx = awsmiddleware.SetServiceID(ctx, s3AccessPoint)
// disable host prefix behavior
ctx = http.DisableEndpointHostPrefix(ctx, true)
// remove the serialized arn in place of /{Bucket}
ctx = setBucketToRemoveOnContext(ctx, tv.String())
// skip arn processing, if arn region resolves to a immutable endpoint
if endpoint.HostnameImmutable {
return ctx, nil
}
updateS3HostForS3AccessPoint(req)
ctx, err = buildAccessPointHostPrefix(ctx, req, tv)
if err != nil {
return ctx, err
}
return ctx, nil
}
func buildS3ObjectLambdaAccessPointRequest(ctx context.Context, options accesspointOptions) (context.Context, error) {
tv := options.resource
req := options.request
resolveRegion := options.resolveRegion
resolveService := tv.Service
// resolve endpoint
endpoint, err := options.EndpointResolver.ResolveEndpoint(resolveRegion, options.EndpointResolverOptions)
if err != nil {
return ctx, s3shared.NewFailedToResolveEndpointError(
tv,
options.partitionID,
options.requestRegion,
err,
)
}
// assign resolved endpoint url to request url
req.URL, err = url.Parse(endpoint.URL)
if err != nil {
return ctx, fmt.Errorf("failed to parse endpoint URL: %w", err)
}
if len(endpoint.SigningName) != 0 && endpoint.Source == aws.EndpointSourceCustom {
ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
} else {
// Must sign with s3-object-lambda
ctx = awsmiddleware.SetSigningName(ctx, resolveService)
}
if len(endpoint.SigningRegion) != 0 {
ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
} else {
ctx = awsmiddleware.SetSigningRegion(ctx, resolveRegion)
}
// update serviceID to "s3-object-lambda"
ctx = awsmiddleware.SetServiceID(ctx, s3ObjectLambda)
// disable host prefix behavior
ctx = http.DisableEndpointHostPrefix(ctx, true)
// remove the serialized arn in place of /{Bucket}
ctx = setBucketToRemoveOnContext(ctx, tv.String())
// skip arn processing, if arn region resolves to a immutable endpoint
if endpoint.HostnameImmutable {
return ctx, nil
}
if endpoint.Source == aws.EndpointSourceServiceMetadata {
updateS3HostForS3ObjectLambda(req)
}
ctx, err = buildAccessPointHostPrefix(ctx, req, tv)
if err != nil {
return ctx, err
}
return ctx, nil
}
func buildMultiRegionAccessPointsRequest(ctx context.Context, options accesspointOptions) (context.Context, error) {
const s3GlobalLabel = "s3-global."
const accesspointLabel = "accesspoint."
tv := options.resource
req := options.request
resolveService := tv.Service
resolveRegion := options.requestRegion
arnPartition := tv.Partition
// resolve endpoint
endpoint, err := options.EndpointResolver.ResolveEndpoint(resolveRegion, options.EndpointResolverOptions)
if err != nil {
return ctx, s3shared.NewFailedToResolveEndpointError(
tv,
options.partitionID,
options.requestRegion,
err,
)
}
// set signing region and version for MRAP
endpoint.SigningRegion = "*"
ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
ctx = SetSignerVersion(ctx, v4a.Version)
if len(endpoint.SigningName) != 0 {
ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
} else {
ctx = awsmiddleware.SetSigningName(ctx, resolveService)
}
// skip arn processing, if arn region resolves to a immutable endpoint
if endpoint.HostnameImmutable {
return ctx, nil
}
// modify endpoint host to use s3-global host prefix
scheme := strings.SplitN(endpoint.URL, "://", 2)
dnsSuffix, err := endpoints.GetDNSSuffix(arnPartition)
if err != nil {
return ctx, fmt.Errorf("Error determining dns suffix from arn partition, %w", err)
}
// set url as per partition
endpoint.URL = scheme[0] + "://" + s3GlobalLabel + dnsSuffix
// assign resolved endpoint url to request url
req.URL, err = url.Parse(endpoint.URL)
if err != nil {
return ctx, fmt.Errorf("failed to parse endpoint URL: %w", err)
}
// build access point host prefix
accessPointHostPrefix := tv.AccessPointName + "." + accesspointLabel
// add host prefix to url
req.URL.Host = accessPointHostPrefix + req.URL.Host
if len(req.Host) > 0 {
req.Host = accessPointHostPrefix + req.Host
}
// validate the endpoint host
if err := http.ValidateEndpointHost(req.URL.Host); err != nil {
return ctx, fmt.Errorf("endpoint validation error: %w, when using arn %v", err, tv)
}
// disable host prefix behavior
ctx = http.DisableEndpointHostPrefix(ctx, true)
// remove the serialized arn in place of /{Bucket}
ctx = setBucketToRemoveOnContext(ctx, tv.String())
return ctx, nil
}
func buildAccessPointHostPrefix(ctx context.Context, req *http.Request, tv arn.AccessPointARN) (context.Context, error) {
// add host prefix for access point
accessPointHostPrefix := tv.AccessPointName + "-" + tv.AccountID + "."
req.URL.Host = accessPointHostPrefix + req.URL.Host
if len(req.Host) > 0 {
req.Host = accessPointHostPrefix + req.Host
}
// validate the endpoint host
if err := http.ValidateEndpointHost(req.URL.Host); err != nil {
return ctx, s3shared.NewInvalidARNError(tv, err)
}
return ctx, nil
}
// ====== Outpost Accesspoint ========
type outpostAccessPointOptions struct {
processARNResource
request *http.Request
resource arn.OutpostAccessPointARN
partitionID string
requestRegion string
}
func buildOutpostAccessPointRequest(ctx context.Context, options outpostAccessPointOptions) (context.Context, error) {
tv := options.resource
req := options.request
resolveRegion := tv.Region
resolveService := tv.Service
endpointsID := resolveService
if strings.EqualFold(resolveService, "s3-outposts") {
// assign endpoints ID as "S3"
endpointsID = "s3"
}
// resolve regional endpoint for resolved region.
endpoint, err := options.EndpointResolver.ResolveEndpoint(resolveRegion, options.EndpointResolverOptions)
if err != nil {
return ctx, s3shared.NewFailedToResolveEndpointError(
tv,
options.partitionID,
options.requestRegion,
err,
)
}
// assign resolved endpoint url to request url
req.URL, err = url.Parse(endpoint.URL)
if err != nil {
return ctx, fmt.Errorf("failed to parse endpoint URL: %w", err)
}
// assign resolved service from arn as signing name
if len(endpoint.SigningName) != 0 && endpoint.Source == aws.EndpointSourceCustom {
ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
} else {
ctx = awsmiddleware.SetSigningName(ctx, resolveService)
}
if len(endpoint.SigningRegion) != 0 {
// redirect signer to use resolved endpoint signing name and region
ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
} else {
ctx = awsmiddleware.SetSigningRegion(ctx, resolveRegion)
}
// update serviceID to resolved service id
ctx = awsmiddleware.SetServiceID(ctx, resolveService)
// disable host prefix behavior
ctx = http.DisableEndpointHostPrefix(ctx, true)
// remove the serialized arn in place of /{Bucket}
ctx = setBucketToRemoveOnContext(ctx, tv.String())
// skip further customizations, if arn region resolves to a immutable endpoint
if endpoint.HostnameImmutable {
return ctx, nil
}
updateHostPrefix(req, endpointsID, resolveService)
// add host prefix for s3-outposts
outpostAPHostPrefix := tv.AccessPointName + "-" + tv.AccountID + "." + tv.OutpostID + "."
req.URL.Host = outpostAPHostPrefix + req.URL.Host
if len(req.Host) > 0 {
req.Host = outpostAPHostPrefix + req.Host
}
// validate the endpoint host
if err := http.ValidateEndpointHost(req.URL.Host); err != nil {
return ctx, s3shared.NewInvalidARNError(tv, err)
}
return ctx, nil
}

View File

@ -0,0 +1,58 @@
package customizations
import (
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
"github.com/aws/smithy-go/transport/http"
)
// removeBucketFromPathMiddleware needs to be executed after serialize step is performed
type removeBucketFromPathMiddleware struct {
}
func (m *removeBucketFromPathMiddleware) ID() string {
return "S3:RemoveBucketFromPathMiddleware"
}
func (m *removeBucketFromPathMiddleware) HandleSerialize(
ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
// check if a bucket removal from HTTP path is required
bucket, ok := getRemoveBucketFromPath(ctx)
if !ok {
return next.HandleSerialize(ctx, in)
}
req, ok := in.Request.(*http.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}
removeBucketFromPath(req.URL, bucket)
return next.HandleSerialize(ctx, in)
}
type removeBucketKey struct {
bucket string
}
// setBucketToRemoveOnContext sets the bucket name to be removed.
//
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
// to clear all stack values.
func setBucketToRemoveOnContext(ctx context.Context, bucket string) context.Context {
return middleware.WithStackValue(ctx, removeBucketKey{}, bucket)
}
// getRemoveBucketFromPath returns the bucket name to remove from the path.
//
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
// to clear all stack values.
func getRemoveBucketFromPath(ctx context.Context) (string, bool) {
v, ok := middleware.GetStackValue(ctx, removeBucketKey{}).(string)
return v, ok
}

View File

@ -0,0 +1,85 @@
package customizations
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/smithy-go/middleware"
"github.com/aws/smithy-go/transport/http"
"net/url"
)
type s3ObjectLambdaEndpoint struct {
// whether the operation should use the s3-object-lambda endpoint
UseEndpoint bool
// use dualstack
UseDualstack bool
// use transfer acceleration
UseAccelerate bool
EndpointResolver EndpointResolver
EndpointResolverOptions EndpointResolverOptions
}
func (t *s3ObjectLambdaEndpoint) ID() string {
return "S3:ObjectLambdaEndpoint"
}
func (t *s3ObjectLambdaEndpoint) HandleSerialize(
ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
if !t.UseEndpoint {
return next.HandleSerialize(ctx, in)
}
req, ok := in.Request.(*http.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown transport type: %T", in.Request)
}
if t.UseDualstack {
return out, metadata, fmt.Errorf("client configured for dualstack but not supported for operation")
}
if t.UseAccelerate {
return out, metadata, fmt.Errorf("client configured for accelerate but not supported for operation")
}
region := awsmiddleware.GetRegion(ctx)
endpoint, err := t.EndpointResolver.ResolveEndpoint(region, t.EndpointResolverOptions)
if err != nil {
return out, metadata, err
}
// Set the ServiceID and SigningName
ctx = awsmiddleware.SetServiceID(ctx, s3ObjectLambda)
if len(endpoint.SigningName) > 0 && endpoint.Source == aws.EndpointSourceCustom {
ctx = awsmiddleware.SetSigningName(ctx, endpoint.SigningName)
} else {
ctx = awsmiddleware.SetSigningName(ctx, s3ObjectLambda)
}
req.URL, err = url.Parse(endpoint.URL)
if err != nil {
return out, metadata, err
}
if len(endpoint.SigningRegion) > 0 {
ctx = awsmiddleware.SetSigningRegion(ctx, endpoint.SigningRegion)
} else {
ctx = awsmiddleware.SetSigningRegion(ctx, region)
}
if endpoint.Source == aws.EndpointSourceServiceMetadata || !endpoint.HostnameImmutable {
updateS3HostForS3ObjectLambda(req)
}
return next.HandleSerialize(ctx, in)
}

View File

@ -0,0 +1,213 @@
package customizations
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/service/s3/internal/v4a"
"github.com/aws/smithy-go/middleware"
)
type signerVersionKey struct{}
// GetSignerVersion retrieves the signer version to use for signing
//
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
// to clear all stack values.
func GetSignerVersion(ctx context.Context) (v string) {
v, _ = middleware.GetStackValue(ctx, signerVersionKey{}).(string)
return v
}
// SetSignerVersion sets the signer version to be used for signing the request
//
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
// to clear all stack values.
func SetSignerVersion(ctx context.Context, version string) context.Context {
return middleware.WithStackValue(ctx, signerVersionKey{}, version)
}
// SignHTTPRequestMiddlewareOptions is the configuration options for the SignHTTPRequestMiddleware middleware.
type SignHTTPRequestMiddlewareOptions struct {
// credential provider
CredentialsProvider aws.CredentialsProvider
// log signing
LogSigning bool
// v4 signer
V4Signer v4.HTTPSigner
//v4a signer
V4aSigner v4a.HTTPSigner
}
// NewSignHTTPRequestMiddleware constructs a SignHTTPRequestMiddleware using the given Signer for signing requests
func NewSignHTTPRequestMiddleware(options SignHTTPRequestMiddlewareOptions) *SignHTTPRequestMiddleware {
return &SignHTTPRequestMiddleware{
credentialsProvider: options.CredentialsProvider,
v4Signer: options.V4Signer,
v4aSigner: options.V4aSigner,
logSigning: options.LogSigning,
}
}
// SignHTTPRequestMiddleware is a `FinalizeMiddleware` implementation to select HTTP Signing method
type SignHTTPRequestMiddleware struct {
// credential provider
credentialsProvider aws.CredentialsProvider
// log signing
logSigning bool
// v4 signer
v4Signer v4.HTTPSigner
//v4a signer
v4aSigner v4a.HTTPSigner
}
// ID is the SignHTTPRequestMiddleware identifier
func (s *SignHTTPRequestMiddleware) ID() string {
return "Signing"
}
// HandleFinalize will take the provided input and sign the request using the SigV4 authentication scheme
func (s *SignHTTPRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
// fetch signer type from context
signerVersion := GetSignerVersion(ctx)
switch signerVersion {
case v4a.Version:
v4aCredentialProvider, ok := s.credentialsProvider.(v4a.CredentialsProvider)
if !ok {
return out, metadata, fmt.Errorf("invalid credential-provider provided for sigV4a Signer")
}
mw := v4a.NewSignHTTPRequestMiddleware(v4a.SignHTTPRequestMiddlewareOptions{
Credentials: v4aCredentialProvider,
Signer: s.v4aSigner,
LogSigning: s.logSigning,
})
return mw.HandleFinalize(ctx, in, next)
default:
mw := v4.NewSignHTTPRequestMiddleware(v4.SignHTTPRequestMiddlewareOptions{
CredentialsProvider: s.credentialsProvider,
Signer: s.v4Signer,
LogSigning: s.logSigning,
})
return mw.HandleFinalize(ctx, in, next)
}
}
// RegisterSigningMiddleware registers the wrapper signing middleware to the stack. If a signing middleware is already
// present, this provided middleware will be swapped. Otherwise the middleware will be added at the tail of the
// finalize step.
func RegisterSigningMiddleware(stack *middleware.Stack, signingMiddleware *SignHTTPRequestMiddleware) (err error) {
const signedID = "Signing"
_, present := stack.Finalize.Get(signedID)
if present {
_, err = stack.Finalize.Swap(signedID, signingMiddleware)
} else {
err = stack.Finalize.Add(signingMiddleware, middleware.After)
}
return err
}
// PresignHTTPRequestMiddlewareOptions is the options for the PresignHTTPRequestMiddleware middleware.
type PresignHTTPRequestMiddlewareOptions struct {
CredentialsProvider aws.CredentialsProvider
V4Presigner v4.HTTPPresigner
V4aPresigner v4a.HTTPPresigner
LogSigning bool
}
// PresignHTTPRequestMiddleware provides the Finalize middleware for creating a
// presigned URL for an HTTP request.
//
// Will short circuit the middleware stack and not forward onto the next
// Finalize handler.
type PresignHTTPRequestMiddleware struct {
// cred provider and signer for sigv4
credentialsProvider aws.CredentialsProvider
// sigV4 signer
v4Signer v4.HTTPPresigner
// sigV4a signer
v4aSigner v4a.HTTPPresigner
// log signing
logSigning bool
}
// NewPresignHTTPRequestMiddleware constructs a PresignHTTPRequestMiddleware using the given Signer for signing requests
func NewPresignHTTPRequestMiddleware(options PresignHTTPRequestMiddlewareOptions) *PresignHTTPRequestMiddleware {
return &PresignHTTPRequestMiddleware{
credentialsProvider: options.CredentialsProvider,
v4Signer: options.V4Presigner,
v4aSigner: options.V4aPresigner,
logSigning: options.LogSigning,
}
}
// ID provides the middleware ID.
func (*PresignHTTPRequestMiddleware) ID() string { return "PresignHTTPRequest" }
// HandleFinalize will take the provided input and create a presigned url for
// the http request using the SigV4 or SigV4a presign authentication scheme.
//
// Since the signed request is not a valid HTTP request
func (p *PresignHTTPRequestMiddleware) HandleFinalize(
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
// fetch signer type from context
signerVersion := GetSignerVersion(ctx)
switch signerVersion {
case v4a.Version:
v4aCredentialProvider, ok := p.credentialsProvider.(v4a.CredentialsProvider)
if !ok {
return out, metadata, fmt.Errorf("invalid credential-provider provided for sigV4a Signer")
}
mw := v4a.NewPresignHTTPRequestMiddleware(v4a.PresignHTTPRequestMiddlewareOptions{
CredentialsProvider: v4aCredentialProvider,
Presigner: p.v4aSigner,
LogSigning: p.logSigning,
})
return mw.HandleFinalize(ctx, in, next)
default:
mw := v4.NewPresignHTTPRequestMiddleware(v4.PresignHTTPRequestMiddlewareOptions{
CredentialsProvider: p.credentialsProvider,
Presigner: p.v4Signer,
LogSigning: p.logSigning,
})
return mw.HandleFinalize(ctx, in, next)
}
}
// RegisterPreSigningMiddleware registers the wrapper pre-signing middleware to the stack. If a pre-signing middleware is already
// present, this provided middleware will be swapped. Otherwise the middleware will be added at the tail of the
// finalize step.
func RegisterPreSigningMiddleware(stack *middleware.Stack, signingMiddleware *PresignHTTPRequestMiddleware) (err error) {
const signedID = "PresignHTTPRequest"
_, present := stack.Finalize.Get(signedID)
if present {
_, err = stack.Finalize.Swap(signedID, signingMiddleware)
} else {
err = stack.Finalize.Add(signingMiddleware, middleware.After)
}
return err
}

View File

@ -0,0 +1,318 @@
package customizations
import (
"context"
"fmt"
"github.com/aws/smithy-go/encoding/httpbinding"
"log"
"net/url"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/aws-sdk-go-v2/service/internal/s3shared"
internalendpoints "github.com/aws/aws-sdk-go-v2/service/s3/internal/endpoints"
)
// EndpointResolver interface for resolving service endpoints.
type EndpointResolver interface {
ResolveEndpoint(region string, options EndpointResolverOptions) (aws.Endpoint, error)
}
// EndpointResolverOptions is the service endpoint resolver options
type EndpointResolverOptions = internalendpoints.Options
// UpdateEndpointParameterAccessor represents accessor functions used by the middleware
type UpdateEndpointParameterAccessor struct {
// functional pointer to fetch bucket name from provided input.
// The function is intended to take an input value, and
// return a string pointer to value of string, and bool if
// input has no bucket member.
GetBucketFromInput func(interface{}) (*string, bool)
}
// UpdateEndpointOptions provides the options for the UpdateEndpoint middleware setup.
type UpdateEndpointOptions struct {
// Accessor are parameter accessors used by the middleware
Accessor UpdateEndpointParameterAccessor
// use path style
UsePathStyle bool
// use transfer acceleration
UseAccelerate bool
// indicates if an operation supports s3 transfer acceleration.
SupportsAccelerate bool
// use dualstack
UseDualstack bool
// use ARN region
UseARNRegion bool
// Indicates that the operation should target the s3-object-lambda endpoint.
// Used to direct operations that do not route based on an input ARN.
TargetS3ObjectLambda bool
// EndpointResolver used to resolve endpoints. This may be a custom endpoint resolver
EndpointResolver EndpointResolver
// EndpointResolverOptions used by endpoint resolver
EndpointResolverOptions EndpointResolverOptions
// DisableMultiRegionAccessPoints indicates multi-region access point support is disabled
DisableMultiRegionAccessPoints bool
}
// UpdateEndpoint adds the middleware to the middleware stack based on the UpdateEndpointOptions.
func UpdateEndpoint(stack *middleware.Stack, options UpdateEndpointOptions) (err error) {
// initial arn look up middleware
err = stack.Initialize.Add(&s3shared.ARNLookup{
GetARNValue: options.Accessor.GetBucketFromInput,
}, middleware.Before)
if err != nil {
return err
}
// process arn
err = stack.Serialize.Insert(&processARNResource{
UseARNRegion: options.UseARNRegion,
UseAccelerate: options.UseAccelerate,
UseDualstack: options.UseDualstack,
EndpointResolver: options.EndpointResolver,
EndpointResolverOptions: options.EndpointResolverOptions,
DisableMultiRegionAccessPoints: options.DisableMultiRegionAccessPoints,
}, "OperationSerializer", middleware.Before)
if err != nil {
return err
}
// process whether the operation requires the s3-object-lambda endpoint
// Occurs before operation serializer so that hostPrefix mutations
// can be handled correctly.
err = stack.Serialize.Insert(&s3ObjectLambdaEndpoint{
UseEndpoint: options.TargetS3ObjectLambda,
UseAccelerate: options.UseAccelerate,
UseDualstack: options.UseDualstack,
EndpointResolver: options.EndpointResolver,
EndpointResolverOptions: options.EndpointResolverOptions,
}, "OperationSerializer", middleware.Before)
if err != nil {
return err
}
// remove bucket arn middleware
err = stack.Serialize.Insert(&removeBucketFromPathMiddleware{}, "OperationSerializer", middleware.After)
if err != nil {
return err
}
// enable dual stack support
err = stack.Serialize.Insert(&s3shared.EnableDualstack{
UseDualstack: options.UseDualstack,
DefaultServiceID: "s3",
}, "OperationSerializer", middleware.After)
if err != nil {
return err
}
// update endpoint to use options for path style and accelerate
err = stack.Serialize.Insert(&updateEndpoint{
usePathStyle: options.UsePathStyle,
getBucketFromInput: options.Accessor.GetBucketFromInput,
useAccelerate: options.UseAccelerate,
supportsAccelerate: options.SupportsAccelerate,
}, (*s3shared.EnableDualstack)(nil).ID(), middleware.After)
if err != nil {
return err
}
return err
}
type updateEndpoint struct {
// path style options
usePathStyle bool
getBucketFromInput func(interface{}) (*string, bool)
// accelerate options
useAccelerate bool
supportsAccelerate bool
}
// ID returns the middleware ID.
func (*updateEndpoint) ID() string {
return "S3:UpdateEndpoint"
}
func (u *updateEndpoint) HandleSerialize(
ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
// if arn was processed, skip this middleware
if _, ok := s3shared.GetARNResourceFromContext(ctx); ok {
return next.HandleSerialize(ctx, in)
}
// skip this customization if host name is set as immutable
if smithyhttp.GetHostnameImmutable(ctx) {
return next.HandleSerialize(ctx, in)
}
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}
// check if accelerate is supported
if u.useAccelerate && !u.supportsAccelerate {
// accelerate is not supported, thus will be ignored
log.Println("Transfer acceleration is not supported for the operation, ignoring UseAccelerate.")
u.useAccelerate = false
}
// transfer acceleration is not supported with path style urls
if u.useAccelerate && u.usePathStyle {
log.Println("UseAccelerate is not compatible with UsePathStyle, ignoring UsePathStyle.")
u.usePathStyle = false
}
if u.getBucketFromInput != nil {
// Below customization only apply if bucket name is provided
bucket, ok := u.getBucketFromInput(in.Parameters)
if ok && bucket != nil {
region := awsmiddleware.GetRegion(ctx)
if err := u.updateEndpointFromConfig(req, *bucket, region); err != nil {
return out, metadata, err
}
}
}
return next.HandleSerialize(ctx, in)
}
func (u updateEndpoint) updateEndpointFromConfig(req *smithyhttp.Request, bucket string, region string) error {
// do nothing if path style is enforced
if u.usePathStyle {
return nil
}
if !hostCompatibleBucketName(req.URL, bucket) {
// bucket name must be valid to put into the host for accelerate operations.
// For non-accelerate operations the bucket name can stay in the path if
// not valid hostname.
var err error
if u.useAccelerate {
err = fmt.Errorf("bucket name %s is not compatible with S3", bucket)
}
// No-Op if not using accelerate.
return err
}
// accelerate is only supported if use path style is disabled
if u.useAccelerate {
parts := strings.Split(req.URL.Host, ".")
if len(parts) < 3 {
return fmt.Errorf("unable to update endpoint host for S3 accelerate, hostname invalid, %s", req.URL.Host)
}
if parts[0] == "s3" || strings.HasPrefix(parts[0], "s3-") {
parts[0] = "s3-accelerate"
}
for i := 1; i+1 < len(parts); i++ {
if strings.EqualFold(parts[i], region) {
parts = append(parts[:i], parts[i+1:]...)
break
}
}
// construct the url host
req.URL.Host = strings.Join(parts, ".")
}
// move bucket to follow virtual host style
moveBucketNameToHost(req.URL, bucket)
return nil
}
// updates endpoint to use virtual host styling
func moveBucketNameToHost(u *url.URL, bucket string) {
u.Host = bucket + "." + u.Host
removeBucketFromPath(u, bucket)
}
// remove bucket from url
func removeBucketFromPath(u *url.URL, bucket string) {
if strings.HasPrefix(u.Path, "/"+bucket) {
// modify url path
u.Path = strings.Replace(u.Path, "/"+bucket, "", 1)
// modify url raw path
u.RawPath = strings.Replace(u.RawPath, "/"+httpbinding.EscapePath(bucket, true), "", 1)
}
if u.Path == "" {
u.Path = "/"
}
if u.RawPath == "" {
u.RawPath = "/"
}
}
// hostCompatibleBucketName returns true if the request should
// put the bucket in the host. This is false if S3ForcePathStyle is
// explicitly set or if the bucket is not DNS compatible.
func hostCompatibleBucketName(u *url.URL, bucket string) bool {
// Bucket might be DNS compatible but dots in the hostname will fail
// certificate validation, so do not use host-style.
if u.Scheme == "https" && strings.Contains(bucket, ".") {
return false
}
// if the bucket is DNS compatible
return dnsCompatibleBucketName(bucket)
}
// dnsCompatibleBucketName returns true if the bucket name is DNS compatible.
// Buckets created outside of the classic region MUST be DNS compatible.
func dnsCompatibleBucketName(bucket string) bool {
if strings.Contains(bucket, "..") {
return false
}
// checks for `^[a-z0-9][a-z0-9\.\-]{1,61}[a-z0-9]$` domain mapping
if !((bucket[0] > 96 && bucket[0] < 123) || (bucket[0] > 47 && bucket[0] < 58)) {
return false
}
for _, c := range bucket[1:] {
if !((c > 96 && c < 123) || (c > 47 && c < 58) || c == 46 || c == 45) {
return false
}
}
// checks for `^(\d+\.){3}\d+$` IPaddressing
v := strings.SplitN(bucket, ".", -1)
if len(v) == 4 {
for _, c := range bucket {
if !((c > 47 && c < 58) || c == 46) {
// we confirm that this is not a IP address
return true
}
}
// this is a IP address
return false
}
return true
}