diff --git a/.flake8 b/.flake8 index d6e1aa0e36618..3f8cdcc4c541e 100644 --- a/.flake8 +++ b/.flake8 @@ -11,7 +11,7 @@ ignore = # these ignores are from flake8-bugbear; please fix! B007,B008, # these ignores are from flake8-comprehensions; please fix! - C400,C401,C402,C405,C407 + C407 per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950 diff --git a/benchmarks/distributed/ddp/diff.py b/benchmarks/distributed/ddp/diff.py index dc984626888a0..d427a5b29d919 100644 --- a/benchmarks/distributed/ddp/diff.py +++ b/benchmarks/distributed/ddp/diff.py @@ -25,7 +25,7 @@ def main(): ja = load(args.file[0]) jb = load(args.file[1]) - keys = (set(ja.keys()) | set(jb.keys())) - set(["benchmark_results"]) + keys = (set(ja.keys()) | set(jb.keys())) - {"benchmark_results"} print("{:20s} {:>20s} {:>20s}".format("", "baseline", "test")) print("{:20s} {:>20s} {:>20s}".format("", "-" * 20, "-" * 20)) for key in sorted(keys): diff --git a/scripts/release_notes/namespace_check.py b/scripts/release_notes/namespace_check.py index 54196bdfbe6f1..1b9a91c12f8a4 100644 --- a/scripts/release_notes/namespace_check.py +++ b/scripts/release_notes/namespace_check.py @@ -39,7 +39,7 @@ def get_content(submod): return content def namespace_filter(data): - out = set(d for d in data if d[0] != "_") + out = {d for d in data if d[0] != "_"} return out def run(args, submod): diff --git a/test/ao/sparsity/test_sparsifier.py b/test/ao/sparsity/test_sparsifier.py index 512c58b188367..582f12fe4861b 100644 --- a/test/ao/sparsity/test_sparsifier.py +++ b/test/ao/sparsity/test_sparsifier.py @@ -417,7 +417,7 @@ def test_mask_squash(self): assert torch.all(weights == torch.eye(height, width) * weights) # only diagonal to be present def test_sparsity_levels(self): - nearliness_levels = list(nearliness for nearliness in range(-1, 100)) + nearliness_levels = list(range(-1, 100)) model = nn.Sequential() p = re.compile(r'[-\.\s]') diff --git a/test/distributed/fsdp/test_fsdp_ignored_modules.py b/test/distributed/fsdp/test_fsdp_ignored_modules.py index 3676acdbda549..d93a923f5f790 100644 --- a/test/distributed/fsdp/test_fsdp_ignored_modules.py +++ b/test/distributed/fsdp/test_fsdp_ignored_modules.py @@ -244,9 +244,9 @@ def _test_diff_ignored_modules_across_ranks( {"ignored_modules": layer1_ignored_modules} if ignore_modules else { - "ignored_parameters": set( + "ignored_parameters": { p for m in layer1_ignored_modules for p in m.parameters() - ) + } } ) model.layer1 = FSDP(model.layer1, **ignore_kwargs) @@ -260,9 +260,9 @@ def _test_diff_ignored_modules_across_ranks( {"ignored_modules": model_ignored_modules} if ignore_modules else { - "ignored_parameters": set( + "ignored_parameters": { p for m in model_ignored_modules for p in m.parameters() - ) + } } ) wrapped_model = FSDP(model, **ignore_kwargs_top) @@ -279,9 +279,9 @@ def test_ignored_modules_not_under_wrapped_root(self, ignore_modules: bool): {"ignored_modules": ignored_modules} if ignore_modules else { - "ignored_parameters": set( + "ignored_parameters": { p for m in ignored_modules for p in m.parameters() - ) + } } ) diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index 62d3da621ffa3..ddb960e3dc81d 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -783,9 +783,7 @@ def test_state_dict_save_load_flow(self, state_dict_type): def test_fsdp_state_dict_keys(self, state_dict_type): state_dict = self._state_dict(self._initialize_model(True), state_dict_type) if state_dict_type == "local_state_dict": - self.assertEqual( - set([FLAT_PARAM, f"inner.{FLAT_PARAM}"]), state_dict.keys() - ) + self.assertEqual({FLAT_PARAM, f"inner.{FLAT_PARAM}"}, state_dict.keys()) elif state_dict_type in ("state_dict", "sharded_state_dict"): # Keys should match local model. local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False) diff --git a/test/distributed/fsdp/test_utils.py b/test/distributed/fsdp/test_utils.py index 249fb5326f21f..45b78148eb2ed 100644 --- a/test/distributed/fsdp/test_utils.py +++ b/test/distributed/fsdp/test_utils.py @@ -66,8 +66,8 @@ class SomeDataClass: # create a mixed bag of data. data = [1, "str"] data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3}) - data.insert(0, set(["x", get_a_tensor(), get_a_tensor()])) - data.append(([1], get_a_tensor(), (1), [get_a_tensor()], set((1, 2)))) + data.insert(0, {"x", get_a_tensor(), get_a_tensor()}) + data.append(([1], get_a_tensor(), (1), [get_a_tensor()], {1, 2})) data.append({"abc": SomeDataClass("some_key", 1.0, [get_a_tensor()])}) od = OrderedDict() od["k"] = "value" diff --git a/test/distributed/pipeline/sync/test_pipe.py b/test/distributed/pipeline/sync/test_pipe.py index abfa738603a1f..cce106919159e 100644 --- a/test/distributed/pipeline/sync/test_pipe.py +++ b/test/distributed/pipeline/sync/test_pipe.py @@ -662,7 +662,7 @@ def test_named_children(setup_rpc): model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) model = Pipe(model) - names = set(n for n, _ in model.named_modules()) + names = {n for n, _ in model.named_modules()} assert "partitions.0.0" in names assert "partitions.1.0" in names diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index de0d8e7c25a68..87c804acd9b16 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1120,7 +1120,7 @@ def _test_sequence_num_incremented_default_group(self, backend_name): ) self._test_sequence_num_incremented( c10d._get_default_group(), - ranks=list(i for i in range(dist.get_world_size())), + ranks=list(range(dist.get_world_size())), ) def _test_sequence_num_incremented_subgroup(self, backend_name): diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index 2b5f3f4a9465f..dfdfe442ab44d 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -2296,9 +2296,9 @@ def _test_broadcast_coalesced(self, process_group, device, root_rank): # The tensors to pass to broadcast are identical to the target # only on the process that is the root of the broadcast. if self.rank == root_rank: - tensors = list(tensor.clone() for tensor in target) + tensors = [tensor.clone() for tensor in target] else: - tensors = list(torch.zeros_like(tensor) for tensor in target) + tensors = [torch.zeros_like(tensor) for tensor in target] if self.rank != root_rank: self.assertNotEqual(tensors, target) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index d1ecdba6da177..a1c7ad28a0d17 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2623,9 +2623,9 @@ def _test_broadcast_coalesced(self, process_group, device, root_rank): # The tensors to pass to broadcast are idential to the target # only on the process that is the root of the broadcast. if self.rank == root_rank: - tensors = list(tensor.clone() for tensor in target) + tensors = [tensor.clone() for tensor in target] else: - tensors = list(torch.zeros_like(tensor) for tensor in target) + tensors = [torch.zeros_like(tensor) for tensor in target] if self.rank != root_rank: self.assertNotEqual(tensors, target) diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index 8e51ec5daf3f8..b8b5f99740b58 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -55,15 +55,13 @@ class OptimizerTests(torch._dynamo.test_case.TestCase): # exclude SparseAdam because other areas of the stack don't support it yet # the others are handled specially above -exclude = set( - [ - "SGD", # Handled above - "Optimizer", - "SparseAdam", # Unsupported - "LBFGS", # Unsupported - "RAdam", # Has data dependent control for rectification (needs symint) - ] -) +exclude = { + "SGD", # Handled above + "Optimizer", + "SparseAdam", # Unsupported + "LBFGS", # Unsupported + "RAdam", # Has data dependent control for rectification (needs symint) +} optimizers = [ opt diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 7e8477d673c52..d20305513c152 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -649,7 +649,9 @@ def _get_min_chunk_len(config): return config.lsh_attn_chunk_length elif len(attn_types_set) == 1 and attn_types[0] == "local": return config.local_attn_chunk_length - elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]): + elif len(attn_types_set) == 2 and attn_types_set == set( # noqa: C405 + ["lsh", "local"] + ): return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length) else: raise NotImplementedError( diff --git a/test/functorch/discover_coverage.py b/test/functorch/discover_coverage.py index 6d1e055d01f21..aafa179bc81bd 100644 --- a/test/functorch/discover_coverage.py +++ b/test/functorch/discover_coverage.py @@ -803,7 +803,7 @@ def all(cls): def query(self, operator_method, filter=(Support.NO, Support.YES, Support.UNKNOWN)): result = {} for key in filter: - result[key] = set([]) + result[key] = set() for op in self.data: support_status = operator_method(op) if support_status in filter: diff --git a/test/jit/test_builtins.py b/test/jit/test_builtins.py index e3670aa798721..aa78a976be587 100644 --- a/test/jit/test_builtins.py +++ b/test/jit/test_builtins.py @@ -158,20 +158,20 @@ def fn(x): return x.{} """ - EQUALITY_MISMATCH = set([ + EQUALITY_MISMATCH = { # TorchScript doesn't have real enums so they return an int instead # of the actual value 'dtype', 'layout', - ]) - MISSING_PROPERTIES = set([ + } + MISSING_PROPERTIES = { 'grad_fn', # This is an undocumented property so it's not included "output_nr", # This has a longer implementation, maybe not worth copying to # TorchScript if named tensors don't work there anyways 'names', - ]) + } for p in properties: if p in MISSING_PROPERTIES: diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 980b76cf59978..3fdce7e1a6586 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -1516,7 +1516,7 @@ def specialized_list(): li.append(3) return li - self.assertTrue(set(specialized_list()) == set([1, 2, 3])) + self.assertTrue(set(specialized_list()) == {1, 2, 3}) @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") def test_values(self): diff --git a/test/jit/test_misc.py b/test/jit/test_misc.py index 2c2bf2ceb6919..d4bca3da64714 100644 --- a/test/jit/test_misc.py +++ b/test/jit/test_misc.py @@ -221,11 +221,11 @@ def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor): torch._C._enable_mobile_interface_call_export() scripted_M_mod = torch.jit.script(M()) - self.assertTrue(set(['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal']).issubset( + self.assertTrue({'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'}.issubset( set(torch.jit.export_opnames(scripted_M_mod)))) scripted_M_mod.sub = torch.jit.script(FooMod()) - self.assertTrue(set(['aten::add.Tensor', 'aten::mul.Scalar']).issubset( + self.assertTrue({'aten::add.Tensor', 'aten::mul.Scalar'}.issubset( set(torch.jit.export_opnames(scripted_M_mod)))) def test_math_inf(self): diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 81a24f6680231..6f32bc96dc496 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -525,8 +525,8 @@ def forward(self, x): len(list(m.named_modules())), len(list(m_loaded.named_modules())) ) self.assertEqual( - set(name for name, _ in m.named_modules()), - set(name for name, _ in m_loaded.named_modules()), + {name for name, _ in m.named_modules()}, + {name for name, _ in m_loaded.named_modules()}, ) # Check parameters. m_params = dict(m.named_parameters()) diff --git a/test/jit/test_slice.py b/test/jit/test_slice.py index 5878f6c43bf21..ceb3c3b48e89d 100644 --- a/test/jit/test_slice.py +++ b/test/jit/test_slice.py @@ -133,7 +133,7 @@ def tuple_slice(a): self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3)) tuple_graph = scripted_fn.graph slices = tuple_graph.findAllNodes("prim::TupleConstruct") - num_outputs = set(len(x.output().type().elements()) for x in slices) + num_outputs = {len(x.output().type().elements()) for x in slices} # there should be only one tupleSlice with length of 2 self.assertTrue(num_outputs == {2}) self.run_pass('lower_all_tuples', tuple_graph) diff --git a/test/lazy/test_ts_opinfo.py b/test/lazy/test_ts_opinfo.py index 092ba3d0388d0..070d97af189dd 100644 --- a/test/lazy/test_ts_opinfo.py +++ b/test/lazy/test_ts_opinfo.py @@ -34,8 +34,8 @@ def init_lists(): yaml_ts = yaml.load(f, yaml.Loader) LAZY_OPS_LIST = set(remove_suffixes(itertools.chain(yaml_ts["full_codegen"], yaml_ts["supported"], yaml_ts["autograd"]))) HAS_SYMINT_SUFFIX = yaml_ts["symint"] - FALLBACK_LIST = set(["clamp"]) - SKIP_RUNTIME_ERROR_LIST = set([ + FALLBACK_LIST = {"clamp"} + SKIP_RUNTIME_ERROR_LIST = { 'index_select', # Empty output_sizes is not supported 'clone', # is clone decomposed? @@ -46,19 +46,19 @@ def init_lists(): 'all', # ASAN failure 'any', # ASAN failure 'logdet', # ASAN failure - ]) - SKIP_INCORRECT_RESULTS_LIST = set([ + } + SKIP_INCORRECT_RESULTS_LIST = { 'squeeze', # Value out of range 't', # Value out of range 'transpose', # Value out of range 'bernoulli', # incorrect results 'pow', # incorrect results 'addcdiv', # incorrect results (on CI not locally?) - ]) + } # The following ops all show up directly in ts_native_functions.yaml, # but run functionalized versions of the composite kernels in core. # This means that we don't expect the ops to show directly in the LTC metrics. - FUNCTIONAL_DECOMPOSE_LIST = set([ + FUNCTIONAL_DECOMPOSE_LIST = { 'diag_embed', 'block_diag', 'new_empty_strided', @@ -70,13 +70,13 @@ def init_lists(): 'linalg_inv_ex', 'linalg_pinv.atol_rtol_tensor', 'logsumexp', - ]) + } # For some ops, we don't support all variants. Here we use formatted_name # to uniquely identify the variant. - SKIP_VARIANT_LIST = set([ + SKIP_VARIANT_LIST = { 'norm_nuc', 'min_reduction_with_dim' - ]) + } return (LAZY_OPS_LIST, FALLBACK_LIST, diff --git a/test/package/test_dependency_hooks.py b/test/package/test_dependency_hooks.py index df155ab1dea30..a4824f9a42e34 100644 --- a/test/package/test_dependency_hooks.py +++ b/test/package/test_dependency_hooks.py @@ -31,7 +31,7 @@ def my_extern_hook(package_exporter, module_name): exporter.register_extern_hook(my_extern_hook) exporter.save_source_string("foo", "import module_a") - self.assertEqual(my_externs, set(["module_a"])) + self.assertEqual(my_externs, {"module_a"}) def test_multiple_extern_hooks(self): buffer = BytesIO() @@ -93,7 +93,7 @@ def my_extern_hook2(package_exporter, module_name): exporter.save_source_string("foo", "import module_a") self.assertEqual(my_externs, set()) - self.assertEqual(my_externs2, set(["module_a"])) + self.assertEqual(my_externs2, {"module_a"}) def test_extern_and_mock_hook(self): buffer = BytesIO() @@ -114,8 +114,8 @@ def my_mock_hook(package_exporter, module_name): exporter.register_mock_hook(my_mock_hook) exporter.save_source_string("foo", "import module_a; import package_a") - self.assertEqual(my_externs, set(["module_a"])) - self.assertEqual(my_mocks, set(["package_a"])) + self.assertEqual(my_externs, {"module_a"}) + self.assertEqual(my_mocks, {"package_a"}) if __name__ == "__main__": diff --git a/test/package/test_digraph.py b/test/package/test_digraph.py index 92f469868f7c9..90dc11f3a100f 100644 --- a/test/package/test_digraph.py +++ b/test/package/test_digraph.py @@ -82,7 +82,7 @@ def test_iter(self): for n in g: nodes.add(n) - self.assertEqual(nodes, set([1, 2, 3])) + self.assertEqual(nodes, {1, 2, 3}) def test_contains(self): g = DiGraph() @@ -101,8 +101,8 @@ def test_forward_closure(self): g.add_edge("2", "3") g.add_edge("5", "4") g.add_edge("4", "3") - self.assertTrue(g.forward_transitive_closure("1") == set(["1", "2", "3"])) - self.assertTrue(g.forward_transitive_closure("4") == set(["4", "3"])) + self.assertTrue(g.forward_transitive_closure("1") == {"1", "2", "3"}) + self.assertTrue(g.forward_transitive_closure("4") == {"4", "3"}) def test_all_paths(self): g = DiGraph() diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 1ec22594d3797..1d38d39df85e0 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -2443,7 +2443,7 @@ def test_instance_norm(self): affine_list = (True, False) combined = [shape_list, torch_types, y_scales, y_zero_points, channels_last_list, affine_list] test_cases_product = itertools.product(*combined) - test_cases = list(test_case for test_case in test_cases_product) + test_cases = list(test_cases_product) # add just one test case to test overflow test_cases.append([ [1, 4, 224, 224, 160], # shape, diff --git a/test/quantization/eager/test_model_numerics.py b/test/quantization/eager/test_model_numerics.py index bcefb78bd7529..1a1ef3b917fc1 100644 --- a/test/quantization/eager/test_model_numerics.py +++ b/test/quantization/eager/test_model_numerics.py @@ -95,8 +95,8 @@ def test_weight_only_activation_only_fakequant(self): torch.manual_seed(67) calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32) eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32) - qconfigset = set([torch.ao.quantization.default_weight_only_qconfig, - torch.ao.quantization.default_activation_only_qconfig]) + qconfigset = {torch.ao.quantization.default_weight_only_qconfig, + torch.ao.quantization.default_activation_only_qconfig} SQNRTarget = [35, 45] for idx, qconfig in enumerate(qconfigset): my_model = ModelMultipleOpsNoAvgPool().to(torch.float32) diff --git a/test/quantization/eager/test_quantize_eager_ptq.py b/test/quantization/eager/test_quantize_eager_ptq.py index a20a17d6637df..3b878b7ec7573 100644 --- a/test/quantization/eager/test_quantize_eager_ptq.py +++ b/test/quantization/eager/test_quantize_eager_ptq.py @@ -1120,7 +1120,7 @@ def checkQuantized(model): # Test set qconfig model = SingleLayerLinearDynamicModel() - quantize_dynamic(model, set([nn.Linear]), inplace=True, dtype=dtype) + quantize_dynamic(model, {nn.Linear}, inplace=True, dtype=dtype) checkQuantized(model) def test_two_layers(self): diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py index e0a428a987b53..24bb7c44eef5a 100644 --- a/test/quantization/fx/test_model_report_fx.py +++ b/test/quantization/fx/test_model_report_fx.py @@ -895,7 +895,7 @@ def test_constructor(self): model_prep = quantize_fx.prepare_fx(model, q_config_mapping, model.get_example_inputs()[0]) # make an example set of detectors - test_detector_set = set([DynamicStaticDetector(), PerChannelDetector(backend)]) + test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)} # initialize with an empty detector model_report = ModelReport(model_prep, test_detector_set) @@ -905,7 +905,7 @@ def test_constructor(self): # now attempt with no valid reports, should raise error with self.assertRaises(ValueError): - model_report = ModelReport(model, set([])) + model_report = ModelReport(model, set()) # number of expected obs of interest entries num_expected_entries = len(test_detector_set) @@ -932,7 +932,7 @@ def test_prepare_model_callibration(self): # make an example set of detectors torch.backends.quantized.engine = "fbgemm" backend = torch.backends.quantized.engine - test_detector_set = set([DynamicStaticDetector(), PerChannelDetector(backend)]) + test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)} # initialize with an empty detector # prepare the model @@ -1029,8 +1029,8 @@ def test_generate_report(self): torch.backends.quantized.engine = "fbgemm" # check whether the correct number of reports are being generated - filled_detector_set = set([DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)]) - single_detector_set = set([DynamicStaticDetector()]) + filled_detector_set = {DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)} + single_detector_set = {DynamicStaticDetector()} # create our models model_full = TwoThreeOps() @@ -1316,7 +1316,7 @@ def test_input_weight_equalization_determine_points(self): # then create model report instance with detector with override_quantized_engine('fbgemm'): - detector_set = set([InputWeightEqualizationDetector(0.5)]) + detector_set = {InputWeightEqualizationDetector(0.5)} # get tst model and callibrate non_fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set) @@ -1326,7 +1326,7 @@ def test_input_weight_equalization_determine_points(self): for prepared_for_callibrate_model, mod_report in [non_fused, fused]: # supported modules to check - mods_to_check = set([nn.Linear, nn.Conv2d]) + mods_to_check = {nn.Linear, nn.Conv2d} # get the set of all nodes in the graph their fqns node_fqns = {node.target for node in prepared_for_callibrate_model.graph.nodes} @@ -1362,7 +1362,7 @@ def test_input_weight_equalization_report_gen(self): with override_quantized_engine('fbgemm'): test_input_weight_detector = InputWeightEqualizationDetector(0.4) - detector_set = set([test_input_weight_detector]) + detector_set = {test_input_weight_detector} model = self.TwoBlockComplexNet() # prepare the model for callibration prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model( @@ -1471,7 +1471,7 @@ def test_input_weight_equalization_report_gen_empty(self): # then create model report instance with detector with override_quantized_engine('fbgemm'): test_input_weight_detector = InputWeightEqualizationDetector(0.4) - detector_set = set([test_input_weight_detector]) + detector_set = {test_input_weight_detector} model = self.ReluOnly() # prepare the model for callibration prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(model, detector_set) @@ -1547,7 +1547,7 @@ def test_outlier_detection_determine_points(self): # not explicitly testing fusion because fx workflow automatically with override_quantized_engine('fbgemm'): - detector_set = set([OutlierDetector(reference_percentile=0.95)]) + detector_set = {OutlierDetector(reference_percentile=0.95)} # get tst model and callibrate prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model( @@ -1555,7 +1555,7 @@ def test_outlier_detection_determine_points(self): ) # supported modules to check - mods_to_check = set([nn.Linear, nn.Conv2d, nn.ReLU]) + mods_to_check = {nn.Linear, nn.Conv2d, nn.ReLU} # there should be 4 node fqns that have the observer inserted correct_number_of_obs_inserted = 4 @@ -1590,7 +1590,7 @@ def test_no_outlier_report_gen(self): dynamic_static_detector = DynamicStaticDetector(tolerance=0.5) param_size: int = 4 - detector_set = set([outlier_detector, dynamic_static_detector]) + detector_set = {outlier_detector, dynamic_static_detector} model = self.LargeBatchModel(param_size=param_size) # get tst model and callibrate @@ -1640,7 +1640,7 @@ def test_all_outlier_report_gen(self): outlier_detector = OutlierDetector(ratio_threshold=1, reference_percentile=0) param_size: int = 16 - detector_set = set([outlier_detector]) + detector_set = {outlier_detector} model = self.LargeBatchModel(param_size=param_size) # get tst model and callibrate @@ -1690,7 +1690,7 @@ def test_multiple_run_consistent_spike_outlier_report_gen(self): outlier_detector = OutlierDetector(reference_percentile=0.95) param_size: int = 8 - detector_set = set([outlier_detector]) + detector_set = {outlier_detector} model = self.LargeBatchModel(param_size=param_size) # get tst model and callibrate @@ -1874,8 +1874,8 @@ def test_generate_tables_match_with_report(self): channel_headers, channel_table = table_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY] # these two together should be the same as the generated report info in terms of keys - tensor_info_modules = set(row[1] for row in tensor_table) - channel_info_modules = set(row[1] for row in channel_table) + tensor_info_modules = {row[1] for row in tensor_table} + channel_info_modules = {row[1] for row in channel_table} combined_modules: Set = tensor_info_modules.union(channel_info_modules) generated_report_keys: Set = set(mod_rep_visualizer.generated_reports.keys()) @@ -1901,8 +1901,8 @@ def test_generate_tables_no_match(self): tensor_headers, tensor_table = empty_tables_dict[ModelReportVisualizer.TABLE_TENSOR_KEY] channel_headers, channel_table = empty_tables_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY] - tensor_info_modules = set(row[1] for row in tensor_table) - channel_info_modules = set(row[1] for row in channel_table) + tensor_info_modules = {row[1] for row in tensor_table} + channel_info_modules = {row[1] for row in channel_table} combined_modules: Set = tensor_info_modules.union(channel_info_modules) self.assertEqual(len(combined_modules), 0) # should be no matching modules diff --git a/test/quantization/jit/test_quantize_jit.py b/test/quantization/jit/test_quantize_jit.py index 7726dc04c7111..01fb7e9ae23dd 100644 --- a/test/quantization/jit/test_quantize_jit.py +++ b/test/quantization/jit/test_quantize_jit.py @@ -660,16 +660,16 @@ def forward(self, x): m = torch.jit.script(M()) qconfig_dict = {"": default_qconfig} m = prepare_jit(m, qconfig_dict) - activation_dtypes = set( + activation_dtypes = { obs.getattr("dtype") for x, obs in m._modules._c.items() if x.startswith("_observer_") - ) - weight_dtypes = set( + } + weight_dtypes = { obs.getattr("dtype") for x, obs in m.conv._modules._c.items() if x.startswith("_observer_") - ) + } assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype" assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype" assert ( diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 3adfef4ca1166..82113efed7b1d 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -1557,7 +1557,7 @@ def test_pow_inplace_resizing_exception(self, device): ((2, 1), (2, 2)), ((2, 2), (2, 1, 1)), ) - test_inputs = list( + test_inputs = [ ( make_tensor( base_size, dtype=torch.float64, device=device, high=10.0, low=0.0 @@ -1567,7 +1567,7 @@ def test_pow_inplace_resizing_exception(self, device): ), ) for base_size, exp_size in test_cases - ) + ] for base, exponent in test_inputs: regex = "doesn't match the broadcast shape" self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent) @@ -1605,10 +1605,10 @@ def test_float_scalar_pow_float_tensor(self, device, dtype): (2, 1), (2, 2, 2), ) - tensors = list( + tensors = [ make_tensor(shape, dtype=dtype, device=device, low=0) for shape in exponent_shapes - ) + ] floats_tensor = torch.tensor(floats, dtype=dtype, device=device) for base in floats: self._test_pow(base, floats_tensor) diff --git a/test/test_bundled_inputs.py b/test/test_bundled_inputs.py index 0330af378746f..db3c8df9b872f 100644 --- a/test/test_bundled_inputs.py +++ b/test/test_bundled_inputs.py @@ -194,7 +194,7 @@ def foo(self, arg): # Check helper that work on all functions all_info = loaded.get_bundled_inputs_functions_and_info() - self.assertEqual(set(all_info.keys()), set(['forward', 'foo'])) + self.assertEqual(set(all_info.keys()), {'forward', 'foo'}) self.assertEqual(all_info['forward']['get_inputs_function_name'], ['get_all_bundled_inputs_for_forward']) self.assertEqual(all_info['foo']['get_inputs_function_name'], ['get_all_bundled_inputs_for_foo']) self.assertEqual(all_info['forward']['info'], info) diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index 6cfe26a14f783..ac24193fcc74f 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -191,7 +191,7 @@ def check_union(self, funcs): In these cases we expect to get exactly one function per python type. """ # Verify that all functions have the same return type. - union_type = set(self.expected_return_type(f) for f in funcs) + union_type = {self.expected_return_type(f) for f in funcs} assert len(union_type) == 1 union_type = union_type.pop() self.assertIs(Union, get_origin(union_type)) diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 56856748b762b..05119686d5160 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -1361,7 +1361,7 @@ def test_iterable_style_dataset(self): dataloader_iter = iter(dataloader) fetched = list(dataloader_iter) self.assertEqual(len(fetched), 4) - fetched = set(tuple(t.tolist()) for t in fetched) + fetched = {tuple(t.tolist()) for t in fetched} self.assertEqual(fetched, {tuple(range(4)), tuple(range(7)), tuple(range(7, 14)), tuple(range(14, 20))}) # [auto-batching] test that workers exit gracefully @@ -1399,7 +1399,7 @@ def test_iterable_style_dataset(self): dataloader_iter = iter(dataloader) fetched = list(dataloader_iter) self.assertEqual(len(fetched), 2) - fetched = set(tuple(t.tolist()) for t in fetched) + fetched = {tuple(t.tolist()) for t in fetched} self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))}) # [auto-batching & drop_last] test that workers exit gracefully @@ -1500,7 +1500,7 @@ def get_dataloader(): num_workers = 6 batch_size = 1 dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers) - self.assertEqual(set(int(batch) for batch in get_dataloader()), set(int(batch) for batch in get_dataloader())) + self.assertEqual({int(batch) for batch in get_dataloader()}, {int(batch) for batch in get_dataloader()}) def test_multi_epochs_reproducibility(self): num_workers = 2 diff --git a/test/test_datapipe.py b/test/test_datapipe.py index fbb7156677e65..59abbc28260e6 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -1755,7 +1755,7 @@ def test_zip_iterdatapipe(self): len(zipped_dp) # Functional Test: zips the results properly - exp = list((i, i) for i in range(5)) + exp = [(i, i) for i in range(5)] self.assertEqual(list(zipped_dp), exp) # Functional Test: zips the inputs properly even when lengths are different (zips to the shortest) @@ -2364,7 +2364,7 @@ def __iter__(self) -> Iterator[T]: # Context Manager to disable the runtime validation with runtime_validation_disabled(): - self.assertEqual(list(d for d in dp3), ds) + self.assertEqual(list(dp3), ds) class NumbersDataset(IterDataPipe): diff --git a/test/test_decomp.py b/test/test_decomp.py index 221c76121ad4f..c27ffadb6123f 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -739,9 +739,9 @@ def all_aten_overloads(): # This is for operators that are only registered in some CI # configurations, so would cause the test to fail - allow_list = set([aten.get_gradients.default]) + allow_list = {aten.get_gradients.default} - overloads_wanting_decomp = set(op for op in all_aten_overloads() if can_appear_in_trace(op)) + overloads_wanting_decomp = {op for op in all_aten_overloads() if can_appear_in_trace(op)} ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys() ops_missing_decomp -= allow_list self.assertExpected("".join(sorted(op.name() + "\n" for op in ops_missing_decomp))) diff --git a/test/test_foreach.py b/test/test_foreach.py index 130f010a8565e..2f63e1451bad4 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -466,7 +466,7 @@ def test_binary_op_tensors_on_different_devices(self, device, dtype, op): # `tensors2`: ['cuda', 'cpu'] _cuda_tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True))[0].input _cpu_tensors = list(op.sample_inputs("cpu", dtype, num_input_tensors=[2], same_size=True))[0].input - tensors1, tensors2 = list(tensors for tensors in zip(_cuda_tensors, _cpu_tensors)) + tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors)) foreach_op, foreach_op_ = op.method_variant, op.inplace_variant native_op, native_op_ = op.ref, op.ref_inplace @@ -494,7 +494,7 @@ def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op): # tensors3: ['cuda', 'cpu] _cuda_tensors = list(op.sample_inputs(device, dtype, num_input_tensors=[3], same_size=True))[0].input _cpu_tensors = list(op.sample_inputs("cpu", dtype, num_input_tensors=[3], same_size=True))[0].input - tensors1, tensors2, tensors3 = list(tensors for tensors in zip(_cuda_tensors, _cpu_tensors)) + tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors)) foreach_op, foreach_op_, native_op = op.method_variant, op.inplace_variant, op.ref actual = foreach_op(tensors1, tensors2, tensors3) diff --git a/test/test_fx.py b/test/test_fx.py index bc4a821f2c964..2b70c581a392f 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1598,8 +1598,8 @@ def forward(self, x): if node.op == 'output': output_shape = node.args[0].meta['tensor_meta'].shape output_stride = node.args[0].meta['tensor_meta'].stride - self.assertEqual(opcodes, set(['placeholder', 'get_attr', 'call_function', 'call_method', - 'call_module', 'output'])) + self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method', + 'call_module', 'output'}) # Test shape propagation and make sure results match actual self.assertEqual(output_shape, ref_out.shape) @@ -1832,8 +1832,8 @@ def test_interpreter_gc_values(self): interp = Interpreter(symbolic_trace(rn18)) inp = torch.rand(5, 3, 224, 224) out = interp.run(inp) - env_key_names = set(n.name for n in interp.env.keys()) - self.assertEqual(env_key_names, set(['output'])) + env_key_names = {n.name for n in interp.env.keys()} + self.assertEqual(env_key_names, {'output'}) def test_interpreter_default_args(self): class Model(torch.nn.Module): @@ -2052,7 +2052,7 @@ def test_deepcopy_recursion_depth(self): for orig_node, new_node in zip(g.nodes, copied_graph.nodes): orig_users = set(orig_node.users.keys()) - orig_users_equiv = set(val_map[u] for u in orig_users) + orig_users_equiv = {val_map[u] for u in orig_users} new_users = set(new_node.users.keys()) self.assertEqual(orig_users_equiv, new_users) @@ -2230,7 +2230,7 @@ def test_find_uses(self): users_of_x = x.node.users self.assertEqual(len(users_of_x), 3) - expected_ops = set(['relu', 'add', 'neg']) + expected_ops = {'relu', 'add', 'neg'} for use in users_of_x: assert any(use.name.startswith(prefix) for prefix in expected_ops) diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index f81627999722e..298ef8fec3e0c 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -873,7 +873,7 @@ def is_leaf_module( ) -> bool: # `leaves` contains the set of standard `nn.Modules` that are not # currently symbolically traceable. Ideally this set would be empty - leaves = set([torch.nn.BatchNorm2d]) + leaves = {torch.nn.BatchNorm2d} return type(m) in leaves traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m)) @@ -1057,7 +1057,7 @@ def is_leaf_module( ) -> bool: # `leaves` contains the set of standard `nn.Modules` that are not # currently symbolically traceable. Ideally this set would be empty - leaves = set([torch.nn.BatchNorm2d]) + leaves = {torch.nn.BatchNorm2d} return type(m) in leaves traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m)) diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index 856b883a7aece..8f9b467393c75 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -3743,7 +3743,7 @@ def find_nearest_divisor(N): result += 1 return result - complete_views = set([tuple(original_view)]) + complete_views = {tuple(original_view)} to_visit = [] # empty new view, curent originaal view, start pos=0, move count = 0, last_move diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index 2dd2598f831cd..ebdd2eefaa37c 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -27,7 +27,7 @@ def strip_profiling_nodes(nodes): - profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut']) + profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'} return [n for n in nodes if n.kind() not in profiling_opcodes] diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 9b1e30f27a7ea..08e2911115f27 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -46,7 +46,7 @@ autograd_check_set = {'aten::__is__', 'prim::AutogradAllNonZero', 'prim::AutogradAllZero', 'prim::ListConstruct'} def strip_profiling_nodes(nodes): - profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut']) + profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'} return [n for n in nodes if n.kind() not in profiling_opcodes] def warmup_forward(f, *args, profiling_count=2): @@ -189,7 +189,7 @@ def func(x): return x2.sum() with texpr_reductions_enabled(): - a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') + a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu') a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) self.assertLastGraphAllFused() @@ -205,7 +205,7 @@ def func_neg(x): return x.sum((-2, )) * 2 with texpr_reductions_enabled(): - a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') + a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu') a = a.reshape(5, 3) scripted = self.checkScript(func, (a,)) self.assertLastGraphAllFused() @@ -217,7 +217,7 @@ def func(x): return x.sum((0, ), keepdim=True, dtype=torch.double) * 2 with texpr_reductions_enabled(): - a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') + a = torch.tensor(list(range(0, 15)), dtype=torch.float, device='cpu') a = a.reshape(5, 3) self.checkScript(func, (a,)) diff --git a/test/test_modules.py b/test/test_modules.py index 6a8e064b11423..2ae17f5f8cf85 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -498,7 +498,7 @@ def test_cpu_gpu_parity(self, device, dtype, module_info, training): # TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a # nicer way for eval mode only. # See https://github.com/pytorch/pytorch/issues/79161 - rnn_modules = set([torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM]) + rnn_modules = {torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM} if (module_info.module_cls in rnn_modules and not training and 'cuda' in device diff --git a/test/test_ops.py b/test/test_ops.py index 21a27790b5ec8..230a2e33fc8c0 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1719,7 +1719,7 @@ class TestRefsOpsInfo(TestCase): module_alls = [(path, import_module(f"torch.{path}").__all__) for path in import_paths] ref_ops_names = tuple(itertools.chain.from_iterable( [f"{path}.{op}" for op in module_all] for path, module_all in module_alls)) - ref_db_names = set(ref_op.name for ref_op in python_ref_db) + ref_db_names = {ref_op.name for ref_op in python_ref_db} # TODO: References that do not have an entry in python_ref_db skip_ref_ops = { @@ -1910,9 +1910,7 @@ def test_refs_are_in_decomp_table(self, op): fake_autocast_device_skips = defaultdict(dict) # TODO: investigate/fix -fake_autocast_device_skips["cpu"] = set( - ("linalg.pinv",) -) +fake_autocast_device_skips["cpu"] = {"linalg.pinv"} dynamic_output_op_tests = ( diff --git a/test/test_optim.py b/test/test_optim.py index b8910c300767a..3ea7b49b9216a 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -145,8 +145,8 @@ def _test_basic_cases_template( constructor_accepts_maximize=True, constructor_accepts_foreach=False, ): - maximize_options = set([False, constructor_accepts_maximize]) - foreach_options = set([False, constructor_accepts_foreach]) + maximize_options = {False, constructor_accepts_maximize} + foreach_options = {False, constructor_accepts_foreach} four_arg_constructor = constructor if constructor_accepts_maximize and constructor_accepts_foreach: @@ -317,7 +317,7 @@ def fn_base(optimizer, weight, bias): # validate deepcopy() copies all public attributes def getPublicAttr(obj): - return set(k for k in obj.__dict__ if not k.startswith("_")) + return {k for k in obj.__dict__ if not k.startswith("_")} self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer))) @@ -346,8 +346,8 @@ def make_two_arg_constructor( return constructor for maximize, foreach in itertools.product( - set([False, constructor_accepts_maximize]), - set([False, constructor_accepts_foreach]), + {False, constructor_accepts_maximize}, + {False, constructor_accepts_foreach}, ): self._test_state_dict( torch.randn(10, 5), diff --git a/test/test_reductions.py b/test/test_reductions.py index 6784f0f22c0cb..073b91f3323bb 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -80,7 +80,7 @@ def _reduced_shape(shape, dim=None, keepdim=False): # Wrap negative dims dim = dim if isinstance(dim, Sequence) else [dim] - dim = set(i if i >= 0 else len(shape) + i for i in dim) + dim = {i if i >= 0 else len(shape) + i for i in dim} result = [] for i, size in enumerate(shape): diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py index 45796d8ffa47f..fc974b2509492 100644 --- a/tools/autograd/gen_trace_type.py +++ b/tools/autograd/gen_trace_type.py @@ -19,33 +19,29 @@ # - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration # Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now. # You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp -MANUAL_BACKEND = set( - [ - "options", - "data", - "set_data", - "is_leaf", - "output_nr", - "_version", - "retain_grad", - "_backward", - "requires_grad_", - ] -) +MANUAL_BACKEND = { + "options", + "data", + "set_data", + "is_leaf", + "output_nr", + "_version", + "retain_grad", + "_backward", + "requires_grad_", +} # For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys. # You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp -MANUAL_AUTOGRAD_AND_TRACER = set( - [ - "resize_", - "resize_as_", - "detach", - "detach_", - "copy_", - "_fw_primal", - "_make_dual", - ] -) +MANUAL_AUTOGRAD_AND_TRACER = { + "resize_", + "resize_as_", + "detach", + "detach_", + "copy_", + "_fw_primal", + "_make_dual", +} # Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops: # union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 66edb8ce30208..2b43df10dc9c5 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -968,10 +968,10 @@ def find_args_with_derivatives( """Find arguments that have derivative definitions""" if info is None or not info.has_derivatives: return differentiable_inputs - names = set(name for d in info.derivatives for name in d.var_names) + names = {name for d in info.derivatives for name in d.var_names} differentiable = [arg for arg in differentiable_inputs if arg.name in names] if len(differentiable) != len(names): - missing = names - set(arg.name for arg in differentiable) + missing = names - {arg.name for arg in differentiable} raise RuntimeError( f"Missing arguments for derivatives: {missing} in {info.name}" ) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 4a76a7b816d7b..d8b20a9f932e5 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1408,7 +1408,7 @@ def MATCH_KEYS(self, inst): assert isinstance(tos1, ConstDictVariable) match_obj = tos1.items if all(key in match_obj for key in keys): - self.push(TupleVariable(list(match_obj[key] for key in keys))) + self.push(TupleVariable([match_obj[key] for key in keys])) self.push(ConstantVariable(True)) else: self.push(ConstantVariable(None)) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index d7513f393f6de..c48bed0c00099 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -764,14 +764,14 @@ def dict_param_key_ids(value): def dict_const_keys(value): - return set(k for k in value.keys() if not isinstance(k, torch.nn.Parameter)) + return {k for k in value.keys() if not isinstance(k, torch.nn.Parameter)} def dict_const_keys_repr(const_keys): if any(isinstance(k, enum.Enum) for k in const_keys): # To workaround repr(Enum) returning invalid global reference before python 3.11 # by calling enum_repr and removing quotes to render enum in guard code. - const_keys_str = f"{set(enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys)}".replace( + const_keys_str = f"{ {enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys} }".replace( "'", "" ) else: diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 655d0a7b1b342..67845104b44fa 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -451,7 +451,7 @@ def get_state_from_generator(): for x in args ] ) - bin_ops = set(["add", "sub", "mul", "div", "sqrt"]) + bin_ops = {"add", "sub", "mul", "div", "sqrt"} if ( getattr(self.value, "__module__", "") == "torch" and self.value.__name__ in bin_ops @@ -903,7 +903,7 @@ def speculate_branch(branch): args[0].as_proxy(), true_node, false_node, - list(a.as_proxy() for a in sub_args), + [a.as_proxy() for a in sub_args], ) # TODO: assert that the true/false return values are # consistent diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 80c024740a3b6..03b5563e9966a 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -388,11 +388,11 @@ def is_tensor_node(x): fusible_ops = recomputable_ops | set(random_ops) if AOT_PARTITIONER_DEBUG: - joint_module_ops = set( + joint_module_ops = { str(node.target._overloadpacket) for node in joint_module.graph.nodes if node.op == "call_function" and hasattr(node.target, "_overloadpacket") - ) + } ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops} print("Ops banned from rematerialization: ", ops_ignored) print() @@ -400,7 +400,7 @@ def is_tensor_node(x): AGGRESSIVE_RECOMPUTATION = False def is_materialized_backwards(node): - cur_nodes = set([node]) + cur_nodes = {node} while len(cur_nodes) > 0: cur = cur_nodes.pop() for user in cur.users: diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index ccca91884dfd8..619e8ac0220e6 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -949,7 +949,7 @@ def reduction(self, name, dtype, src_dtype, reduction_type, index, value): dim = len(self.range_trees) - 1 result_var = self.cse.newvar() - result_var.mask_vars = set(var for var in masks if var[0] != "r") + result_var.mask_vars = {var for var in masks if var[0] != "r"} if (src_dtype, reduction_type, value) not in self.cse.reduction_cache: self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var accumulator = f"_{result_var}" diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 659edeb3b9b7f..1333093ba1430 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -531,15 +531,15 @@ def count_bytes(self): def get_read_write_buffers_sizes(node): if isinstance(node, NopKernelSchedulerNode): return 0 - reads = set(dep.name for dep in node.read_writes.reads) - writes = set(dep.name for dep in node.read_writes.writes) + reads = {dep.name for dep in node.read_writes.reads} + writes = {dep.name for dep in node.read_writes.writes} def is_materialized(buf): buf_uses = {user.node for user in scheduler.name_to_node[buf].users} return len(buf_uses - set(node.snodes)) > 0 if isinstance(node, FusedSchedulerNode): - removed_buffers = set(dep for dep in writes if not is_materialized(dep)) + removed_buffers = {dep for dep in writes if not is_materialized(dep)} writes = writes - removed_buffers reads = reads - removed_buffers node_bytes = 0 diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 50c499d0ee19a..df3a67cdbe9b8 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2995,7 +2995,7 @@ def gen_kwarg(k, v): tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] constant_args = [Shim(repr(x)) for x in self.constant_args] args, kwargs = self.unflatten_args(tensor_args, constant_args) - return list(map(repr, args)) + list(gen_kwarg(k, v) for k, v in kwargs.items()) + return list(map(repr, args)) + [gen_kwarg(k, v) for k, v in kwargs.items()] @classmethod def create(cls, kernel, *args, **kwargs): diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 1969d88d19c11..452df067b2171 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -177,7 +177,7 @@ def get_first_name(self) -> str: return self.get_name() def get_names(self) -> Set[str]: - return set([self.get_name()]) + return {self.get_name()} def get_nodes(self) -> List["BaseSchedulerNode"]: return [self] diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index f36af67a356c5..dc48ed389894c 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -295,14 +295,12 @@ def free_symbol_startswith(index: sympy.Expr, prefix: str): def has_incompatible_cudagraph_ops(gm): - forbidden_list = set( - [ - "aten._fused_moving_avg_obs_fq_helper.default", - "aten._fused_moving_avg_obs_fq_helper_functional.default", - "fbgemm.dense_to_jagged.default", - "fbgemm.jagged_to_padded_dense.default", - ] - ) + forbidden_list = { + "aten._fused_moving_avg_obs_fq_helper.default", + "aten._fused_moving_avg_obs_fq_helper_functional.default", + "fbgemm.dense_to_jagged.default", + "fbgemm.jagged_to_padded_dense.default", + } for node in gm.graph.nodes: if str(node.target) in forbidden_list: return True diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 0ba2a5a0234a7..d7713413463f6 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -243,14 +243,12 @@ def is_channels_last_contiguous_3d(a: Tensor) -> bool: return True -_memory_formats = set( - ( - torch.contiguous_format, - torch.preserve_format, - torch.channels_last, - torch.channels_last_3d, - ) -) +_memory_formats = { + torch.contiguous_format, + torch.preserve_format, + torch.channels_last, + torch.channels_last_3d, +} def validate_memory_format(memory_format: torch.memory_format): diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 4b0c9a63fbb8e..9ada634e412b4 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -2956,7 +2956,7 @@ def native_group_norm( out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps) out = out.view(input.shape) - broadcast_dims = [0] + list(dim for dim in range(2, input.ndim)) + broadcast_dims = [0] + list(range(2, input.ndim)) unsqueeze_bias = None if bias is not None: unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims) diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index d71488ae3d785..b0af9e669876f 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -816,13 +816,9 @@ def from_float(cls, mod): return super(ConvReLU3d, cls).from_float(mod) def update_bn_stats(mod): - if type(mod) in set( - [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d] - ): + if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}: mod.update_bn_stats() def freeze_bn_stats(mod): - if type(mod) in set( - [ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d] - ): + if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}: mod.freeze_bn_stats() diff --git a/torch/ao/nn/quantized/dynamic/modules/rnn.py b/torch/ao/nn/quantized/dynamic/modules/rnn.py index 09d0e535aaf0f..9cdaac1205df6 100644 --- a/torch/ao/nn/quantized/dynamic/modules/rnn.py +++ b/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -267,10 +267,8 @@ def weight_bias_name(ihhh, layer, suffix): @classmethod def from_float(cls, mod): - assert type(mod) in set( - [torch.nn.LSTM, - torch.nn.GRU] - ), 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU' + assert type(mod) in {torch.nn.LSTM, + torch.nn.GRU}, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU' assert hasattr( mod, 'qconfig' @@ -823,9 +821,9 @@ def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '' @classmethod def from_float(cls, mod): - assert type(mod) in set([torch.nn.LSTMCell, - torch.nn.GRUCell, - torch.nn.RNNCell]), 'nn.quantized.dynamic.RNNCellBase.from_float \ + assert type(mod) in {torch.nn.LSTMCell, + torch.nn.GRUCell, + torch.nn.RNNCell}, 'nn.quantized.dynamic.RNNCellBase.from_float \ only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell' assert hasattr( mod, 'qconfig'), 'Input float module must have qconfig defined' diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py index b196e99ca5fba..3f0df31dfd2a1 100644 --- a/torch/ao/ns/_numeric_suite.py +++ b/torch/ao/ns/_numeric_suite.py @@ -222,12 +222,12 @@ def forward(self, x): def _convert_tuple_to_list(t: Any) -> Any: - return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t + return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t def _dequantize_tensor_list(t: Any) -> Any: return ( - list(_dequantize_tensor_list(x) for x in t) + [_dequantize_tensor_list(x) for x in t] if type(t) is list else t.dequantize() if t.is_quantized diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py index 3000f90a22e6b..ca04ac4d3ba90 100644 --- a/torch/ao/ns/fx/mappings.py +++ b/torch/ao/ns/fx/mappings.py @@ -27,303 +27,303 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: # note: this set is modified below by items from backend_config sets_of_related_ops: List[Set[NSNodeTargetType]] = [ # conv modules - set([ + { nn.Conv1d, - ]), - set([ + }, + { nn.Conv2d, - ]), - set([ + }, + { nn.Conv3d, - ]), + }, # conv functionals - set([ + { F.conv1d, - ]), - set([ + }, + { F.conv2d, - ]), - set([ + }, + { F.conv3d, - ]), + }, # linear modules - set([ + { nn.Linear, - ]), + }, # linear functionals - set([ + { F.linear, - ]), + }, # average pool - set([ + { nn.AvgPool1d, torch.avg_pool1d, - ]), - set([ + }, + { nn.AvgPool2d, torch._C._nn.avg_pool2d, - ]), - set([ + }, + { nn.AvgPool3d, torch._C._nn.avg_pool3d, - ]), + }, # adaptive average pool - set([ + { nn.AdaptiveAvgPool1d, F.adaptive_avg_pool1d, - ]), - set([ + }, + { nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d, - ]), - set([ + }, + { nn.AdaptiveAvgPool3d, F.adaptive_avg_pool3d, - ]), + }, # LSTM - set([ + { nn.LSTM, - ]), + }, # add - set([ + { torch.add, operator.add, # x + y - ]), + }, # cat - set([ + { torch.cat, - ]), + }, # mul - set([ + { torch.mul, operator.mul, - ]), + }, # relu - set([ + { F.relu, nn.ReLU, 'relu', 'relu_', torch.relu, - ]), + }, # maxpool - set([ + { nn.MaxPool1d, F.max_pool1d, - ]), - set([ + }, + { nn.MaxPool2d, F.max_pool2d, - ]), - set([ + }, + { nn.MaxPool3d, F.max_pool3d, - ]), + }, # sigmoid - set([ + { torch.sigmoid, 'sigmoid', 'sigmoid_', nn.Sigmoid, F.sigmoid, - ]), + }, # BatchNorm - set([ + { nn.BatchNorm2d, - ]), - set([ + }, + { nn.BatchNorm3d, - ]), + }, # ConvTranspose - set([ + { nn.ConvTranspose1d, - ]), - set([ + }, + { nn.ConvTranspose2d, - ]), - set([ + }, + { nn.ConvTranspose3d, - ]), + }, # ELU - set([ + { nn.ELU, - ]), + }, # Embedding - set([ + { nn.Embedding, - ]), + }, # EmbeddingBag - set([ + { nn.EmbeddingBag, - ]), + }, # GroupNorm - set([ + { nn.GroupNorm, - ]), + }, # Hardswish - set([ + { nn.Hardswish, - ]), + }, # InstanceNorm - set([ + { nn.InstanceNorm1d, - ]), - set([ + }, + { nn.InstanceNorm2d, - ]), - set([ + }, + { nn.InstanceNorm3d, - ]), + }, # LayerNorm - set([ + { nn.LayerNorm, - ]), + }, # LeakyReLU - set([ + { nn.LeakyReLU, - ]), + }, # ReLU6 - set([ + { nn.ReLU6, F.relu6, - ]), + }, # F.elu - set([ + { F.elu, - ]), + }, # F.hardswish - set([ + { F.hardswish, - ]), + }, # F.group_norm - set([ + { F.group_norm, - ]), + }, # F.instance_norm - set([ + { F.instance_norm, - ]), + }, # F.layer_norm - set([ + { F.layer_norm, - ]), + }, # F.leaky_relu - set([ + { F.leaky_relu, - ]), + }, # F.silu - set([ + { nn.SiLU, F.silu, - ]), + }, # F.mish - set([ + { nn.Mish, F.mish, - ]), + }, # F.tanh - set([ + { nn.Tanh, F.tanh, torch.tanh, 'tanh_', 'tanh', - ]), + }, # F.hardsigmoid - set([ + { 'hardsigmoid_', 'hardsigmoid', F.hardsigmoid, nn.Hardsigmoid, - ]), + }, # F.hardtanh - set([ + { nn.Hardtanh, F.hardtanh, F.hardtanh_, - ]), + }, # floordiv - set([ + { operator.floordiv, - ]), + }, # unsqueeze - set([ + { torch.unsqueeze, - ]), + }, # stack - set([ + { torch.stack, - ]), + }, # squeeze - set([ + { torch.squeeze, - ]), + }, # sort - set([ + { torch.sort, - ]), + }, # repeat_interleave - set([ + { torch.repeat_interleave, - ]), + }, # min - set([ + { torch.min, - ]), + }, # mean - set([ + { torch.mean, - ]), + }, # max - set([ + { torch.max, - ]), + }, # transpose - set([ + { torch.transpose, - ]), + }, # flatten - set([ + { torch.flatten, - ]), + }, # clamp - set([ + { torch.clamp, - ]), + }, # chunk - set([ + { torch.chunk, - ]), + }, # interpolate - set([ + { torch.nn.functional.interpolate, - ]), + }, # dropout - set([ + { nn.Dropout, - ]), + }, # F.dropout - set([ + { F.dropout, - ]), + }, # matmul - set([ + { torch.matmul, - ]), + }, # Softmax - set([ + { nn.Softmax, - ]), + }, # PReLU - set([ + { nn.PReLU, nnq.PReLU, - ]), + }, # F.prelu - set([ + { F.prelu, toq.prelu, - ]), + }, ] # for each floating point op, add versions of the op added by @@ -453,12 +453,12 @@ def add_op_to_sets_of_related_ops( counter = 0 while str(counter) in base_name_to_sets_of_related_ops: counter += 1 - base_name_to_sets_of_related_ops[str(counter)] = set([op]) + base_name_to_sets_of_related_ops[str(counter)] = {op} # TODO(future PR): clean this up def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: - FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = set([ + FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = { F.linear, F.conv1d, F.conv2d, @@ -478,11 +478,11 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: torch.mul, torch.sum, F.prelu, - ]) + } FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set() - FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([ + FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = { toq.linear, toq.linear_relu, toq.conv1d, @@ -503,9 +503,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: # uncomment below # toq.add, # toq.mul, - ]) + } - FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = set([ + FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = { F.relu, F.tanh, torch.tanh, @@ -541,9 +541,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: torch.stack, torch.unsqueeze, operator.add, - ]) + } - MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = set([ + MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = { nn.Linear, nnqat.Linear, nnqatd.Linear, @@ -606,9 +606,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: nni.LinearTanh, nni.ConvAdd2d, nni.ConvAddReLU2d, - ]) + } - MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([ + MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = { nnq.Linear, nnq.Conv1d, nnq.Conv2d, @@ -640,9 +640,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: nniq.LinearTanh, nniq.ConvAdd2d, nniq.ConvAddReLU2d, - ]) + } - MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = set([ + MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = { nn.ReLU, nn.Tanh, nn.Sigmoid, @@ -660,9 +660,9 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: nn.MaxPool2d, nn.MaxPool3d, nn.ReLU6, - ]) + } - METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = set([ + METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = { 'sigmoid_', 'sigmoid', 'tanh_', @@ -671,7 +671,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: 'hardsigmoid', 'relu_', 'relu', - ]) + } return { 'funs_io_type_fp32': FUNS_IO_TYPE_FP32, @@ -687,16 +687,16 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]: - FUNS_UNMATCHABLE: Set[NSNodeTargetType] = set([ + FUNS_UNMATCHABLE: Set[NSNodeTargetType] = { torch.quantize_per_tensor, operator.getitem, - ]) + } - MODS_UNMATCHABLE: Set[NSNodeTargetType] = set([ + MODS_UNMATCHABLE: Set[NSNodeTargetType] = { nn.Identity, - ]) + } - METHS_UNMATCHABLE: Set[NSNodeTargetType] = set([ + METHS_UNMATCHABLE: Set[NSNodeTargetType] = { 'to', 'dequantize', 'reshape', @@ -719,7 +719,7 @@ def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]: 'contiguous', 'clamp', 'chunk', - ]) + } return { 'funs_unmatchable': FUNS_UNMATCHABLE, diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py index 495986a1b9cb7..a5a5921cbd99a 100644 --- a/torch/ao/ns/fx/n_shadows_utils.py +++ b/torch/ao/ns/fx/n_shadows_utils.py @@ -991,9 +991,9 @@ def extract_weight_comparison(m: GraphModule) -> NSResultsType: # use functions. # TODO(future PR): move this to config - weighted_ops = set([ + weighted_ops = { torch.nn.functional.linear, - ]) + } results: NSResultsType = { 'model': {NSSingleResultValuesType.WEIGHT.value: {}} diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index fa5f3e6728ef8..bbca4609a2c66 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -219,10 +219,10 @@ class PerChannelDetector(DetectorBase): # Default map for representing supported per channel quantization modules for different backends DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = { - "fbgemm": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]), - "qnnpack": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]), - "onednn": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]), - "x86": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]), + "fbgemm": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d}, + "qnnpack": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d}, + "onednn": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d}, + "x86": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d}, } def __init__(self, backend: str = torch.backends.quantized.engine): @@ -230,7 +230,7 @@ def __init__(self, backend: str = torch.backends.quantized.engine): # store the backend information self.backend_chosen = backend - self.supported_modules = set([]) + self.supported_modules = set() if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[self.backend_chosen] else: @@ -413,17 +413,17 @@ class DynamicStaticDetector(DetectorBase): IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported" # modules that are supported both dynamic and static for this report function - DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = set([nn.Linear]) + DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear} # modules that will be supported soon for both - DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = set([nn.Conv1d, nn.Conv2d, nn.Conv3d]) + DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d} def __init__(self, tolerance=0.5): super().__init__() # set tolerance level and initialize a set to keep track of useful fqn locations self.tolerance = tolerance - self.useful_observer_fqns: Set[str] = set([]) + self.useful_observer_fqns: Set[str] = set() def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]: r""" @@ -737,9 +737,14 @@ class InputWeightEqualizationDetector(DetectorBase): * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector """ - SUPPORTED_MODULES: Set[Callable] = set( - [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d] - ) + SUPPORTED_MODULES: Set[Callable] = {nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d} # names for the pre and post observers that are inserted DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer" diff --git a/torch/ao/quantization/fx/_model_report/model_report.py b/torch/ao/quantization/fx/_model_report/model_report.py index 27a9aa3d05ba3..8bc2aec135036 100644 --- a/torch/ao/quantization/fx/_model_report/model_report.py +++ b/torch/ao/quantization/fx/_model_report/model_report.py @@ -129,7 +129,7 @@ def __init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBas # initialize each report to have empty set of observers of interest for desired_report in self._desired_detector_names: - self._detector_name_to_observer_fqns[desired_report] = set([]) + self._detector_name_to_observer_fqns[desired_report] = set() # flags to ensure that we can only prepare and remove observers once self._prepared_flag = False @@ -287,7 +287,7 @@ def generate_model_report( if remove_inserted_observers: self._removed_observers = True # get the set of all Observers inserted by this instance of ModelReport - all_observers_of_interest: Set[str] = set([]) + all_observers_of_interest: Set[str] = set() for desired_report in self._detector_name_to_observer_fqns: observers_of_interest = self._detector_name_to_observer_fqns[desired_report] all_observers_of_interest.update(observers_of_interest) diff --git a/torch/ao/quantization/fx/graph_module.py b/torch/ao/quantization/fx/graph_module.py index 32768c61045ee..cc9187285ae63 100644 --- a/torch/ao/quantization/fx/graph_module.py +++ b/torch/ao/quantization/fx/graph_module.py @@ -30,7 +30,7 @@ def __deepcopy__(self, memo): class ObservedGraphModule(GraphModule): def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]): - self.preserved_attr_names = set([ + self.preserved_attr_names = { '_activation_post_process_map', '_activation_post_process_indexes', '_patterns', @@ -40,7 +40,7 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, p '_node_name_to_scope', '_qconfig_mapping', '_is_qat', - '_observed_node_names']).union(preserved_attr_names) + '_observed_node_names'}.union(preserved_attr_names) preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)} super().__init__(root, graph) for attr in preserved_attrs: @@ -64,9 +64,9 @@ def _get_observed_graph_module_attr(model: Union[torch.nn.Module, GraphModule], class ObservedStandaloneGraphModule(ObservedGraphModule): def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]): - preserved_attr_names = preserved_attr_names.union(set([ + preserved_attr_names = preserved_attr_names.union({ "_standalone_module_input_quantized_idxs", - "_standalone_module_output_quantized_idxs"])) + "_standalone_module_output_quantized_idxs"}) super().__init__(root, graph, preserved_attr_names) def __deepcopy__(self, memo): diff --git a/torch/ao/quantization/quantization_mappings.py b/torch/ao/quantization/quantization_mappings.py index 8b4d66e4aa77d..96db52624acd3 100644 --- a/torch/ao/quantization/quantization_mappings.py +++ b/torch/ao/quantization/quantization_mappings.py @@ -208,10 +208,10 @@ def no_observer_set() -> Set[Any]: r"""These modules cannot have observers inserted by default.""" - no_observers = set([ + no_observers = { nn.quantizable.LSTM, nn.quantizable.MultiheadAttention - ]) + } return no_observers def get_default_static_quant_module_mappings() -> Dict[Callable, Any]: diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index dd56ff517b61a..4ee98d42f9282 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -1609,8 +1609,8 @@ def gradgradcheck( # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs # before running forward mode AD - diff_input_args_indices = set(i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad) - diff_grad_output_indices = set(i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad) + diff_input_args_indices = {i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad} + diff_grad_output_indices = {i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad} def new_func(*args): # Restore the requires_grad information diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index b13b3dc8e7830..6498d5c9b5b49 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -491,7 +491,7 @@ def _parse_visible_devices() -> Set[int]: """Parse CUDA_VISIBLE_DEVICES environment variable.""" var = os.getenv("CUDA_VISIBLE_DEVICES") if var is None: - return set(x for x in range(64)) + return set(range(64)) def _strtoul(s: str) -> int: """Return -1 or positive integer sequence string starts with,""" diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index f034639cceba8..dc7ebc67d8a88 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -85,11 +85,11 @@ def _seg_info(seg): f = io.StringIO() - before_segs = set(_seg_key(seg) for seg in before) - after_segs = set(_seg_key(seg) for seg in after) + before_segs = {_seg_key(seg) for seg in before} + after_segs = {_seg_key(seg) for seg in after} - print(f'only_before = {list(a for a,_ in (before_segs - after_segs))}') - print(f'only_after = {list(a for a,_ in (after_segs - before_segs))}') + print(f'only_before = {[a for a,_ in (before_segs - after_segs)]}') + print(f'only_after = {[a for a,_ in (after_segs - before_segs)]}') for seg in before: if _seg_key(seg) not in after_segs: diff --git a/torch/distributed/_composable/_ddp.py b/torch/distributed/_composable/_ddp.py index 1704e0854bfde..4a20665b7aae7 100644 --- a/torch/distributed/_composable/_ddp.py +++ b/torch/distributed/_composable/_ddp.py @@ -383,7 +383,7 @@ def _build_params_for_reducer(self): ] # Build list of parameters. - parameters = list(parameter for _, parameter in modules_and_parameters) + parameters = [parameter for _, parameter in modules_and_parameters] # Checks if a module will produce a sparse gradient. def produces_sparse_gradient(module): @@ -393,9 +393,9 @@ def produces_sparse_gradient(module): # Build list of booleans indicating whether or not to expect sparse # gradients for the corresponding parameters. - expect_sparse_gradient = list( + expect_sparse_gradient = [ produces_sparse_gradient(module) for module, _ in modules_and_parameters - ) + ] self._assign_modules_buffers() diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py index b6125e69b16e8..e38f1dc15e7ca 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py @@ -281,7 +281,7 @@ def _handle_row_wise_sharding_tensor( indices[placement.rank()] = list( range(offset_start_idx, offset_start_idx + split_size) ) - indices_flatten = list(idx for indice in indices for idx in indice) + indices_flatten = [idx for indice in indices for idx in indice] input_t = input_t.index_select( 0, torch.tensor(indices_flatten, device=input_t.device) diff --git a/torch/distributed/_tensor/dispatch.py b/torch/distributed/_tensor/dispatch.py index e583a52d23e0a..9a51986a08fd5 100644 --- a/torch/distributed/_tensor/dispatch.py +++ b/torch/distributed/_tensor/dispatch.py @@ -38,10 +38,10 @@ def wrap(res: object, spec: OutputSpecType) -> object: assert spec is not None and isinstance( spec, list ), f"output spec does not match with output! Expected list, got {spec}." - return list( + return [ dtensor.DTensor(e, s.mesh, s.placements, size=s.shape) for e, s in zip(res, spec) - ) + ] elif isinstance(res, tuple): assert spec is not None and isinstance( spec, tuple diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py index 2eb6c300036b1..5856bcca5642c 100644 --- a/torch/distributed/_tensor/ops/tensor_ops.py +++ b/torch/distributed/_tensor/ops/tensor_ops.py @@ -397,7 +397,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding: assert isinstance(indices_output_spec, DTensorSpec) indices_spec = indices_output_spec - lookup_dims = set(v[0] for v in valid_indices_spec) + lookup_dims = {v[0] for v in valid_indices_spec} need_reshard_on_values = tuple( (isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard))) diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py index 9999ee320d979..f7f6f290c18f0 100644 --- a/torch/distributed/_tensor/ops/view_ops.py +++ b/torch/distributed/_tensor/ops/view_ops.py @@ -370,7 +370,7 @@ def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap: dim2 = normalize_dim(dim2, ndim) assert dim1 < ndim assert dim2 < ndim - dimmap = list(InputDim(i) for i in range(ndim)) + dimmap = [InputDim(i) for i in range(ndim)] swapdim = dimmap[dim1] dimmap[dim1] = dimmap[dim2] dimmap[dim2] = swapdim @@ -480,7 +480,7 @@ def propagate_shape_and_sharding( if the leftmost split size is divisible by the mesh dimension """ assert len(in_shard) == len(mesh_sizes) - sharded_in_dims: Set[int] = set(s.dim for s in in_shard if isinstance(s, Shard)) + sharded_in_dims: Set[int] = {s.dim for s in in_shard if isinstance(s, Shard)} # for each input dim, for each mesh dim, provides a list of possible shardable dimensions shardable_dims: torch.Tensor = torch.ones( (len(local_in_shape), len(mesh_sizes)), dtype=torch.bool diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 1ee50e74304ad..f806318774077 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -567,12 +567,12 @@ def _get_ignored_modules( # that this FSDP instance can get any ignored modules from its children. # Include child modules and exclude nested FSDP modules themselves - ignored_modules = set( + ignored_modules = { child for module in ignored_root_modules for child in module.modules() if not isinstance(child, fsdp_file.FullyShardedDataParallel) - ) + } if root_module in ignored_modules: warnings.warn( "Trying to ignore the top-level module passed into the FSDP " @@ -599,16 +599,16 @@ def _get_ignored_params( """ all_ignored_params: Set[torch.nn.Parameter] = set() - params_in_ignored_modules = set( + params_in_ignored_modules = { p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p) - ) + } all_ignored_params.update(params_in_ignored_modules) if ignored_parameters is not None: - params_in_ignored_parameters = set( + params_in_ignored_parameters = { p for p in ignored_parameters if not _is_fsdp_flattened(p) - ) + } all_ignored_params.update(params_in_ignored_parameters) # Include nested FSDP modules' ignored parameters @@ -626,9 +626,9 @@ def _get_buffer_names(root_module: nn.Module) -> Set[str]: Returns the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a class:`set`. """ - return set( + return { clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers() - ) + } def _check_single_device_module( @@ -640,7 +640,7 @@ def _check_single_device_module( ignoring the parameters in ``ignored_params``. Thus, after this method, the module must be either fully on the CPU or fully on a non-CPU device. """ - devices = set(param.device for param in _get_orig_params(module, ignored_params)) + devices = {param.device for param in _get_orig_params(module, ignored_params)} if len(devices) > 1: raise RuntimeError( f"FSDP only supports single device modules but got params on {devices}" diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 736984f5c7175..1353391cc965a 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -485,7 +485,7 @@ def _flatten_optim_state( are_pos_dim_tensors &= torch.is_tensor(v) and v.dim() > 0 are_zero_dim_tensors &= _is_zero_dim_tensor(v) are_non_tensors &= not torch.is_tensor(v) - types = set(type(v) for v in non_none_state_values) + types = {type(v) for v in non_none_state_values} if len(types) != 1 or not ( are_pos_dim_tensors or are_zero_dim_tensors or are_non_tensors ): @@ -570,7 +570,7 @@ def _flatten_tensor_optim_state( """ non_none_tensors = [t for t in pos_dim_tensors if t is not None] # Check that all are tensors with the same dtype - dtypes = set(t.dtype for t in non_none_tensors) + dtypes = {t.dtype for t in non_none_tensors} if len(dtypes) != 1: raise ValueError( "All unflattened parameters comprising a single flattened " @@ -648,8 +648,8 @@ def _flatten_zero_dim_tensor_optim_state( """ non_none_tensors = [t for t in zero_dim_tensors if t is not None] # Enforce that all have the same value and dtype - values_set = set(t.item() if t is not None else None for t in zero_dim_tensors) - dtypes = set(t.dtype if t is not None else None for t in zero_dim_tensors) + values_set = {t.item() if t is not None else None for t in zero_dim_tensors} + dtypes = {t.dtype if t is not None else None for t in zero_dim_tensors} if ( len(non_none_tensors) != len(zero_dim_tensors) or len(values_set) != 1 @@ -1004,10 +1004,10 @@ def _rekey_sharded_optim_state_dict( for unflat_param_group in sharded_osd["param_groups"]: flat_param_group = copy.deepcopy(unflat_param_group) flat_param_keys = sorted( - set( + { unflat_param_name_to_flat_param_key[unflat_param_name] for unflat_param_name in unflat_param_group["params"] - ) + } ) flat_param_group["params"] = flat_param_keys rekeyed_osd_param_groups.append(flat_param_group) diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index 9d27f5e5bf52b..b7a13689e4ff7 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -1068,7 +1068,7 @@ def _get_training_state( ) -> HandleTrainingState: """Returns the training state of the handles in ``handles_key``.""" p_assert(len(handles_key) > 0, "Expects a non-empty handles key") - training_states = set(handle._training_state for handle in handles_key) + training_states = {handle._training_state for handle in handles_key} p_assert( len(training_states) == 1, f"Expects uniform training state but got {training_states}", diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py index 3bdac64adbc3c..a70d6fbd32613 100644 --- a/torch/distributed/fsdp/flat_param.py +++ b/torch/distributed/fsdp/flat_param.py @@ -274,8 +274,8 @@ def _init_metadata( self._fqns = tuple(fqns) self._shared_param_infos = tuple(shared_param_infos) self._param_extensions = tuple(param_extensions) - self._modules = set(pi.module for pi in self._param_infos).union( - set(spi.module for spi in self._shared_param_infos) + self._modules = {pi.module for pi in self._param_infos}.union( + {spi.module for spi in self._shared_param_infos} ) assert (params is None) == (shared_params is None) if params is not None: @@ -1857,8 +1857,8 @@ def flat_param_to(self, *args, **kwargs): def _get_modules(self) -> Set[nn.Module]: """Returns a :class:`set` of the modules whose parameters are included in this handle's flattened parameter.""" - return set(pi.module for pi in self.flat_param._param_infos).union( - set(spi.module for spi in self.flat_param._shared_param_infos) + return {pi.module for pi in self.flat_param._param_infos}.union( + {spi.module for spi in self.flat_param._shared_param_infos} ) def is_sharded(self, tensor: Tensor) -> bool: diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index c5396a1ea736f..996d92f8cb709 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -1968,7 +1968,7 @@ def _get_grad_norm( if len(params_with_grad) == 0: return torch.tensor(0.0) grads = [param.grad for param in params_with_grad] - grad_dtypes = set(grad.dtype for grad in grads) + grad_dtypes = {grad.dtype for grad in grads} if len(grad_dtypes) != 1: raise ValueError( f"Requires uniform dtype across all gradients but got {grad_dtypes}" diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 3e3607b3f390c..5a4d6ce1b546c 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -54,7 +54,7 @@ def register_rendezvous_handler(scheme, handler): # Query will have format "rank=0&world_size=1" and is # converted into {"rank": 0, "world_size": 1} def _query_to_dict(query: str) -> Dict[str, str]: - return dict((pair[0], pair[1]) for pair in (pair.split("=") for pair in filter(None, query.split("&")))) + return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))} def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs): diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index cd2267c701fb6..3b5d5afe0f208 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -275,7 +275,7 @@ def check_dependency(partition): """Given a partition,check if there is a circular dependency on this partition using bfs """ - visited: Set[Partition] = set([partition]) + visited: Set[Partition] = {partition} queue: Deque[Partition] = deque([partition]) while queue: p = queue.popleft() diff --git a/torch/fx/experimental/unification/core.py b/torch/fx/experimental/unification/core.py index 32116f93c30f4..3a0e572c09eb7 100644 --- a/torch/fx/experimental/unification/core.py +++ b/torch/fx/experimental/unification/core.py @@ -30,7 +30,7 @@ def _reify(t, s): @dispatch(dict, dict) # type: ignore[no-redef] def _reify(d, s): - return dict((k, reify(v, s)) for k, v in d.items()) + return {k: reify(v, s) for k, v in d.items()} _reify @dispatch(object, dict) # type: ignore[no-redef] diff --git a/torch/fx/experimental/unification/match.py b/torch/fx/experimental/unification/match.py index e7890986636c8..c4fd64c64acf1 100644 --- a/torch/fx/experimental/unification/match.py +++ b/torch/fx/experimental/unification/match.py @@ -55,7 +55,7 @@ class VarDispatcher(Dispatcher): """ def __call__(self, *args, **kwargs): func, s = self.resolve(args) - d = dict((k.token, v) for k, v in s.items()) + d = {k.token: v for k, v in s.items()} return func(**d) @@ -86,7 +86,7 @@ def supercedes(a, b): s = unify(a, b) if s is False: return False - s = dict((k, v) for k, v in s.items() if not isvar(k) or not isvar(v)) + s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)} if reify(a, s) == a: return True if reify(b, s) == b: @@ -117,5 +117,5 @@ def ordering(signatures): for s in signatures: if s not in edges: edges[s] = [] - edges = dict((k, [b for a, b in v]) for k, v in edges.items()) # type: ignore[attr-defined, assignment] + edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment] return _toposort(edges) diff --git a/torch/fx/experimental/unification/multipledispatch/conflict.py b/torch/fx/experimental/unification/multipledispatch/conflict.py index 5aa0c0ed19ed6..2eaf6141b18b4 100644 --- a/torch/fx/experimental/unification/multipledispatch/conflict.py +++ b/torch/fx/experimental/unification/multipledispatch/conflict.py @@ -80,11 +80,11 @@ def ambiguous(a, b): def ambiguities(signatures): """ All signature pairs such that A is ambiguous with B """ signatures = list(map(tuple, signatures)) - return set((a, b) for a in signatures for b in signatures - if hash(a) < hash(b) - and ambiguous(a, b) - and not any(supercedes(c, a) and supercedes(c, b) - for c in signatures)) + return {(a, b) for a in signatures for b in signatures + if hash(a) < hash(b) + and ambiguous(a, b) + and not any(supercedes(c, a) and supercedes(c, b) + for c in signatures)} def super_signature(signatures): @@ -92,7 +92,7 @@ def super_signature(signatures): n = len(signatures[0]) assert all(len(s) == n for s in signatures) - return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] + return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] for i in range(n)] @@ -115,5 +115,5 @@ def ordering(signatures): for s in signatures: if s not in edges: edges[s] = [] - edges = dict((k, [b for a, b in v]) for k, v in edges.items()) # type: ignore[assignment, attr-defined] + edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[assignment, attr-defined] return _toposort(edges) diff --git a/torch/fx/experimental/unification/utils.py b/torch/fx/experimental/unification/utils.py index 2eda80f4ee868..d74799a714c5d 100644 --- a/torch/fx/experimental/unification/utils.py +++ b/torch/fx/experimental/unification/utils.py @@ -45,8 +45,8 @@ def _toposort(edges): [2] http://en.wikipedia.org/wiki/Toposort#Algorithms """ incoming_edges = reverse_dict(edges) - incoming_edges = dict((k, set(val)) for k, val in incoming_edges.items()) - S = set((v for v in edges if v not in incoming_edges)) + incoming_edges = {k: set(val) for k, val in incoming_edges.items()} + S = ({v for v in edges if v not in incoming_edges}) L = [] while S: diff --git a/torch/fx/passes/dialect/common/cse_pass.py b/torch/fx/passes/dialect/common/cse_pass.py index fdfdc791569b5..bfbefcae8619e 100644 --- a/torch/fx/passes/dialect/common/cse_pass.py +++ b/torch/fx/passes/dialect/common/cse_pass.py @@ -11,9 +11,9 @@ # stateful ops are banned from CSE -rand_ops = set([aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm]) # noqa: E501 +rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501 -inplace_ops = set([aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_]) # noqa: E501 +inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501 @torch.fx._compatibility.compatibility(is_backward_compatible=False) diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index 3271e652fde11..bb5839f98cb42 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -468,10 +468,10 @@ def f(x): # so we know not to re-inplace them. # NOTE: later, we'll need to add an optimization for fully recovering performance # on programs that mutate inputs. - input_storages = set( + input_storages = { StorageWeakRef( node.meta['fake_result']._typed_storage() - ) for node in gm.graph.nodes if node.op == 'placeholder') + ) for node in gm.graph.nodes if node.op == 'placeholder'} # We also need to know for a given node, what are all of its aliasing nodes. @@ -627,14 +627,14 @@ def replace_arg(a): old_flattened_res, _ = tree_flatten(old.meta['fake_result']) node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result']) - old_res_storage = set( + old_res_storage = { StorageWeakRef( x._typed_storage() - ) for x in old_flattened_res if isinstance(x, FakeTensor)) - node_res_storage = set( + ) for x in old_flattened_res if isinstance(x, FakeTensor)} + node_res_storage = { StorageWeakRef( x._typed_storage() - ) for x in node_flattened_res if isinstance(x, FakeTensor)) + ) for x in node_flattened_res if isinstance(x, FakeTensor)} # This will happen if we're updating a view op, e.g. # e.g. replacing @@ -648,10 +648,10 @@ def replace_arg(a): # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage: new_flattened_res, _ = tree_flatten(new.meta['fake_result']) - new_res_storage = set( + new_res_storage = { StorageWeakRef( x._typed_storage() - ) for x in new_flattened_res if isinstance(x, FakeTensor)) + ) for x in new_flattened_res if isinstance(x, FakeTensor)} assert len(new_res_storage) == 1 (old_ref,) = old_res_storage (new_ref,) = new_res_storage diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index 26c340efa36fc..f2c45ab5acd56 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -229,7 +229,7 @@ def generate_inputs_for_submodules( handles = [] results = {} - submodule_to_names = dict((mod, name) for name, mod in model.named_modules()) + submodule_to_names = {mod: name for name, mod in model.named_modules()} def pre_forward(module, module_inputs): results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py index e54a14356f07f..777a531d077df 100644 --- a/torch/jit/_builtins.py +++ b/torch/jit/_builtins.py @@ -117,7 +117,7 @@ def _gen_torch_functional_registered_ops(): # some functions directly map to their aten:: implementations. # TODO: add support for more ops ops = ["stft", "istft", "lu", "cdist", "norm", "unique", "unique_consecutive", "tensordot"] - return set(getattr(torch.functional, name) for name in ops) + return {getattr(torch.functional, name) for name in ops} _functional_registered_ops = _gen_torch_functional_registered_ops() diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 8ac426ca736b3..5d3a1c5c5d0c3 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -89,7 +89,7 @@ def jit_ignored_properties(module): user_annotated_ignored_attributes = getattr(module, "__jit_ignored_attributes__", list()) def get_properties_names(module): - return set(k for k, v in vars(module).items() if isinstance(v, property)) + return {k for k, v in vars(module).items() if isinstance(v, property)} properties = get_properties_names(type(module)) user_annoted_ignored_properties = set() diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index 9d13d159f18ec..0295c20ec9649 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -352,7 +352,7 @@ def try_ann_to_type(ann, loc): return OptionalType(valid_type) if is_union(ann): # TODO: this is hack to recognize NumberType - if set(ann.__args__) == set([int, float, complex]): + if set(ann.__args__) == {int, float, complex}: return NumberType.get() inner: List = [] # We need these extra checks because both `None` and invalid diff --git a/torch/jit/unsupported_tensor_ops.py b/torch/jit/unsupported_tensor_ops.py index e1364f4538d5b..29d910051cfd9 100644 --- a/torch/jit/unsupported_tensor_ops.py +++ b/torch/jit/unsupported_tensor_ops.py @@ -14,7 +14,7 @@ def func(x): return x.{op}() ''') - deprecated_apis = set(["volatile", "resize", "reinforce", "new", "name", "map2_", "has_names", "grad_fn", "resize_as"]) + deprecated_apis = {"volatile", "resize", "reinforce", "new", "name", "map2_", "has_names", "grad_fn", "resize_as"} tensor_attrs = tensor_attrs - deprecated_apis properties = [] diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index 4b81a9a8bb10f..a1b44f328427c 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -378,11 +378,11 @@ def _generate_docstring(func): ) # Apply function name info to docstring templates: - templates = dict( - (k, v.format_map(template_data)) + templates = { + k: v.format_map(template_data) for k, v in docstring_templates.items() if k.startswith(op_kind) - ) + } templates.update( (k, v.format_map(template_data) if isinstance(v, str) else v) for k, v in template_data.items() diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 0459f24587bd7..ae1c46d2bf824 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -90,7 +90,7 @@ def _helper(a, map_fn): def _wrap_result(result_data, result_mask): if isinstance(result_data, list): - return list(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)) + return [_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)] if isinstance(result_data, tuple): return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)) if torch.is_tensor(result_data): diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 91e517486283b..87304d2456441 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -173,7 +173,7 @@ def flatten_parameters(self) -> None: # a sufficient check, because overlapping parameter buffers that don't completely # alias would break the assumptions of the uniqueness check in # Module.named_parameters(). - unique_data_ptrs = set(p.data_ptr() for p in self._flat_weights) + unique_data_ptrs = {p.data_ptr() for p in self._flat_weights} if len(unique_data_ptrs) != len(self._flat_weights): return diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index ea3f536501890..742b3bb3bf5a4 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -929,7 +929,7 @@ def _build_params_for_reducer(self): ] # Build list of parameters. - parameters = list(parameter for _, parameter in modules_and_parameters) + parameters = [parameter for _, parameter in modules_and_parameters] # Checks if a module will produce a sparse gradient. def produces_sparse_gradient(module): @@ -939,10 +939,10 @@ def produces_sparse_gradient(module): # Build list of booleans indicating whether or not to expect sparse # gradients for the corresponding parameters. - expect_sparse_gradient = list( + expect_sparse_gradient = [ produces_sparse_gradient(module) for module, _ in modules_and_parameters - ) + ] self._assign_modules_buffers() diff --git a/torch/nn/utils/_named_member_accessor.py b/torch/nn/utils/_named_member_accessor.py index e12739a13a8a9..1c65dbaf9b52b 100644 --- a/torch/nn/utils/_named_member_accessor.py +++ b/torch/nn/utils/_named_member_accessor.py @@ -296,7 +296,7 @@ def check_keys(self, keys: Iterable[str]) -> Tuple[List[str], List[str]]: Check that the given keys are valid. """ keys = set(keys) - valid_keys = set(name for name, _ in self.named_tensors(remove_duplicate=False)) + valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)} missing_keys = valid_keys - keys unexpected_keys = keys - valid_keys return sorted(missing_keys), sorted(unexpected_keys) diff --git a/torch/onnx/_internal/diagnostics/infra/engine.py b/torch/onnx/_internal/diagnostics/infra/engine.py index c2ac449ac6458..001d52b4a73d5 100644 --- a/torch/onnx/_internal/diagnostics/infra/engine.py +++ b/torch/onnx/_internal/diagnostics/infra/engine.py @@ -197,7 +197,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def sarif(self) -> sarif.Run: """Returns the SARIF Run object.""" - unique_rules = set(diagnostic.rule for diagnostic in self.diagnostics) + unique_rules = {diagnostic.rule for diagnostic in self.diagnostics} return sarif.Run( tool=sarif.Tool( driver=sarif.ToolComponent( diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index bb08162039678..84ac973bc8ceb 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -914,7 +914,7 @@ def verify_aten_graph( graph = graph.copy() # Execute aten graph and get reference torch jit outputs. - graph_inputs = list(v for v in graph.inputs()) + graph_inputs = list(graph.inputs()) jit_inputs = tuple([arg for arg in input_args if arg is not None]) weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]] assert all([w is not None for w in weights]) @@ -940,7 +940,7 @@ def verify_aten_graph( # NOTE: Verification is unstable. Try catch to emit information for debugging. try: # NOTE: Input might be dce'ed, so we need to remove those from the input args. - new_input_names = set(v.debugName() for v in graph.inputs()) + new_input_names = {v.debugName() for v in graph.inputs()} new_input_args = [] for v, arg in zip(original_jit_graph.inputs(), input_args): if v.debugName() in new_input_names: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 73cdd909c8977..cd897c35a5d4b 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -7919,9 +7919,7 @@ def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs): 'nn.functional.max_unpool3d': 3 } - unpool_to_pool_name_dict = dict(( - (k, f'nn.functional.{v.__name__}') for k, v in unpool_name_to_pool_method_dict.items() - )) + unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()} pool_dim = unpool_name_to_dim[op_info.name] pool_method = unpool_name_to_pool_method_dict[op_info.name] diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py index 069420bec4f7e..26f2984ec1ace 100644 --- a/torch/testing/_internal/composite_compliance.py +++ b/torch/testing/_internal/composite_compliance.py @@ -507,7 +507,7 @@ def maybe_tangent(t): if isinstance(t, torch.Tensor) and t.requires_grad: return torch.randn_like(t) elif is_tensorlist(t): - return list(torch.randn_like(e) if e.requires_grad else None for e in t) + return [torch.randn_like(e) if e.requires_grad else None for e in t] return None tangent_args = tuple(maybe_tangent(arg) for arg in args) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 778700cc84df2..eb5130f296370 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -5822,12 +5822,7 @@ def parse_env(var): params = list(model_DDP.parameters()) num_params = 0 param_size = 0 - params = list( - parameter - for parameter in filter( - lambda parameter: parameter.requires_grad, params - ) - ) + params = list(filter(lambda parameter: parameter.requires_grad, params)) for p in params: num_params += 1 param_size += p.numel() * p.element_size() @@ -6665,7 +6660,7 @@ def _run_uneven_inputs_test( dist.all_gather(tensor_list, final_rank_tensor) max_rank = dist.get_world_size() - 1 self.assertSetEqual( - {max_rank}, set(tensor.item() for tensor in tensor_list) + {max_rank}, {tensor.item() for tensor in tensor_list} ) # Ensure that all models are the same across ranks after all have joined. self.validate_net_equivalence(net) @@ -7298,7 +7293,7 @@ def __init__(self, t): def tuple_and_list_validator(x): self.assertTrue(len(x), expected_len) - self.assertEqual(1, len(set(t.device for t in x))) + self.assertEqual(1, len({t.device for t in x})) self.assertEqual(x[0].device.index, self.rank) return x[0] + x[1] @@ -7317,7 +7312,7 @@ def custom_type_validator(x): def dict_validator(x): self.assertTrue(EXPECTED_FIELDS[0] in x.keys()) self.assertTrue(EXPECTED_FIELDS[1] in x.keys()) - self.assertEqual(1, len(set(t.device for t in x.values()))) + self.assertEqual(1, len({t.device for t in x.values()})) self.assertEqual(x[EXPECTED_FIELDS[0]].device.index, self.rank) return x[EXPECTED_FIELDS[0]] + x[EXPECTED_FIELDS[1]] @@ -8183,14 +8178,14 @@ def test_monitored_barrier_gloo_subgroup(self): def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks): # tests expected behavior when nonzero rank hangs. nccl_pg = dist.new_group( - ranks=list(i for i in range(int(self.world_size))), + ranks=list(range(int(self.world_size))), # provide sufficient timeout so communicators # can be initialized in ctor. timeout=timedelta(seconds=15), backend=dist.Backend.NCCL, ) gloo_pg = dist.new_group( - ranks=list(i for i in range(int(self.world_size))), + ranks=list(range(int(self.world_size))), backend=dist.Backend.GLOO, ) tensors = [torch.ones(10, device=self.rank) * self.rank] @@ -8256,7 +8251,7 @@ def test_monitored_barrier_allreduce_hang_wait_all_ranks(self): def test_monitored_barrier_gloo_rank_0_timeout(self): # tests error when rank 0 exhausts its given timeout. process_group = dist.new_group( - ranks=list(i for i in range(int(self.world_size))) + ranks=list(range(int(self.world_size))) ) timeout = timedelta(seconds=0) if self.rank == 0: diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py index 997006353bfbd..83736b33b316b 100644 --- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py +++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py @@ -604,78 +604,78 @@ def test_invalid_devices(self): RuntimeError, r"Expected one of .+ device type at start of device string", ): - list( + [ m.forward() for m in self._create_remote_module_iter( "{}/foo".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) - ) + ] with self.assertRaisesRegex( RuntimeError, r"CUDA error: invalid device ordinal" ): - list( + [ m.forward() for m in self._create_remote_module_iter( "{}/cuda:100".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) - ) + ] with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"): - list( + [ m.forward() for m in self._create_remote_module_iter( "{}/cpu2".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) - ) + ] with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"): - list( + [ m.forward() for m in self._create_remote_module_iter( "{}/".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) - ) + ] with self.assertRaisesRegex( ValueError, r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '/'", ): - list( + [ m.forward() for m in self._create_remote_module_iter( "{}/cuda:0/cuda:1".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) - ) + ] with self.assertRaisesRegex( ValueError, r"Could not parse remote_device: /. The valid format is '/'", ): - list( + [ m.forward() for m in self._create_remote_module_iter( "/", modes=[ModuleCreationMode.MODULE_CTOR], ) - ) + ] with self.assertRaisesRegex( ValueError, r"Could not parse remote_device: /cuda:0. The valid format is '/'", ): - list( + [ m.forward() for m in self._create_remote_module_iter( "/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR], ) - ) + ] @skip_if_lt_x_gpu(1) @dist_utils.dist_init diff --git a/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py b/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py index 6586b7824bb35..d050a2138b792 100644 --- a/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py @@ -54,7 +54,7 @@ def test_verify_backend_options(self): @dist_init(faulty_messages=["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"]) def test_custom_faulty_messages(self): self.assertEqual( - set(["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"]), + {"RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"}, set(self.rpc_backend_options.messages_to_fail), ) diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 4c0239ac653ee..d85066930cf1d 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -1808,7 +1808,7 @@ def test_profiler_rpc_memory(self): res = fut.wait() function_events = p.function_events - event_cpu_mem_usages = set(event.cpu_memory_usage for event in function_events) + event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events} # if cpu_memory_usage was not propagated over the wire, this set would # only contain 0 (indicates no memory being profiled) self.assertNotEqual({0}, event_cpu_mem_usages) @@ -1818,7 +1818,7 @@ def test_profiler_rpc_memory(self): res = fut.wait() function_events = p.function_events - event_cpu_mem_usages = set(event.cpu_memory_usage for event in function_events) + event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events} self.assertEqual({0}, event_cpu_mem_usages) @dist_init diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index a429415ea7632..7bf183a5a4534 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -2711,5 +2711,5 @@ def clone_tensor(t): return SampleInput( clone_tensor(sample.input), args=tuple(map(clone_tensor, sample.args)), - kwargs=dict(((k, clone_tensor(v)) for k, v in sample_kwargs.items())), + kwargs={k: clone_tensor(v) for k, v in sample_kwargs.items()}, ) diff --git a/torch/utils/benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py index ed8b6734ed218..9c7863e6a740e 100644 --- a/torch/utils/benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -155,7 +155,7 @@ def __init__( trim_significant_figures: bool, highlight_warnings: bool ): - assert len(set(r.label for r in results)) == 1 + assert len({r.label for r in results}) == 1 self.results = results self._colorize = colorize diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 9483a742eddd0..733d5b1a4f2f6 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -41,8 +41,8 @@ def check_backward_validity(inputs: Iterable[Any]) -> None: def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: # This will not error out if "arg" is a CPU tensor or a non-tensor type because # the conditionals short-circuit. - fwd_gpu_devices = list(set(arg.get_device() for arg in args - if isinstance(arg, torch.Tensor) and arg.is_cuda)) + fwd_gpu_devices = list({arg.get_device() for arg in args + if isinstance(arg, torch.Tensor) and arg.is_cuda}) fwd_gpu_states = [] for device in fwd_gpu_devices: diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 1e6a5a8aaa454..11b233f27124e 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1401,7 +1401,7 @@ def load_inline(name, functions = [functions] if isinstance(functions, list): # Make the function docstring the same as the function name. - functions = dict((f, f) for f in functions) + functions = {f: f for f in functions} elif not isinstance(functions, dict): raise ValueError(f"Expected 'functions' to be a list or dict, but was {type(functions)}") for function_name, docstring in functions.items(): diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py index ab5e3fb33b60a..a7cd07179d926 100644 --- a/torch/utils/data/datapipes/_typing.py +++ b/torch/utils/data/datapipes/_typing.py @@ -101,7 +101,7 @@ def _decompose_type(t, to_list=True): return None ts = [t] # Ignored: Generator has incompatible item type "object"; expected "Type[Any]" - ts = list(TYPE2ABC.get(_t, _t) for _t in ts) # type: ignore[misc] + ts = [TYPE2ABC.get(_t, _t) for _t in ts] # type: ignore[misc] return ts diff --git a/torchgen/api/python.py b/torchgen/api/python.py index da461248198fd..f6c2ecc678f6d 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -756,9 +756,9 @@ def signature_from_schema( args.extend(func.arguments.post_tensor_options_kwarg_only) args.extend(func.arguments.out) - input_arg_set = set(a.name for a in func.arguments.flat_positional) - kwarg_only_set = set(a.name for a in func.arguments.flat_kwarg_only) - out_arg_set = set(a.name for a in func.arguments.out) + input_arg_set = {a.name for a in func.arguments.flat_positional} + kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only} + out_arg_set = {a.name for a in func.arguments.out} input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args))) input_kwargs = tuple( @@ -1072,7 +1072,7 @@ def dispatch_lambda_args( method=False, cpp_no_default_args=f.cpp_no_default_args, ) - out_args: Set[str] = set(a.name for a in schema.arguments.out) + out_args: Set[str] = {a.name for a in schema.arguments.out} # Convert from cpp argument to lambda argument def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument: diff --git a/torchgen/gen.py b/torchgen/gen.py index e034b62d76d2c..0df9e3e81fcc8 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -1188,8 +1188,8 @@ def compute_declaration_yaml(f: NativeFunction) -> object: # These sets are used to conveniently test if an argument is a # kwarg-only or out argument - kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only) - out_arg_set = set(a.name for a in f.func.arguments.out) + kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only} + out_arg_set = {a.name for a in f.func.arguments.out} sig_group = CppSignatureGroup.from_native_function( f, method=False, fallback_binding=False @@ -2099,21 +2099,19 @@ def gen_aten_interned_strings() -> Dict[str, str]: # These are keywords in C++, so aren't valid symbol names # https://en.cppreference.com/w/cpp/language/operator_alternative - names -= set( - [ - "and", - "and_eq", - "bitand", - "bitor", - "compl", - "not", - "not_eq", - "or", - "or_eq", - "xor", - "xor_eq", - ] - ) + names -= { + "and", + "and_eq", + "bitand", + "bitor", + "compl", + "not", + "not_eq", + "or", + "or_eq", + "xor", + "xor_eq", + } return { "aten_symbols": " \\\n".join( diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index 87a1392f7abe4..a7a820e774ad2 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -526,13 +526,9 @@ def map_index( ) -> Dict[OperatorName, BackendMetadata]: return {op: m[op] for op in m if op in op_names} - backend_indices = dict( - ( - k, - map_index(b.index), - ) - for (k, b) in parsed_yaml.backend_indices.items() - ) + backend_indices = { + k: map_index(b.index) for (k, b) in parsed_yaml.backend_indices.items() + } return native_functions, backend_indices else: return [], {} diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index 590c5730b6419..ee8fc0312f872 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -319,14 +319,14 @@ def generate_function( ) } } - tags = set(["generated"]) | set(f.tags & {"nondeterministic_seeded", "view_copy"}) + tags = {"generated"} | set(f.tags & {"nondeterministic_seeded", "view_copy"}) return ( NativeFunction( func=func, use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors, # These generated fn's aren't meant to be user friendly- don't generate methods. - variants=set([Variant.function]), + variants={Variant.function}, structured=False, structured_delegate=None, structured_inherits=None,