diff --git a/src/codemodder/codemods/utils_mixin.py b/src/codemodder/codemods/utils_mixin.py index 0646c935..787577b0 100644 --- a/src/codemodder/codemods/utils_mixin.py +++ b/src/codemodder/codemods/utils_mixin.py @@ -86,19 +86,18 @@ def get_imported_prefix( return (import_node, alias) return None - def get_aliased_prefix_name(self, node: cst.CSTNode, name: str): + def get_aliased_prefix_name(self, node: cst.CSTNode, name: str) -> Optional[str]: """ Returns the alias of name if name is imported and used as a prefix for this node. """ maybe_import = self.get_imported_prefix(node) maybe_name = None - if maybe_import: - imp, ia = maybe_import - match imp: - case cst.Import(): - imp_name = get_full_name_for_node(ia.name) - if imp_name == name and ia.asname: - maybe_name = ia.asname.name.value + if maybe_import and matchers.matches(maybe_import[0], matchers.Import()): + _, ia = maybe_import + imp_name = get_full_name_for_node(ia.name) + if imp_name == name and ia.asname: + # AsName is always a Name for ImportAlias + maybe_name = ia.asname.name.value return maybe_name def find_assignments( diff --git a/src/core_codemods/harden_pyyaml.py b/src/core_codemods/harden_pyyaml.py index 792a4d31..5638eccd 100644 --- a/src/core_codemods/harden_pyyaml.py +++ b/src/core_codemods/harden_pyyaml.py @@ -50,7 +50,7 @@ def rule(cls): def on_result_found(self, original_node, updated_node): maybe_name = self.get_aliased_prefix_name(original_node, self._module_name) maybe_name = maybe_name or self._module_name - if maybe_name and maybe_name == self._module_name: + if maybe_name == self._module_name: self.add_needed_import(self._module_name) new_args = [ *updated_node.args[:1], diff --git a/src/core_codemods/tempfile_mktemp.py b/src/core_codemods/tempfile_mktemp.py index d45637e5..8de39ac1 100644 --- a/src/core_codemods/tempfile_mktemp.py +++ b/src/core_codemods/tempfile_mktemp.py @@ -31,7 +31,7 @@ def rule(cls): def on_result_found(self, original_node, updated_node): maybe_name = self.get_aliased_prefix_name(original_node, self._module_name) maybe_name = maybe_name or self._module_name - if maybe_name and maybe_name == self._module_name: + if maybe_name == self._module_name: self.add_needed_import(self._module_name) self.remove_unused_import(original_node) return self.update_call_target(updated_node, maybe_name, "mkstemp") diff --git a/src/core_codemods/upgrade_sslcontext_minimum_version.py b/src/core_codemods/upgrade_sslcontext_minimum_version.py index ed8c5e7c..85ee1a85 100644 --- a/src/core_codemods/upgrade_sslcontext_minimum_version.py +++ b/src/core_codemods/upgrade_sslcontext_minimum_version.py @@ -52,7 +52,7 @@ def on_result_found(self, original_node, updated_node): original_node.value, self._module_name ) maybe_name = maybe_name or self._module_name - if maybe_name and maybe_name == self._module_name: + if maybe_name == self._module_name: self.add_needed_import(self._module_name) self.remove_unused_import(original_node) return self.update_assign_rhs(updated_node, f"{maybe_name}.TLSVersion.TLSv1_2") diff --git a/tests/codemods/test_harden_ruamel.py b/tests/codemods/test_harden_ruamel.py index 903dc9a5..2a388a4d 100644 --- a/tests/codemods/test_harden_ruamel.py +++ b/tests/codemods/test_harden_ruamel.py @@ -51,14 +51,13 @@ def test_unsafe_import(self, tmpdir, loader): """ self.run_and_assert(tmpdir, input_code, expected) - @pytest.mark.skip() @pytest.mark.parametrize("loader", ["YAML(typ='base')", "YAML(typ='unsafe')"]) def test_import_alias(self, tmpdir, loader): input_code = f"""from ruamel import yaml as yam serializer = yam.{loader} """ - expected = """import ruamel + expected = """from ruamel import yaml as yam serializer = yam.YAML(typ="safe") """ diff --git a/tests/codemods/test_https_connection.py b/tests/codemods/test_https_connection.py index 117a6ef0..00558658 100644 --- a/tests/codemods/test_https_connection.py +++ b/tests/codemods/test_https_connection.py @@ -21,6 +21,18 @@ def test_simple(self, tmpdir): after = r"""import urllib3 urllib3.HTTPSConnectionPool("localhost", "80") +""" + self.run_and_assert(tmpdir, before, after) + assert len(self.file_context.codemod_changes) == 1 + + def test_module_alias(self, tmpdir): + before = r"""import urllib3 as module + +module.HTTPConnectionPool("localhost", "80") +""" + after = r"""import urllib3 as module + +module.HTTPSConnectionPool("localhost", "80") """ self.run_and_assert(tmpdir, before, after) assert len(self.file_context.codemod_changes) == 1 diff --git a/tests/codemods/test_tempfile_mktemp.py b/tests/codemods/test_tempfile_mktemp.py index 6dcca028..b4bb69e5 100644 --- a/tests/codemods/test_tempfile_mktemp.py +++ b/tests/codemods/test_tempfile_mktemp.py @@ -1,4 +1,3 @@ -import pytest # pylint: disable=unused-import from core_codemods.tempfile_mktemp import TempfileMktemp from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest @@ -60,6 +59,19 @@ def test_import_alias(self, tmpdir): _tempfile.mkstemp() var = "hello" +""" + self.run_and_assert(tmpdir, input_code, expected_output) + + def test_import_method_alias(self, tmpdir): + input_code = """from tempfile import mktemp as get_temp_file + +get_temp_file() +var = "hello" +""" + expected_output = """import tempfile + +tempfile.mkstemp() +var = "hello" """ self.run_and_assert(tmpdir, input_code, expected_output) diff --git a/tests/codemods/test_use_defused_xml.py b/tests/codemods/test_use_defused_xml.py index a86d023a..1df986a8 100644 --- a/tests/codemods/test_use_defused_xml.py +++ b/tests/codemods/test_use_defused_xml.py @@ -32,6 +32,26 @@ def test_etree_simple_call(self, tmpdir, module, method): self.run_and_assert(tmpdir, original_code, new_code) self.assert_dependency(DefusedXML) + @pytest.mark.parametrize("method", ETREE_METHODS) + def test_etree_module_alias(self, tmpdir, method): + original_code = f""" + import xml.etree.ElementTree as alias + import xml.etree.cElementTree as calias + + et = alias.{method}('some.xml') + cet = calias.{method}('some.xml') + """ + + new_code = f""" + import defusedxml.ElementTree + + et = defusedxml.ElementTree.{method}('some.xml') + cet = defusedxml.ElementTree.{method}('some.xml') + """ + + self.run_and_assert(tmpdir, original_code, new_code) + self.assert_dependency(DefusedXML) + @pytest.mark.parametrize("method", ETREE_METHODS) @pytest.mark.parametrize("module", ["ElementTree", "cElementTree"]) def test_etree_attribute_call(self, tmpdir, module, method):