feat(training): new features
* add flip-image option * add command to list models * add option to override image size when training is launched * add option to disable aws spot instance
This commit is contained in:
73
pkg/models/models.go
Normal file
73
pkg/models/models.go
Normal file
@ -0,0 +1,73 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/cyrilix/robocar-tools/pkg/awsutils"
|
||||
"go.uber.org/zap"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
func ListModels(ctx context.Context, bucket string) error {
|
||||
|
||||
// Create an Amazon S3 service client
|
||||
client := s3.NewFromConfig(awsutils.MustLoadConfig())
|
||||
|
||||
// Get the first page of results for ListObjectsV2 for a bucket
|
||||
outputs, err := client.ListObjectsV2(
|
||||
ctx,
|
||||
&s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(bucket),
|
||||
Prefix: aws.String("output"),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to list models in bucket %v: %w", bucket, err)
|
||||
}
|
||||
for _, output := range outputs.Contents {
|
||||
fmt.Printf("model: %s\n", *output.Key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownloadArchive(ctx context.Context, bucketName, modelPath string) ([]byte, error) {
|
||||
l := zap.S().With(
|
||||
"bucket", bucketName,
|
||||
"model", modelPath,
|
||||
)
|
||||
client := s3.NewFromConfig(awsutils.MustLoadConfig())
|
||||
|
||||
l.Debug("download model")
|
||||
archive, err := client.GetObject(
|
||||
ctx,
|
||||
|
||||
&s3.GetObjectInput{
|
||||
Bucket: aws.String(bucketName),
|
||||
Key: aws.String(modelPath),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to download model: %w", err)
|
||||
}
|
||||
l.Debug("model downloaded")
|
||||
resp, err := ioutil.ReadAll(archive.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read model archive content: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func DownloadArchiveToFile(ctx context.Context, bucketName, modelPath, outputFile string) error {
|
||||
arch, err := DownloadArchive(ctx, bucketName, modelPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to download model '%v/%v': %w", bucketName, modelPath, err)
|
||||
}
|
||||
err = ioutil.WriteFile(outputFile, arch, os.FileMode(0755))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to write model '%s' to file '%s': %v", modelPath, outputFile, err)
|
||||
}
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user