[train-archive] Implement command to generate train.zip file

This commit is contained in:
Cyrille Nofficial 2020-02-16 19:14:38 +01:00
parent ec25a07928
commit 737d277ea5
32 changed files with 380 additions and 29 deletions

View File

@ -4,6 +4,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"github.com/cyrilix/robocar-base/cli" "github.com/cyrilix/robocar-base/cli"
"github.com/cyrilix/robocar-tools/data"
"github.com/cyrilix/robocar-tools/part" "github.com/cyrilix/robocar-tools/part"
"github.com/cyrilix/robocar-tools/record" "github.com/cyrilix/robocar-tools/record"
"github.com/cyrilix/robocar-tools/video" "github.com/cyrilix/robocar-tools/video"
@ -14,6 +15,7 @@ import (
const ( const (
DefaultClientId = "robocar-tools" DefaultClientId = "robocar-tools"
DefaultTrainSliceSize = 0
) )
func main() { func main() {
@ -22,7 +24,9 @@ func main() {
var fps int var fps int
var frameTopic, objectsTopic, roadTopic, recordTopic string var frameTopic, objectsTopic, roadTopic, recordTopic string
var withObjects, withRoad bool var withObjects, withRoad bool
var jsonPath, imgPath string var recordsPath string
var trainArchiveName string
var trainSliceSize int
mqttQos := cli.InitIntFlag("MQTT_QOS", 0) mqttQos := cli.InitIntFlag("MQTT_QOS", 0)
_, mqttRetain := os.LookupEnv("MQTT_RETAIN") _, mqttRetain := os.LookupEnv("MQTT_RETAIN")
@ -31,6 +35,7 @@ func main() {
fmt.Printf("Usage of %s:\n", os.Args[0]) fmt.Printf("Usage of %s:\n", os.Args[0])
fmt.Printf(" display\n \tDisplay events on live frames\n") fmt.Printf(" display\n \tDisplay events on live frames\n")
fmt.Printf(" record \n \tRecord event for tensorflow training\n") fmt.Printf(" record \n \tRecord event for tensorflow training\n")
fmt.Printf(" train-archive \n \tGenerate zip archive for training \n")
} }
displayFlags := flag.NewFlagSet("display", flag.ExitOnError) displayFlags := flag.NewFlagSet("display", flag.ExitOnError)
@ -48,8 +53,16 @@ func main() {
recordFlags := flag.NewFlagSet("record", flag.ExitOnError) recordFlags := flag.NewFlagSet("record", flag.ExitOnError)
cli.InitMqttFlagSet(recordFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain) cli.InitMqttFlagSet(recordFlags, DefaultClientId, &mqttBroker, &username, &password, &clientId, &mqttQos, &mqttRetain)
recordFlags.StringVar(&recordTopic, "mqtt-topic-records", os.Getenv("MQTT_TOPIC_RECORDS"), "Mqtt topic that contains record data for training, use MQTT_TOPIC_RECORDS if args not set") recordFlags.StringVar(&recordTopic, "mqtt-topic-records", os.Getenv("MQTT_TOPIC_RECORDS"), "Mqtt topic that contains record data for training, use MQTT_TOPIC_RECORDS if args not set")
recordFlags.StringVar(&jsonPath, "record-json-path", os.Getenv("RECORD_JSON_PATH"), "Path where to write json files, use RECORD_JSON_PATH if args not set") recordFlags.StringVar(&recordsPath, "record-path", os.Getenv("RECORD_PATH"), "Path where to write records files, use RECORD_PATH if args not set")
recordFlags.StringVar(&imgPath, "record-image-path", os.Getenv("RECORD_IMAGE_PATH"), "Path where to write jpeg files, use RECORD_IMAGE_PATH if args not set")
trainArchiveFlags := flag.NewFlagSet("train-archive", flag.ExitOnError)
err := cli.SetIntDefaultValueFromEnv(&trainSliceSize, "TRAIN_SLICE_SIZE", DefaultTrainSliceSize)
if err != nil {
log.Printf("unable to parse horizon value arg: %v", err)
}
trainArchiveFlags.StringVar(&recordsPath, "record-path", os.Getenv("RECORD_PATH"), "Path where records files are stored, use RECORD_PATH if args not set")
trainArchiveFlags.StringVar(&trainArchiveName, "output", os.Getenv("TRAIN_ARCHIVE_NAME"), "Zip archive file name, use TRAIN_ARCHIVE_NAME if args not set")
trainArchiveFlags.IntVar(&trainSliceSize, "slice-size", trainSliceSize, "Number of record to shift with image, use TRAIN_SLICE_SIZE if args not set")
flag.Parse() flag.Parse()
@ -79,35 +92,44 @@ func main() {
log.Fatalf("unable to connect to mqtt bus: %v", err) log.Fatalf("unable to connect to mqtt bus: %v", err)
} }
defer client.Disconnect(50) defer client.Disconnect(50)
runRecord(client, jsonPath, imgPath, recordTopic) runRecord(client, recordsPath, recordTopic)
case trainArchiveFlags.Name():
if err := trainArchiveFlags.Parse(os.Args[2:]); err == flag.ErrHelp {
trainArchiveFlags.PrintDefaults()
os.Exit(0)
}
runTrainArchive(recordsPath, trainArchiveName, 2)
default: default:
flag.PrintDefaults() flag.PrintDefaults()
os.Exit(1) os.Exit(1)
} }
cmd := flag.Arg(1)
switch cmd {
case "display":
case "record":
default:
log.Errorf("invalid command: %v", cmd)
} }
func runRecord(client mqtt.Client, recordsDir, recordTopic string) {
r, err := record.New(client, recordsDir, recordTopic)
if err != nil {
log.Fatalf("unable to init record part: %v", err)
} }
func runRecord(client mqtt.Client, jsonDir, imgDir string, recordTopic string) {
r := record.New(client, jsonDir, imgDir, recordTopic)
defer r.Stop() defer r.Stop()
cli.HandleExit(r) cli.HandleExit(r)
err := r.Start() err = r.Start()
if err != nil { if err != nil {
log.Fatalf("unable to start service: %v", err) log.Fatalf("unable to start service: %v", err)
} }
} }
func runTrainArchive(basedir, archiveName string, sliceSize int) {
err := data.BuildArchive(basedir, archiveName, sliceSize)
if err != nil {
log.Fatalf("unable to build archive file %v: %v", archiveName, err)
}
}
func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int, objectsTopic string, roadTopic string, withObjects bool, withRoad bool) { func runDisplay(client mqtt.Client, framePath string, frameTopic string, fps int, objectsTopic string, roadTopic string, withObjects bool, withRoad bool) {
if framePath != "" { if framePath != "" {

179
data/data.go Normal file
View File

@ -0,0 +1,179 @@
package data
import (
"archive/zip"
"bytes"
"encoding/json"
"fmt"
"github.com/cyrilix/robocar-tools/record"
log "github.com/sirupsen/logrus"
"io/ioutil"
"os"
"path"
"regexp"
)
var camSubDir = "cam"
func BuildArchive(basedir string, archiveName string, sliceSize int) error {
dirItems, err := ioutil.ReadDir(basedir)
if err != nil {
return fmt.Errorf("unable to list directory in %v dir: %v", basedir, err)
}
imgCams := make([]string, 0)
records := make([]string, 0)
for _, dirItem := range dirItems {
log.Debugf("process %v directory", dirItem)
imgDir := path.Join(basedir, dirItem.Name(), camSubDir)
imgs, err := ioutil.ReadDir(imgDir)
if err != nil {
return fmt.Errorf("unable to list cam images in directory %v: %v", imgDir, err)
}
for _, img := range imgs {
idx, err := indexFromFile(img.Name())
if err != nil {
return fmt.Errorf("unable to find index in cam image name %v: %v", img.Name(), err)
}
log.Debugf("found image with index %v", idx)
records = append(records, path.Join(basedir, dirItem.Name(), fmt.Sprintf(record.RecorNameFormat, idx)))
imgCams = append(imgCams, path.Join(basedir, dirItem.Name(), camSubDir, img.Name()))
}
}
if sliceSize > 0{
imgCams, records, err = applySlice(imgCams, records, sliceSize)
}
content, err := buildArchiveContent(&imgCams, &records)
if err != nil {
return fmt.Errorf("unable to build archive: %v", err)
}
err = ioutil.WriteFile(archiveName, *content, os.FileMode(0755))
if err != nil {
return fmt.Errorf("unable to write archive content to disk: %v", err)
}
return nil
}
func applySlice(imgCams []string, records []string, sliceSize int) ([]string, []string, error) {
// Add sliceSize images shift
i := imgCams[:len(imgCams)-sliceSize]
r := records[sliceSize:]
return i, r, nil
}
var indexRegexp *regexp.Regexp
func init() {
re, err := regexp.Compile("image_array_(?P<idx>[0-9]+)\\.jpg$")
if err != nil {
log.Fatalf("unable to compile regex: %v", err)
}
indexRegexp = re
}
func indexFromFile(fileName string) (string, error) {
matches := findNamedMatches(indexRegexp, fileName)
if matches["idx"] == "" {
return "", fmt.Errorf("no index in filename")
}
return matches["idx"], nil
}
func findNamedMatches(regex *regexp.Regexp, str string) map[string]string {
match := regex.FindStringSubmatch(str)
results := map[string]string{}
for i, name := range match {
results[regex.SubexpNames()[i]] = name
}
return results
}
func buildArchiveContent(imgFiles *[]string, recordFiles *[]string) (*[]byte, error) {
// Create a buffer to write our archive to.
buf := new(bytes.Buffer)
// Create a new zip archive.
w := zip.NewWriter(buf)
err := addJsonFiles(recordFiles, imgFiles, w)
if err != nil {
return nil, fmt.Errorf("unable to write json files in zip archive: %v", err)
}
err = addCamImages(imgFiles, w)
if err != nil {
return nil, fmt.Errorf("unable to cam files in zip archive: %v", err)
}
err = w.Close()
if err != nil {
return nil, fmt.Errorf("unable to build archive: %v", err)
}
content, err := ioutil.ReadAll(buf)
return &content, err
}
func addCamImages(imgFiles *[]string, w *zip.Writer) error {
for _, img := range *imgFiles {
imgContent, err := ioutil.ReadFile(img)
if err != nil {
return fmt.Errorf("unable to read img: %v", err)
}
_, imgName := path.Split(img)
err = addToArchive(w, imgName, &imgContent)
if err != nil {
return fmt.Errorf("unable to create new img entry in archive: %v", err)
}
}
return nil
}
func addJsonFiles(recordFiles *[]string, imgCam *[]string, w *zip.Writer) error {
for idx, r := range *recordFiles {
content, err := ioutil.ReadFile(r)
if err != nil {
return fmt.Errorf("unable to read json content: %v", err)
}
var rcd record.Record
err = json.Unmarshal(content, &rcd)
if err != nil {
return fmt.Errorf("unable to unmarshal record: %v", err)
}
_, camName := path.Split((*imgCam)[idx])
rcd.CamImageArray = camName
recordBytes, err := json.Marshal(&rcd)
if err != nil {
return fmt.Errorf("unable to marshal %v record: %v", rcd, err)
}
_, recordName := path.Split(r)
err = addToArchive(w, recordName, &recordBytes)
if err != nil {
return fmt.Errorf("unable to create new record in archive: %v", err)
}
}
return nil
}
func addToArchive(w *zip.Writer, name string, content *[]byte) error {
recordWriter, err := w.Create(name)
if err != nil {
return fmt.Errorf("unable to create new entry %v in archive: %v", name, err)
}
_, err = recordWriter.Write(*content)
if err != nil {
return fmt.Errorf("unable to add content in %v zip archive: %v", name, err)
}
return nil
}

118
data/data_test.go Normal file
View File

@ -0,0 +1,118 @@
package data
import (
"archive/zip"
"encoding/json"
"fmt"
"github.com/cyrilix/robocar-tools/record"
log "github.com/sirupsen/logrus"
"io/ioutil"
"os"
"path"
"strings"
"testing"
)
func TestBuildArchive(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "buildarchive")
if err != nil {
t.Fatalf("unable to make tmpdir: %v", err)
}
defer func() {
err := os.RemoveAll(tmpDir)
if err != nil {
log.Warnf("unable to remove tempdir %v: %v", tmpDir, err)
}
}()
archive := path.Join(tmpDir, "train.zip")
expectedRecordFiles, expectedImgFiles := expectedFiles()
err = BuildArchive("testdata", archive, 0)
if err != nil {
t.Errorf("unable to build archive: %v", err)
}
r, err := zip.OpenReader(archive)
if err != nil {
t.Errorf("unable to read archive, %v", err)
}
defer r.Close()
if len(r.File) != len(expectedImgFiles)+len(expectedRecordFiles) {
t.Errorf("bad number of files in archive: %v, wants %v", len(r.File), len(expectedImgFiles)+len(expectedRecordFiles))
}
// Iterate through the files in the archive,
// printing some of their contents.
for _, f := range r.File {
filename := f.Name
if filename[len(filename)-4:] == "json" {
expectedRecordFiles[filename] = true
expectedtImgName := strings.Replace(filename, "record", "cam-image_array", 1)
expectedtImgName = strings.Replace(expectedtImgName, "json", "jpg", 1)
checkJsonContent(t, f, expectedtImgName)
continue
}
if filename[len(filename)-3:] == "jpg" {
expectedImgFiles[filename] = true
continue
}
t.Errorf("unexpected file in archive: %v", filename)
}
checkAllFilesAreFoundInArchive(expectedRecordFiles, t, expectedImgFiles)
}
func checkAllFilesAreFoundInArchive(expectedRecordFiles map[string]bool, t *testing.T, expectedImgFiles map[string]bool) {
for f, found := range expectedRecordFiles {
if !found {
t.Errorf("%v not found in archive", f)
}
}
for f, found := range expectedImgFiles {
if !found {
t.Errorf("%v not found in archive", f)
}
}
}
func checkJsonContent(t *testing.T, f *zip.File, expectedCamImage string) {
rc, err := f.Open()
if err != nil {
t.Errorf("unable to read file content of %v: %v", f.Name, err)
}
defer rc.Close()
content, err := ioutil.ReadAll(rc)
if err != nil {
t.Errorf("%v has invalid json content: %v", f.Name, err)
}
var rcd record.Record
err = json.Unmarshal(content, &rcd)
if err != nil {
t.Errorf("unable to unmarshal json content of%v: %v", f.Name, err)
}
if rcd.CamImageArray != expectedCamImage {
t.Errorf("record %v: invalid image ref: %v, wants %v", f.Name, rcd.CamImageArray, expectedCamImage)
}
if rcd.UserAngle == 0. {
t.Errorf("record %v: user angle has not been initialised", f.Name)
}
}
func expectedFiles() (map[string]bool, map[string]bool) {
expectedRecordFiles := make(map[string]bool)
expectedImgFiles := make(map[string]bool)
for i := 1; i <= 8; i++ {
expectedRecordFiles[fmt.Sprintf("record_%07d.json", i)] = false
expectedImgFiles[fmt.Sprintf("cam-image_array_%07d.jpg", i)] = false
}
for i := 101; i <= 106; i++ {
expectedRecordFiles[fmt.Sprintf("record_%07d.json", i)] = false
expectedImgFiles[fmt.Sprintf("cam-image_array_%07d.jpg", i)] = false
}
return expectedRecordFiles, expectedImgFiles
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

View File

@ -0,0 +1 @@
{"user/angle":0.04117644,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000001.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.045098066,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000002.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.045098066,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000003.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.04117644,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000004.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.04117644,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000005.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.043137312,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000006.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.04117644,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000007.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.045098066,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000008.jpg"}

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

View File

@ -0,0 +1 @@
{"user/angle":0.04117644,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000101.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.045098066,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000102.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.045098066,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000103.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.04117644,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000104.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.04117644,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000105.jpg"}

View File

@ -0,0 +1 @@
{"user/angle":0.043137312,"cam/image_array":"/tmp/record//2020021819-3/cam/cam-image_array_0000106.jpg"}

View File

@ -12,24 +12,29 @@ import (
"os" "os"
) )
func New(client mqtt.Client, jsonDir, imgDir string, recordTopic string) *Recorder { func New(client mqtt.Client, recordsDir, recordTopic string) (*Recorder, error) {
err := os.MkdirAll(recordsDir, os.FileMode(0755))
if err != nil {
return nil, fmt.Errorf("unable to create %v directory: %v", recordsDir, err)
}
return &Recorder{ return &Recorder{
client: client, client: client,
jsonDir: jsonDir, recordsDir: recordsDir,
imgDir: imgDir,
recordTopic: recordTopic, recordTopic: recordTopic,
cancel: make(chan interface{}), cancel: make(chan interface{}),
} }, nil
} }
type Recorder struct { type Recorder struct {
client mqtt.Client client mqtt.Client
jsonDir, imgDir string recordsDir string
recordTopic string recordTopic string
cancel chan interface{} cancel chan interface{}
} }
var RecorNameFormat = "record_%s.json"
func (r *Recorder) Start() error { func (r *Recorder) Start() error {
err := service.RegisterCallback(r.client, r.recordTopic, r.onRecordMsg) err := service.RegisterCallback(r.client, r.recordTopic, r.onRecordMsg)
if err != nil { if err != nil {
@ -51,16 +56,30 @@ func (r *Recorder) onRecordMsg(_ mqtt.Client, message mqtt.Message) {
log.Errorf("unable to unmarshal protobuf %T: %v", msg, err) log.Errorf("unable to unmarshal protobuf %T: %v", msg, err)
return return
} }
fmt.Printf("record %s: %s\r", msg.GetRecordSet(), msg.GetFrame().GetId().GetId())
os.MkdirAll() recordDir := fmt.Sprintf("%s/%s", r.recordsDir, msg.GetRecordSet())
imgName := fmt.Sprintf("%s/%s/cam-image_array_%s.jpg", r.imgDir, msg.GetRecordSet(), msg.GetFrame().GetId().GetId())
err = ioutil.WriteFile(imgName, msg.GetFrame().GetFrame(), 0755) imgDir := fmt.Sprintf("%s/cam", recordDir)
imgName := fmt.Sprintf("%s/cam-image_array_%s.jpg", imgDir, msg.GetFrame().GetId().GetId())
err = os.MkdirAll(imgDir, os.FileMode(0755))
if err != nil { if err != nil {
log.Errorf("unable to write json file %v: %v", imgName, err) log.Errorf("unable to create %v directory: %v", imgDir, err)
return
}
err = ioutil.WriteFile(imgName, msg.GetFrame().GetFrame(), os.FileMode(0755))
if err != nil {
log.Errorf("unable to write img file %v: %v", imgName, err)
return return
} }
recordName := fmt.Sprintf("record_%s.jpg", msg.GetFrame().GetId().GetId()) jsonDir := fmt.Sprintf("%s/", recordDir)
recordName := fmt.Sprintf("%s/%s", jsonDir, fmt.Sprintf(RecorNameFormat, msg.GetFrame().GetId().GetId()))
err = os.MkdirAll(jsonDir, os.FileMode(0755))
if err != nil {
log.Errorf("unable to create %v directory: %v", jsonDir, err)
return
}
record := Record{ record := Record{
UserAngle: msg.GetSteering().GetSteering(), UserAngle: msg.GetSteering().GetSteering(),
CamImageArray: imgName, CamImageArray: imgName,
@ -70,7 +89,6 @@ func (r *Recorder) onRecordMsg(_ mqtt.Client, message mqtt.Message) {
log.Errorf("unable to marshal json content: %v", err) log.Errorf("unable to marshal json content: %v", err)
return return
} }
err = ioutil.WriteFile(recordName, jsonBytes, 0755) err = ioutil.WriteFile(recordName, jsonBytes, 0755)
if err != nil { if err != nil {
log.Errorf("unable to write json file %v: %v", recordName, err) log.Errorf("unable to write json file %v: %v", recordName, err)