fix: bad inference result conversion
This commit is contained in:
		
							
								
								
									
										40
									
								
								pkg/tools/tools.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								pkg/tools/tools.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,40 @@
 | 
			
		||||
package tools
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"sort"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// LinearBin  perform inverse linear_bin, taking
 | 
			
		||||
func LinearBin(arr []byte, n int, offset int, r float64) (float64, float64) {
 | 
			
		||||
	outputSize := len(arr)
 | 
			
		||||
	type result struct {
 | 
			
		||||
		score float64
 | 
			
		||||
		index int
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var results []result
 | 
			
		||||
	minScore := 0.2
 | 
			
		||||
	for i := 0; i < outputSize; i++ {
 | 
			
		||||
		score := float64(arr[i]) / 255.0
 | 
			
		||||
		if score < minScore {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		results = append(results, result{score: score, index: i})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(results) == 0 {
 | 
			
		||||
		zap.L().Warn(fmt.Sprintf("none steering with score > %0.2f found", minScore))
 | 
			
		||||
		return 0., 0.
 | 
			
		||||
	}
 | 
			
		||||
	zap.S().Debugf("raw result: %v", results)
 | 
			
		||||
 | 
			
		||||
	sort.Slice(results, func(i, j int) bool {
 | 
			
		||||
		return results[i].score > results[j].score
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	b := results[0].index
 | 
			
		||||
	a := float64(b)*(r/(float64(n)+float64(offset))) + float64(offset)
 | 
			
		||||
	return a, results[0].score
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										77
									
								
								pkg/tools/tools_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								pkg/tools/tools_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,77 @@
 | 
			
		||||
package tools
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Test_LinearBin(t *testing.T) {
 | 
			
		||||
	type args struct {
 | 
			
		||||
		arr    []byte
 | 
			
		||||
		n      int
 | 
			
		||||
		offset int
 | 
			
		||||
		r      float64
 | 
			
		||||
	}
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name  string
 | 
			
		||||
		args  args
 | 
			
		||||
		want  float64
 | 
			
		||||
		want1 float64
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name: "default",
 | 
			
		||||
			args: args{
 | 
			
		||||
				arr:    []byte{0, 0, 0, 0, 0, 0, 0, 255, 0, 0, 0, 0, 0, 0, 0},
 | 
			
		||||
				n:      15,
 | 
			
		||||
				offset: -1,
 | 
			
		||||
				r:      2.0,
 | 
			
		||||
			},
 | 
			
		||||
			want:  0.,
 | 
			
		||||
			want1: 1.0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "left",
 | 
			
		||||
			args: args{
 | 
			
		||||
				arr:    []byte{255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
 | 
			
		||||
				n:      15,
 | 
			
		||||
				offset: -1,
 | 
			
		||||
				r:      2.0,
 | 
			
		||||
			},
 | 
			
		||||
			want:  -1.,
 | 
			
		||||
			want1: 1.0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "right",
 | 
			
		||||
			args: args{
 | 
			
		||||
				arr:    []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255},
 | 
			
		||||
				n:      15,
 | 
			
		||||
				offset: -1,
 | 
			
		||||
				r:      2.0,
 | 
			
		||||
			},
 | 
			
		||||
			want:  1.,
 | 
			
		||||
			want1: 1.0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "right",
 | 
			
		||||
			args: args{
 | 
			
		||||
				arr:    []byte{0, 0, 0, 0, 0, 0, 0, 5, 10, 15, 20, 40, 100, 60, 5},
 | 
			
		||||
				n:      15,
 | 
			
		||||
				offset: -1,
 | 
			
		||||
				r:      2.0,
 | 
			
		||||
			},
 | 
			
		||||
			want:  0.7142857142857142,
 | 
			
		||||
			want1: 0.39215686274509803,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			got, got1 := LinearBin(tt.args.arr, tt.args.n, tt.args.offset, tt.args.r)
 | 
			
		||||
			if got != tt.want {
 | 
			
		||||
				t.Errorf("linearBin() got = %v, want %v", got, tt.want)
 | 
			
		||||
			}
 | 
			
		||||
			if got1 != tt.want1 {
 | 
			
		||||
				t.Errorf("linearBin() got1 = %v, want %v", got1, tt.want1)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user