// Source: robocar-steering-tflite-edgetpu/cmd/rc-steering/rc-steering_test.go (107 lines, 2.9 KiB, Go)

package main
import (
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
"testing"
)
// Test_parseModelName checks that parseModelName extracts the model type,
// image width, image height and horizon encoded in a model file name such as
// "model_categorical_120x160h10.tflite". Per these expectations, the first
// number is the width and the second is the height.
//
// It also checks that names with an unrecognised model type, or with a missing
// or non-numeric width, height or horizon, produce an error. In those cases
// every other returned value must be ModelTypeUnknown or zero.
func Test_parseModelName(t *testing.T) {
	cases := []struct {
		name        string
		modelPath   string
		wantType    tools.ModelType
		wantWidth   int
		wantHeight  int
		wantHorizon int
		wantErr     bool
	}{
		{
			name:        "categorical",
			modelPath:   "/tmp/model_categorical_120x160h10.tflite",
			wantType:    tools.ModelTypeCategorical,
			wantWidth:   120,
			wantHeight:  160,
			wantHorizon: 10,
		},
		{
			name:        "linear",
			modelPath:   "/tmp/model_linear_120x160h10.tflite",
			wantType:    tools.ModelTypeLinear,
			wantWidth:   120,
			wantHeight:  160,
			wantHorizon: 10,
		},
		{
			// An extra token before the model type must not prevent detection.
			name:        "linear with prefix",
			modelPath:   "/tmp/model_angle_linear_120x160h10.tflite",
			wantType:    tools.ModelTypeLinear,
			wantWidth:   120,
			wantHeight:  160,
			wantHorizon: 10,
		},
		{
			name:      "bad-model",
			modelPath: "/tmp/model_123_120x160h10.tflite",
			wantType:  tools.ModelTypeUnknown,
			wantErr:   true,
		},
		{
			name:      "bad-width",
			modelPath: "/tmp/model_categorical_ax160h10.tflite",
			wantType:  tools.ModelTypeUnknown,
			wantErr:   true,
		},
		{
			name:      "bad-height",
			modelPath: "/tmp/model_categorical_120xh10.tflite",
			wantType:  tools.ModelTypeUnknown,
			wantErr:   true,
		},
		{
			name:      "bad-horizon",
			modelPath: "/tmp/model_categorical_120x160h.tflite",
			wantType:  tools.ModelTypeUnknown,
			wantErr:   true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			modelType, width, height, horizon, err := parseModelName(tc.modelPath)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Fatalf("parseModelName() error = %v, wantErr %v", err, tc.wantErr)
			}

			// Values are still compared when an expected error occurred, so
			// error cases also pin the unknown/zero results.
			checks := []struct {
				label     string
				got, want interface{}
			}{
				{"gotModelType", modelType, tc.wantType},
				{"gotImgWidth", width, tc.wantWidth},
				{"gotImgHeight", height, tc.wantHeight},
				{"gotHorizon", horizon, tc.wantHorizon},
			}
			for _, c := range checks {
				if c.got != c.want {
					t.Errorf("parseModelName() %s = %v, want %v", c.label, c.got, c.want)
				}
			}
		})
	}
}