diff --git a/zeus/utils/framework.py b/zeus/utils/framework.py index c745b9d1..86b9bd8a 100644 --- a/zeus/utils/framework.py +++ b/zeus/utils/framework.py @@ -122,7 +122,7 @@ def all_reduce( return object # wrap object in a tensor - tensor = torch.Tensor(object) + tensor = torch.Tensor(object, device="cuda") # determine operation if operation == "sum":