From de3f507e342eebd6f93ab982ec973f72bb1c98f2 Mon Sep 17 00:00:00 2001 From: tianhaodongbd Date: Mon, 21 Oct 2024 11:44:48 +0800 Subject: [PATCH] fix rand_like bug --- paddleapex/apex/utils/data_generate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddleapex/apex/utils/data_generate.py b/paddleapex/apex/utils/data_generate.py index 25c09ce..5c43b85 100644 --- a/paddleapex/apex/utils/data_generate.py +++ b/paddleapex/apex/utils/data_generate.py @@ -290,11 +290,11 @@ def rand_like(data, seed=1234): os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) if isinstance(data, paddle.Tensor): - if data.dtype.name in ["BF16", "FP16"]: + if data.dtype.name in ["BF16", "FP16", "BFLOAT16", "FLOAT16"]: random_normals = numpy.random.randn(*data.shape) x = paddle.to_tensor(random_normals, dtype=data.dtype) return x - elif data.dtype.name in ["FP32", "FP64"]: + elif data.dtype.name in ["FP32", "FP64", "FLOAT32", "FLOAT64"]: random_normals = numpy.random.randn(*data.shape) x = paddle.to_tensor(random_normals, dtype=data.dtype) return x