Skip to content

Commit

Permalink
update code with new black version
Browse files Browse the repository at this point in the history
  • Loading branch information
clavedeluna committed Jan 29, 2024
1 parent 6964366 commit cb87a2b
Show file tree
Hide file tree
Showing 22 changed files with 119 additions and 58 deletions.
8 changes: 5 additions & 3 deletions src/codemodder/codemods/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ def __new__(cls, *args, **kwargs):

return cls.codemod_base(
metadata=cls.metadata,
detector=SemgrepRuleDetector(cls.detector_pattern)
if getattr(cls, "detector_pattern", None)
else None,
detector=(
SemgrepRuleDetector(cls.detector_pattern)
if getattr(cls, "detector_pattern", None)
else None
),
# This allows the transformer to inherit all the methods of the class itself
transformer=LibcstTransformerPipeline(cls),
)
6 changes: 2 additions & 4 deletions src/codemodder/codemods/base_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,11 @@ def __init__(

@property
@abstractmethod
def origin(self) -> str:
...
def origin(self) -> str: ...

@property
@abstractmethod
def docs_module_path(self) -> str:
...
def docs_module_path(self) -> str: ...

@property
def name(self) -> str:
Expand Down
3 changes: 1 addition & 2 deletions src/codemodder/codemods/base_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ def apply(
codemod_id: str,
context: CodemodExecutionContext,
files_to_analyze: list[Path],
) -> ResultSet:
...
) -> ResultSet: ...
14 changes: 8 additions & 6 deletions src/codemodder/codemods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ def infer_expression_type(node: cst.BaseExpression) -> Optional[BaseType]:
"""
# The current implementation covers some common cases and is in no way complete
match node:
case cst.Integer() | cst.Imaginary() | cst.Float() | cst.Call(
func=cst.Name("int")
) | cst.Call(func=cst.Name("float")) | cst.Call(
func=cst.Name("abs")
) | cst.Call(
func=cst.Name("len")
case (
cst.Integer()
| cst.Imaginary()
| cst.Float()
| cst.Call(func=cst.Name("int"))
| cst.Call(func=cst.Name("float"))
| cst.Call(func=cst.Name("abs"))
| cst.Call(func=cst.Name("len"))
):
return BaseType.NUMBER
case cst.Call(name=cst.Name("list")) | cst.List() | cst.ListComp():
Expand Down
18 changes: 11 additions & 7 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,11 @@ def is_value_of_assignment(
"""
parent = self.get_metadata(ParentNodeProvider, expr)
match parent:
case cst.AnnAssign(value=value) | cst.Assign(value=value) | cst.WithItem(
item=value
) | cst.NamedExpr(
value=value
case (
cst.AnnAssign(value=value)
| cst.Assign(value=value)
| cst.WithItem(item=value)
| cst.NamedExpr(value=value)
) if expr == value: # type: ignore
return parent
return None
Expand Down Expand Up @@ -448,9 +449,12 @@ class NameAndAncestorResolutionMixin(NameResolutionMixin, AncestorPatternsMixin)

def extract_value(self, node: cst.AnnAssign | cst.Assign | cst.WithItem):
match node:
case cst.AnnAssign(value=value) | cst.Assign(value=value) | cst.WithItem(
item=value
) | cst.NamedExpr(value=value):
case (
cst.AnnAssign(value=value)
| cst.Assign(value=value)
| cst.WithItem(item=value)
| cst.NamedExpr(value=value)
):
return value
return None

Expand Down
14 changes: 8 additions & 6 deletions src/core_codemods/fix_mutable_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@ def _gather_and_update_params(
)
add_annotation = add_annotation or annotation is not None
updated_params.append(
updated.with_changes(
default=cst.Name("None"),
annotation=annotation,
)
if needs_update
else updated,
(
updated.with_changes(
default=cst.Name("None"),
annotation=annotation,
)
if needs_update
else updated
),
)

return updated_params, new_var_decls, add_annotation
Expand Down
13 changes: 12 additions & 1 deletion src/core_codemods/literal_or_new_object_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,18 @@ class LiteralOrNewObjectIdentity(SimpleCodemod, NameAndAncestorResolutionMixin):

def _is_object_creation_or_literal(self, node: cst.BaseExpression):
match node:
case cst.List() | cst.Dict() | cst.Tuple() | cst.Set() | cst.Integer() | cst.Float() | cst.Imaginary() | cst.SimpleString() | cst.ConcatenatedString() | cst.FormattedString():
case (
cst.List()
| cst.Dict()
| cst.Tuple()
| cst.Set()
| cst.Integer()
| cst.Float()
| cst.Imaginary()
| cst.SimpleString()
| cst.ConcatenatedString()
| cst.FormattedString()
):
return True
case cst.Call(func=cst.Name() as name):
return self.is_builtin_function(node) and name.value in (
Expand Down
14 changes: 8 additions & 6 deletions src/core_codemods/refactor/refactor_new_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,15 @@ def leave_ImportFrom(self, original: cst.ImportFrom, updated: cst.ImportFrom):

def leave_ClassDef(self, original: cst.ClassDef, new: cst.ClassDef) -> cst.ClassDef:
new_bases: list[cst.Arg] = [
base.with_changes(value=cst.Name(self.new_api_class))
if self.find_base_name(base.value)
in (
"codemodder.codemods.api.BaseCodemod",
"codemodder.codemods.api.SemgrepCodemod",
(
base.with_changes(value=cst.Name(self.new_api_class))
if self.find_base_name(base.value)
in (
"codemodder.codemods.api.BaseCodemod",
"codemodder.codemods.api.SemgrepCodemod",
)
else base
)
else base
for base in original.bases
]

Expand Down
5 changes: 4 additions & 1 deletion src/core_codemods/remove_module_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def leave_Global(
self,
original_node: cst.Global,
updated_node: cst.Global,
) -> Union[cst.Global, cst.RemovalSentinel,]:
) -> Union[
cst.Global,
cst.RemovalSentinel,
]:
if not self.filter_by_path_includes_or_excludes(
self.node_position(original_node)
):
Expand Down
6 changes: 3 additions & 3 deletions src/core_codemods/sql_parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,9 @@ def leave_Call(self, original_node: cst.Call) -> None:
first_arg.value.visit(query_visitor)
for expr in query_visitor.leaves:
match expr:
case cst.SimpleString() | cst.FormattedStringText() if self._has_keyword(
expr.value
):
case (
cst.SimpleString() | cst.FormattedStringText()
) if self._has_keyword(expr.value):
self.calls[original_node] = query_visitor.leaves


Expand Down
4 changes: 3 additions & 1 deletion tests/codemods/test_combine_startswith_endswith.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def test_no_change(self, tmpdir, code):
self.run_and_assert(tmpdir, code, code)

def test_exclude_line(self, tmpdir):
input_code = expected = """\
input_code = (
expected
) = """\
x = "foo"
x.startswith("foo") or x.startswith("f")
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/codemods/test_enable_jinja2_autoescape.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ def test_aiohttp_import_alias(self, tmpdir):
self.run_and_assert(tmpdir, input_code, expected_output)

def test_aiohttp_import_alias_no_change(self, tmpdir):
expected_output = input_code = """
expected_output = (
input_code
) = """
from aiohttp_jinja2 import foo as setup
setup_jinja2(app)
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/codemods/test_exception_without_raise.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def test_raised_exception(self, tmpdir):
self.run_and_assert(tmpdir, dedent(input_code), dedent(input_code))

def test_exclude_line(self, tmpdir):
input_code = expected = """\
input_code = (
expected
) = """\
print(1)
ValueError("Bad value!")
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/codemods/test_fix_assert_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def test_no_change(self, tmpdir, code):
self.run_and_assert(tmpdir, code, code)

def test_exclude_line(self, tmpdir):
input_code = expected = """\
input_code = (
expected
) = """\
assert (1, 2)
"""
lines_to_exclude = [1]
Expand Down
8 changes: 6 additions & 2 deletions tests/codemods/test_fix_deprecated_abstractproperty.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def foo(self):
self.run_and_assert(tmpdir, original_code, new_code)

def test_different_abstractproperty(self, tmpdir):
new_code = original_code = """
new_code = (
original_code
) = """
from xyz import abstractproperty
class A:
Expand Down Expand Up @@ -123,7 +125,9 @@ def foo(self):
self.run_and_assert(tmpdir, original_code, new_code)

def test_exclude_line(self, tmpdir):
input_code = expected = """\
input_code = (
expected
) = """\
import abc
class A:
Expand Down
4 changes: 3 additions & 1 deletion tests/codemods/test_fix_empty_sequence_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def test_no_change(self, tmpdir, code):
self.run_and_assert(tmpdir, code, code)

def test_exclude_line(self, tmpdir):
input_code = expected = """\
input_code = (
expected
) = """\
x = [1]
if x != []:
pass
Expand Down
8 changes: 6 additions & 2 deletions tests/codemods/test_harden_pyyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def test_import_alias(self, tmpdir):
self.run_and_assert(tmpdir, input_code, expected)

def test_preserve_custom_loader(self, tmpdir):
expected = input_code = """
expected = (
input_code
) = """
import yaml
from custom import CustomLoader
Expand All @@ -73,7 +75,9 @@ def test_preserve_custom_loader(self, tmpdir):
self.run_and_assert(tmpdir, input_code, expected)

def test_preserve_custom_loader_kwarg(self, tmpdir):
expected = input_code = """
expected = (
input_code
) = """
import yaml
from custom import CustomLoader
Expand Down
4 changes: 3 additions & 1 deletion tests/codemods/test_remove_debug_breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def something():
self.run_and_assert(tmpdir, input_code, expected)

def test_exclude_line(self, tmpdir):
input_code = expected = """\
input_code = (
expected
) = """\
x = "foo"
breakpoint()
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/codemods/test_remove_unnecessary_f_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def test_change(self, tmpdir):
self.run_and_assert(tmpdir, before, after, num_changes=3)

def test_exclude_line(self, tmpdir):
input_code = expected = """\
input_code = (
expected
) = """\
bad: str = f"bad" + "bad"
"""
lines_to_exclude = [1]
Expand Down
4 changes: 3 additions & 1 deletion tests/codemods/test_subprocess_shell_false.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def test_shell_False(self, tmpdir, func):
self.run_and_assert(tmpdir, input_code, input_code)

def test_exclude_line(self, tmpdir):
input_code = expected = """\
input_code = (
expected
) = """\
import subprocess
subprocess.run(args, shell=True)
"""
Expand Down
16 changes: 12 additions & 4 deletions tests/codemods/test_url_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ def test_requests_with_alias(self, add_dependency, tmpdir):
add_dependency.assert_called_once_with(Security)

def test_ignore_hardcoded(self, _, tmpdir):
expected = input_code = """
expected = (
input_code
) = """
import requests
requests.get("www.google.com")
Expand All @@ -197,7 +199,9 @@ def test_ignore_hardcoded(self, _, tmpdir):
self.run_and_assert(tmpdir, input_code, expected)

def test_ignore_hardcoded_from_global_variable(self, _, tmpdir):
expected = input_code = """
expected = (
input_code
) = """
import requests
URL = "www.google.com"
Expand All @@ -207,7 +211,9 @@ def test_ignore_hardcoded_from_global_variable(self, _, tmpdir):
self.run_and_assert(tmpdir, input_code, expected)

def test_ignore_hardcoded_from_local_variable(self, _, tmpdir):
expected = input_code = """
expected = (
input_code
) = """
import requests
def foo():
Expand All @@ -218,7 +224,9 @@ def foo():
self.run_and_assert(tmpdir, input_code, expected)

def test_ignore_hardcoded_from_local_variable_transitive(self, _, tmpdir):
expected = input_code = """
expected = (
input_code
) = """
import requests
def foo():
Expand Down
12 changes: 9 additions & 3 deletions tests/codemods/test_use_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,26 @@ def test_list_comprehension(self, tmpdir, func):
self.run_and_assert(tmpdir, original_code, new_code)

def test_not_special_builtin(self, tmpdir):
expected = original_code = """
expected = (
original_code
) = """
x = some([i for i in range(10)])
"""
self.run_and_assert(tmpdir, original_code, expected)

def test_not_global_function(self, tmpdir):
expected = original_code = """
expected = (
original_code
) = """
from foo import any
x = any([i for i in range(10)])
"""
self.run_and_assert(tmpdir, original_code, expected)

def test_exclude_line(self, tmpdir):
input_code = expected = """\
input_code = (
expected
) = """\
x = any([i for i in range(10)])
"""
lines_to_exclude = [1]
Expand Down

0 comments on commit cb87a2b

Please sign in to comment.