From 415d39c5bb7860f684c024168edda0269b0ed62f Mon Sep 17 00:00:00 2001 From: Cyrille Nofficial Date: Wed, 19 Apr 2023 19:56:45 +0200 Subject: [PATCH] fix: bad typing declaration --- tf_container/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf_container/train.py b/tf_container/train.py index 9ecee37..35cb8be 100644 --- a/tf_container/train.py +++ b/tf_container/train.py @@ -187,7 +187,7 @@ def train(model_type: str, record_field: str, batch_size: int, slide_size: int, # Save model for tensorflow using model.save(f'/opt/ml/model/model_{record_field.replace("_", "")}_{model_type}_{str(img_width)}x{str(img_height)}h{str(horizon)}') - def representative_dataset() -> typing.Generator[typing.List[tf.float32], typing.Any, None]: + def representative_dataset() -> typing.Generator[typing.List[float], typing.Any, None]: for d in tf.data.Dataset.from_tensor_slices(images).batch(1).take(100): yield [tf.dtypes.cast(d, tf.float32)]