parent
91758cfffa
commit
b31e2632d0
@ -90,7 +90,7 @@ def train(model_type: str, batch_size: int, slide_size: int, img_height: int, im
|
|||||||
if horizon > 0:
|
if horizon > 0:
|
||||||
images = np.array([img_to_array(load_img(os.path.join(d[1], d[2])).crop((0, horizon, img_width, img_height))) for d in data], 'f')
|
images = np.array([img_to_array(load_img(os.path.join(d[1], d[2])).crop((0, horizon, img_width, img_height))) for d in data], 'f')
|
||||||
else:
|
else:
|
||||||
images = np.array( [img_to_array(load_img(os.path.join(d[1], d[2]))) for d in data], 'f')
|
images = np.array([img_to_array(load_img(os.path.join(d[1], d[2]))) for d in data], 'f')
|
||||||
|
|
||||||
# slide images vs orders
|
# slide images vs orders
|
||||||
if slide_size > 0:
|
if slide_size > 0:
|
||||||
@ -114,8 +114,18 @@ def train(model_type: str, batch_size: int, slide_size: int, img_height: int, im
|
|||||||
model_filepath = '/opt/ml/model/model_other'
|
model_filepath = '/opt/ml/model/model_other'
|
||||||
if model_type == MODEL_CATEGORICAL:
|
if model_type == MODEL_CATEGORICAL:
|
||||||
model_filepath = '/opt/ml/model/model_cat'
|
model_filepath = '/opt/ml/model/model_cat'
|
||||||
|
angle_cat_array = np.array([linear_bin(float(a)) for a in angle_array])
|
||||||
|
model = default_categorical(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
||||||
|
loss = {'angle_out': 'categorical_crossentropy', }
|
||||||
|
optimizer = 'adam'
|
||||||
elif model_type == MODEL_LINEAR:
|
elif model_type == MODEL_LINEAR:
|
||||||
model_filepath = '/opt/ml/model/model_lin'
|
model_filepath = '/opt/ml/model/model_lin'
|
||||||
|
angle_cat_array = np.array([a for a in angle_array])
|
||||||
|
model = default_linear(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
||||||
|
loss = 'mse'
|
||||||
|
optimizer = 'rmsprop'
|
||||||
|
else:
|
||||||
|
raise Exception("invalid model type")
|
||||||
|
|
||||||
save_best = callbacks.ModelCheckpoint(model_filepath, monitor='val_loss', verbose=1,
|
save_best = callbacks.ModelCheckpoint(model_filepath, monitor='val_loss', verbose=1,
|
||||||
save_best_only=True, mode='min')
|
save_best_only=True, mode='min')
|
||||||
@ -128,14 +138,8 @@ def train(model_type: str, batch_size: int, slide_size: int, img_height: int, im
|
|||||||
# categorical output of the angle
|
# categorical output of the angle
|
||||||
callbacks_list = [save_best, early_stop, logs]
|
callbacks_list = [save_best, early_stop, logs]
|
||||||
|
|
||||||
angle_cat_array = np.array([linear_bin(float(a)) for a in angle_array])
|
model.compile(optimizer=optimizer,
|
||||||
|
loss=loss,)
|
||||||
#model = default_model(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
|
||||||
model = default_categorical(input_shape=(img_height - horizon, img_width, img_depth), drop=drop)
|
|
||||||
|
|
||||||
model.compile(optimizer='adam',
|
|
||||||
loss={'angle_out': 'categorical_crossentropy', },
|
|
||||||
loss_weights={'angle_out': 0.9})
|
|
||||||
model.fit({'img_in': images}, {'angle_out': angle_cat_array, }, batch_size=batch_size,
|
model.fit({'img_in': images}, {'angle_out': angle_cat_array, }, batch_size=batch_size,
|
||||||
epochs=100, verbose=1, validation_split=0.2, shuffle=True, callbacks=callbacks_list)
|
epochs=100, verbose=1, validation_split=0.2, shuffle=True, callbacks=callbacks_list)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user