From 3d82d8d0ed000117f78c49ec684c75f00b371014 Mon Sep 17 00:00:00 2001
From: Aaron Gokaslan
Date: Fri, 10 Feb 2023 23:40:26 +0000
Subject: [PATCH] [BE] Enable more flake8-comprehensions checks (#94601)

I applied some flake8-comprehensions fixes and enabled checking for
them in the linter. I also enabled some checks for my previous
comprehensions PR.

This is a follow-up to #94323: it enables the flake8 checkers for the
fixes made there and fixes a few more of them.
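As a quick reference, these are the comprehension codes this change
starts enforcing. The snippets are illustrative sketches of each
rewrite (per the flake8-comprehensions docs), not lines from this
patch:

    C403  set([x for x in y])           ->  {x for x in y}
    C404  dict([(k, v) for k, v in y])  ->  {k: v for k, v in y}
    C411  list([x for x in y])          ->  [x for x in y]
    C413  list(sorted(y))               ->  sorted(y)
    C414  sorted(list(y))               ->  sorted(y)
    C415  set(y[::-1])                  ->  set(y)

C400, C401, C402, C405, and C407 remain on the ignore list for now.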
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94601
Approved by: https://github.com/ezyang
---
 .flake8                                       |  2 +-
 .../microbenchmarks/operator_inp_utils.py     |  2 +-
 .../model_zoo/update-models-from-caffe2.py    |  2 +-
 test/distributed/test_c10d_nccl.py            |  2 +-
 test/functorch/discover_coverage.py           | 12 +++----
 test/functorch/test_aotdispatch.py            |  4 +--
 test/functorch/test_minifier.py               |  4 +--
 test/functorch/xfail_suggester.py             |  2 +-
 test/jit/test_list_dict.py                    |  2 +-
 test/mobile/model_test/gen_test_model.py      |  2 +-
 test/onnx/onnx_test_common.py                 |  2 +-
 test/package/test_digraph.py                  |  2 +-
 .../eager/test_quantize_eager_ptq.py          |  8 ++---
 test/quantization/fx/test_model_report_fx.py  |  4 +--
 test/test_namedtuple_return_api.py            |  2 +-
 test/test_proxy_tensor.py                     |  2 +-
 test/test_sparse.py                           |  2 +-
 torch/_dynamo/skipfiles.py                    | 12 +++----
 torch/_dynamo/utils.py                        |  4 +--
 torch/_dynamo/variables/builder.py            | 17 ++++-----
 torch/_functorch/partitioners.py              |  6 ++--
 torch/_inductor/graph.py                      |  4 +--
 torch/_inductor/utils.py                      |  4 ++-
 .../fx/_model_report/model_report.py          |  2 +-
 torch/distributed/fsdp/_optim_utils.py        |  2 +-
 torch/fx/_symbolic_trace.py                   |  2 +-
 torch/testing/_internal/common_utils.py       |  4 +--
 torchgen/gen_backend_stubs.py                 | 36 +++++++++----------
 torchgen/model.py                             |  2 +-
 torchgen/selective_build/selector.py          |  2 +-
 30 files changed, 71 insertions(+), 82 deletions(-)

diff --git a/.flake8 b/.flake8
index a16d89827371fc..d6e1aa0e366184 100644
--- a/.flake8
+++ b/.flake8
@@ -11,7 +11,7 @@ ignore =
     # these ignores are from flake8-bugbear; please fix!
     B007,B008,
     # these ignores are from flake8-comprehensions; please fix!
-    C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
+    C400,C401,C402,C405,C407
 per-file-ignores =
     __init__.py: F401
     torch/utils/cpp_extension.py: B950
diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py
index 7b7b9a09e5e644..046a1dd9c9b18f 100644
--- a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py
+++ b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py
@@ -181,7 +181,7 @@ def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
         return out

     def log_to_file(self, output_filename, *, skip_non_compute_operators=True):
-        sorted_operators = sorted(list(self.func_db.keys()))
+        sorted_operators = sorted(self.func_db.keys())
         with open(output_filename, "w") as f:
             for operator in sorted_operators:
                 if skip_non_compute_operators and non_compute_operator(eval(operator)):
diff --git a/scripts/model_zoo/update-models-from-caffe2.py b/scripts/model_zoo/update-models-from-caffe2.py
index f3b485f495d314..fb58871275ca2b 100644
--- a/scripts/model_zoo/update-models-from-caffe2.py
+++ b/scripts/model_zoo/update-models-from-caffe2.py
@@ -163,7 +163,7 @@ def tensortype_to_ndarray(tensor_type):


 def generate_test_input_data(onnx_model, scale):
-    real_inputs_names = list(set([input.name for input in onnx_model.graph.input]) - set([init.name for init in onnx_model.graph.initializer]))
+    real_inputs_names = list({input.name for input in onnx_model.graph.input} - {init.name for init in onnx_model.graph.initializer})
     real_inputs = []
     for name in real_inputs_names:
         for input in onnx_model.graph.input:
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 920e95630812d1..d1ecdba6da177a 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -2297,7 +2297,7 @@ def test_ddp_packed_sequence(self):
             store=store,
         )
         seqs = ["sequence_sequence", "seq", "sequence"]
-        vocab = [""] + sorted(set([ch for seq in seqs for ch in seq]))
+        vocab = [""] + sorted({ch for seq in seqs for ch in seq})
         vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
         # Set the seed to make the embedding and LSTM deterministic (even
         # across ranks since DDP broadcasts parameters from rank 0)
diff --git a/test/functorch/discover_coverage.py b/test/functorch/discover_coverage.py
index 3f4f74b9224de7..6d1e055d01f210 100644
--- a/test/functorch/discover_coverage.py
+++ b/test/functorch/discover_coverage.py
@@ -426,7 +426,7 @@ def remove_torch(name):

 def get_list_of_all_tests():
     all_tests = list(tested_overridable_outplace_ops.keys())
-    return set([remove_torch(test) for test in all_tests])
+    return {remove_torch(test) for test in all_tests}


 mytest = {
@@ -459,11 +459,11 @@ def get_jvp_coverage(subset=None):
     supports_forwardad_ops_dct = {name: op_to_opinfo[fn] for name, fn in ops_dct.items()
                                   if op_to_opinfo[fn][0].supports_forward_ad}
-    ops = set([remove_torch(test) for test in list(ops_dct.keys())])
-    supports_autograd = set([remove_torch(test)
-                             for test in list(supports_autograd_ops_dct.keys())])
-    supports_forward_ad = set([remove_torch(test)
-                               for test in list(supports_forwardad_ops_dct.keys())])
+    ops = {remove_torch(test) for test in list(ops_dct.keys())}
+    supports_autograd = {remove_torch(test)
+                         for test in list(supports_autograd_ops_dct.keys())}
+    supports_forward_ad = {remove_torch(test)
+                           for test in list(supports_forwardad_ops_dct.keys())}
     assert supports_forward_ad.issubset(supports_autograd)
     assert supports_autograd.issubset(ops)

diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index e078856c43d2ce..ebf835874c60ff 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -169,12 +169,12 @@ def f(x):
             return torch.tanh(x).sum()

         fx_f = make_fx(grad(f))(torch.randn(5))
-        ops = set([i.target for i in fx_f.graph.nodes])
+        ops = {i.target for i in fx_f.graph.nodes}
         self.assertEqual(torch.ops.aten.tanh_backward in ops, True)

         fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
-        ops = set([i.target for i in fx_f.graph.nodes])
+        ops = {i.target for i in fx_f.graph.nodes}
         self.assertEqual(torch.ops.aten.tanh_backward in ops, False)

     def test_nnc_jit(self, device):
diff --git a/test/functorch/test_minifier.py b/test/functorch/test_minifier.py
index 7ed13921d90770..9e6f495bcd4b64 100644
--- a/test/functorch/test_minifier.py
+++ b/test/functorch/test_minifier.py
@@ -18,7 +18,7 @@ def failing_f(x, y):
         failing_f = make_fx(failing_f)(*inps)

         def has_mul(fx_g, inps):
-            return (torch.ops.aten.mul.Tensor in set([i.target for i in fx_g.graph.nodes]))
+            return (torch.ops.aten.mul.Tensor in (i.target for i in fx_g.graph.nodes))

         min_f, inps = minifier(failing_f, inps, has_mul)
         self.assertEqual(len(min_f.graph.nodes), 4)
@@ -74,7 +74,7 @@ def f(a, b):
         inps = [torch.randn(3), torch.randn(3)]

         def has_add(fx_g, inps):
-            return (torch.ops.aten.add.Tensor in set([i.target for i in fx_g.graph.nodes]))
+            return (torch.ops.aten.add.Tensor in (i.target for i in fx_g.graph.nodes))

         failing_f = make_fx(f)(*inps)
         min_f, inps = minifier(failing_f, inps, has_add)
diff --git a/test/functorch/xfail_suggester.py b/test/functorch/xfail_suggester.py
index cdf2cca13671cf..cfe1460a01ac3d 100644
--- a/test/functorch/xfail_suggester.py
+++ b/test/functorch/xfail_suggester.py
@@ -114,7 +114,7 @@ def get_suggested_xfails(base, tests):
     tests = [test[len(base):] for test in tests
             if belongs_to_base(test, base)]

-    base_tests = set([remove_device_dtype(test) for test in tests])
+    base_tests = {remove_device_dtype(test) for test in tests}
     tests = set(tests)
     for base in base_tests:
         cpu_variant = base + '_cpu_float32'
diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py
index 29f633c153fa87..980b76cf599789 100644
--- a/test/jit/test_list_dict.py
+++ b/test/jit/test_list_dict.py
@@ -226,7 +226,7 @@ def foo2():
         self.checkScript(foo2, ())

         def foo3():
-            return list(list("abc"))
+            return list(list("abc"))  # noqa: C414
         self.checkScript(foo3, ())

         FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph)
diff --git a/test/mobile/model_test/gen_test_model.py b/test/mobile/model_test/gen_test_model.py
index 370e8d08541f7b..7c6b780e8d6d4c 100644
--- a/test/mobile/model_test/gen_test_model.py
+++ b/test/mobile/model_test/gen_test_model.py
@@ -140,7 +140,7 @@ def calcOpsCoverage(ops):
                 "_coverage": round(coverage, 2),
                 "uncovered_ops": uncovered_ops_dict,
                 "covered_ops": covered_ops_dict,
-                "all_generated_ops": sorted(list(all_generated_ops)),
+                "all_generated_ops": sorted(all_generated_ops),
             },
             f,
         )
diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py
index fe5e2411aa3838..50013fbc7dde59 100644
--- a/test/onnx/onnx_test_common.py
+++ b/test/onnx/onnx_test_common.py
@@ -40,7 +40,7 @@ def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
     if hasattr(test_suite, "check_dtype"):
         options.check_dtype = test_suite.check_dtype

-    names = set([f.name for f in dataclasses.fields(options)])
+    names = {f.name for f in dataclasses.fields(options)}
     keywords_to_pop = []
     for k, v in kwargs.items():
         if k in names:
diff --git a/test/package/test_digraph.py b/test/package/test_digraph.py
index 0ccc09bcf74c73..92f469868f7c9e 100644
--- a/test/package/test_digraph.py
+++ b/test/package/test_digraph.py
@@ -116,7 +116,7 @@ def test_all_paths(self):
         result = g.all_paths("1", "3")

         # to get rid of indeterminism
-        actual = set([i.strip("\n") for i in result.split(";")[2:-1]])
+        actual = {i.strip("\n") for i in result.split(";")[2:-1]}
         expected = {
             '"2" -> "3"',
             '"1" -> "7"',
diff --git a/test/quantization/eager/test_quantize_eager_ptq.py b/test/quantization/eager/test_quantize_eager_ptq.py
index 7a5a631080f98a..a20a17d6637df9 100644
--- a/test/quantization/eager/test_quantize_eager_ptq.py
+++ b/test/quantization/eager/test_quantize_eager_ptq.py
@@ -365,10 +365,10 @@ def checkQuantized(model):
         # test one line API - out of place version
         base = AnnotatedSingleLayerLinearModel(qengine)
         base.qconfig = qconfig
-        keys_before = set(list(base.state_dict().keys()))
+        keys_before = set(base.state_dict().keys())
         model = quantize(base, test_only_eval_fn, [self.calib_data])
         checkQuantized(model)
-        keys_after = set(list(base.state_dict().keys()))
+        keys_after = set(base.state_dict().keys())
         self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

         # in-place version
@@ -1107,10 +1107,10 @@ def checkQuantized(model):

         # test one line API - out of place version
         base = SingleLayerLinearDynamicModel()
-        keys_before = set(list(base.state_dict().keys()))
+        keys_before = set(base.state_dict().keys())
         model = quantize_dynamic(base, qconfig_dict)
         checkQuantized(model)
-        keys_after = set(list(base.state_dict().keys()))
+        keys_after = set(base.state_dict().keys())
         self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

         # in-place version
diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py
index 6e367b0eb7fa53..e0a428a987b534 100644
--- a/test/quantization/fx/test_model_report_fx.py
+++ b/test/quantization/fx/test_model_report_fx.py
@@ -900,7 +900,7 @@ def test_constructor(self):
             model_report = ModelReport(model_prep, test_detector_set)

             # make sure internal valid reports matches
-            detector_name_set = set([detector.get_detector_name() for detector in test_detector_set])
+            detector_name_set = {detector.get_detector_name() for detector in test_detector_set}
             self.assertEqual(model_report.get_desired_reports_names(), detector_name_set)

             # now attempt with no valid reports, should raise error
@@ -1329,7 +1329,7 @@ def test_input_weight_equalization_determine_points(self):
         mods_to_check = set([nn.Linear, nn.Conv2d])

         # get the set of all nodes in the graph their fqns
-        node_fqns = set([node.target for node in prepared_for_callibrate_model.graph.nodes])
+        node_fqns = {node.target for node in prepared_for_callibrate_model.graph.nodes}

         # there should be 4 node fqns that have the observer inserted
         correct_number_of_obs_inserted = 4
diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py
index bd0f1b1abfeba5..8330a6eb95656b 100644
--- a/test/test_namedtuple_return_api.py
+++ b/test/test_namedtuple_return_api.py
@@ -167,7 +167,7 @@ def check_torch_return_type(f, names):
                 ret3 = meth(*op.input)
                 check_namedtuple(ret3, op.names)

-        all_covered_operators = set([x for y in operators for x in y.operators])
+        all_covered_operators = {x for y in operators for x in y.operators}

         self.assertEqual(all_operators_with_namedtuple_return, all_covered_operators, textwrap.dedent('''
         The set of covered operators does not match the `all_operators_with_namedtuple_return` of
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 3563ac4d9556e7..7368a85c73cc93 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -579,7 +579,7 @@ def forward(mod_self, x):  # noqa: B902

         gm = make_fx(Emformer())(torch.randn(16, 1, 256))

-        ops = set([n.target for n in gm.graph.nodes if n.op == 'call_function'])
+        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
         self.assertEqual(len(ops), 2)

diff --git a/test/test_sparse.py b/test/test_sparse.py
index ddb8e9b3e11ba3..c466dd2e52a04c 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -264,7 +264,7 @@ def _test_coalesce(t):
                 else:
                     value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val

-            new_indices = sorted(list(value_map.keys()))
+            new_indices = sorted(value_map.keys())
             _new_values = [value_map[idx] for idx in new_indices]
             if t._values().ndimension() < 2:
                 new_values = t._values().new(_new_values)
diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py
index 9ef0851aa33f67..64e901fe1d2364 100644
--- a/torch/_dynamo/skipfiles.py
+++ b/torch/_dynamo/skipfiles.py
@@ -130,13 +130,11 @@ def _module_dir(m: types.ModuleType):
 }

 # Include optimizer code for tracing
-FILENAME_ALLOWLIST |= set(
-    [
-        inspect.getfile(obj)
-        for obj in torch.optim.__dict__.values()
-        if inspect.isclass(obj)
-    ]
-)
+FILENAME_ALLOWLIST |= {
+    inspect.getfile(obj)
+    for obj in torch.optim.__dict__.values()
+    if inspect.isclass(obj)
+}
 FILENAME_ALLOWLIST |= {torch.optim._functional.__file__}

 if HAS_PRIMS_REFS:
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index d261c139d8bd5f..d7513f393f6de0 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -760,7 +760,7 @@ def enum_repr(value):


 def dict_param_key_ids(value):
-    return set([id(k) for k in value.keys() if isinstance(k, torch.nn.Parameter)])
+    return {id(k) for k in value.keys() if isinstance(k, torch.nn.Parameter)}


 def dict_const_keys(value):
@@ -771,7 +771,7 @@ def dict_const_keys_repr(const_keys):
     if any(isinstance(k, enum.Enum) for k in const_keys):
         # To workaround repr(Enum) returning invalid global reference before python 3.11
         # by calling enum_repr and removing quotes to render enum in guard code.
-        const_keys_str = f"{set([enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys])}".replace(
+        const_keys_str = f"{set(enum_repr(k) if isinstance(k, enum.Enum) else repr(k) for k in const_keys)}".replace(
             "'", ""
         )
     else:
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index eba6589caab7d4..67a0a534ffb94b 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -304,17 +304,12 @@ def index_source(key):
             else:
                 return key

-        result = dict(
-            [
-                (
-                    k,
-                    VariableBuilder(
-                        self.tx, GetItemSource(self.get_source(), index_source(k))
-                    )(value[k]).add_guards(guards),
-                )
-                for k in value.keys()
-            ]
-        )
+        result = {
+            k: VariableBuilder(
+                self.tx, GetItemSource(self.get_source(), index_source(k))
+            )(value[k]).add_guards(guards)
+            for k in value.keys()
+        }

         if istype(value, collections.defaultdict):
             result = DefaultDictVariable(
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index 63562895d41e8d..80c024740a3b69 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -393,7 +393,7 @@ def is_tensor_node(x):
         for node in joint_module.graph.nodes
         if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
     )
-    ops_ignored = joint_module_ops - set([str(i) for i in recomputable_ops])
+    ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops}
     print("Ops banned from rematerialization: ", ops_ignored)
     print()

@@ -522,8 +522,8 @@ def get_node_weight(node) -> int:
         joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)
     if AOT_PARTITIONER_DEBUG:
         print("Theoretical Activations Stored: ", sum([_size_of(i) for i in saved_values]) / 1e9)
-        fw_module_nodes = set([node.name for node in fw_module.graph.nodes if node.op == 'call_function'])
-        bw_module_nodes = set([node.name for node in bw_module.graph.nodes if node.op == 'call_function'])
+        fw_module_nodes = {node.name for node in fw_module.graph.nodes if node.op == 'call_function'}
+        bw_module_nodes = {node.name for node in bw_module.graph.nodes if node.op == 'call_function'}
         remat_nodes = fw_module_nodes & bw_module_nodes

         counts = defaultdict(int)
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index cfbfa8e2722dad..659edeb3b9b7f7 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -535,9 +535,7 @@ def get_read_write_buffers_sizes(node):
         writes = set(dep.name for dep in node.read_writes.writes)

         def is_materialized(buf):
-            buf_uses = set(
-                [user.node for user in scheduler.name_to_node[buf].users]
-            )
+            buf_uses = {user.node for user in scheduler.name_to_node[buf].users}
             return len(buf_uses - set(node.snodes)) > 0

         if isinstance(node, FusedSchedulerNode):
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 8c66bbc31957a6..f36af67a356c50 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -344,7 +344,9 @@ def fresh_inductor_cache(cache_entries=None):

 def argsort(seq):
     # preserve original order for equal strides
-    return list(reversed(sorted(range(len(seq)), key=seq.__getitem__, reverse=True)))
+    getter = seq.__getitem__
+    a_r = range(len(seq))
+    return list(reversed(sorted(a_r, key=getter, reverse=True)))  # noqa: C413


 @functools.lru_cache(8)
diff --git a/torch/ao/quantization/fx/_model_report/model_report.py b/torch/ao/quantization/fx/_model_report/model_report.py
index ee96dd4bf5a9c9..27a9aa3d05ba35 100644
--- a/torch/ao/quantization/fx/_model_report/model_report.py
+++ b/torch/ao/quantization/fx/_model_report/model_report.py
@@ -120,7 +120,7 @@ def __init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBas

         # keep the reports private so they can't be modified
         self._desired_report_detectors = desired_report_detectors
-        self._desired_detector_names = set([detector.get_detector_name() for detector in desired_report_detectors])
+        self._desired_detector_names = {detector.get_detector_name() for detector in desired_report_detectors}

         # keep a mapping of desired reports to observers of interest
         # this is to get the readings, and to remove them, can create a large set
diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py
index c05413c9951661..736984f5c71752 100644
--- a/torch/distributed/fsdp/_optim_utils.py
+++ b/torch/distributed/fsdp/_optim_utils.py
@@ -1598,7 +1598,7 @@ def _all_gather_optim_state(
     gathered_state: Dict[str, Any] = {}

     all_tensor_states = sorted(
-        set([n for state in object_list for n in state.tensors.keys()])
+        {n for state in object_list for n in state.tensors.keys()}
     )
     empty_ranks: Set[int] = set()
     for name in all_tensor_states:
diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py
index 73e0ed6de7087e..a88dc3e90adcea 100644
--- a/torch/fx/_symbolic_trace.py
+++ b/torch/fx/_symbolic_trace.py
@@ -264,7 +264,7 @@ def __init__(
             for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
             if not name.startswith("_") and callable(value)
         }
-        self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
+        self._autowrap_function_ids.update({id(f) for f in autowrap_functions})

         # Python modules to apply autowrap to at the start, in addition to
         # modules we see while tracing
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 6ac12e42959bbc..962e067c9fcbec 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -3611,8 +3611,8 @@ def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs):
     torch = kwargs.get('torch', globals()['torch'])
     dtype = kwargs.get('dtype', torch.double)
     device = kwargs.get('device', 'cpu')
-    data = dict([((i, i), float(i + 1) / matrix_size)
-                 for i in range(matrix_size)])
+    data = {(i, i): float(i + 1) / matrix_size
+            for i in range(matrix_size)}


     def multiply(data, N, i, j, cs, sn, left=True):
diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py
index 5768ff2facb9a5..a8dc476254cf4f 100644
--- a/torchgen/gen_backend_stubs.py
+++ b/torchgen/gen_backend_stubs.py
@@ -377,29 +377,25 @@ def gen_dispatchkey_nativefunc_headers(
     # Convert to a set first to remove duplicate kernel names.
     # Backends are allowed to repeat kernel names; only generate the declaration once!
     # Sort for deterministic output.
-    backend_declarations = list(
-        sorted(
-            set(
-                concatMap(
-                    lambda f: dest.compute_native_function_declaration(
-                        f, backend_indices[backend_dispatch_key]
-                    ),
-                    grouped_native_functions,
-                )
+    backend_declarations = sorted(
+        set(
+            concatMap(
+                lambda f: dest.compute_native_function_declaration(
+                    f, backend_indices[backend_dispatch_key]
+                ),
+                grouped_native_functions,
             )
         )
     )
-    autograd_declarations = list(
-        sorted(
-            set(
-                concatMap(
-                    lambda f: []
-                    if autograd_dispatch_key is None
-                    else dest.compute_native_function_declaration(
-                        f, backend_indices[autograd_dispatch_key]
-                    ),
-                    grouped_native_functions,
-                )
+    autograd_declarations = sorted(
+        set(
+            concatMap(
+                lambda f: []
+                if autograd_dispatch_key is None
+                else dest.compute_native_function_declaration(
+                    f, backend_indices[autograd_dispatch_key]
+                ),
+                grouped_native_functions,
             )
         )
     )
diff --git a/torchgen/model.py b/torchgen/model.py
index 6e34f85b679f77..a1efbdf459bd46 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -1058,7 +1058,7 @@ def __post_init__(self) -> None:
         for f in self.functions():
             expected_generated_fns.update(str(op) for op in f.autogen)
         expected_generated_fns_str = ", ".join(
-            str(x) for x in sorted(list(expected_generated_fns))
+            str(x) for x in sorted(expected_generated_fns)
         )
         if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
             raise RuntimeError(
diff --git a/torchgen/selective_build/selector.py b/torchgen/selective_build/selector.py
index 32f0f9e219cafb..03e638c179f530 100644
--- a/torchgen/selective_build/selector.py
+++ b/torchgen/selective_build/selector.py
@@ -231,7 +231,7 @@ def to_dict(self) -> Dict[str, object]:
             ret["debug_info"] = sorted(self._debug_info)

         ret["kernel_metadata"] = {
-            k: sorted(list(v)) for (k, v) in self.kernel_metadata.items()
+            k: sorted(v) for (k, v) in self.kernel_metadata.items()
         }

         ret["custom_classes"] = sorted(self.custom_classes)