Compare commits
6 Commits
chore/upgr
...
v0.5.3
Author | SHA1 | Date | |
---|---|---|---|
b57698380e | |||
c1221d2e0e | |||
24d39461b4 | |||
2e8bf4252b | |||
bc90193edf | |||
acf6066fea |
@ -4,8 +4,8 @@ IMAGE_NAME=robocar-steering-tflite-edgetpu
|
||||
TAG=$(git describe)
|
||||
FULL_IMAGE_NAME=docker.io/cyrilix/${IMAGE_NAME}:${TAG}
|
||||
BINARY=rc-steering
|
||||
TFLITE_VERSION=2.6.0
|
||||
GOLANG_VERSION=1.18
|
||||
TFLITE_VERSION=2.10.0
|
||||
GOLANG_VERSION=1.19
|
||||
|
||||
GOTAGS="-tags netgo"
|
||||
BUILDER_CONTAINER="${IMAGE_NAME}-builder"
|
||||
@ -35,6 +35,20 @@ image_build_binaries(){
|
||||
buildah run $containerName ln -s /usr/lib/aarch64-linux-gnu/libedgetpu.so.1 /usr/lib/aarch64-linux-gnu/libedgetpu.so
|
||||
|
||||
printf "Compile for linux/amd64\n"
|
||||
LIB_ARCH=x86_64-linux-gnu
|
||||
LIB_FLAGS="-L /usr/local/lib/${LIB_ARCH} \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/base -labsl_base -labsl_throw_delegate -labsl_raw_logging_internal -labsl_spinlock_wait -labsl_malloc_internal -labsl_log_severity -labsl_strerror \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/status -labsl_status \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/hash -labsl_hash -labsl_city -labsl_low_level_hash \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/flags -labsl_flags -labsl_flags_internal -labsl_flags_marshalling -labsl_flags_reflection -labsl_flags_config -labsl_flags_program_name -labsl_flags_private_handle_accessor -labsl_flags_commandlineflag -labsl_flags_commandlineflag_internal\
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/types -labsl_bad_variant_access -labsl_bad_optional_access -labsl_bad_any_cast_impl \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/strings -labsl_strings -labsl_str_format_internal -labsl_cord -labsl_cordz_info -labsl_cord_internal -labsl_cordz_functions -labsl_cordz_handle -labsl_strings_internal \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/time -labsl_time -labsl_time_zone -labsl_civil_time \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/numeric -labsl_int128 \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/synchronization -labsl_synchronization -labsl_graphcycles_internal\
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/debugging -labsl_stacktrace -labsl_symbolize -labsl_debugging_internal -labsl_demangle_internal \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/profiling -labsl_exponential_biased \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/container -labsl_raw_hash_set -labsl_hashtablez_sampler"
|
||||
buildah run \
|
||||
--env CGO_ENABLED=1 \
|
||||
--env CC=gcc \
|
||||
@ -43,25 +57,27 @@ image_build_binaries(){
|
||||
--env GOARCH=amd64 \
|
||||
--env GOARM=${GOARM} \
|
||||
--env CGO_CPPFLAGS="-I/usr/local/include" \
|
||||
--env CGO_LDFLAGS="-L /usr/local/lib/x86_64-linux-gnu -L /usr/lib/x86_64-linux-gnu" \
|
||||
--env CGO_LDFLAGS="${LIB_FLAGS}" \
|
||||
$containerName \
|
||||
go build -a -o rc-steering.amd64 ./cmd/rc-steering
|
||||
#--env CGO_CXXFLAGS="--std=c++1z" \
|
||||
|
||||
printf "Compile for linux/arm/v7\n"
|
||||
buildah run \
|
||||
--env CGO_ENABLED=1 \
|
||||
--env CC=arm-linux-gnueabihf-gcc \
|
||||
--env CXX=arm-linux-gnueabihf-g++ \
|
||||
--env GOOS=linux \
|
||||
--env GOARCH=arm \
|
||||
--env GOARM=7 \
|
||||
--env CGO_CPPFLAGS="-I/usr/local/include" \
|
||||
--env CGO_LDFLAGS="-L /usr/lib/arm-linux-gnueabihf -L /usr/local/lib/arm-linux-gnueabihf" \
|
||||
$containerName \
|
||||
go build -a -o rc-steering.armhf ./cmd/rc-steering
|
||||
|
||||
printf "Compile for linux/arm64\n"
|
||||
LIB_ARCH=aarch64-linux-gnu
|
||||
LIB_FLAGS="-L /usr/local/lib/${LIB_ARCH} \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/base -labsl_base -labsl_throw_delegate -labsl_raw_logging_internal -labsl_spinlock_wait -labsl_malloc_internal -labsl_log_severity -labsl_strerror\
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/status -labsl_status \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/hash -labsl_hash -labsl_city -labsl_low_level_hash \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/flags -labsl_flags -labsl_flags_internal -labsl_flags_marshalling -labsl_flags_reflection -labsl_flags_config -labsl_flags_program_name -labsl_flags_private_handle_accessor -labsl_flags_commandlineflag -labsl_flags_commandlineflag_internal\
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/types -labsl_bad_variant_access -labsl_bad_optional_access -labsl_bad_any_cast_impl \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/strings -labsl_strings -labsl_str_format_internal -labsl_cord -labsl_cordz_info -labsl_cord_internal -labsl_cordz_functions -labsl_cordz_handle -labsl_strings_internal \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/time -labsl_time -labsl_time_zone -labsl_civil_time \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/numeric -labsl_int128 \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/synchronization -labsl_synchronization -labsl_graphcycles_internal\
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/debugging -labsl_stacktrace -labsl_symbolize -labsl_debugging_internal -labsl_demangle_internal \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/profiling -labsl_exponential_biased \
|
||||
-L/usr/local/lib/${LIB_ARCH}/absl/container -labsl_raw_hash_set -labsl_hashtablez_sampler"
|
||||
buildah run \
|
||||
--env CGO_ENABLED=1 \
|
||||
--env CC=aarch64-linux-gnu-gcc \
|
||||
@ -69,7 +85,7 @@ image_build_binaries(){
|
||||
--env GOOS=linux \
|
||||
--env GOARCH=arm64 \
|
||||
--env CGO_CPPFLAGS="-I/usr/local/include" \
|
||||
--env CGO_LDFLAGS="-L /usr/lib/aarch64-linux-gnu -L /usr/local/lib/aarch64-linux-gnu" \
|
||||
--env CGO_LDFLAGS="${LIB_FLAGS}" \
|
||||
$containerName \
|
||||
go build -a -o rc-steering.arm64 ./cmd/rc-steering
|
||||
}
|
||||
@ -110,7 +126,6 @@ image_build_binaries
|
||||
|
||||
image_build linux/amd64
|
||||
image_build linux/arm64
|
||||
image_build linux/arm/v7
|
||||
|
||||
|
||||
# push image
|
||||
|
@ -3,18 +3,26 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/cyrilix/robocar-base/cli"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/metrics"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/steering"
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
||||
"go.uber.org/zap"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultClientId = "robocar-steering-tflite-edgetpu"
|
||||
)
|
||||
|
||||
var (
|
||||
modelNameRegex = regexp.MustCompile(".*model_(?P<type>(categorical)|(linear))_(?P<imgWidth>\\d+)x(?P<imgHeight>\\d+)h(?P<horizon>\\d+)_edgetpu.tflite$")
|
||||
)
|
||||
|
||||
func main() {
|
||||
var mqttBroker, username, password, clientId string
|
||||
var cameraTopic, steeringTopic string
|
||||
@ -31,8 +39,8 @@ func main() {
|
||||
flag.StringVar(&steeringTopic, "mqtt-topic-road", os.Getenv("MQTT_TOPIC_STEERING"), "Mqtt topic to publish road detection result, use MQTT_TOPIC_STEERING if args not set")
|
||||
flag.StringVar(&cameraTopic, "mqtt-topic-camera", os.Getenv("MQTT_TOPIC_CAMERA"), "Mqtt topic that contains camera frame values, use MQTT_TOPIC_CAMERA if args not set")
|
||||
flag.IntVar(&edgeVerbosity, "edge-verbosity", 0, "Edge TPU Verbosity")
|
||||
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model (mandatory)")
|
||||
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model (mandatory)")
|
||||
flag.IntVar(&imgWidth, "img-width", 0, "image width expected by model")
|
||||
flag.IntVar(&imgHeight, "img-height", 0, "image height expected by model")
|
||||
flag.IntVar(&horizon, "horizon", 0, "upper zone to crop from image. Models expect size 'imgHeight - horizon'")
|
||||
logLevel := zap.LevelFlag("log", zap.InfoLevel, "log level")
|
||||
flag.Parse()
|
||||
@ -63,19 +71,37 @@ func main() {
|
||||
flag.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
modelType, width, height, horizonFromName, err := parseModelName(modelPath)
|
||||
if err != nil {
|
||||
zap.S().Panicf("bad model name '%v', unable to detect configuration from name pattern: %v", modelPath, err)
|
||||
}
|
||||
if imgWidth == 0 {
|
||||
imgWidth = width
|
||||
}
|
||||
if imgHeight == 0 {
|
||||
imgHeight = height
|
||||
}
|
||||
if horizonFromName == 0 {
|
||||
horizon = horizonFromName
|
||||
}
|
||||
if imgWidth <= 0 || imgHeight <= 0 {
|
||||
zap.L().Error("img-width and img-height are mandatory")
|
||||
flag.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
zap.S().Infof("model type : %v", modelType)
|
||||
zap.S().Infof("model for image width : %v", imgWidth)
|
||||
zap.S().Infof("model for image height: %v", imgHeight)
|
||||
zap.S().Infof("model with horizon : %v", horizon)
|
||||
|
||||
client, err := cli.Connect(mqttBroker, username, password, clientId)
|
||||
if err != nil {
|
||||
zap.L().Fatal("unable to connect to mqtt bus", zap.Error(err))
|
||||
}
|
||||
defer client.Disconnect(50)
|
||||
|
||||
p := steering.NewPart(client, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
|
||||
p := steering.NewPart(client, modelType, modelPath, steeringTopic, cameraTopic, edgeVerbosity, imgWidth, imgHeight, horizon)
|
||||
defer p.Stop()
|
||||
|
||||
cli.HandleExit(p)
|
||||
@ -85,3 +111,33 @@ func main() {
|
||||
zap.L().Fatal("unable to start service", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func parseModelName(modelPath string) (modelType tools.ModelType, imgWidth, imgHeight int, horizon int, err error) {
|
||||
match := modelNameRegex.FindStringSubmatch(modelPath)
|
||||
|
||||
results := map[string]string{}
|
||||
for i, name := range match {
|
||||
results[modelNameRegex.SubexpNames()[i]] = name
|
||||
}
|
||||
modelType = tools.ParseModelType(results["type"])
|
||||
if modelType == tools.ModelTypeUnknown {
|
||||
err = fmt.Errorf("unknown model type '%v'", results["type"])
|
||||
return
|
||||
}
|
||||
imgWidth, err = strconv.Atoi(results["imgWidth"])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to convert image width '%v' to integer: %v", results["imgWidth"], err)
|
||||
return
|
||||
}
|
||||
imgHeight, err = strconv.Atoi(results["imgHeight"])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to convert image height '%v' to integer: %v", results["imgHeight"], err)
|
||||
return
|
||||
}
|
||||
horizon, err = strconv.Atoi(results["horizon"])
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to convert horizon '%v' to integer: %v", results["horizon"], err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
97
cmd/rc-steering/rc-steering_test.go
Normal file
97
cmd/rc-steering/rc-steering_test.go
Normal file
@ -0,0 +1,97 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/cyrilix/robocar-steering-tflite-edgetpu/pkg/tools"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_parseModelName(t *testing.T) {
|
||||
type args struct {
|
||||
modelPath string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantModelType tools.ModelType
|
||||
wantImgWidth int
|
||||
wantImgHeight int
|
||||
wantHorizon int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "categorical",
|
||||
args: args{modelPath: "/tmp/model_categorical_120x160h10.tflite"},
|
||||
wantModelType: tools.ModelTypeCategorical,
|
||||
wantImgWidth: 120,
|
||||
wantImgHeight: 160,
|
||||
wantHorizon: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "linear",
|
||||
args: args{modelPath: "/tmp/model_linear_120x160h10.tflite"},
|
||||
wantModelType: tools.ModelTypeLinear,
|
||||
wantImgWidth: 120,
|
||||
wantImgHeight: 160,
|
||||
wantHorizon: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "bad-model",
|
||||
args: args{modelPath: "/tmp/model_123_120x160h10.tflite"},
|
||||
wantModelType: tools.ModelTypeUnknown,
|
||||
wantImgWidth: 0,
|
||||
wantImgHeight: 0,
|
||||
wantHorizon: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "bad-width",
|
||||
args: args{modelPath: "/tmp/model_categorical_ax160h10.tflite"},
|
||||
wantModelType: tools.ModelTypeUnknown,
|
||||
wantImgWidth: 0,
|
||||
wantImgHeight: 0,
|
||||
wantHorizon: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "bad-height",
|
||||
args: args{modelPath: "/tmp/model_categorical_120xh10.tflite"},
|
||||
wantModelType: tools.ModelTypeUnknown,
|
||||
wantImgWidth: 0,
|
||||
wantImgHeight: 0,
|
||||
wantHorizon: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "bad-horizon",
|
||||
args: args{modelPath: "/tmp/model_categorical_120x160h.tflite"},
|
||||
wantModelType: tools.ModelTypeUnknown,
|
||||
wantImgWidth: 0,
|
||||
wantImgHeight: 0,
|
||||
wantHorizon: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotModelType, gotImgWidth, gotImgHeight, gotHorizon, err := parseModelName(tt.args.modelPath)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseModelName() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if gotModelType != tt.wantModelType {
|
||||
t.Errorf("parseModelName() gotModelType = %v, want %v", gotModelType, tt.wantModelType)
|
||||
}
|
||||
if gotImgWidth != tt.wantImgWidth {
|
||||
t.Errorf("parseModelName() gotImgWidth = %v, want %v", gotImgWidth, tt.wantImgWidth)
|
||||
}
|
||||
if gotImgHeight != tt.wantImgHeight {
|
||||
t.Errorf("parseModelName() gotImgHeight = %v, want %v", gotImgHeight, tt.wantImgHeight)
|
||||
}
|
||||
if gotHorizon != tt.wantHorizon {
|
||||
t.Errorf("parseModelName() gotHorizon = %v, want %v", gotHorizon, tt.wantHorizon)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
4
go.mod
4
go.mod
@ -1,13 +1,13 @@
|
||||
module github.com/cyrilix/robocar-steering-tflite-edgetpu
|
||||
|
||||
go 1.18
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/cyrilix/robocar-base v0.1.7
|
||||
github.com/cyrilix/robocar-protobuf/go v1.0.5
|
||||
github.com/disintegration/imaging v1.6.2
|
||||
github.com/eclipse/paho.mqtt.golang v1.4.1
|
||||
github.com/mattn/go-tflite v1.0.2
|
||||
github.com/mattn/go-tflite v1.0.4
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v0.30.0
|
||||
go.opentelemetry.io/otel/metric v0.30.0
|
||||
go.opentelemetry.io/otel/sdk/metric v0.30.0
|
||||
|
5
go.sum
5
go.sum
@ -1,7 +1,6 @@
|
||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
||||
github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
|
||||
github.com/cyrilix/robocar-base v0.1.7 h1:EVzZ0KjigSFpke5f3A/PybEH3WFUEIrYSc3z/dhOZ48=
|
||||
github.com/cyrilix/robocar-base v0.1.7/go.mod h1:4E11HQSNy2NT8e7MW188y6ST9C0RzarKyn7sK/3V/Lk=
|
||||
github.com/cyrilix/robocar-protobuf/go v1.0.5 h1:PX1At+pf6G7gJwT4LzJLQu3/LPFTTNNlZmZSYtnSELY=
|
||||
@ -31,8 +30,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
|
||||
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
|
||||
github.com/mattn/go-tflite v1.0.2 h1:P9CKqjyRSRM31SfL65WklD8U5B/iPD4CJQiRkB8K02g=
|
||||
github.com/mattn/go-tflite v1.0.2/go.mod h1:2NwhEYXoP8vxRIpu95DElqMkZoV39ABRPF3AETN7N1w=
|
||||
github.com/mattn/go-tflite v1.0.4 h1:wpfNKjMr3IJz4xI+oUeHE70RU6Q5dZc0FK/X8vCWLAo=
|
||||
github.com/mattn/go-tflite v1.0.4/go.mod h1:j7bVlVHgKURK0p7AQOw3OqlGE2SVXqck7JsJo4wI+bc=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
|
@ -19,9 +19,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewPart(client mqtt.Client, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
|
||||
func NewPart(client mqtt.Client, modelType tools.ModelType, modelPath, steeringTopic, cameraTopic string, edgeVerbosity int, imgWidth, imgHeight, horizon int) *Part {
|
||||
return &Part{
|
||||
client: client,
|
||||
modelType: modelType,
|
||||
modelPath: modelPath,
|
||||
steeringTopic: steeringTopic,
|
||||
cameraTopic: cameraTopic,
|
||||
@ -42,6 +43,7 @@ type Part struct {
|
||||
|
||||
options *tflite.InterpreterOptions
|
||||
interpreter *tflite.Interpreter
|
||||
modelType tools.ModelType
|
||||
modelPath string
|
||||
model *tflite.Model
|
||||
edgeVebosity int
|
||||
@ -214,7 +216,14 @@ func (p *Part) Value(img image.Image) (float32, float32, error) {
|
||||
output := p.interpreter.GetOutputTensor(0).UInt8s()
|
||||
zap.L().Debug("raw steering", zap.Uint8s("result", output))
|
||||
|
||||
steering, score := tools.LinearBin(output, 15, -1, 2.0)
|
||||
var steering, score float64
|
||||
switch p.modelType {
|
||||
case tools.ModelTypeCategorical:
|
||||
steering, score = tools.LinearBin(output, 15, -1, 2.0)
|
||||
case tools.ModelTypeLinear:
|
||||
steering = 2*(float64(output[0])/255.) - 1.
|
||||
score = 0.6
|
||||
}
|
||||
zap.L().Debug("found steering",
|
||||
zap.Float64("steering", steering),
|
||||
zap.Float64("score", score),
|
||||
|
@ -1,9 +1,39 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ModelType int
|
||||
|
||||
func ParseModelType(s string) ModelType {
|
||||
switch strings.ToLower(s) {
|
||||
case "categorical":
|
||||
return ModelTypeCategorical
|
||||
case "linear":
|
||||
return ModelTypeLinear
|
||||
default:
|
||||
return ModelTypeUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func (m ModelType) String() string {
|
||||
switch m {
|
||||
case ModelTypeCategorical:
|
||||
return "categorical"
|
||||
case ModelTypeLinear:
|
||||
return "linear"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
ModelTypeUnknown ModelType = iota
|
||||
ModelTypeCategorical
|
||||
ModelTypeLinear
|
||||
)
|
||||
|
||||
// LinearBin perform inverse linear_bin, taking
|
||||
@ -15,19 +45,11 @@ func LinearBin(arr []uint8, n int, offset int, r float64) (float64, float64) {
|
||||
}
|
||||
|
||||
var results []result
|
||||
minScore := 0.2
|
||||
for i := 0; i < outputSize; i++ {
|
||||
score := float64(int(arr[i])) / 255.0
|
||||
if score < minScore {
|
||||
continue
|
||||
}
|
||||
results = append(results, result{score: score, index: i})
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore))
|
||||
return 0., 0.
|
||||
}
|
||||
zap.S().Debugf("raw result: %v", results)
|
||||
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
|
@ -75,3 +75,28 @@ func Test_LinearBin(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelType(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want ModelType
|
||||
}{
|
||||
{name: "categorical", args: args{s: "categorical"}, want: ModelTypeCategorical},
|
||||
{name: "categorical-upper", args: args{s: "caTeGorical"}, want: ModelTypeCategorical},
|
||||
{name: "linear", args: args{s: "linear"}, want: ModelTypeLinear},
|
||||
{name: "linear-upper", args: args{s: "LineAr"}, want: ModelTypeLinear},
|
||||
{name: "unknown", args: args{s: "1234"}, want: ModelTypeUnknown},
|
||||
{name: "empty", args: args{s: ""}, want: ModelTypeUnknown},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ParseModelType(tt.args.s); got != tt.want {
|
||||
t.Errorf("ParseModelType() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
2
vendor/github.com/mattn/go-tflite/README.md
generated
vendored
2
vendor/github.com/mattn/go-tflite/README.md
generated
vendored
@ -92,5 +92,5 @@ rm -Rf edgetpu
|
||||
MIT
|
||||
|
||||
## Author
|
||||
Yasuhrio Matsumoto (a.k.a. mattn)
|
||||
Yasuhiro Matsumoto (a.k.a. mattn)
|
||||
|
||||
|
1
vendor/github.com/mattn/go-tflite/tflite_experimental.go
generated
vendored
1
vendor/github.com/mattn/go-tflite/tflite_experimental.go
generated
vendored
@ -32,7 +32,6 @@ _make_registration(void* o_init, void* o_free, void* o_prepare, void* o_invoke,
|
||||
}
|
||||
|
||||
static void look_context(TfLiteContext *context) {
|
||||
context->tensors;
|
||||
TfLiteIntArray *plan = NULL;
|
||||
context->GetExecutionPlan(context, &plan);
|
||||
if (plan == NULL) return;
|
||||
|
2
vendor/modules.txt
vendored
2
vendor/modules.txt
vendored
@ -28,7 +28,7 @@ github.com/gorilla/websocket
|
||||
# github.com/mattn/go-pointer v0.0.1
|
||||
## explicit
|
||||
github.com/mattn/go-pointer
|
||||
# github.com/mattn/go-tflite v1.0.2
|
||||
# github.com/mattn/go-tflite v1.0.4
|
||||
## explicit; go 1.13
|
||||
github.com/mattn/go-tflite
|
||||
github.com/mattn/go-tflite/delegates
|
||||
|
Reference in New Issue
Block a user