feat: add display record command and refactor training command

This commit is contained in:
2022-06-09 12:19:54 +02:00
parent 02db9e241e
commit 8b8d53af58
8 changed files with 255 additions and 47 deletions

View File

@@ -35,10 +35,10 @@ type Training struct {
outputBucket string
}
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgHeight, imgWidth int, sliceSize int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
l := zap.S()
l.Infof("run training with data from %s", basedir)
archive, err := data.BuildArchive(basedir, sliceSize, withFlipImage)
archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
if err != nil {
return fmt.Errorf("unable to build data archive: %w", err)
}
@@ -110,7 +110,7 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
S3OutputPath: aws.String(t.outputBucket),
},
ResourceConfig: &types.ResourceConfig{
InstanceCount: 1,
InstanceCount: 1,
//InstanceType: types.TrainingInstanceTypeMlP2Xlarge,
InstanceType: types.TrainingInstanceTypeMlG4dnXlarge,
VolumeSizeInGB: 1,
@@ -168,7 +168,7 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
}
switch status.TrainingJobStatus {
case types.TrainingJobStatusInProgress:
l.Infof("job in progress: %v - %v - %v", status.TrainingJobStatus, status.SecondaryStatus, *status.SecondaryStatusTransitions[len(status.SecondaryStatusTransitions) - 1].StatusMessage)
l.Infof("job in progress: %v - %v - %v", status.TrainingJobStatus, status.SecondaryStatus, *status.SecondaryStatusTransitions[len(status.SecondaryStatusTransitions)-1].StatusMessage)
continue
case types.TrainingJobStatusFailed:
return fmt.Errorf("job %s finished with status %v", jobName, status.TrainingJobStatus)
@@ -198,5 +198,5 @@ func ListJob(ctx context.Context) error {
for _, job := range jobs.TrainingJobSummaries {
fmt.Printf("%s\t\t%s\n", *job.TrainingJobName, job.TrainingJobStatus)
}
return nil
}
return nil
}