feat: load model from oci image
This commit is contained in:
		@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/cyrilix/robocar-base/cli"
 | 
			
		||||
	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
 | 
			
		||||
	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/oci"
 | 
			
		||||
	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
 | 
			
		||||
	"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
@@ -26,7 +27,7 @@ var (
 | 
			
		||||
func main() {
 | 
			
		||||
	var mqttBroker, username, password, clientId string
 | 
			
		||||
	var cameraTopic, steeringTopic string
 | 
			
		||||
	var modelPath string
 | 
			
		||||
	var modelPath, modelsDir, ociRef string
 | 
			
		||||
	var edgeVerbosity int
 | 
			
		||||
	var imgWidth, imgHeight, horizon int
 | 
			
		||||
 | 
			
		||||
@@ -36,6 +37,8 @@ func main() {
 | 
			
		||||
	cli.InitMqttFlags(DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain)
 | 
			
		||||
 | 
			
		||||
	flag.StringVar(&modelPath, "model", "", "path to model file")
 | 
			
		||||
	flag.StringVar(&ociRef, "oci-model", "", "oci image to pull")
 | 
			
		||||
	flag.StringVar(&modelsDir, "models-dir", "/tmp/robocar/models", "path where to store model file")
 | 
			
		||||
	flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
 | 
			
		||||
	flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
 | 
			
		||||
	flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
 | 
			
		||||
@@ -65,16 +68,33 @@ func main() {
 | 
			
		||||
 | 
			
		||||
	cleanup := metrics.Init(context.Background())
 | 
			
		||||
	defer cleanup()
 | 
			
		||||
 | 
			
		||||
	if modelPath == "" {
 | 
			
		||||
		zap.L().Error("model path is mandatory")
 | 
			
		||||
	if modelPath == "" && ociRef == "" {
 | 
			
		||||
		zap.L().Error("model path or oci image is mandatory")
 | 
			
		||||
		flag.PrintDefaults()
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
	}
 | 
			
		||||
	modelType, width, height, horizonFromName, err := parseModelName(modelPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
 | 
			
		||||
	if modelPath != "" && ociRef != "" {
 | 
			
		||||
		zap.L().Error("model path and oci image are exclusives")
 | 
			
		||||
		flag.PrintDefaults()
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var modelType tools.ModelType
 | 
			
		||||
	var width, height, horizonFromName int
 | 
			
		||||
 | 
			
		||||
	if modelPath != "" {
 | 
			
		||||
		modelType, width, height, horizonFromName, err = parseModelName(modelPath)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		modelPath, modelType, width, height, horizonFromName, err = oci.PullOciImage(ociRef, modelsDir)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if imgWidth == 0 {
 | 
			
		||||
		imgWidth = width
 | 
			
		||||
	}
 | 
			
		||||
@@ -90,6 +110,11 @@ func main() {
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ociRef == "" {
 | 
			
		||||
		zap.S().Infof("model path            : %v", modelPath)
 | 
			
		||||
	} else {
 | 
			
		||||
		zap.S().Infof("oci image model       : %v", ociRef)
 | 
			
		||||
	}
 | 
			
		||||
	zap.S().Infof("model type            : %v", modelType)
 | 
			
		||||
	zap.S().Infof("model for image width : %v", imgWidth)
 | 
			
		||||
	zap.S().Infof("model for image height: %v", imgHeight)
 | 
			
		||||
 
 | 
			
		||||
@@ -36,6 +36,15 @@ func Test_parseModelName(t *testing.T) {
 | 
			
		||||
			wantHorizon:   10,
 | 
			
		||||
			wantErr:       false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "linear with prefix",
 | 
			
		||||
			args:          args{modelPath: "/tmp/model_angle_linear_120x160h10.tflite"},
 | 
			
		||||
			wantModelType: tools.ModelTypeLinear,
 | 
			
		||||
			wantImgWidth:  120,
 | 
			
		||||
			wantImgHeight: 160,
 | 
			
		||||
			wantHorizon:   10,
 | 
			
		||||
			wantErr:       false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "bad-model",
 | 
			
		||||
			args:          args{modelPath: "/tmp/model_123_120x160h10.tflite"},
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user