Skip to content

Commit

Permalink
Sync .pyre_configuration.external files for those that got out of syn…
Browse files Browse the repository at this point in the history
…c with last upgrade (pytorch#2294)

Summary:
Pull Request resolved: pytorch#2294

X-link: pytorch/torchx#943

This downgrades the .pyre_configurations in the following projects to keep their internal and external configurations synced.
- beanmachine/beanmachine/ppl
- tools/sapp
- torchrec
- torchx

The versions were chosen to match the most recent Pyre version in [pyre-check-nightly](https://pypi.org/project/pyre-check-nightly/#history).

Differential Revision: D61211101
  • Loading branch information
connernilsen authored and facebook-github-bot committed Aug 15, 2024
1 parent 924d393 commit 6b7ca1b
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 19 deletions.
9 changes: 6 additions & 3 deletions torchrec/distributed/benchmark/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,9 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:

sharded_module = _shard_modules(
module=copied_module,
# pyre-ignore [6]
# pyre-fixme[6]: For 2nd argument expected
# `Optional[List[ModuleSharder[Module]]]` but got
# `List[ModuleSharder[Variable[T (bound to Module)]]]`.
sharders=[sharder],
device=device,
plan=plan,
Expand All @@ -489,13 +491,14 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:

if compile_mode == CompileMode.FX_SCRIPT:
return fx_script_module(
# pyre-ignore [6]
# pyre-fixme[6]: For 1st argument expected `Module` but got
# `Optional[Module]`.
sharded_module
if not benchmark_unsharded_module
else module
)
else:
# pyre-ignore [7]
# pyre-fixme[7]: Expected `Module` but got `Optional[Module]`.
return sharded_module if not benchmark_unsharded_module else module


Expand Down
4 changes: 3 additions & 1 deletion torchrec/distributed/composable/tests/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def _test_sharding( # noqa C901
# pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got
# `Optional[ProcessGroup]`.
env=ShardingEnv.from_process_group(ctx.pg),
# pyre-ignore
# pyre-fixme[6]: For 4th argument expected
# `Optional[List[ModuleSharder[Module]]]` but got
# `List[EmbeddingCollectionSharder]`.
sharders=[sharder],
device=ctx.device,
)
Expand Down
3 changes: 2 additions & 1 deletion torchrec/distributed/keyed_jagged_tensor_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,8 @@ def _update_local(
raise NotImplementedError("Inference does not support update")

def _update_preproc(self, values: KeyedJaggedTensor) -> KeyedJaggedTensor:
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
# pyre-fixme[7]: Expected `KeyedJaggedTensor` but got implicit return value
# of `None`.
pass


Expand Down
8 changes: 5 additions & 3 deletions torchrec/distributed/object_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ def input_dist(
# pyre-ignore[2]
**kwargs,
) -> Awaitable[Awaitable[torch.Tensor]]:
# pyre-ignore
# pyre-fixme[7]: Expected `Awaitable[Awaitable[Tensor]]` but got implicit
# return value of `None`.
pass

def compute(self, ctx: ShrdCtx, dist_input: torch.Tensor) -> DistOut:
# pyre-ignore
# pyre-fixme[7]: Expected `DistOut` but got implicit return value of `None`.
pass

def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
# pyre-ignore
# pyre-fixme[7]: Expected `LazyAwaitable[Variable[Out]]` but got implicit
# return value of `None`.
pass
4 changes: 3 additions & 1 deletion torchrec/distributed/shards_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
aten = torch.ops.aten # pyre-ignore[5]


class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__
# pyre-fixme[13]: Attribute `_local_shards` is never initialized.
# pyre-fixme[13]: Attribute `_storage_meta` is never initialized.
class LocalShardsWrapper(torch.Tensor):
"""
A wrapper class to hold local shards of a DTensor.
This class is used largely for checkpointing purposes and implicity subtypes
Expand Down
6 changes: 4 additions & 2 deletions torchrec/distributed/test_utils/infer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,8 @@ def shard_qebc(
quant_model_copy = copy.deepcopy(mi.quant_model)
sharded_model = _shard_modules(
module=quant_model_copy,
# pyre-ignore
# pyre-fixme[6]: For 2nd argument expected
# `Optional[List[ModuleSharder[Module]]]` but got `List[TestQuantEBCSharder]`.
sharders=[sharder],
device=device,
plan=plan,
Expand Down Expand Up @@ -912,7 +913,8 @@ def shard_qec(
quant_model_copy = copy.deepcopy(mi.quant_model)
sharded_model = _shard_modules(
module=quant_model_copy,
# pyre-ignore
# pyre-fixme[6]: For 2nd argument expected
# `Optional[List[ModuleSharder[Module]]]` but got `List[TestQuantECSharder]`.
sharders=[sharder],
device=device,
plan=plan,
Expand Down
8 changes: 6 additions & 2 deletions torchrec/distributed/tests/test_infer_hetero_shardings.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def test_sharder_different_world_sizes_for_qec(self, sharding_device: str) -> No

sharded_model = _shard_modules(
module=non_sharded_model,
# pyre-ignore
# pyre-fixme[6]: For 2nd argument expected
# `Optional[List[ModuleSharder[Module]]]` but got
# `List[QuantEmbeddingCollectionSharder]`.
sharders=[sharder],
device=torch.device(sharding_device),
plan=plan,
Expand Down Expand Up @@ -201,7 +203,9 @@ def test_sharder_different_world_sizes_for_qebc(self) -> None:
}
sharded_model = _shard_modules(
module=non_sharded_model,
# pyre-ignore
# pyre-fixme[6]: For 2nd argument expected
# `Optional[List[ModuleSharder[Module]]]` but got
# `List[QuantEmbeddingBagCollectionSharder]`.
sharders=[sharder],
device=torch.device("cpu"),
plan=plan,
Expand Down
8 changes: 6 additions & 2 deletions torchrec/distributed/tests/test_infer_shardings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2025,7 +2025,9 @@ def test_sharded_quant_fp_ebc_tw(

sharded_model = _shard_modules(
module=quant_model,
# pyre-ignore
# pyre-fixme[6]: For 2nd argument expected
# `Optional[List[ModuleSharder[Module]]]` but got
# `List[QuantFeatureProcessedEmbeddingBagCollectionSharder]`.
sharders=[sharder],
device=local_device,
plan=plan,
Expand Down Expand Up @@ -2180,7 +2182,9 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None:

sharded_model = _shard_modules(
module=quant_model,
# pyre-ignore
# pyre-fixme[6]: For 2nd argument expected
# `Optional[List[ModuleSharder[Module]]]` but got
# `List[QuantFeatureProcessedEmbeddingBagCollectionSharder]`.
sharders=[sharder],
# shard on meta to simulate device movement from cpu -> meta QFPEBC
device=torch.device("meta"),
Expand Down
12 changes: 9 additions & 3 deletions torchrec/distributed/tests/test_infer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def test_get_tbe_specs_from_sqebc(self) -> None:

sharded_model = _shard_modules(
module=quant_model[0],
# pyre-ignore
# pyre-fixme[6]: For 2nd argument expected
# `Optional[List[ModuleSharder[Module]]]` but got
# `List[TestQuantEBCSharder]`.
sharders=[sharder],
device=device,
plan=plan,
Expand Down Expand Up @@ -178,7 +180,9 @@ def test_get_tbe_specs_from_sqec(self) -> None:

sharded_model = _shard_modules(
module=quant_model[0],
# pyre-ignore
# pyre-fixme[6]: For 2nd argument expected
# `Optional[List[ModuleSharder[Module]]]` but got
# `List[TestQuantECSharder]`.
sharders=[sharder],
device=device,
plan=plan,
Expand Down Expand Up @@ -256,7 +260,9 @@ def test_get_all_torchrec_modules_for_single_module(self) -> None:

sharded_model = _shard_modules(
module=quant_model[0],
# pyre-ignore
# pyre-fixme[6]: For 2nd argument expected
# `Optional[List[ModuleSharder[Module]]]` but got
# `List[TestQuantEBCSharder]`.
sharders=[sharder],
device=device,
plan=plan,
Expand Down
3 changes: 2 additions & 1 deletion torchrec/distributed/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
# other metaclasses (i.e. AwaitableMeta) for customized
# behaviors, as Generic is non-trival metaclass in
# python 3.6 and below
from typing import GenericMeta # pyre-ignore: python 3.6
# pyre-fixme[21]: Could not find name `GenericMeta` in `typing` (stubbed).
from typing import GenericMeta
except ImportError:
# In python 3.7+, GenericMeta doesn't exist as it's no
# longer a non-trival metaclass,
Expand Down

0 comments on commit 6b7ca1b

Please sign in to comment.