Export tf model
This commit is contained in:
parent
9ec80414c9
commit
84a8b11942
@ -1,19 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import container_support as cs
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import zipfile
|
||||
from keras.preprocessing.image import load_img, img_to_array
|
||||
import numpy as np
|
||||
|
||||
from keras.layers import Input, Dense, merge
|
||||
from keras.models import Model
|
||||
from keras.layers import Convolution2D, MaxPooling2D, Reshape, BatchNormalization
|
||||
from keras.layers import Activation, Dropout, Flatten, Dense
|
||||
import container_support as cs
|
||||
import json
|
||||
import numpy as np
|
||||
import re
|
||||
import tensorflow as tf
|
||||
import zipfile
|
||||
from keras import backend as K
|
||||
from keras import callbacks
|
||||
from keras.layers import Convolution2D
|
||||
from keras.layers import Dropout, Flatten, Dense
|
||||
from keras.layers import Input
|
||||
from keras.models import Model
|
||||
from keras.preprocessing.image import load_img, img_to_array
|
||||
from tensorflow.python.client import device_lib
|
||||
|
||||
|
||||
@ -101,7 +102,12 @@ def train():
|
||||
patience=10,
|
||||
verbose=1,
|
||||
mode='auto')
|
||||
img_in = Input(shape=(128, 160, 3), name='img_in') # First layer, input layer, Shape comes from camera.py resolution, RGB
|
||||
# Only for export model to tensorflow
|
||||
sess = tf.Session()
|
||||
K.set_session(sess)
|
||||
|
||||
img_in = Input(shape=(128, 160, 3),
|
||||
name='img_in') # First layer, input layer, Shape comes from camera.py resolution, RGB
|
||||
x = img_in
|
||||
x = Convolution2D(24, (5,5), strides=(2,2), activation='relu')(x) # 24 features, 5 pixel x 5 pixel kernel (convolution, feauture) window, 2wx2h stride, relu activation
|
||||
x = Convolution2D(32, (5,5), strides=(2,2), activation='relu')(x) # 32 features, 5px5p kernel window, 2wx2h stride, relu activatiion
|
||||
@ -129,3 +135,10 @@ def train():
|
||||
'throttle_out': 'mean_absolute_error'},
|
||||
loss_weights={'angle_out': 0.9, 'throttle_out': .001})
|
||||
model.fit({'img_in':images},{'angle_out': angle_cat_array, 'throttle_out': throttle_array}, batch_size=32, epochs=100, verbose=1, validation_split=0.2, shuffle=True, callbacks=callbacks_list)
|
||||
# Save model for tensorflow using
|
||||
builder = tf.saved_model.builder.SavedModelBuilder("myModel")
|
||||
|
||||
# Tag the model, required for Go
|
||||
builder.add_meta_graph_and_variables(sess, ["myTag"])
|
||||
builder.save()
|
||||
sess.close()
|
||||
|
Loading…
Reference in New Issue
Block a user