[BE] Apply almost all remaining flake8-comprehension checks (pytorch#94676)

Applies the remaining flake8-comprehension fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into a direct `set(b)` call.

Pull Request resolved: pytorch#94676
Approved by: https://github.com/ezyang
Skylion007 authored and pytorchmergebot committed Feb 12, 2023
1 parent 54c0f37 commit 67d9790
Showing 113 changed files with 500 additions and 526 deletions.
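As context for the diffs below, here is an illustrative sketch of the rewrite patterns the newly enabled checks enforce. It is a sketch only: the names `items` and `pairs` are placeholder data, not code taken from this commit.

```python
# Sketch of the flake8-comprehensions rewrites applied in this commit.
# `items` and `pairs` are placeholder data, not taken from the diff.
items = ["a", "_b", "c"]
pairs = [("x", 1), ("y", 2)]

# C400: unnecessary generator passed to list()
old_list = list(x for x in items)
new_list = [x for x in items]

# C401: unnecessary generator passed to set()
old_set = set(x for x in items if not x.startswith("_"))
new_set = {x for x in items if not x.startswith("_")}

# C402: unnecessary generator passed to dict()
old_dict = dict((k, v) for k, v in pairs)
new_dict = {k: v for k, v in pairs}

# C405: set() called on a list literal
old_literal = set(["module_a"])
new_literal = {"module_a"}

# A generator that only re-wraps its iterable collapses into a direct call.
old_range = list(i for i in range(10))
new_range = list(range(10))

assert old_list == new_list and old_set == new_set and old_dict == new_dict
assert old_literal == new_literal and old_range == new_range
```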
2 changes: 1 addition & 1 deletion .flake8
@@ -11,7 +11,7 @@ ignore =
     # these ignores are from flake8-bugbear; please fix!
     B007,B008,
     # these ignores are from flake8-comprehensions; please fix!
-    C400,C401,C402,C405,C407
+    C407
 per-file-ignores =
     __init__.py: F401
     torch/utils/cpp_extension.py: B950
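For reference, C407 — the only comprehensions code still left in the ignore list above — flags comprehensions passed to builtins such as sum() or any() that already accept a generator. A minimal sketch of the pattern (with a placeholder list, not code from this repo):

```python
# Minimal sketch of what C407 reports; `values` is a placeholder list.
values = [3, 1, 4, 1, 5]

flagged = sum([v * v for v in values])  # C407: the list comprehension is unnecessary
preferred = sum(v * v for v in values)  # a generator expression suffices here

assert flagged == preferred
```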
2 changes: 1 addition & 1 deletion benchmarks/distributed/ddp/diff.py
@@ -25,7 +25,7 @@ def main():
     ja = load(args.file[0])
     jb = load(args.file[1])

-    keys = (set(ja.keys()) | set(jb.keys())) - set(["benchmark_results"])
+    keys = (set(ja.keys()) | set(jb.keys())) - {"benchmark_results"}
     print("{:20s} {:>20s} {:>20s}".format("", "baseline", "test"))
     print("{:20s} {:>20s} {:>20s}".format("", "-" * 20, "-" * 20))
     for key in sorted(keys):
2 changes: 1 addition & 1 deletion scripts/release_notes/namespace_check.py
@@ -39,7 +39,7 @@ def get_content(submod):
     return content

 def namespace_filter(data):
-    out = set(d for d in data if d[0] != "_")
+    out = {d for d in data if d[0] != "_"}
     return out

 def run(args, submod):
2 changes: 1 addition & 1 deletion test/ao/sparsity/test_sparsifier.py
@@ -417,7 +417,7 @@ def test_mask_squash(self):
         assert torch.all(weights == torch.eye(height, width) * weights) # only diagonal to be present

     def test_sparsity_levels(self):
-        nearliness_levels = list(nearliness for nearliness in range(-1, 100))
+        nearliness_levels = list(range(-1, 100))
         model = nn.Sequential()

         p = re.compile(r'[-\.\s]')
12 changes: 6 additions & 6 deletions test/distributed/fsdp/test_fsdp_ignored_modules.py
@@ -244,9 +244,9 @@ def _test_diff_ignored_modules_across_ranks(
             {"ignored_modules": layer1_ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": set(
+                "ignored_parameters": {
                     p for m in layer1_ignored_modules for p in m.parameters()
-                )
+                }
             }
         )
         model.layer1 = FSDP(model.layer1, **ignore_kwargs)
@@ -260,9 +260,9 @@ def _test_diff_ignored_modules_across_ranks(
             {"ignored_modules": model_ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": set(
+                "ignored_parameters": {
                     p for m in model_ignored_modules for p in m.parameters()
-                )
+                }
             }
         )
         wrapped_model = FSDP(model, **ignore_kwargs_top)
@@ -279,9 +279,9 @@ def test_ignored_modules_not_under_wrapped_root(self, ignore_modules: bool):
             {"ignored_modules": ignored_modules}
             if ignore_modules
             else {
-                "ignored_parameters": set(
+                "ignored_parameters": {
                     p for m in ignored_modules for p in m.parameters()
-                )
+                }
             }
         )

4 changes: 1 addition & 3 deletions test/distributed/fsdp/test_fsdp_state_dict.py
@@ -783,9 +783,7 @@ def test_state_dict_save_load_flow(self, state_dict_type):
     def test_fsdp_state_dict_keys(self, state_dict_type):
         state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
         if state_dict_type == "local_state_dict":
-            self.assertEqual(
-                set([FLAT_PARAM, f"inner.{FLAT_PARAM}"]), state_dict.keys()
-            )
+            self.assertEqual({FLAT_PARAM, f"inner.{FLAT_PARAM}"}, state_dict.keys())
         elif state_dict_type in ("state_dict", "sharded_state_dict"):
             # Keys should match local model.
             local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)
4 changes: 2 additions & 2 deletions test/distributed/fsdp/test_utils.py
@@ -66,8 +66,8 @@ class SomeDataClass:
         # create a mixed bag of data.
         data = [1, "str"]
         data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
-        data.insert(0, set(["x", get_a_tensor(), get_a_tensor()]))
-        data.append(([1], get_a_tensor(), (1), [get_a_tensor()], set((1, 2))))
+        data.insert(0, {"x", get_a_tensor(), get_a_tensor()})
+        data.append(([1], get_a_tensor(), (1), [get_a_tensor()], {1, 2}))
         data.append({"abc": SomeDataClass("some_key", 1.0, [get_a_tensor()])})
         od = OrderedDict()
         od["k"] = "value"
2 changes: 1 addition & 1 deletion test/distributed/pipeline/sync/test_pipe.py
@@ -662,7 +662,7 @@ def test_named_children(setup_rpc):
     model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
     model = Pipe(model)

-    names = set(n for n, _ in model.named_modules())
+    names = {n for n, _ in model.named_modules()}
     assert "partitions.0.0" in names
     assert "partitions.1.0" in names

2 changes: 1 addition & 1 deletion test/distributed/test_c10d_common.py
@@ -1120,7 +1120,7 @@ def _test_sequence_num_incremented_default_group(self, backend_name):
         )
         self._test_sequence_num_incremented(
             c10d._get_default_group(),
-            ranks=list(i for i in range(dist.get_world_size())),
+            ranks=list(range(dist.get_world_size())),
         )

     def _test_sequence_num_incremented_subgroup(self, backend_name):
4 changes: 2 additions & 2 deletions test/distributed/test_c10d_gloo.py
@@ -2296,9 +2296,9 @@ def _test_broadcast_coalesced(self, process_group, device, root_rank):
         # The tensors to pass to broadcast are identical to the target
         # only on the process that is the root of the broadcast.
         if self.rank == root_rank:
-            tensors = list(tensor.clone() for tensor in target)
+            tensors = [tensor.clone() for tensor in target]
         else:
-            tensors = list(torch.zeros_like(tensor) for tensor in target)
+            tensors = [torch.zeros_like(tensor) for tensor in target]

         if self.rank != root_rank:
             self.assertNotEqual(tensors, target)
4 changes: 2 additions & 2 deletions test/distributed/test_c10d_nccl.py
@@ -2623,9 +2623,9 @@ def _test_broadcast_coalesced(self, process_group, device, root_rank):
         # The tensors to pass to broadcast are idential to the target
         # only on the process that is the root of the broadcast.
         if self.rank == root_rank:
-            tensors = list(tensor.clone() for tensor in target)
+            tensors = [tensor.clone() for tensor in target]
         else:
-            tensors = list(torch.zeros_like(tensor) for tensor in target)
+            tensors = [torch.zeros_like(tensor) for tensor in target]

         if self.rank != root_rank:
             self.assertNotEqual(tensors, target)
16 changes: 7 additions & 9 deletions test/dynamo/test_optimizers.py
@@ -55,15 +55,13 @@ class OptimizerTests(torch._dynamo.test_case.TestCase):

 # exclude SparseAdam because other areas of the stack don't support it yet
 # the others are handled specially above
-exclude = set(
-    [
-        "SGD", # Handled above
-        "Optimizer",
-        "SparseAdam", # Unsupported
-        "LBFGS", # Unsupported
-        "RAdam", # Has data dependent control for rectification (needs symint)
-    ]
-)
+exclude = {
+    "SGD", # Handled above
+    "Optimizer",
+    "SparseAdam", # Unsupported
+    "LBFGS", # Unsupported
+    "RAdam", # Has data dependent control for rectification (needs symint)
+}

 optimizers = [
     opt
4 changes: 3 additions & 1 deletion test/dynamo/test_repros.py
@@ -649,7 +649,9 @@ def _get_min_chunk_len(config):
         return config.lsh_attn_chunk_length
     elif len(attn_types_set) == 1 and attn_types[0] == "local":
         return config.local_attn_chunk_length
-    elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]):
+    elif len(attn_types_set) == 2 and attn_types_set == set(  # noqa: C405
+        ["lsh", "local"]
+    ):
         return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
     else:
         raise NotImplementedError(
2 changes: 1 addition & 1 deletion test/functorch/discover_coverage.py
@@ -803,7 +803,7 @@ def all(cls):
     def query(self, operator_method, filter=(Support.NO, Support.YES, Support.UNKNOWN)):
         result = {}
         for key in filter:
-            result[key] = set([])
+            result[key] = set()
         for op in self.data:
             support_status = operator_method(op)
             if support_status in filter:
8 changes: 4 additions & 4 deletions test/jit/test_builtins.py
@@ -158,20 +158,20 @@ def fn(x):
             return x.{}
         """

-        EQUALITY_MISMATCH = set([
+        EQUALITY_MISMATCH = {
             # TorchScript doesn't have real enums so they return an int instead
             # of the actual value
             'dtype',
             'layout',
-        ])
-        MISSING_PROPERTIES = set([
+        }
+        MISSING_PROPERTIES = {
             'grad_fn',
             # This is an undocumented property so it's not included
             "output_nr",
             # This has a longer implementation, maybe not worth copying to
             # TorchScript if named tensors don't work there anyways
             'names',
-        ])
+        }

         for p in properties:
             if p in MISSING_PROPERTIES:
2 changes: 1 addition & 1 deletion test/jit/test_list_dict.py
@@ -1516,7 +1516,7 @@ def specialized_list():
             li.append(3)
             return li

-        self.assertTrue(set(specialized_list()) == set([1, 2, 3]))
+        self.assertTrue(set(specialized_list()) == {1, 2, 3})

     @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
     def test_values(self):
4 changes: 2 additions & 2 deletions test/jit/test_misc.py
@@ -221,11 +221,11 @@ def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):

         torch._C._enable_mobile_interface_call_export()
         scripted_M_mod = torch.jit.script(M())
-        self.assertTrue(set(['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal']).issubset(
+        self.assertTrue({'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'}.issubset(
             set(torch.jit.export_opnames(scripted_M_mod))))

         scripted_M_mod.sub = torch.jit.script(FooMod())
-        self.assertTrue(set(['aten::add.Tensor', 'aten::mul.Scalar']).issubset(
+        self.assertTrue({'aten::add.Tensor', 'aten::mul.Scalar'}.issubset(
             set(torch.jit.export_opnames(scripted_M_mod))))

     def test_math_inf(self):
4 changes: 2 additions & 2 deletions test/jit/test_save_load.py
@@ -525,8 +525,8 @@ def forward(self, x):
             len(list(m.named_modules())), len(list(m_loaded.named_modules()))
         )
         self.assertEqual(
-            set(name for name, _ in m.named_modules()),
-            set(name for name, _ in m_loaded.named_modules()),
+            {name for name, _ in m.named_modules()},
+            {name for name, _ in m_loaded.named_modules()},
         )
         # Check parameters.
         m_params = dict(m.named_parameters())
2 changes: 1 addition & 1 deletion test/jit/test_slice.py
@@ -133,7 +133,7 @@ def tuple_slice(a):
         self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3))
         tuple_graph = scripted_fn.graph
         slices = tuple_graph.findAllNodes("prim::TupleConstruct")
-        num_outputs = set(len(x.output().type().elements()) for x in slices)
+        num_outputs = {len(x.output().type().elements()) for x in slices}
         # there should be only one tupleSlice with length of 2
         self.assertTrue(num_outputs == {2})
         self.run_pass('lower_all_tuples', tuple_graph)
18 changes: 9 additions & 9 deletions test/lazy/test_ts_opinfo.py
@@ -34,8 +34,8 @@ def init_lists():
         yaml_ts = yaml.load(f, yaml.Loader)
     LAZY_OPS_LIST = set(remove_suffixes(itertools.chain(yaml_ts["full_codegen"], yaml_ts["supported"], yaml_ts["autograd"])))
     HAS_SYMINT_SUFFIX = yaml_ts["symint"]
-    FALLBACK_LIST = set(["clamp"])
-    SKIP_RUNTIME_ERROR_LIST = set([
+    FALLBACK_LIST = {"clamp"}
+    SKIP_RUNTIME_ERROR_LIST = {
         'index_select', # Empty output_sizes is not supported
         'clone', # is clone decomposed?

@@ -46,19 +46,19 @@ def init_lists():
         'all', # ASAN failure
         'any', # ASAN failure
         'logdet', # ASAN failure
-    ])
-    SKIP_INCORRECT_RESULTS_LIST = set([
+    }
+    SKIP_INCORRECT_RESULTS_LIST = {
         'squeeze', # Value out of range
         't', # Value out of range
         'transpose', # Value out of range
         'bernoulli', # incorrect results
         'pow', # incorrect results
         'addcdiv', # incorrect results (on CI not locally?)
-    ])
+    }
     # The following ops all show up directly in ts_native_functions.yaml,
     # but run functionalized versions of the composite kernels in core.
     # This means that we don't expect the ops to show directly in the LTC metrics.
-    FUNCTIONAL_DECOMPOSE_LIST = set([
+    FUNCTIONAL_DECOMPOSE_LIST = {
         'diag_embed',
         'block_diag',
         'new_empty_strided',
@@ -70,13 +70,13 @@
         'linalg_inv_ex',
         'linalg_pinv.atol_rtol_tensor',
         'logsumexp',
-    ])
+    }
     # For some ops, we don't support all variants. Here we use formatted_name
     # to uniquely identify the variant.
-    SKIP_VARIANT_LIST = set([
+    SKIP_VARIANT_LIST = {
         'norm_nuc',
         'min_reduction_with_dim'
-    ])
+    }

     return (LAZY_OPS_LIST,
             FALLBACK_LIST,
8 changes: 4 additions & 4 deletions test/package/test_dependency_hooks.py
@@ -31,7 +31,7 @@ def my_extern_hook(package_exporter, module_name):
             exporter.register_extern_hook(my_extern_hook)
             exporter.save_source_string("foo", "import module_a")

-        self.assertEqual(my_externs, set(["module_a"]))
+        self.assertEqual(my_externs, {"module_a"})

     def test_multiple_extern_hooks(self):
         buffer = BytesIO()
@@ -93,7 +93,7 @@ def my_extern_hook2(package_exporter, module_name):
             exporter.save_source_string("foo", "import module_a")

         self.assertEqual(my_externs, set())
-        self.assertEqual(my_externs2, set(["module_a"]))
+        self.assertEqual(my_externs2, {"module_a"})

     def test_extern_and_mock_hook(self):
         buffer = BytesIO()
@@ -114,8 +114,8 @@ def my_mock_hook(package_exporter, module_name):
             exporter.register_mock_hook(my_mock_hook)
             exporter.save_source_string("foo", "import module_a; import package_a")

-        self.assertEqual(my_externs, set(["module_a"]))
-        self.assertEqual(my_mocks, set(["package_a"]))
+        self.assertEqual(my_externs, {"module_a"})
+        self.assertEqual(my_mocks, {"package_a"})


 if __name__ == "__main__":
6 changes: 3 additions & 3 deletions test/package/test_digraph.py
@@ -82,7 +82,7 @@ def test_iter(self):
         for n in g:
             nodes.add(n)

-        self.assertEqual(nodes, set([1, 2, 3]))
+        self.assertEqual(nodes, {1, 2, 3})

     def test_contains(self):
         g = DiGraph()
@@ -101,8 +101,8 @@ def test_forward_closure(self):
         g.add_edge("2", "3")
         g.add_edge("5", "4")
         g.add_edge("4", "3")
-        self.assertTrue(g.forward_transitive_closure("1") == set(["1", "2", "3"]))
-        self.assertTrue(g.forward_transitive_closure("4") == set(["4", "3"]))
+        self.assertTrue(g.forward_transitive_closure("1") == {"1", "2", "3"})
+        self.assertTrue(g.forward_transitive_closure("4") == {"4", "3"})

     def test_all_paths(self):
         g = DiGraph()
2 changes: 1 addition & 1 deletion test/quantization/core/test_quantized_op.py
@@ -2443,7 +2443,7 @@ def test_instance_norm(self):
         affine_list = (True, False)
         combined = [shape_list, torch_types, y_scales, y_zero_points, channels_last_list, affine_list]
         test_cases_product = itertools.product(*combined)
-        test_cases = list(test_case for test_case in test_cases_product)
+        test_cases = list(test_cases_product)
         # add just one test case to test overflow
         test_cases.append([
             [1, 4, 224, 224, 160], # shape,
4 changes: 2 additions & 2 deletions test/quantization/eager/test_model_numerics.py
@@ -95,8 +95,8 @@ def test_weight_only_activation_only_fakequant(self):
         torch.manual_seed(67)
         calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
         eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
-        qconfigset = set([torch.ao.quantization.default_weight_only_qconfig,
-                          torch.ao.quantization.default_activation_only_qconfig])
+        qconfigset = {torch.ao.quantization.default_weight_only_qconfig,
+                      torch.ao.quantization.default_activation_only_qconfig}
         SQNRTarget = [35, 45]
         for idx, qconfig in enumerate(qconfigset):
             my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
2 changes: 1 addition & 1 deletion test/quantization/eager/test_quantize_eager_ptq.py
@@ -1120,7 +1120,7 @@ def checkQuantized(model):

         # Test set qconfig
         model = SingleLayerLinearDynamicModel()
-        quantize_dynamic(model, set([nn.Linear]), inplace=True, dtype=dtype)
+        quantize_dynamic(model, {nn.Linear}, inplace=True, dtype=dtype)
         checkQuantized(model)

     def test_two_layers(self):