feat: load model from oci image
This commit is contained in:
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/cyrilix/robocar-base/cli"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/oci"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
||||
"go.uber.org/zap"
|
||||
@ -26,7 +27,7 @@ var (
|
||||
func main() {
|
||||
var mqttBroker, username, password, clientId string
|
||||
var cameraTopic, steeringTopic string
|
||||
var modelPath string
|
||||
var modelPath, modelsDir, ociRef string
|
||||
var edgeVerbosity int
|
||||
var imgWidth, imgHeight, horizon int
|
||||
|
||||
@ -36,6 +37,8 @@ func main() {
|
||||
cli.InitMqttFlags(DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain)
|
||||
|
||||
flag.StringVar(&modelPath, "model", "", "path to model file")
|
||||
flag.StringVar(&ociRef, "oci-model", "", "oci image to pull")
|
||||
flag.StringVar(&modelsDir, "models-dir", "/tmp/robocar/models", "path where to store model file")
|
||||
flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
|
||||
flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
|
||||
flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
|
||||
@ -65,16 +68,33 @@ func main() {
|
||||
|
||||
cleanup := metrics.Init(context.Background())
|
||||
defer cleanup()
|
||||
|
||||
if modelPath == "" {
|
||||
zap.L().Error("model path is mandatory")
|
||||
if modelPath == "" && ociRef == "" {
|
||||
zap.L().Error("model path or oci image is mandatory")
|
||||
flag.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
modelType, width, height, horizonFromName, err := parseModelName(modelPath)
|
||||
if err != nil {
|
||||
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
|
||||
if modelPath != "" && ociRef != "" {
|
||||
zap.L().Error("model path and oci image are exclusives")
|
||||
flag.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var modelType tools.ModelType
|
||||
var width, height, horizonFromName int
|
||||
|
||||
if modelPath != "" {
|
||||
modelType, width, height, horizonFromName, err = parseModelName(modelPath)
|
||||
if err != nil {
|
||||
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
|
||||
}
|
||||
} else {
|
||||
modelPath, modelType, width, height, horizonFromName, err = oci.PullOciImage(ociRef, modelsDir)
|
||||
if err != nil {
|
||||
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if imgWidth == 0 {
|
||||
imgWidth = width
|
||||
}
|
||||
@ -90,6 +110,11 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if ociRef == "" {
|
||||
zap.S().Infof("model path : %v", modelPath)
|
||||
} else {
|
||||
zap.S().Infof("oci image model : %v", ociRef)
|
||||
}
|
||||
zap.S().Infof("model type : %v", modelType)
|
||||
zap.S().Infof("model for image width : %v", imgWidth)
|
||||
zap.S().Infof("model for image height: %v", imgHeight)
|
||||
|
@ -36,6 +36,15 @@ func Test_parseModelName(t *testing.T) {
|
||||
wantHorizon: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "linear with prefix",
|
||||
args: args{modelPath: "/tmp/model_angle_linear_120x160h10.tflite"},
|
||||
wantModelType: tools.ModelTypeLinear,
|
||||
wantImgWidth: 120,
|
||||
wantImgHeight: 160,
|
||||
wantHorizon: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "bad-model",
|
||||
args: args{modelPath: "/tmp/model_123_120x160h10.tflite"},
|
||||
|
Reference in New Issue
Block a user