First implementation
This commit is contained in:
21
vendor/github.com/mattn/go-pointer/LICENSE
generated
vendored
Normal file
21
vendor/github.com/mattn/go-pointer/LICENSE
generated
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2019 Yasuhiro Matsumoto
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
29
vendor/github.com/mattn/go-pointer/README.md
generated
vendored
Normal file
29
vendor/github.com/mattn/go-pointer/README.md
generated
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
# go-pointer
|
||||
|
||||
Utility for cgo
|
||||
|
||||
## Usage
|
||||
|
||||
https://github.com/golang/proposal/blob/master/design/12416-cgo-pointers.md
|
||||
|
||||
In go 1.6, cgo argument can't be passed Go pointer.
|
||||
|
||||
```
|
||||
var s string
|
||||
C.pass_pointer(pointer.Save(&s))
|
||||
v := *(pointer.Restore(C.get_from_pointer()).(*string))
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
```
|
||||
go get github.com/mattn/go-pointer
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
||||
## Author
|
||||
|
||||
Yasuhiro Matsumoto (a.k.a mattn)
|
1
vendor/github.com/mattn/go-pointer/doc.go
generated
vendored
Normal file
1
vendor/github.com/mattn/go-pointer/doc.go
generated
vendored
Normal file
@ -0,0 +1 @@
|
||||
package pointer
|
57
vendor/github.com/mattn/go-pointer/pointer.go
generated
vendored
Normal file
57
vendor/github.com/mattn/go-pointer/pointer.go
generated
vendored
Normal file
@ -0,0 +1,57 @@
|
||||
package pointer
|
||||
|
||||
// #include <stdlib.h>
|
||||
import "C"
|
||||
import (
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
mutex sync.RWMutex
|
||||
store = map[unsafe.Pointer]interface{}{}
|
||||
)
|
||||
|
||||
func Save(v interface{}) unsafe.Pointer {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate real fake C pointer.
|
||||
// This pointer will not store any data, but will bi used for indexing purposes.
|
||||
// Since Go doest allow to cast dangling pointer to unsafe.Pointer, we do rally allocate one byte.
|
||||
// Why we need indexing, because Go doest allow C code to store pointers to Go data.
|
||||
var ptr unsafe.Pointer = C.malloc(C.size_t(1))
|
||||
if ptr == nil {
|
||||
panic("can't allocate 'cgo-pointer hack index pointer': ptr == nil")
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
store[ptr] = v
|
||||
mutex.Unlock()
|
||||
|
||||
return ptr
|
||||
}
|
||||
|
||||
func Restore(ptr unsafe.Pointer) (v interface{}) {
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mutex.RLock()
|
||||
v = store[ptr]
|
||||
mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func Unref(ptr unsafe.Pointer) {
|
||||
if ptr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
delete(store, ptr)
|
||||
mutex.Unlock()
|
||||
|
||||
C.free(ptr)
|
||||
}
|
19
vendor/github.com/mattn/go-tflite/.gitignore
generated
vendored
Normal file
19
vendor/github.com/mattn/go-tflite/.gitignore
generated
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
tmp
|
||||
*.exe
|
||||
*.dll
|
||||
*.bmp
|
||||
*.jpg
|
||||
*.png
|
||||
*.mp4
|
||||
#_example/grace_hopper.bmp
|
||||
#_example/grace_hopper.png
|
||||
#_example/main.exe
|
||||
#_example/notebook.png
|
||||
#_example/output_graph.tflite
|
||||
#_example/output_labels.txt
|
||||
#_example/webcam/aaa.mp4
|
||||
#_example/webcam/bbb.mp4
|
||||
#_example/webcam/libtensorflowlite_c.dll
|
||||
#_example/webcam/mobilenet_quant_v1_224.tflite
|
||||
#_example/webcam/webcam.exe
|
||||
.ipynb_checkpoints
|
21
vendor/github.com/mattn/go-tflite/LICENSE
generated
vendored
Normal file
21
vendor/github.com/mattn/go-tflite/LICENSE
generated
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2019 Yasuhiro Matsumoto
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
49
vendor/github.com/mattn/go-tflite/Makefile.tflite
generated
vendored
Normal file
49
vendor/github.com/mattn/go-tflite/Makefile.tflite
generated
vendored
Normal file
@ -0,0 +1,49 @@
|
||||
SRCS = \
|
||||
c_api.cc \
|
||||
c_api_experimental.cc
|
||||
|
||||
OBJS = $(subst .cc,.o,$(subst .cxx,.o,$(subst .cpp,.o,$(SRCS))))
|
||||
|
||||
TENSORFLOW_ROOT = $(shell go env GOPATH)/src/github.com/tensorflow/tensorflow
|
||||
CXXFLAGS := -fPIC -DTF_COMPILE_LIBRARY -I$(TENSORFLOW_ROOT) \
|
||||
-I$(TENSORFLOW_ROOT)/tensorflow/lite/tools/make/downloads/flatbuffers/include \
|
||||
-I$(TENSORFLOW_ROOT)/tensorflow/lite/tools/make/downloads/absl
|
||||
TARGET = libtensorflowlite_c
|
||||
ifeq ($(OS),Windows_NT)
|
||||
OS_ARCH = windows_x86_64
|
||||
TARGET_SHARED := $(TARGET).dll
|
||||
else
|
||||
ifeq ($(shell uname -s),Darwin)
|
||||
CXXFLAGS := -std=c++11 $(CXXFLAGS)
|
||||
OS_ARCH = osx_$(shell uname -m)
|
||||
else
|
||||
ifeq ($(shell uname -m),x86_64)
|
||||
OS_ARCH = linux_x86_64
|
||||
else
|
||||
ifeq ($(shell uname -m),armv6l)
|
||||
OS_ARCH = linux_armv6l
|
||||
else
|
||||
OS_ARCH = rpi_armv7l
|
||||
endif
|
||||
endif
|
||||
endif
|
||||
TARGET_SHARED := $(TARGET).so
|
||||
endif
|
||||
LDFLAGS += -L$(TENSORFLOW_ROOT)/tensorflow/lite/tools/make/gen/$(OS_ARCH)/lib
|
||||
LIBS = -ltensorflow-lite
|
||||
|
||||
.SUFFIXES: .cpp .cxx .o
|
||||
|
||||
all : $(TARGET_SHARED)
|
||||
|
||||
$(TARGET_SHARED) : $(OBJS)
|
||||
g++ -shared -o $@ $(OBJS) $(LDFLAGS) $(LIBS)
|
||||
|
||||
.cxx.o :
|
||||
g++ -std=c++14 -c $(CXXFLAGS) -I. $< -o $@
|
||||
|
||||
.cpp.o :
|
||||
g++ -std=c++14 -c $(CXXFLAGS) -I. $< -o $@
|
||||
|
||||
clean :
|
||||
rm -f *.o $(TARGET_SHARED)
|
96
vendor/github.com/mattn/go-tflite/README.md
generated
vendored
Normal file
96
vendor/github.com/mattn/go-tflite/README.md
generated
vendored
Normal file
@ -0,0 +1,96 @@
|
||||
# go-tflite
|
||||
|
||||
Go binding for TensorFlow Lite
|
||||
|
||||

|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
model := tflite.NewModelFromFile("sin_model.tflite")
|
||||
if model == nil {
|
||||
log.Fatal("cannot load model")
|
||||
}
|
||||
defer model.Delete()
|
||||
|
||||
options := tflite.NewInterpreterOptions()
|
||||
defer options.Delete()
|
||||
|
||||
interpreter := tflite.NewInterpreter(model, options)
|
||||
defer interpreter.Delete()
|
||||
|
||||
interpreter.AllocateTensors()
|
||||
|
||||
v := float64(1.2) * math.Pi / 180.0
|
||||
input := interpreter.GetInputTensor(0)
|
||||
input.Float32s()[0] = float32(v)
|
||||
interpreter.Invoke()
|
||||
got := float64(interpreter.GetOutputTensor(0).Float32s()[0])
|
||||
```
|
||||
|
||||
See `_example` for more examples
|
||||
|
||||
## Requirements
|
||||
|
||||
* TensorFlow Lite - This release requires 2.2.0-rc3
|
||||
|
||||
## Tensorflow Installation
|
||||
|
||||
You must install Tensorflow Lite C API. Assuming the source is under /source/directory/tensorflow
|
||||
|
||||
```
|
||||
$ cd /source/directory/tensorflow
|
||||
$ bazel build --config opt --config monolithic tensorflow:libtensorflow_c.so
|
||||
```
|
||||
|
||||
Or to just compile the tensorflow lite libraries:
|
||||
```
|
||||
$ cd /some/path/tensorflow
|
||||
$ bazel build --config opt --config monolithic //tensorflow/lite:libtensorflowlite.so
|
||||
$ bazel build --config opt --config monolithic //tensorflow/lite/c:libtensorflowlite_c.so
|
||||
```
|
||||
|
||||
In order for go to find the headers you must set the CGO_CFLAGS environment variable for the source and libraries of tensorflow.
|
||||
If your libraries are not installed in a standard location, you must also give the go linker the path to the shared librares
|
||||
with the CGO_LDFLAGS environment variable.
|
||||
|
||||
```
|
||||
$ export CGO_CFLAGS=-I/source/directory/tensorflow
|
||||
$ export CGO_LDFLAGS=-L/path/to/tensorflow/libaries
|
||||
```
|
||||
|
||||
If you don't love bazel, you can try `Makefile.tflite`.
|
||||
Put this file as `Makefile` in `tensorflow/lite/c`, and run `make`.
|
||||
Sorry, this has not been test for Linux or Mac
|
||||
|
||||
Then run `go build` on some of the examples.
|
||||
|
||||
## Edge TPU
|
||||
To be able to compile and use the EdgeTPU delegate, you need to install the libraries from here:
|
||||
https://github.com/google-coral/edgetpu
|
||||
|
||||
There is also a deb package here:
|
||||
https://coral.withgoogle.com/docs/accelerator/get-started/#1-install-the-edge-tpu-runtime
|
||||
|
||||
The libraries from should be installed in a system wide library path like `/usr/local/lib`
|
||||
The include files should be installed somewhere that is accesable from your CGO include path
|
||||
|
||||
For x86:
|
||||
```
|
||||
cd /tmp && git clone https://github.com/google-coral/edgetpu.git && \
|
||||
cp edgetpu/libedgetpu/direct/k8/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1.0 && \
|
||||
ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so.1 && \
|
||||
ln -rs /usr/local/lib/libedgetpu.so.1.0 /usr/local/lib/libedgetpu.so && \
|
||||
mkdir -p /usr/local/include/libedgetpu && \
|
||||
cp edgetpu/libedgetpu/edgetpu.h /usr/local/include/edgetpu.h && \
|
||||
cp edgetpu/libedgetpu/edgetpu_c.h /usr/local/include/edgetpu_c.h && \
|
||||
rm -Rf edgetpu
|
||||
```
|
||||
|
||||
|
||||
## License
|
||||
MIT
|
||||
|
||||
## Author
|
||||
Yasuhrio Matsumoto (a.k.a. mattn)
|
||||
|
19
vendor/github.com/mattn/go-tflite/callback.go
generated
vendored
Normal file
19
vendor/github.com/mattn/go-tflite/callback.go
generated
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
package tflite
|
||||
|
||||
import "C"
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/mattn/go-pointer"
|
||||
)
|
||||
|
||||
type callbackInfo struct {
|
||||
user_data interface{}
|
||||
f func(msg string, user_data interface{})
|
||||
}
|
||||
|
||||
//export _go_error_reporter
|
||||
func _go_error_reporter(user_data unsafe.Pointer, msg *C.char) {
|
||||
cb := pointer.Restore(user_data).(*callbackInfo)
|
||||
cb.f(C.GoString(msg), cb.user_data)
|
||||
}
|
14
vendor/github.com/mattn/go-tflite/delegates/delegates.go
generated
vendored
Normal file
14
vendor/github.com/mattn/go-tflite/delegates/delegates.go
generated
vendored
Normal file
@ -0,0 +1,14 @@
|
||||
package delegates
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type ModifyGraphWithDelegater interface {
|
||||
ModifyGraphWithDelegate(Delegater)
|
||||
}
|
||||
|
||||
type Delegater interface {
|
||||
Delete()
|
||||
Ptr() unsafe.Pointer
|
||||
}
|
104
vendor/github.com/mattn/go-tflite/delegates/edgetpu/edgetpu.go
generated
vendored
Normal file
104
vendor/github.com/mattn/go-tflite/delegates/edgetpu/edgetpu.go
generated
vendored
Normal file
@ -0,0 +1,104 @@
|
||||
package edgetpu
|
||||
|
||||
// +build !windows
|
||||
|
||||
/*
|
||||
#ifndef GO_EDGETPU_H
|
||||
#include "edgetpu.go.h"
|
||||
#include <edgetpu_c.h>
|
||||
#endif
|
||||
#cgo LDFLAGS: -ledgetpu
|
||||
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mattn/go-tflite/delegates"
|
||||
)
|
||||
|
||||
const (
|
||||
// The Device Types
|
||||
TypeApexPCI DeviceType = C.EDGETPU_APEX_PCI
|
||||
TypeApexUSB DeviceType = C.EDGETPU_APEX_USB
|
||||
)
|
||||
|
||||
type DeviceType uint32
|
||||
|
||||
type Device struct {
|
||||
Type DeviceType
|
||||
Path string
|
||||
}
|
||||
|
||||
// There are no options
|
||||
type DelegateOptions struct {
|
||||
}
|
||||
|
||||
// Delegate is the tflite delegate
|
||||
type Delegate struct {
|
||||
d *C.TfLiteDelegate
|
||||
}
|
||||
|
||||
func New(device Device) delegates.Delegater {
|
||||
var d *C.TfLiteDelegate
|
||||
d = C.edgetpu_create_delegate(uint32(device.Type), C.CString(device.Path), nil, 0)
|
||||
if d == nil {
|
||||
return nil
|
||||
}
|
||||
return &Delegate{
|
||||
d: d,
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the delegate
|
||||
func (d *Delegate) Delete() {
|
||||
C.edgetpu_free_delegate(d.d)
|
||||
}
|
||||
|
||||
// Return a pointer
|
||||
func (d *Delegate) Ptr() unsafe.Pointer {
|
||||
return unsafe.Pointer(d.d)
|
||||
}
|
||||
|
||||
// Version fetches the EdgeTPU runtime version information
|
||||
func Version() (string, error) {
|
||||
version := C.edgetpu_version()
|
||||
if version == nil {
|
||||
return "", fmt.Errorf("could not get version")
|
||||
}
|
||||
return C.GoString(version), nil
|
||||
}
|
||||
|
||||
// Verbosity sets the edgetpu verbosity
|
||||
func Verbosity(v int) {
|
||||
C.edgetpu_verbosity(C.int(v))
|
||||
}
|
||||
|
||||
// DeviceList fetches a list of devices
|
||||
func DeviceList() ([]Device, error) {
|
||||
// Fetch the list of devices
|
||||
var numDevices C.size_t
|
||||
cDevices := C.edgetpu_list_devices(&numDevices)
|
||||
|
||||
if cDevices == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Cast the result to a Go slice
|
||||
deviceSlice := (*[1024]C.struct_edgetpu_device)(unsafe.Pointer(cDevices))[:numDevices:numDevices]
|
||||
|
||||
// Convert the list to go struct
|
||||
var devices []Device
|
||||
for i := C.size_t(0); i < numDevices; i++ {
|
||||
devices = append(devices, Device{
|
||||
Type: DeviceType(deviceSlice[i]._type),
|
||||
Path: C.GoString(deviceSlice[i].path),
|
||||
})
|
||||
}
|
||||
|
||||
// Free the list
|
||||
C.edgetpu_free_devices(cDevices)
|
||||
|
||||
return devices, nil
|
||||
}
|
11
vendor/github.com/mattn/go-tflite/delegates/edgetpu/edgetpu.go.h
generated
vendored
Normal file
11
vendor/github.com/mattn/go-tflite/delegates/edgetpu/edgetpu.go.h
generated
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
#ifndef GO_EDGETPU_H
|
||||
#define GO_EDGETPU_H
|
||||
|
||||
#define _GNU_SOURCE
|
||||
#include <stdio.h>
|
||||
#include <stdarg.h>
|
||||
#include <stdlib.h>
|
||||
#include <tensorflow/lite/c/c_api.h>
|
||||
#include <edgetpu_c.h>
|
||||
|
||||
#endif
|
258
vendor/github.com/mattn/go-tflite/tflite.go
generated
vendored
Normal file
258
vendor/github.com/mattn/go-tflite/tflite.go
generated
vendored
Normal file
@ -0,0 +1,258 @@
|
||||
package tflite
|
||||
|
||||
/*
|
||||
#ifndef GO_TFLITE_H
|
||||
#include "tflite.go.h"
|
||||
#endif
|
||||
#cgo LDFLAGS: -ltensorflowlite_c
|
||||
#cgo android LDFLAGS: -ldl
|
||||
#cgo linux,!android LDFLAGS: -ldl -lrt
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
"github.com/mattn/go-pointer"
|
||||
"github.com/mattn/go-tflite/delegates"
|
||||
)
|
||||
|
||||
//go:generate stringer -type TensorType,Status -output type_string.go .
|
||||
|
||||
// Model is TfLiteModel.
|
||||
type Model struct {
|
||||
m *C.TfLiteModel
|
||||
}
|
||||
|
||||
// NewModel create new Model from buffer.
|
||||
func NewModel(model_data []byte) *Model {
|
||||
m := C.TfLiteModelCreate(C.CBytes(model_data), C.size_t(len(model_data)))
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &Model{m: m}
|
||||
}
|
||||
|
||||
// NewModelFromFile create new Model from file data.
|
||||
func NewModelFromFile(model_path string) *Model {
|
||||
ptr := C.CString(model_path)
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
m := C.TfLiteModelCreateFromFile(ptr)
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &Model{m: m}
|
||||
}
|
||||
|
||||
// Delete delete instance of model.
|
||||
func (m *Model) Delete() {
|
||||
if m != nil {
|
||||
C.TfLiteModelDelete(m.m)
|
||||
}
|
||||
}
|
||||
|
||||
// InterpreterOptions implement TfLiteInterpreterOptions.
|
||||
type InterpreterOptions struct {
|
||||
o *C.TfLiteInterpreterOptions
|
||||
}
|
||||
|
||||
// NewInterpreterOptions create new InterpreterOptions.
|
||||
func NewInterpreterOptions() *InterpreterOptions {
|
||||
o := C.TfLiteInterpreterOptionsCreate()
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return &InterpreterOptions{o: o}
|
||||
}
|
||||
|
||||
// SetNumThread set number of threads.
|
||||
func (o *InterpreterOptions) SetNumThread(num_threads int) {
|
||||
C.TfLiteInterpreterOptionsSetNumThreads(o.o, C.int32_t(num_threads))
|
||||
}
|
||||
|
||||
// SetErrorRepoter set a function of reporter.
|
||||
func (o *InterpreterOptions) SetErrorReporter(f func(string, interface{}), user_data interface{}) {
|
||||
C._TfLiteInterpreterOptionsSetErrorReporter(o.o, pointer.Save(&callbackInfo{
|
||||
user_data: user_data,
|
||||
f: f,
|
||||
}))
|
||||
}
|
||||
|
||||
func (o *InterpreterOptions) AddDelegate(d delegates.Delegater) {
|
||||
C.TfLiteInterpreterOptionsAddDelegate(o.o, (*C.TfLiteDelegate)(d.Ptr()))
|
||||
}
|
||||
|
||||
// Delete delete instance of InterpreterOptions.
|
||||
func (o *InterpreterOptions) Delete() {
|
||||
if o != nil {
|
||||
C.TfLiteInterpreterOptionsDelete(o.o)
|
||||
}
|
||||
}
|
||||
|
||||
// Interpreter implement TfLiteInterpreter.
|
||||
type Interpreter struct {
|
||||
i *C.TfLiteInterpreter
|
||||
}
|
||||
|
||||
// NewInterpreter create new Interpreter.
|
||||
func NewInterpreter(model *Model, options *InterpreterOptions) *Interpreter {
|
||||
var o *C.TfLiteInterpreterOptions
|
||||
if options != nil {
|
||||
o = options.o
|
||||
}
|
||||
i := C.TfLiteInterpreterCreate(model.m, o)
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
return &Interpreter{i: i}
|
||||
}
|
||||
|
||||
// Delete delete instance of Interpreter.
|
||||
func (i *Interpreter) Delete() {
|
||||
if i != nil {
|
||||
C.TfLiteInterpreterDelete(i.i)
|
||||
}
|
||||
}
|
||||
|
||||
// Tensor implement TfLiteTensor.
|
||||
type Tensor struct {
|
||||
t *C.TfLiteTensor
|
||||
}
|
||||
|
||||
// GetInputTensorCount return number of input tensors.
|
||||
func (i *Interpreter) GetInputTensorCount() int {
|
||||
return int(C.TfLiteInterpreterGetInputTensorCount(i.i))
|
||||
}
|
||||
|
||||
// GetInputTensor return input tensor specified by index.
|
||||
func (i *Interpreter) GetInputTensor(index int) *Tensor {
|
||||
t := C.TfLiteInterpreterGetInputTensor(i.i, C.int32_t(index))
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return &Tensor{t: t}
|
||||
}
|
||||
|
||||
// State implement TfLiteStatus.
|
||||
type Status int
|
||||
|
||||
const (
|
||||
OK Status = 0
|
||||
Error
|
||||
)
|
||||
|
||||
// ResizeInputTensor resize the tensor specified by index with dims.
|
||||
func (i *Interpreter) ResizeInputTensor(index int, dims []int32) Status {
|
||||
s := C.TfLiteInterpreterResizeInputTensor(i.i, C.int32_t(index), (*C.int32_t)(unsafe.Pointer(&dims[0])), C.int32_t(len(dims)))
|
||||
return Status(s)
|
||||
}
|
||||
|
||||
// AllocateTensor allocate tensors for the interpreter.
|
||||
func (i *Interpreter) AllocateTensors() Status {
|
||||
if i != nil {
|
||||
s := C.TfLiteInterpreterAllocateTensors(i.i)
|
||||
return Status(s)
|
||||
}
|
||||
return Error
|
||||
}
|
||||
|
||||
// Invoke invoke the task.
|
||||
func (i *Interpreter) Invoke() Status {
|
||||
s := C.TfLiteInterpreterInvoke(i.i)
|
||||
return Status(s)
|
||||
}
|
||||
|
||||
// GetOutputTensorCount return number of output tensors.
|
||||
func (i *Interpreter) GetOutputTensorCount() int {
|
||||
return int(C.TfLiteInterpreterGetOutputTensorCount(i.i))
|
||||
}
|
||||
|
||||
// GetOutputTensor return output tensor specified by index.
|
||||
func (i *Interpreter) GetOutputTensor(index int) *Tensor {
|
||||
t := C.TfLiteInterpreterGetOutputTensor(i.i, C.int32_t(index))
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return &Tensor{t: t}
|
||||
}
|
||||
|
||||
// TensorType is types of the tensor.
|
||||
type TensorType int
|
||||
|
||||
const (
|
||||
NoType TensorType = 0
|
||||
Float32 TensorType = 1
|
||||
Int32 TensorType = 2
|
||||
UInt8 TensorType = 3
|
||||
Int64 TensorType = 4
|
||||
String TensorType = 5
|
||||
Bool TensorType = 6
|
||||
Int16 TensorType = 7
|
||||
Complex64 TensorType = 8
|
||||
Int8 TensorType = 9
|
||||
)
|
||||
|
||||
// Type return TensorType.
|
||||
func (t *Tensor) Type() TensorType {
|
||||
return TensorType(C.TfLiteTensorType(t.t))
|
||||
}
|
||||
|
||||
// NumDims return number of dimensions.
|
||||
func (t *Tensor) NumDims() int {
|
||||
return int(C.TfLiteTensorNumDims(t.t))
|
||||
}
|
||||
|
||||
// Dim return dimension of the element specified by index.
|
||||
func (t *Tensor) Dim(index int) int {
|
||||
return int(C.TfLiteTensorDim(t.t, C.int32_t(index)))
|
||||
}
|
||||
|
||||
// Shape return shape of the tensor.
|
||||
func (t *Tensor) Shape() []int {
|
||||
shape := make([]int, t.NumDims())
|
||||
for i := 0; i < t.NumDims(); i++ {
|
||||
shape[i] = t.Dim(i)
|
||||
}
|
||||
return shape
|
||||
}
|
||||
|
||||
// ByteSize return byte size of the tensor.
|
||||
func (t *Tensor) ByteSize() uint {
|
||||
return uint(C.TfLiteTensorByteSize(t.t))
|
||||
}
|
||||
|
||||
// Data return pointer of buffer.
|
||||
func (t *Tensor) Data() unsafe.Pointer {
|
||||
return C.TfLiteTensorData(t.t)
|
||||
}
|
||||
|
||||
// Name return name of the tensor.
|
||||
func (t *Tensor) Name() string {
|
||||
return C.GoString(C.TfLiteTensorName(t.t))
|
||||
}
|
||||
|
||||
// QuantizationParams implement TfLiteQuantizationParams.
|
||||
type QuantizationParams struct {
|
||||
Scale float64
|
||||
ZeroPoint int
|
||||
}
|
||||
|
||||
// QuantizationParams return quantization parameters of the tensor.
|
||||
func (t *Tensor) QuantizationParams() QuantizationParams {
|
||||
q := C.TfLiteTensorQuantizationParams(t.t)
|
||||
return QuantizationParams{
|
||||
Scale: float64(q.scale),
|
||||
ZeroPoint: int(q.zero_point),
|
||||
}
|
||||
}
|
||||
|
||||
// CopyFromBuffer write buffer to the tensor.
|
||||
func (t *Tensor) CopyFromBuffer(b interface{}) Status {
|
||||
return Status(C.TfLiteTensorCopyFromBuffer(t.t, unsafe.Pointer(reflect.ValueOf(b).Pointer()), C.size_t(t.ByteSize())))
|
||||
}
|
||||
|
||||
// CopyToBuffer write buffer from the tensor.
|
||||
func (t *Tensor) CopyToBuffer(b interface{}) Status {
|
||||
return Status(C.TfLiteTensorCopyToBuffer(t.t, unsafe.Pointer(reflect.ValueOf(b).Pointer()), C.size_t(t.ByteSize())))
|
||||
}
|
24
vendor/github.com/mattn/go-tflite/tflite.go.h
generated
vendored
Normal file
24
vendor/github.com/mattn/go-tflite/tflite.go.h
generated
vendored
Normal file
@ -0,0 +1,24 @@
|
||||
#ifndef GO_TFLITE_H
|
||||
#define GO_TFLITE_H
|
||||
|
||||
#define _GNU_SOURCE
|
||||
#include <stdio.h>
|
||||
#include <stdarg.h>
|
||||
#include <stdlib.h>
|
||||
#include <tensorflow/lite/c/c_api.h>
|
||||
|
||||
extern void _go_error_reporter(void*, char*);
|
||||
|
||||
static void
|
||||
_error_reporter(void *user_data, const char* format, va_list args) {
|
||||
char *ptr;
|
||||
if (asprintf(&ptr, format, args)) {}
|
||||
_go_error_reporter(user_data, ptr);
|
||||
free(ptr);
|
||||
}
|
||||
|
||||
static void
|
||||
_TfLiteInterpreterOptionsSetErrorReporter(TfLiteInterpreterOptions* options, void* user_data) {
|
||||
TfLiteInterpreterOptionsSetErrorReporter(options, _error_reporter, user_data);
|
||||
}
|
||||
#endif
|
388
vendor/github.com/mattn/go-tflite/tflite_experimental.go
generated
vendored
Normal file
388
vendor/github.com/mattn/go-tflite/tflite_experimental.go
generated
vendored
Normal file
@ -0,0 +1,388 @@
|
||||
package tflite
|
||||
|
||||
/*
|
||||
#ifndef GO_TFLITE_EXPERIMENTAL_H
|
||||
#include "tflite_experimental.go.h"
|
||||
#endif
|
||||
|
||||
typedef void* (*f_tflite_registration_init)(TfLiteContext* context, const char* buffer, size_t length);
|
||||
void* _tflite_registration_init(TfLiteContext* context, char* buffer, size_t length);
|
||||
|
||||
typedef void (*f_tflite_registration_free)(TfLiteContext* context, void* buffer);
|
||||
void _tflite_registration_free(TfLiteContext* context, void* buffer);
|
||||
|
||||
typedef TfLiteStatus (*f_tflite_registration_prepare)(TfLiteContext* context, TfLiteNode* node);
|
||||
TfLiteStatus _tflite_registration_prepare(TfLiteContext* context, TfLiteNode* node);
|
||||
|
||||
typedef TfLiteStatus (*f_tflite_registration_invoke)(TfLiteContext* context, TfLiteNode* node);
|
||||
TfLiteStatus _tflite_registration_invoke(TfLiteContext* context, TfLiteNode* node);
|
||||
|
||||
typedef const char* (*f_tflite_registration_profiling_string)(const TfLiteContext* context, const TfLiteNode* node);
|
||||
char* _tflite_registration_profiling_string(TfLiteContext* context, TfLiteNode* node);
|
||||
|
||||
static TfLiteRegistration*
|
||||
_make_registration(void* o_init, void* o_free, void* o_prepare, void* o_invoke, void* o_profiling_string) {
|
||||
TfLiteRegistration* r = (TfLiteRegistration*)malloc(sizeof(TfLiteRegistration));
|
||||
r->init = (f_tflite_registration_init) o_init;
|
||||
r->free = (f_tflite_registration_free) o_free;
|
||||
r->prepare = (f_tflite_registration_prepare) o_prepare;
|
||||
r->invoke = (f_tflite_registration_invoke) o_invoke;
|
||||
r->profiling_string = (f_tflite_registration_profiling_string) o_profiling_string;
|
||||
return r;
|
||||
}
|
||||
|
||||
static void look_context(TfLiteContext *context) {
|
||||
context->tensors;
|
||||
TfLiteIntArray *plan = NULL;
|
||||
context->GetExecutionPlan(context, &plan);
|
||||
if (plan == NULL) return;
|
||||
int i;
|
||||
for (i = 0; i < plan->size; i++) {
|
||||
TfLiteNode *node = NULL;
|
||||
TfLiteRegistration *reg = NULL;
|
||||
context->GetNodeAndRegistration(context, i, &node, ®);
|
||||
printf("%s\n", reg->custom_name);
|
||||
}
|
||||
}
|
||||
|
||||
static void writeToTensorAsVector(TfLiteTensor *tensor, char *bytes, size_t size, int nelem) {
|
||||
static TfLiteIntArray dummy;
|
||||
TfLiteIntArray* new_shape = (TfLiteIntArray*)malloc(sizeof(dummy) + sizeof(dummy.data[0]) * 1);
|
||||
if (new_shape) {
|
||||
new_shape->size = 1;
|
||||
new_shape->data[0] = nelem;
|
||||
memcpy(new_shape->data, tensor->dims->data, tensor->dims->size * sizeof(int));
|
||||
}
|
||||
|
||||
// TfLiteTensorDataFree
|
||||
if (tensor->allocation_type == kTfLiteDynamic && tensor->data.raw) {
|
||||
free(tensor->data.raw);
|
||||
}
|
||||
tensor->data.raw = NULL;
|
||||
|
||||
if (tensor->dims) free(tensor->dims);
|
||||
if (tensor->quantization.type == kTfLiteAffineQuantization) {
|
||||
TfLiteAffineQuantization* q_params =
|
||||
(TfLiteAffineQuantization*)(tensor->quantization.params);
|
||||
if (q_params->scale) {
|
||||
free(q_params->scale);
|
||||
q_params->scale = NULL;
|
||||
}
|
||||
if (q_params->zero_point) {
|
||||
free(q_params->zero_point);
|
||||
q_params->zero_point = NULL;
|
||||
}
|
||||
free(q_params);
|
||||
}
|
||||
tensor->dims = new_shape;
|
||||
tensor->data.raw = bytes;
|
||||
tensor->bytes = size;
|
||||
tensor->allocation_type = kTfLiteMmapRo;
|
||||
|
||||
tensor->quantization.type = kTfLiteNoQuantization;
|
||||
tensor->quantization.params = NULL;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const sizeof_int32_t = 4
|
||||
|
||||
// ResetVariableTensors resets variable tensors.
|
||||
func (i *Interpreter) ResetVariableTensors() Status {
|
||||
return Status(C.TfLiteInterpreterResetVariableTensors(i.i))
|
||||
}
|
||||
|
||||
/*
|
||||
type Registration interface {
|
||||
}
|
||||
|
||||
func (o *InterpreterOptions) AddCustomOp(name string, reg *Registration, minVersion, maxVersion int) {
|
||||
ptr := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
r := C._make_registration()
|
||||
C.TfLiteInterpreterOptionsAddCustomOp(o.o, ptr, r, C.int(minVersion), C.int(maxVersion))
|
||||
}
|
||||
|
||||
type registration struct {
|
||||
ccxt *C.TfLiteContext
|
||||
}
|
||||
|
||||
//export _tflite_registration_init
|
||||
func _tflite_registration_init(ccxt *C.TfLiteContext, buffer *C.char, length C.size_t) unsafe.Pointer {
|
||||
println("registration.init")
|
||||
C.look_context(ccxt)
|
||||
|
||||
//var executionPlan *TfLiteIntArray
|
||||
//status := ccxt.GetExecutionPlan(ccxt, &executionPlan)
|
||||
//if status != C.kTfLiteOk {
|
||||
//return nil
|
||||
//}
|
||||
//var registration *C.TfLiteRegistration
|
||||
//var node *C.TfLiteNode
|
||||
//for i := 0; i < executionPlan.size; i++ {
|
||||
//ccxt.GetNodeAndRegistration(ccxt, 0, &node, ®istration)
|
||||
//}
|
||||
|
||||
println(buffer, length)
|
||||
return nil
|
||||
}
|
||||
|
||||
//export _tflite_registration_free
|
||||
func _tflite_registration_free(ccxt *C.TfLiteContext, buffer unsafe.Pointer) {
|
||||
println("registration.free")
|
||||
}
|
||||
|
||||
//export _tflite_registration_prepare
|
||||
func _tflite_registration_prepare(ccxt *C.TfLiteContext, node *C.TfLiteNode) C.TfLiteStatus {
|
||||
println("registration.prepare")
|
||||
return C.kTfLiteOk
|
||||
}
|
||||
|
||||
//export _tflite_registration_invoke
|
||||
func _tflite_registration_invoke(ccxt *C.TfLiteContext, node *C.TfLiteNode) C.TfLiteStatus {
|
||||
println("registration.invoke")
|
||||
return C.kTfLiteOk
|
||||
}
|
||||
|
||||
//export _tflite_registration_profiling_string
|
||||
func _tflite_registration_profiling_string(ccxt *C.TfLiteContext, node *C.TfLiteNode) *C.char {
|
||||
println("registration.profiling_string")
|
||||
return nil
|
||||
}
|
||||
*/
|
||||
|
||||
// ExtRegistration indicate registration structure.
|
||||
type ExpRegistration struct {
|
||||
Init unsafe.Pointer
|
||||
Free unsafe.Pointer
|
||||
Prepare unsafe.Pointer
|
||||
Invoke unsafe.Pointer
|
||||
ProfilingString unsafe.Pointer
|
||||
}
|
||||
|
||||
type BuiltinOperator int
|
||||
|
||||
const (
|
||||
BuiltinOperator_ADD BuiltinOperator = 0
|
||||
BuiltinOperator_AVERAGE_POOL_2D BuiltinOperator = 1
|
||||
BuiltinOperator_CONCATENATION BuiltinOperator = 2
|
||||
BuiltinOperator_CONV_2D BuiltinOperator = 3
|
||||
BuiltinOperator_DEPTHWISE_CONV_2D BuiltinOperator = 4
|
||||
BuiltinOperator_DEQUANTIZE BuiltinOperator = 6
|
||||
BuiltinOperator_EMBEDDING_LOOKUP BuiltinOperator = 7
|
||||
BuiltinOperator_FLOOR BuiltinOperator = 8
|
||||
BuiltinOperator_FULLY_CONNECTED BuiltinOperator = 9
|
||||
BuiltinOperator_HASHTABLE_LOOKUP BuiltinOperator = 10
|
||||
BuiltinOperator_L2_NORMALIZATION BuiltinOperator = 11
|
||||
BuiltinOperator_L2_POOL_2D BuiltinOperator = 12
|
||||
BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION BuiltinOperator = 13
|
||||
BuiltinOperator_LOGISTIC BuiltinOperator = 14
|
||||
BuiltinOperator_LSH_PROJECTION BuiltinOperator = 15
|
||||
BuiltinOperator_LSTM BuiltinOperator = 16
|
||||
BuiltinOperator_MAX_POOL_2D BuiltinOperator = 17
|
||||
BuiltinOperator_MUL BuiltinOperator = 18
|
||||
BuiltinOperator_RELU BuiltinOperator = 19
|
||||
BuiltinOperator_RELU_N1_TO_1 BuiltinOperator = 20
|
||||
BuiltinOperator_RELU6 BuiltinOperator = 21
|
||||
BuiltinOperator_RESHAPE BuiltinOperator = 22
|
||||
BuiltinOperator_RESIZE_BILINEAR BuiltinOperator = 23
|
||||
BuiltinOperator_RNN BuiltinOperator = 24
|
||||
BuiltinOperator_SOFTMAX BuiltinOperator = 25
|
||||
BuiltinOperator_SPACE_TO_DEPTH BuiltinOperator = 26
|
||||
BuiltinOperator_SVDF BuiltinOperator = 27
|
||||
BuiltinOperator_TANH BuiltinOperator = 28
|
||||
BuiltinOperator_CONCAT_EMBEDDINGS BuiltinOperator = 29
|
||||
BuiltinOperator_SKIP_GRAM BuiltinOperator = 30
|
||||
BuiltinOperator_CALL BuiltinOperator = 31
|
||||
BuiltinOperator_CUSTOM BuiltinOperator = 32
|
||||
BuiltinOperator_EMBEDDING_LOOKUP_SPARSE BuiltinOperator = 33
|
||||
BuiltinOperator_PAD BuiltinOperator = 34
|
||||
BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN BuiltinOperator = 35
|
||||
BuiltinOperator_GATHER BuiltinOperator = 36
|
||||
BuiltinOperator_BATCH_TO_SPACE_ND BuiltinOperator = 37
|
||||
BuiltinOperator_SPACE_TO_BATCH_ND BuiltinOperator = 38
|
||||
BuiltinOperator_TRANSPOSE BuiltinOperator = 39
|
||||
BuiltinOperator_MEAN BuiltinOperator = 40
|
||||
BuiltinOperator_SUB BuiltinOperator = 41
|
||||
BuiltinOperator_DIV BuiltinOperator = 42
|
||||
BuiltinOperator_SQUEEZE BuiltinOperator = 43
|
||||
BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM BuiltinOperator = 44
|
||||
BuiltinOperator_STRIDED_SLICE BuiltinOperator = 45
|
||||
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN BuiltinOperator = 46
|
||||
BuiltinOperator_EXP BuiltinOperator = 47
|
||||
BuiltinOperator_TOPK_V2 BuiltinOperator = 48
|
||||
BuiltinOperator_SPLIT BuiltinOperator = 49
|
||||
BuiltinOperator_LOG_SOFTMAX BuiltinOperator = 50
|
||||
BuiltinOperator_DELEGATE BuiltinOperator = 51
|
||||
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM BuiltinOperator = 52
|
||||
BuiltinOperator_CAST BuiltinOperator = 53
|
||||
BuiltinOperator_PRELU BuiltinOperator = 54
|
||||
BuiltinOperator_MAXIMUM BuiltinOperator = 55
|
||||
BuiltinOperator_ARG_MAX BuiltinOperator = 56
|
||||
BuiltinOperator_MINIMUM BuiltinOperator = 57
|
||||
BuiltinOperator_LESS BuiltinOperator = 58
|
||||
BuiltinOperator_NEG BuiltinOperator = 59
|
||||
BuiltinOperator_PADV2 BuiltinOperator = 60
|
||||
BuiltinOperator_GREATER BuiltinOperator = 61
|
||||
BuiltinOperator_GREATER_EQUAL BuiltinOperator = 62
|
||||
BuiltinOperator_LESS_EQUAL BuiltinOperator = 63
|
||||
BuiltinOperator_SELECT BuiltinOperator = 64
|
||||
BuiltinOperator_SLICE BuiltinOperator = 65
|
||||
BuiltinOperator_SIN BuiltinOperator = 66
|
||||
BuiltinOperator_TRANSPOSE_CONV BuiltinOperator = 67
|
||||
BuiltinOperator_SPARSE_TO_DENSE BuiltinOperator = 68
|
||||
BuiltinOperator_TILE BuiltinOperator = 69
|
||||
BuiltinOperator_EXPAND_DIMS BuiltinOperator = 70
|
||||
BuiltinOperator_EQUAL BuiltinOperator = 71
|
||||
BuiltinOperator_NOT_EQUAL BuiltinOperator = 72
|
||||
BuiltinOperator_LOG BuiltinOperator = 73
|
||||
BuiltinOperator_SUM BuiltinOperator = 74
|
||||
BuiltinOperator_SQRT BuiltinOperator = 75
|
||||
BuiltinOperator_RSQRT BuiltinOperator = 76
|
||||
BuiltinOperator_SHAPE BuiltinOperator = 77
|
||||
BuiltinOperator_POW BuiltinOperator = 78
|
||||
BuiltinOperator_ARG_MIN BuiltinOperator = 79
|
||||
BuiltinOperator_FAKE_QUANT BuiltinOperator = 80
|
||||
BuiltinOperator_REDUCE_PROD BuiltinOperator = 81
|
||||
BuiltinOperator_REDUCE_MAX BuiltinOperator = 82
|
||||
BuiltinOperator_PACK BuiltinOperator = 83
|
||||
BuiltinOperator_LOGICAL_OR BuiltinOperator = 84
|
||||
BuiltinOperator_ONE_HOT BuiltinOperator = 85
|
||||
BuiltinOperator_LOGICAL_AND BuiltinOperator = 86
|
||||
BuiltinOperator_LOGICAL_NOT BuiltinOperator = 87
|
||||
BuiltinOperator_UNPACK BuiltinOperator = 88
|
||||
BuiltinOperator_REDUCE_MIN BuiltinOperator = 89
|
||||
BuiltinOperator_FLOOR_DIV BuiltinOperator = 90
|
||||
BuiltinOperator_REDUCE_ANY BuiltinOperator = 91
|
||||
BuiltinOperator_SQUARE BuiltinOperator = 92
|
||||
BuiltinOperator_ZEROS_LIKE BuiltinOperator = 93
|
||||
BuiltinOperator_FILL BuiltinOperator = 94
|
||||
BuiltinOperator_FLOOR_MOD BuiltinOperator = 95
|
||||
BuiltinOperator_RANGE BuiltinOperator = 96
|
||||
BuiltinOperator_RESIZE_NEAREST_NEIGHBOR BuiltinOperator = 97
|
||||
BuiltinOperator_LEAKY_RELU BuiltinOperator = 98
|
||||
BuiltinOperator_SQUARED_DIFFERENCE BuiltinOperator = 99
|
||||
BuiltinOperator_MIRROR_PAD BuiltinOperator = 100
|
||||
BuiltinOperator_ABS BuiltinOperator = 101
|
||||
BuiltinOperator_SPLIT_V BuiltinOperator = 102
|
||||
BuiltinOperator_UNIQUE BuiltinOperator = 103
|
||||
BuiltinOperator_CEIL BuiltinOperator = 104
|
||||
BuiltinOperator_REVERSE_V2 BuiltinOperator = 105
|
||||
BuiltinOperator_ADD_N BuiltinOperator = 106
|
||||
BuiltinOperator_GATHER_ND BuiltinOperator = 107
|
||||
BuiltinOperator_COS BuiltinOperator = 108
|
||||
BuiltinOperator_WHERE BuiltinOperator = 109
|
||||
BuiltinOperator_RANK BuiltinOperator = 110
|
||||
BuiltinOperator_ELU BuiltinOperator = 111
|
||||
BuiltinOperator_REVERSE_SEQUENCE BuiltinOperator = 112
|
||||
BuiltinOperator_MATRIX_DIAG BuiltinOperator = 113
|
||||
BuiltinOperator_QUANTIZE BuiltinOperator = 114
|
||||
BuiltinOperator_MATRIX_SET_DIAG BuiltinOperator = 115
|
||||
BuiltinOperator_MIN BuiltinOperator = BuiltinOperator_ADD
|
||||
BuiltinOperator_MAX BuiltinOperator = BuiltinOperator_MATRIX_SET_DIAG
|
||||
)
|
||||
|
||||
// ExpAddBuiltinOp add builtin op specified by code and registration. Current implementation is work in progress.
|
||||
func (o *InterpreterOptions) ExpAddBuiltinOp(op BuiltinOperator, reg *ExpRegistration, minVersion, maxVersion int) {
|
||||
r := C._make_registration(
|
||||
reg.Init,
|
||||
reg.Free,
|
||||
reg.Prepare,
|
||||
reg.Invoke,
|
||||
reg.ProfilingString,
|
||||
)
|
||||
C.TfLiteInterpreterOptionsAddBuiltinOp(o.o, C.TfLiteBuiltinOperator(op), r, C.int(minVersion), C.int(maxVersion))
|
||||
}
|
||||
|
||||
// ExpAddCustomOp add custom op specified by name and registration. Current implementation is work in progress.
|
||||
func (o *InterpreterOptions) ExpAddCustomOp(name string, reg *ExpRegistration, minVersion, maxVersion int) {
|
||||
ptr := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
r := C._make_registration(
|
||||
reg.Init,
|
||||
reg.Free,
|
||||
reg.Prepare,
|
||||
reg.Invoke,
|
||||
reg.ProfilingString,
|
||||
)
|
||||
C.TfLiteInterpreterOptionsAddCustomOp(o.o, ptr, r, C.int(minVersion), C.int(maxVersion))
|
||||
}
|
||||
|
||||
// SetUseNNAPI enable or disable the NN API for the interpreter (true to enable).
|
||||
func (o *InterpreterOptions) SetUseNNAPI(enable bool) {
|
||||
C.TfLiteInterpreterOptionsSetUseNNAPI(o.o, C.bool(enable))
|
||||
}
|
||||
|
||||
// DynamicBuffer is buffer hold multiple strings.
|
||||
type DynamicBuffer struct {
|
||||
data bytes.Buffer
|
||||
offset []int
|
||||
}
|
||||
|
||||
// AddString append to the dynamic buffer.
|
||||
func (d *DynamicBuffer) AddString(s string) {
|
||||
b := []byte(s)
|
||||
d.data.Write(b)
|
||||
if len(d.offset) == 0 {
|
||||
d.offset = append(d.offset, len(b))
|
||||
} else {
|
||||
d.offset = append(d.offset, d.offset[len(d.offset)-1]+len(b))
|
||||
}
|
||||
}
|
||||
|
||||
// WriteToTensorAsVector write buffer into the tensor as vector.
|
||||
func (d *DynamicBuffer) WriteToTensorAsVector(t *Tensor) {
|
||||
var out bytes.Buffer
|
||||
|
||||
b := make([]byte, 4)
|
||||
|
||||
// Allocate sufficient memory to tensor buffer.
|
||||
num_strings := len(d.offset)
|
||||
|
||||
// Set num of string
|
||||
binary.LittleEndian.PutUint32(b, uint32(num_strings))
|
||||
out.Write(b)
|
||||
|
||||
if num_strings > 0 {
|
||||
|
||||
// Set offset of strings.
|
||||
start := sizeof_int32_t + sizeof_int32_t*(num_strings+1)
|
||||
offset := start
|
||||
|
||||
binary.LittleEndian.PutUint32(b, uint32(offset))
|
||||
out.Write(b)
|
||||
|
||||
for i := 0; i < len(d.offset); i++ {
|
||||
offset := start + d.offset[i]
|
||||
binary.LittleEndian.PutUint32(b, uint32(offset))
|
||||
out.Write(b)
|
||||
}
|
||||
|
||||
// Copy data of strings.
|
||||
io.Copy(&out, &d.data)
|
||||
}
|
||||
|
||||
b = out.Bytes()
|
||||
C.writeToTensorAsVector(t.t, (*C.char)(unsafe.Pointer(&b[0])), C.size_t(len(b)), C.int(len(d.offset)))
|
||||
}
|
||||
|
||||
// GetString returns string in the string buffer.
|
||||
func (t *Tensor) GetString(index int) string {
|
||||
if t.Type() != String {
|
||||
return ""
|
||||
}
|
||||
ptr := uintptr(t.Data())
|
||||
count := int(*(*C.int32_t)(unsafe.Pointer(ptr)))
|
||||
if index >= count {
|
||||
return ""
|
||||
}
|
||||
offset1 := int(*(*C.int32_t)(unsafe.Pointer(ptr + uintptr(4*(index+1)))))
|
||||
offset2 := int(*(*C.int32_t)(unsafe.Pointer(ptr + uintptr(4*(index+2)))))
|
||||
return string((*((*[1<<31 - 1]uint8)(unsafe.Pointer(ptr))))[offset1:offset2])
|
||||
}
|
10
vendor/github.com/mattn/go-tflite/tflite_experimental.go.h
generated
vendored
Normal file
10
vendor/github.com/mattn/go-tflite/tflite_experimental.go.h
generated
vendored
Normal file
@ -0,0 +1,10 @@
|
||||
#ifndef GO_TFLITE_H
|
||||
#define GO_TFLITE_H
|
||||
|
||||
#define _GNU_SOURCE
|
||||
#include <stdio.h>
|
||||
#include <stdarg.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <tensorflow/lite/c/c_api_experimental.h>
|
||||
#endif
|
198
vendor/github.com/mattn/go-tflite/tflite_type.go
generated
vendored
Normal file
198
vendor/github.com/mattn/go-tflite/tflite_type.go
generated
vendored
Normal file
@ -0,0 +1,198 @@
|
||||
package tflite
|
||||
|
||||
/*
|
||||
#ifndef GO_TFLITE_H
|
||||
#include "tflite.go.h"
|
||||
#endif
|
||||
*/
|
||||
import "C"
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrTypeMismatch is type mismatch.
|
||||
ErrTypeMismatch = errors.New("type mismatch")
|
||||
// ErrBadTensor is bad tensor.
|
||||
ErrBadTensor = errors.New("bad tensor")
|
||||
)
|
||||
|
||||
// SetInt32s sets int32s.
|
||||
func (t *Tensor) SetInt32s(v []int32) error {
|
||||
if t.Type() != Int32 {
|
||||
return ErrTypeMismatch
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return ErrBadTensor
|
||||
}
|
||||
n := t.ByteSize() / 4
|
||||
to := (*((*[1<<29 - 1]int32)(ptr)))[:n]
|
||||
copy(to, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Int32s returns int32s.
|
||||
func (t *Tensor) Int32s() []int32 {
|
||||
if t.Type() != Int32 {
|
||||
return nil
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
n := t.ByteSize() / 4
|
||||
return (*((*[1<<29 - 1]int32)(ptr)))[:n]
|
||||
}
|
||||
|
||||
// SetFloat32s sets float32s.
|
||||
func (t *Tensor) SetFloat32s(v []float32) error {
|
||||
if t.Type() != Float32 {
|
||||
return ErrTypeMismatch
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return ErrBadTensor
|
||||
}
|
||||
n := t.ByteSize() / 4
|
||||
to := (*((*[1<<29 - 1]float32)(ptr)))[:n]
|
||||
copy(to, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Float32s returns float32s.
|
||||
func (t *Tensor) Float32s() []float32 {
|
||||
if t.Type() != Float32 {
|
||||
return nil
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
n := t.ByteSize() / 4
|
||||
return (*((*[1<<29 - 1]float32)(ptr)))[:n]
|
||||
}
|
||||
|
||||
// Float32At returns float32 value located in the dimension.
|
||||
func (t *Tensor) Float32At(at ...int) float32 {
|
||||
pos := 0
|
||||
for i := 0; i < t.NumDims(); i++ {
|
||||
pos = pos*t.Dim(i) + at[i]
|
||||
}
|
||||
return t.Float32s()[pos]
|
||||
}
|
||||
|
||||
// SetUint8s sets uint8s.
|
||||
func (t *Tensor) SetUint8s(v []uint8) error {
|
||||
if t.Type() != UInt8 {
|
||||
return ErrTypeMismatch
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return ErrBadTensor
|
||||
}
|
||||
n := t.ByteSize()
|
||||
to := (*((*[1<<29 - 1]uint8)(ptr)))[:n]
|
||||
copy(to, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UInt8s returns uint8s.
|
||||
func (t *Tensor) UInt8s() []uint8 {
|
||||
if t.Type() != UInt8 {
|
||||
return nil
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
n := t.ByteSize()
|
||||
return (*((*[1<<29 - 1]uint8)(ptr)))[:n]
|
||||
}
|
||||
|
||||
// SetInt64s sets int64s.
|
||||
func (t *Tensor) SetInt64s(v []int64) error {
|
||||
if t.Type() != Int64 {
|
||||
return ErrTypeMismatch
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return ErrBadTensor
|
||||
}
|
||||
n := t.ByteSize() / 8
|
||||
to := (*((*[1<<28 - 1]int64)(ptr)))[:n]
|
||||
copy(to, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Int64s returns int64s.
|
||||
func (t *Tensor) Int64s() []int64 {
|
||||
if t.Type() != Int64 {
|
||||
return nil
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
n := t.ByteSize() / 8
|
||||
return (*((*[1<<28 - 1]int64)(ptr)))[:n]
|
||||
}
|
||||
|
||||
// SetInt16s sets int16s.
|
||||
func (t *Tensor) SetInt16s(v []int16) error {
|
||||
if t.Type() != Int16 {
|
||||
return ErrTypeMismatch
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return ErrBadTensor
|
||||
}
|
||||
n := t.ByteSize() / 2
|
||||
to := (*((*[1<<29 - 1]int16)(ptr)))[:n]
|
||||
copy(to, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Int16s returns int16s.
|
||||
func (t *Tensor) Int16s() []int16 {
|
||||
if t.Type() != Int16 {
|
||||
return nil
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
n := t.ByteSize() / 2
|
||||
return (*((*[1<<29 - 1]int16)(ptr)))[:n]
|
||||
}
|
||||
|
||||
// SetInt8s sets int8s.
|
||||
func (t *Tensor) SetInt8s(v []int8) error {
|
||||
if t.Type() != Int8 {
|
||||
return ErrTypeMismatch
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return ErrBadTensor
|
||||
}
|
||||
n := t.ByteSize()
|
||||
to := (*((*[1<<29 - 1]int8)(ptr)))[:n]
|
||||
copy(to, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Int8s returns int8s.
|
||||
func (t *Tensor) Int8s() []int8 {
|
||||
if t.Type() != Int8 {
|
||||
return nil
|
||||
}
|
||||
ptr := C.TfLiteTensorData(t.t)
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
n := t.ByteSize()
|
||||
return (*((*[1<<29 - 1]int8)(ptr)))[:n]
|
||||
}
|
||||
|
||||
// String returns name of tensor.
|
||||
func (t *Tensor) String() string {
|
||||
return t.Name()
|
||||
}
|
27
vendor/github.com/mattn/go-tflite/type_string.go
generated
vendored
Normal file
27
vendor/github.com/mattn/go-tflite/type_string.go
generated
vendored
Normal file
@ -0,0 +1,27 @@
|
||||
// Code generated by "stringer -type TensorType,Status -output type_string.go ."; DO NOT EDIT.
|
||||
|
||||
package tflite
|
||||
|
||||
import "strconv"
|
||||
|
||||
const _TensorType_name = "NoTypeFloat32Int32UInt8Int64StringBoolInt16Complex64Int8"
|
||||
|
||||
var _TensorType_index = [...]uint8{0, 6, 13, 18, 23, 28, 34, 38, 43, 52, 56}
|
||||
|
||||
func (i TensorType) String() string {
|
||||
if i < 0 || i >= TensorType(len(_TensorType_index)-1) {
|
||||
return "TensorType(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _TensorType_name[_TensorType_index[i]:_TensorType_index[i+1]]
|
||||
}
|
||||
|
||||
const _Status_name = "OK"
|
||||
|
||||
var _Status_index = [...]uint8{0, 2}
|
||||
|
||||
func (i Status) String() string {
|
||||
if i < 0 || i >= Status(len(_Status_index)-1) {
|
||||
return "Status(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _Status_name[_Status_index[i]:_Status_index[i+1]]
|
||||
}
|
Reference in New Issue
Block a user