feat: add display record command and refactor training command
This commit is contained in:
		@@ -35,10 +35,10 @@ type Training struct {
 | 
			
		||||
	outputBucket string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgHeight, imgWidth int, sliceSize int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
 | 
			
		||||
func (t *Training) TrainDir(ctx context.Context, jobName, basedir string, imgWidth, imgHeight, sliceSize int, horizon int, withFlipImage bool, outputModelFile string, enableSpotTraining bool) error {
 | 
			
		||||
	l := zap.S()
 | 
			
		||||
	l.Infof("run training with data from %s", basedir)
 | 
			
		||||
	archive, err := data.BuildArchive(basedir, sliceSize, withFlipImage)
 | 
			
		||||
	archive, err := data.BuildArchive(basedir, sliceSize, imgWidth, imgHeight, horizon, withFlipImage)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("unable to build data archive: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -110,7 +110,7 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
 | 
			
		||||
			S3OutputPath: aws.String(t.outputBucket),
 | 
			
		||||
		},
 | 
			
		||||
		ResourceConfig: &types.ResourceConfig{
 | 
			
		||||
			InstanceCount:  1,
 | 
			
		||||
			InstanceCount: 1,
 | 
			
		||||
			//InstanceType:   types.TrainingInstanceTypeMlP2Xlarge,
 | 
			
		||||
			InstanceType:   types.TrainingInstanceTypeMlG4dnXlarge,
 | 
			
		||||
			VolumeSizeInGB: 1,
 | 
			
		||||
@@ -168,7 +168,7 @@ func (t *Training) runTraining(ctx context.Context, jobName string, slideSize in
 | 
			
		||||
		}
 | 
			
		||||
		switch status.TrainingJobStatus {
 | 
			
		||||
		case types.TrainingJobStatusInProgress:
 | 
			
		||||
			l.Infof("job in progress: %v - %v - %v", status.TrainingJobStatus, status.SecondaryStatus, *status.SecondaryStatusTransitions[len(status.SecondaryStatusTransitions) - 1].StatusMessage)
 | 
			
		||||
			l.Infof("job in progress: %v - %v - %v", status.TrainingJobStatus, status.SecondaryStatus, *status.SecondaryStatusTransitions[len(status.SecondaryStatusTransitions)-1].StatusMessage)
 | 
			
		||||
			continue
 | 
			
		||||
		case types.TrainingJobStatusFailed:
 | 
			
		||||
			return fmt.Errorf("job %s finished with status %v", jobName, status.TrainingJobStatus)
 | 
			
		||||
@@ -198,5 +198,5 @@ func ListJob(ctx context.Context) error {
 | 
			
		||||
	for _, job := range jobs.TrainingJobSummaries {
 | 
			
		||||
		fmt.Printf("%s\t\t%s\n", *job.TrainingJobName, job.TrainingJobStatus)
 | 
			
		||||
	}
 | 
			
		||||
	return  nil
 | 
			
		||||
}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user