feat: add display record command and refactor training command
This commit is contained in:
@ -35,10 +35,10 @@ type Training struct {
|
||||
outputBucket string
|
||||
}
|
||||
|
||||
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgHeight, imgWidth int, sliceSize int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
|
||||
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
|
||||
l := zap.S()
|
||||
l.Infof("run training with data from %s", basedir)
|
||||
archive, err := data.BuildArchive(basedir, sliceSize, withFlipImage)
|
||||
archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to build data archive: %w", err)
|
||||
}
|
||||
@ -110,7 +110,7 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
|
||||
S3OutputPath: aws.String(t.outputBucket),
|
||||
},
|
||||
ResourceConfig: &types.ResourceConfig{
|
||||
InstanceCount: 1,
|
||||
InstanceCount: 1,
|
||||
//InstanceType: types.TrainingInstanceTypeMlP2Xlarge,
|
||||
InstanceType: types.TrainingInstanceTypeMlG4dnXlarge,
|
||||
VolumeSizeInGB: 1,
|
||||
@ -168,7 +168,7 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
|
||||
}
|
||||
switch status.TrainingJobStatus {
|
||||
case types.TrainingJobStatusInProgress:
|
||||
l.Infof("job in progress: %v - %v - %v", status.TrainingJobStatus, status.SecondaryStatus, *status.SecondaryStatusTransitions[len(status.SecondaryStatusTransitions) - 1].StatusMessage)
|
||||
l.Infof("job in progress: %v - %v - %v", status.TrainingJobStatus, status.SecondaryStatus, *status.SecondaryStatusTransitions[len(status.SecondaryStatusTransitions)-1].StatusMessage)
|
||||
continue
|
||||
case types.TrainingJobStatusFailed:
|
||||
return fmt.Errorf("job %s finished with status %v", jobName, status.TrainingJobStatus)
|
||||
@ -198,5 +198,5 @@ func ListJob(ctx context.Context) error {
|
||||
for _, job := range jobs.TrainingJobSummaries {
|
||||
fmt.Printf("%s\t\t%s\n", *job.TrainingJobName, job.TrainingJobStatus)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user