diff --git a/src/tf_container/train_entry_point.py b/src/tf_container/train_entry_point.py index 7f91ac0..a27ff80 100644 --- a/src/tf_container/train_entry_point.py +++ b/src/tf_container/train_entry_point.py @@ -77,9 +77,9 @@ def train(): return arr logs = callbacks.TensorBoard(log_dir='logs', histogram_freq=0, write_graph=True, write_images=True) - save_best = callbacks.ModelCheckpoint('/opt/ml/model/model_cat', monitor='angle_out_loss', verbose=1, + save_best = callbacks.ModelCheckpoint('/opt/ml/model/model_cat', monitor='val_loss', verbose=1, save_best_only=True, mode='min') - early_stop = callbacks.EarlyStopping(monitor='angle_out_loss', + early_stop = callbacks.EarlyStopping(monitor='val_loss', min_delta=.0005, patience=10, verbose=1,