parent
91758cfffa
commit
b31e2632d0
@ -114,8 +114,18 @@ def train(model_type: str, batch_size: int, slide_size: int, img_height: int, im
|
|||||||
model_filepath = '/opt/ml/model/model_other'
|
model_filepath = '/opt/ml/model/model_other'
|
||||||
if model_type == MODEL_CATEGORICAL:
|
if model_type == MODEL_CATEGORICAL:
|
||||||
model_filepath = '/opt/ml/model/model_cat'
|
model_filepath = '/opt/ml/model/model_cat'
|
||||||
|
angle_cat_array = np.array([linear_bin(float(a)) for a in angle_array])
|
||||||
|
model = default_categorical(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
||||||
|
loss = {'angle_out': 'categorical_crossentropy', }
|
||||||
|
optimizer = 'adam'
|
||||||
elif model_type == MODEL_LINEAR:
|
elif model_type == MODEL_LINEAR:
|
||||||
model_filepath = '/opt/ml/model/model_lin'
|
model_filepath = '/opt/ml/model/model_lin'
|
||||||
|
angle_cat_array = np.array([a for a in angle_array])
|
||||||
|
model = default_linear(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
||||||
|
loss = 'mse'
|
||||||
|
optimizer = 'rmsprop'
|
||||||
|
else:
|
||||||
|
raise Exception("invalid model type")
|
||||||
|
|
||||||
save_best = callbacks.ModelCheckpoint(model_filepath, monitor='val_loss', verbose=1,
|
save_best = callbacks.ModelCheckpoint(model_filepath, monitor='val_loss', verbose=1,
|
||||||
save_best_only=True, mode='min')
|
save_best_only=True, mode='min')
|
||||||
@ -128,14 +138,8 @@ def train(model_type: str, batch_size: int, slide_size: int, img_height: int, im
|
|||||||
# categorical output of the angle
|
# categorical output of the angle
|
||||||
callbacks_list = [save_best, early_stop, logs]
|
callbacks_list = [save_best, early_stop, logs]
|
||||||
|
|
||||||
angle_cat_array = np.array([linear_bin(float(a)) for a in angle_array])
|
model.compile(optimizer=optimizer,
|
||||||
|
loss=loss,)
|
||||||
#model = default_model(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
|
||||||
model = default_categorical(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
|
||||||
|
|
||||||
model.compile(optimizer='adam',
|
|
||||||
loss={'angle_out': 'categorical_crossentropy', },
|
|
||||||
loss_weights={'angle_out': 0.9})
|
|
||||||
model.fit({'img_in': images}, {'angle_out': angle_cat_array, }, batch_size=batch_size,
|
model.fit({'img_in': images}, {'angle_out': angle_cat_array, }, batch_size=batch_size,
|
||||||
epochs=100, verbose=1, validation_split=0.2, shuffle=True, callbacks=callbacks_list)
|
epochs=100, verbose=1, validation_split=0.2, shuffle=True, callbacks=callbacks_list)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user