diff --git a/src/codemodder/codemodder.py b/src/codemodder/codemodder.py index 2094f405..d7b5cab4 100644 --- a/src/codemodder/codemodder.py +++ b/src/codemodder/codemodder.py @@ -126,7 +126,7 @@ def run( verbose: bool = False, log_format: OutputFormat = OutputFormat.JSON, project_name: str | None = None, - tool_result_files_map: DefaultDict[str, list[str]] = defaultdict(list), + tool_result_files_map: DefaultDict[str, list[Path]] = defaultdict(list), path_include: list[str] | None = None, path_exclude: list[str] | None = None, codemod_include: list[str] | None = None, @@ -240,8 +240,7 @@ def _run_cli(original_args) -> int: return 1 try: - # TODO: this should be dict[str, list[Path]] - tool_result_files_map: DefaultDict[str, list[str]] = detect_sarif_tools( + tool_result_files_map: DefaultDict[str, list[Path]] = detect_sarif_tools( [Path(name) for name in argv.sarif or []] ) except (DuplicateToolError, FileNotFoundError) as err: diff --git a/src/codemodder/codemods/test/utils.py b/src/codemodder/codemods/test/utils.py index 923f16bb..78fb9045 100644 --- a/src/codemodder/codemods/test/utils.py +++ b/src/codemodder/codemods/test/utils.py @@ -193,7 +193,7 @@ def run_and_assert( directory=root, dry_run=False, verbose=False, - tool_result_files_map={self.tool: [str(tmp_results_file_path)]}, + tool_result_files_map={self.tool: [tmp_results_file_path]}, registry=mock.MagicMock(), providers=load_providers(), repo_manager=mock.MagicMock(), diff --git a/src/codemodder/context.py b/src/codemodder/context.py index 11ab540c..a951be25 100644 --- a/src/codemodder/context.py +++ b/src/codemodder/context.py @@ -49,7 +49,7 @@ class CodemodExecutionContext: path_include: list[str] path_exclude: list[str] max_workers: int = 1 - tool_result_files_map: dict[str, list[str]] + tool_result_files_map: dict[str, list[Path]] semgrep_prefilter_results: ResultSet | None = None openai_llm_client: OpenAI | None = None azure_llama_llm_client: ChatCompletionsClient | None = None @@ -64,7 +64,7 @@ def __init__( repo_manager: PythonRepoManager | None = None, path_include: list[str] | None = None, path_exclude: list[str] | None = None, - tool_result_files_map: dict[str, list[str]] | None = None, + tool_result_files_map: dict[str, list[Path]] | None = None, max_workers: int = 1, ): self.directory = directory diff --git a/src/codemodder/sarifs.py b/src/codemodder/sarifs.py index 37f4261b..f155e2b7 100644 --- a/src/codemodder/sarifs.py +++ b/src/codemodder/sarifs.py @@ -18,18 +18,27 @@ def detect(cls, run_data: dict) -> bool: class DuplicateToolError(ValueError): ... -def detect_sarif_tools(filenames: list[Path]) -> DefaultDict[str, list[str]]: - results: DefaultDict[str, list[str]] = defaultdict(list) +def detect_sarif_tools(filenames: list[Path]) -> DefaultDict[str, list[Path]]: + results: DefaultDict[str, list[Path]] = defaultdict(list) logger.debug("loading registered SARIF tool detectors") detectors = { ent.name: ent.load() for ent in entry_points().select(group="sarif_detectors") } for fname in filenames: - data = json.loads(fname.read_text(encoding="utf-8-sig")) + try: + data = json.loads(fname.read_text(encoding="utf-8-sig")) + except json.JSONDecodeError: + logger.exception("Malformed JSON file: %s", fname) + raise for name, det in detectors.items(): - # TODO: handle malformed sarif? - for run in data["runs"]: + try: + runs = data["runs"] + except KeyError: + logger.exception("Sarif file without `runs` data: %s", fname) + raise + + for run in runs: try: if det.detect(run): logger.debug("detected %s sarif: %s", name, fname) @@ -39,7 +48,7 @@ def detect_sarif_tools(filenames: list[Path]) -> DefaultDict[str, list[str]]: raise DuplicateToolError( f"duplicate tool sarif detected: {name}" ) - results[name].append(str(fname)) + results[name].append(Path(fname)) except DuplicateToolError as err: raise err except (KeyError, AttributeError, ValueError): diff --git a/tests/test_sarif_processing.py b/tests/test_sarif_processing.py index 4a055881..6b90426e 100644 --- a/tests/test_sarif_processing.py +++ b/tests/test_sarif_processing.py @@ -45,6 +45,7 @@ def test_detect_sarif_with_bom_encoding(self, tmpdir): results = detect_sarif_tools([sarif_file_bom]) assert len(results) == 1 + assert isinstance(results["semgrep"][0], Path) @pytest.mark.parametrize("truncate", [True, False]) def test_results_by_rule_id(self, truncate): @@ -111,6 +112,32 @@ def test_two_sarifs_same_tool(self): detect_sarif_tools([Path("tests/samples/webgoat_v8.2.0_codeql.sarif")] * 2) assert "duplicate tool sarif detected: codeql" in str(exc.value) + def test_bad_sarif(self, tmpdir, caplog): + sarif_file = Path("tests") / "samples" / "semgrep.sarif" + bad_json = tmpdir / "bad.sarif" + with open(bad_json, "w") as f: + # remove all { to make a badly formatted json + f.write(sarif_file.read_text(encoding="utf-8").replace("{", "")) + + with pytest.raises(json.JSONDecodeError): + detect_sarif_tools([bad_json]) + assert f"Malformed JSON file: {str(bad_json)}" in caplog.text + + def test_bad_sarif_no_runs_data(self, tmpdir, caplog): + bad_json = tmpdir / "bad.sarif" + data = """ + { + "$schema": "https://docs.oasis-open.org/sarif/sarif/v2.1.0/os/schemas/sarif-schema-2.1.0.json", + "version": "2.1.0" + } + """ + with open(bad_json, "w") as f: + f.write(data) + + with pytest.raises(KeyError): + detect_sarif_tools([bad_json]) + assert f"Sarif file without `runs` data: {str(bad_json)}" in caplog.text + def test_two_sarifs_different_tools(self): results = detect_sarif_tools( [