Export tf model

This commit is contained in:
Cyrille Nofficial 2019-11-05 19:55:27 +01:00
parent 9ec80414c9
commit 84a8b11942

View File

@ -1,19 +1,20 @@
#!/usr/bin/env python3
import container_support as cs
import os
import json
import re
import zipfile
from keras.preprocessing.image import load_img, img_to_array
import numpy as np
from keras.layers import Input, Dense, merge
from keras.models import Model
from keras.layers import Convolution2D, MaxPooling2D, Reshape, BatchNormalization
from keras.layers import Activation, Dropout, Flatten, Dense
import container_support as cs
import json
import numpy as np
import re
import tensorflow as tf
import zipfile
from keras import backend as K
from keras import callbacks
from keras.layers import Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras.layers import Input
from keras.models import Model
from keras.preprocessing.image import load_img, img_to_array
from tensorflow.python.client import device_lib
@ -101,7 +102,12 @@ def train():
patience=10,
verbose=1,
mode='auto')
img_in = Input(shape=(128, 160, 3), name='img_in') # First layer, input layer, Shape comes from camera.py resolution, RGB
# Only for export model to tensorflow
sess = tf.Session()
K.set_session(sess)
img_in = Input(shape=(128, 160, 3),
name='img_in') # First layer, input layer, Shape comes from camera.py resolution, RGB
x = img_in
x = Convolution2D(24, (5,5), strides=(2,2), activation='relu')(x) # 24 features, 5 pixel x 5 pixel kernel (convolution, feauture) window, 2wx2h stride, relu activation
x = Convolution2D(32, (5,5), strides=(2,2), activation='relu')(x) # 32 features, 5px5p kernel window, 2wx2h stride, relu activatiion
@ -129,3 +135,10 @@ def train():
'throttle_out': 'mean_absolute_error'},
loss_weights={'angle_out': 0.9, 'throttle_out': .001})
model.fit({'img_in':images},{'angle_out': angle_cat_array, 'throttle_out': throttle_array}, batch_size=32, epochs=100, verbose=1, validation_split=0.2, shuffle=True, callbacks=callbacks_list)
# Save model for tensorflow using
builder = tf.saved_model.builder.SavedModelBuilder("myModel")
# Tag the model, required for Go
builder.add_meta_graph_and_variables(sess, ["myTag"])
builder.save()
sess.close()