feat: implement linear model

refs robocars/robocar-setup#5
This commit is contained in:
Cyrille Nofficial 2022-06-10 11:46:27 +02:00
parent 91758cfffa
commit b31e2632d0

View File

@ -90,7 +90,7 @@ def train(model_type: str, batch_size: int, slide_size: int, img_height: int, im
if horizon > 0:
images = np.array([img_to_array(load_img(os.path.join(d[1], d[2])).crop((0, horizon, img_width, img_height))) for d in data], 'f')
else:
images = np.array( [img_to_array(load_img(os.path.join(d[1], d[2]))) for d in data], 'f')
images = np.array([img_to_array(load_img(os.path.join(d[1], d[2]))) for d in data], 'f')
# slide images vs orders
if slide_size > 0:
@ -114,8 +114,18 @@ def train(model_type: str, batch_size: int, slide_size: int, img_height: int, im
model_filepath = '/opt/ml/model/model_other'
if model_type == MODEL_CATEGORICAL:
model_filepath = '/opt/ml/model/model_cat'
angle_cat_array = np.array([linear_bin(float(a)) for a in angle_array])
model = default_categorical(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
loss = {'angle_out': 'categorical_crossentropy', }
optimizer = 'adam'
elif model_type == MODEL_LINEAR:
model_filepath = '/opt/ml/model/model_lin'
angle_cat_array = np.array([a for a in angle_array])
model = default_linear(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
loss = 'mse'
optimizer = 'rmsprop'
else:
raise Exception("invalid model type")
save_best = callbacks.ModelCheckpoint(model_filepath, monitor='val_loss', verbose=1,
save_best_only=True, mode='min')
@ -128,14 +138,8 @@ def train(model_type: str, batch_size: int, slide_size: int, img_height: int, im
# categorical output of the angle
callbacks_list = [save_best, early_stop, logs]
angle_cat_array = np.array([linear_bin(float(a)) for a in angle_array])
#model = default_model(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
model = default_categorical(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
model.compile(optimizer='adam',
loss={'angle_out': 'categorical_crossentropy', },
loss_weights={'angle_out': 0.9})
model.compile(optimizer=optimizer,
loss=loss,)
model.fit({'img_in': images}, {'angle_out': angle_cat_array, }, batch_size=batch_size,
epochs=100, verbose=1, validation_split=0.2, shuffle=True, callbacks=callbacks_list)