feat(training): add flag toconfigure model type training
This commit is contained in:
		@@ -12,9 +12,40 @@ import (
 | 
			
		||||
	"github.com/cyrilix/robocar-tools/pkg/models"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ModelType int
 | 
			
		||||
 | 
			
		||||
func ParseModelType(s string) ModelType {
 | 
			
		||||
	switch strings.ToLower(s) {
 | 
			
		||||
	case "categorical":
 | 
			
		||||
		return ModelTypeCategorical
 | 
			
		||||
	case "linear":
 | 
			
		||||
		return ModelTypeLinear
 | 
			
		||||
	default:
 | 
			
		||||
		return ModelTypeUnknown
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m ModelType) String() string {
 | 
			
		||||
	switch m {
 | 
			
		||||
	case ModelTypeCategorical:
 | 
			
		||||
		return "categorical"
 | 
			
		||||
	case ModelTypeLinear:
 | 
			
		||||
		return "linear"
 | 
			
		||||
	default:
 | 
			
		||||
		return "unknown"
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ModelTypeUnknown ModelType = iota
 | 
			
		||||
	ModelTypeCategorical
 | 
			
		||||
	ModelTypeLinear
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func New(bucketName string, ociImage, roleArn string) *Training {
 | 
			
		||||
	return &Training{
 | 
			
		||||
		config:       awsutils.MustLoadConfig(),
 | 
			
		||||
@@ -35,7 +66,7 @@ type Training struct {
 | 
			
		||||
	outputBucket string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
 | 
			
		||||
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, modelType ModelType, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
 | 
			
		||||
	l := zap.S()
 | 
			
		||||
	l.Infof("run training with data from %s", basedir)
 | 
			
		||||
	archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
 | 
			
		||||
@@ -50,14 +81,7 @@ func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWid
 | 
			
		||||
	}
 | 
			
		||||
	l.Info("")
 | 
			
		||||
 | 
			
		||||
	err = t.runTraining(
 | 
			
		||||
		ctx,
 | 
			
		||||
		jobName,
 | 
			
		||||
		sliceSize,
 | 
			
		||||
		imgHeight,
 | 
			
		||||
		imgWidth,
 | 
			
		||||
		enableSpotTraining,
 | 
			
		||||
	)
 | 
			
		||||
	err = t.runTraining(ctx, jobName, sliceSize, imgHeight, imgWidth, horizon, enableSpotTraining, modelType)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("unable to run training: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -95,7 +119,7 @@ func List(bucketName string) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize int, imgHeight, imgWidth int, enableSpotTraining bool) error {
 | 
			
		||||
func (t *Training) runTraining(ctx context.Context, jobName string, slideSize, imgHeight, imgWidth, horizon int, enableSpotTraining bool, modelType ModelType) error {
 | 
			
		||||
	l := zap.S()
 | 
			
		||||
	client := sagemaker.NewFromConfig(awsutils.MustLoadConfig())
 | 
			
		||||
	l.Infof("Start training job '%s'", jobName)
 | 
			
		||||
@@ -125,6 +149,9 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
 | 
			
		||||
			"slide_size":       strconv.Itoa(slideSize),
 | 
			
		||||
			"img_height":       strconv.Itoa(imgHeight),
 | 
			
		||||
			"img_width":        strconv.Itoa(imgWidth),
 | 
			
		||||
			"batch_size":       strconv.Itoa(32),
 | 
			
		||||
			"model_type":       modelType.String(),
 | 
			
		||||
			"horizon":          strconv.Itoa(horizon),
 | 
			
		||||
		},
 | 
			
		||||
		InputDataConfig: []types.Channel{
 | 
			
		||||
			{
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user