parent
91758cfffa
commit
b31e2632d0
@ -90,7 +90,7 @@ def train(model_type: str, batch_size: int, slide_size: int, img_height: int, im
|
||||
if horizon > 0:
|
||||
images = np.array([img_to_array(load_img(os.path.join(d[1], d[2])).crop((0, horizon, img_width, img_height))) for d in data], 'f')
|
||||
else:
|
||||
images = np.array( [img_to_array(load_img(os.path.join(d[1], d[2]))) for d in data], 'f')
|
||||
images = np.array([img_to_array(load_img(os.path.join(d[1], d[2]))) for d in data], 'f')
|
||||
|
||||
# slide images vs orders
|
||||
if slide_size > 0:
|
||||
@ -114,8 +114,18 @@ def train(model_type: str, batch_size: int, slide_size: int, img_height: int, im
|
||||
model_filepath = '/opt/ml/model/model_other'
|
||||
if model_type == MODEL_CATEGORICAL:
|
||||
model_filepath = '/opt/ml/model/model_cat'
|
||||
angle_cat_array = np.array([linear_bin(float(a)) for a in angle_array])
|
||||
model = default_categorical(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
||||
loss = {'angle_out': 'categorical_crossentropy', }
|
||||
optimizer = 'adam'
|
||||
elif model_type == MODEL_LINEAR:
|
||||
model_filepath = '/opt/ml/model/model_lin'
|
||||
angle_cat_array = np.array([a for a in angle_array])
|
||||
model = default_linear(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
||||
loss = 'mse'
|
||||
optimizer = 'rmsprop'
|
||||
else:
|
||||
raise Exception("invalid model type")
|
||||
|
||||
save_best = callbacks.ModelCheckpoint(model_filepath, monitor='val_loss', verbose=1,
|
||||
save_best_only=True, mode='min')
|
||||
@ -128,14 +138,8 @@ def train(model_type: str, batch_size: int, slide_size: int, img_height: int, im
|
||||
# categorical output of the angle
|
||||
callbacks_list = [save_best, early_stop, logs]
|
||||
|
||||
angle_cat_array = np.array([linear_bin(float(a)) for a in angle_array])
|
||||
|
||||
#model = default_model(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
||||
model = default_categorical(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
||||
|
||||
model.compile(optimizer='adam',
|
||||
loss={'angle_out': 'categorical_crossentropy', },
|
||||
loss_weights={'angle_out': 0.9})
|
||||
model.compile(optimizer=optimizer,
|
||||
loss=loss,)
|
||||
model.fit({'img_in': images}, {'angle_out': angle_cat_array, }, batch_size=batch_size,
|
||||
epochs=100, verbose=1, validation_split=0.2, shuffle=True, callbacks=callbacks_list)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user