[spmd] quick fix on batch input view issue (pytorch#98813)
This is a quick fix/hack to work around the issue that some view
operations on a "global" tensor are invalid, yet some models somehow
trigger them, even though the same view on the mini-batch input itself
does not have this issue.

Since ultimately we should remove the dtensor expand and use the new
expansion, this hack is only temporary to unblock
Pull Request resolved: pytorch#98813
Approved by: https://github.com/yifuwang, https://github.com/mrshenli
wanchaol authored and pytorchmergebot committed Apr 11, 2023
1 parent 760967a commit 1568695
Showing 2 changed files with 8 additions and 0 deletions.
7 changes: 7 additions & 0 deletions torch/distributed/_spmd/distribute.py
@@ -169,6 +169,13 @@ def _remap_arg(arg: object) -> object:

op_overload = cast(torch._ops.OpOverload, node.target)

if node.target == torch.ops.aten.view.default:
    # HACK: this is a hack to get around with the fact that some
    # view operations on a "global" tensor is invalid usage
    # but somehow the view operation on the batch input might hit it
    # so we convert the view op to reshape before calling DTensor
    op_overload = torch.ops.aten.reshape.default

# run dispatch once to get the real DTensor output.
out, op_schema, output_sharding = _operator_dispatch(
    op_overload,
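
The hunk above swaps `aten.view.default` for `aten.reshape.default` before dispatch. As a rough illustration of why the two ops are not interchangeable in general, the sketch below uses plain `torch.Tensor` objects (not DTensor): `view` raises when the requested shape is incompatible with the tensor's strides, while `reshape` falls back to a copy. This shows only the general `view`/`reshape` contract; it is an assumption, not something the commit states, that the DTensor failure is of a similar flavor. The helper `try_view` is hypothetical.

```python
import torch

# A transposed tensor is non-contiguous: its strides no longer match
# a row-major layout of its shape.
base = torch.arange(6).reshape(2, 3)
transposed = base.t()  # shape (3, 2)


def try_view(t: torch.Tensor, shape: tuple):
    """Hypothetical helper: return t.view(shape), or None if view is invalid."""
    try:
        return t.view(shape)
    except RuntimeError:
        # view never copies, so it fails when sizes and strides are incompatible
        return None


view_result = try_view(transposed, (6,))   # view cannot flatten this tensor
reshape_result = transposed.reshape(6)     # reshape copies when it has to

print(view_result is None)
print(reshape_result.tolist())
```

Because `reshape` returns a view whenever one is possible and copies otherwise, replacing `view` with `reshape` trades a hard error for a potential copy.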
1 change: 1 addition & 0 deletions torch/distributed/_tensor/ops/view_ops.py
@@ -667,6 +667,7 @@ def reshape_prop(op_schema: OpSchema) -> OutputSharding:
register_prop_rule_map(aten.squeeze.default, torch.squeeze)
register_prop_rule_map(aten.squeeze.dim, torch.squeeze)
register_prop_rule_map(aten.view.default, Tensor.view)
register_prop_rule_map(aten.reshape.default, torch.reshape)
register_prop_rule_map(aten._unsafe_view.default, Tensor.view)
register_prop_rule_map(aten.unsqueeze.default, torch.unsqueeze)
register_prop_rule_map(aten.expand.default, Tensor.expand)
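
The one-line addition above registers a rule for `aten.reshape.default`, presumably because `distribute.py` now dispatches that overload directly. The sketch below is a toy, pure-Python stand-in for an op-to-rule registry, to show why an op with no registered rule cannot be dispatched. Every name in it (`_PROP_RULES`, `register_rule`, `reshape_rule`, `dispatch`) is hypothetical and not part of the DTensor API, and the rule only resolves output shapes, not shardings.

```python
import math
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]

# Hypothetical rule table keyed by op name; a stand-in, not DTensor's real one.
_PROP_RULES: Dict[str, Callable[[Shape, Shape], Shape]] = {}


def register_rule(op_name: str, rule: Callable[[Shape, Shape], Shape]) -> None:
    _PROP_RULES[op_name] = rule


def reshape_rule(in_shape: Shape, target: Shape) -> Shape:
    """Resolve at most one -1 in target and check that element counts match."""
    total = math.prod(in_shape)
    if target.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = math.prod(d for d in target if d != -1)
    if -1 in target:
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot reshape {in_shape} into {target}")
        return tuple(total // known if d == -1 else d for d in target)
    if known != total:
        raise ValueError(f"cannot reshape {in_shape} into {target}")
    return target


def dispatch(op_name: str, in_shape: Shape, target: Shape) -> Shape:
    try:
        rule = _PROP_RULES[op_name]
    except KeyError:
        raise NotImplementedError(f"no rule registered for {op_name}") from None
    return rule(in_shape, target)


# Only view is registered at first, so reshape cannot be dispatched.
register_rule("aten.view.default", reshape_rule)
try:
    dispatch("aten.reshape.default", (2, 6), (3, -1))
    missing_before = False
except NotImplementedError:
    missing_before = True

# Registering reshape under the same rule makes the dispatch succeed.
register_rule("aten.reshape.default", reshape_rule)
resolved = dispatch("aten.reshape.default", (2, 6), (3, -1))
```

In this toy version both op names share one rule function, loosely mirroring how the diff gives reshape its own entry next to view.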