From ad97dd1207560d0c03c5a645b574369e6c695626 Mon Sep 17 00:00:00 2001 From: clavedeluna Date: Fri, 27 Oct 2023 08:04:15 -0300 Subject: [PATCH] handle no args and use api --- src/codemodder/codemods/api/helpers.py | 13 +- .../semgrep/upgrade_sslcontext_tls.yaml | 22 --- src/core_codemods/upgrade_sslcontext_tls.py | 149 +++++++----------- tests/codemods/test_upgrade_sslcontext_tls.py | 9 +- 4 files changed, 72 insertions(+), 121 deletions(-) delete mode 100644 src/core_codemods/semgrep/upgrade_sslcontext_tls.yaml diff --git a/src/codemodder/codemods/api/helpers.py b/src/codemodder/codemods/api/helpers.py index 930738109..590725c11 100644 --- a/src/codemodder/codemods/api/helpers.py +++ b/src/codemodder/codemods/api/helpers.py @@ -67,7 +67,7 @@ def replace_args(self, original_node, args_info): for arg in original_node.args: arg_name, replacement_val, idx = _match_with_existing_arg(arg, args_info) if arg_name is not None: - new = self.make_new_arg(arg_name, replacement_val, arg) + new = self.make_new_arg(replacement_val, arg_name, arg) del args_info[idx] else: new = arg @@ -75,12 +75,19 @@ def replace_args(self, original_node, args_info): for arg_name, replacement_val, add_if_missing in args_info: if add_if_missing: - new = self.make_new_arg(arg_name, replacement_val) + new = self.make_new_arg(replacement_val, arg_name) new_args.append(new) return new_args - def make_new_arg(self, name, value, existing_arg=None): + def make_new_arg(self, value, name=None, existing_arg=None): + if name is None: + # Make a positional argument + return cst.Arg( + value=cst.parse_expression(value), + ) + + # make a keyword argument equal = ( existing_arg.equal if existing_arg diff --git a/src/core_codemods/semgrep/upgrade_sslcontext_tls.yaml b/src/core_codemods/semgrep/upgrade_sslcontext_tls.yaml deleted file mode 100644 index c165b506c..000000000 --- a/src/core_codemods/semgrep/upgrade_sslcontext_tls.yaml +++ /dev/null @@ -1,22 +0,0 @@ -rules: - - id: upgrade-sslcontext-tls - message: Upgrade weak SSL/TLS protocol version in SSLContext - severity: WARNING - languages: - - python - patterns: - - pattern-either: - - pattern: ssl.SSLContext() - - pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv2,...) - - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_SSLv2,...) - - pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv3,...) - - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_SSLv3,...) - - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLSv1,...) - - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLSv1,...) - - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLSv1_1,...) - - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLSv1_1,...) - - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLS,...) - - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLS,...) - - pattern-inside: | - import ssl - ... diff --git a/src/core_codemods/upgrade_sslcontext_tls.py b/src/core_codemods/upgrade_sslcontext_tls.py index 37d4ea5ed..d8538b68b 100644 --- a/src/core_codemods/upgrade_sslcontext_tls.py +++ b/src/core_codemods/upgrade_sslcontext_tls.py @@ -1,36 +1,24 @@ -import libcst as cst -from libcst.codemod import CodemodContext -from codemodder.codemods.base_visitor import BaseTransformer -from codemodder.codemods.base_codemod import ( - SemgrepCodemod, - CodemodMetadata, - ReviewGuidance, -) -from codemodder.change import Change -from codemodder.file_context import FileContext +from codemodder.codemods.base_codemod import ReviewGuidance +from codemodder.codemods.api import SemgrepCodemod +from codemodder.codemods.api.helpers import NewArg -class UpgradeSSLContextTLS(SemgrepCodemod, BaseTransformer): - METADATA = CodemodMetadata( - DESCRIPTION="Replaces known insecure TLS/SSL protocol versions in SSLContext with secure ones.", - NAME="upgrade-sslcontext-tls", - REVIEW_GUIDANCE=ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW, - REFERENCES=[ - { - "url": "https://docs.python.org/3/library/ssl.html#security-considerations", - "description": "", - }, - {"url": "https://datatracker.ietf.org/doc/rfc8996/", "description": ""}, - { - "url": "https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1", - "description": "", - }, - ], - ) +class UpgradeSSLContextTLS(SemgrepCodemod): + NAME = "upgrade-sslcontext-tls" + REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW SUMMARY = "Upgrade TLS Version In SSLContext" + DESCRIPTION = "Replaces known insecure TLS/SSL protocol versions in SSLContext with secure ones." CHANGE_DESCRIPTION = "Upgrade to use a safe version of TLS in SSLContext" - YAML_FILES = [ - "upgrade_sslcontext_tls.yaml", + REFERENCES = [ + { + "url": "https://docs.python.org/3/library/ssl.html#security-considerations", + "description": "", + }, + {"url": "https://datatracker.ietf.org/doc/rfc8996/", "description": ""}, + { + "url": "https://www.digicert.com/blog/depreciating-tls-1-0-and-1-1", + "description": "", + }, ] # TODO: in the majority of cases, using PROTOCOL_TLS_CLIENT will be the @@ -38,70 +26,47 @@ class UpgradeSSLContextTLS(SemgrepCodemod, BaseTransformer): # PROTOCOL_TLS_SERVER instead. We currently don't have a good way to handle # this. Eventually, when the platform supports parameters, we want to # revisit this to provide PROTOCOL_TLS_SERVER as an alternative fix. - SAFE_TLS_PROTOCOL_VERSION = "PROTOCOL_TLS_CLIENT" - PROTOCOL_ARG_INDEX = 0 - PROTOCOL_KWARG_NAME = "protocol" + SAFE_TLS_PROTOCOL_VERSION = "ssl.PROTOCOL_TLS_CLIENT" + # PROTOCOL_ARG_INDEX = 0 + # PROTOCOL_KWARG_NAME = "protocol" - def __init__(self, codemod_context: CodemodContext, file_context: FileContext): - SemgrepCodemod.__init__(self, file_context) - BaseTransformer.__init__(self, codemod_context, self._results) + @classmethod + def rule(cls): + return """ + rules: + - patterns: + - pattern-inside: | + import ssl + ... + - pattern-either: + - pattern: ssl.SSLContext() + - pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv2,...) + - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_SSLv2,...) + - pattern: ssl.SSLContext(...,ssl.PROTOCOL_SSLv3,...) + - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_SSLv3,...) + - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLSv1,...) + - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLSv1,...) + - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLSv1_1,...) + - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLSv1_1,...) + - pattern: ssl.SSLContext(...,ssl.PROTOCOL_TLS,...) + - pattern: ssl.SSLContext(...,protocol=ssl.PROTOCOL_TLS,...) + """ - # TODO: apply unused import remover + def on_result_found(self, original_node, updated_node): + self.remove_unused_import(original_node) + self.add_needed_import("ssl") - def update_arg(self, arg: cst.Arg) -> cst.Arg: - new_name = cst.Name(self.SAFE_TLS_PROTOCOL_VERSION) - # TODO: are there other cases to handle here? - new_value = ( - arg.value.with_changes(attr=new_name) - if isinstance(arg.value, cst.Attribute) - else new_name - ) - return arg.with_changes(value=new_value) - - def leave_Call(self, original_node: cst.Call, updated_node: cst.Arg): - pos_to_match = self.get_metadata(self.METADATA_DEPENDENCIES[0], original_node) - if self.filter_by_result( - pos_to_match - ) and self.filter_by_path_includes_or_excludes(pos_to_match): - line_number = pos_to_match.start.line - self.file_context.codemod_changes.append( - Change(line_number, self.CHANGE_DESCRIPTION) - ) - - if not updated_node.args: - return updated_node.with_changes( - args=[ - self.make_new_arg( - self.PROTOCOL_KWARG_NAME, - f"ssl.{self.SAFE_TLS_PROTOCOL_VERSION}", - ) - ] - ) - - return updated_node.with_changes( - args=[ - self.update_arg(arg) - if idx == self.PROTOCOL_ARG_INDEX - or (arg.keyword and arg.keyword.value == self.PROTOCOL_KWARG_NAME) - else arg - for idx, arg in enumerate(original_node.args) - ] - ) - - return updated_node - - # dedupe with api - def make_new_arg(self, name, value, existing_arg=None): - equal = ( - existing_arg.equal - if existing_arg - else cst.AssignEqual( - whitespace_before=cst.SimpleWhitespace(""), - whitespace_after=cst.SimpleWhitespace(""), + if len((args := original_node.args)) == 1 and args[0].keyword is None: + new_args = [self.make_new_arg(self.SAFE_TLS_PROTOCOL_VERSION)] + else: + new_args = self.replace_args( + original_node, + [ + NewArg( + name="protocol", + value=self.SAFE_TLS_PROTOCOL_VERSION, + add_if_missing=True, + ) + ], ) - ) - return cst.Arg( - keyword=cst.parse_expression(name), - value=cst.parse_expression(value), - equal=equal, - ) + return self.update_arg_target(updated_node, new_args) diff --git a/tests/codemods/test_upgrade_sslcontext_tls.py b/tests/codemods/test_upgrade_sslcontext_tls.py index 4a2bc3c2c..caf5cfe4d 100644 --- a/tests/codemods/test_upgrade_sslcontext_tls.py +++ b/tests/codemods/test_upgrade_sslcontext_tls.py @@ -89,8 +89,9 @@ def test_upgrade_protocol_with_kwarg_import_alias(self, tmpdir, protocol): var = "hello" """ expected_output = """import ssl as whatever +import ssl -context = whatever.SSLContext(protocol=whatever.PROTOCOL_TLS_CLIENT) +context = whatever.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) var = "hello" """ self.run_and_assert(tmpdir, input_code, expected_output) @@ -116,7 +117,7 @@ def test_upgrade_protocol_in_expression_do_not_modify(self, tmpdir): self.run_and_assert(tmpdir, input_code, expected_output) def test_import_no_protocol(self, tmpdir): - input_code = f"""import ssl + input_code = """import ssl context = ssl.SSLContext() """ expected_output = """import ssl @@ -124,13 +125,13 @@ def test_import_no_protocol(self, tmpdir): """ self.run_and_assert(tmpdir, input_code, expected_output) - @pytest.mark.skip() def test_from_import_no_protocol(self, tmpdir): - input_code = f"""from ssl import SSLContext + input_code = """from ssl import SSLContext SSLContext() """ expected_output = """from ssl import SSLContext import ssl + SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) """ self.run_and_assert(tmpdir, input_code, expected_output)