Merge pull request #302 from vipyrsec/better-match-info
Better match info
jonathan-d-zhang authored Aug 10, 2024
2 parents b6a1b3e + 520c680 commit 981b528
Showing 8 changed files with 156 additions and 4 deletions.
29 changes: 29 additions & 0 deletions alembic/versions/587c186d91ee_better_match_information.py
@@ -0,0 +1,29 @@
"""better-match-information
Revision ID: 587c186d91ee
Revises: 6991bcb18f89
Create Date: 2024-07-27 19:51:33.408128
"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "587c186d91ee"
down_revision = "6991bcb18f89"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("scans", sa.Column("files", postgresql.JSONB(), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("scans", "files")
# ### end Alembic commands ###
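
For reference, the upgrade is a single nullable JSONB column on scans, so it can be applied in place with alembic upgrade head. The snippet below is a hypothetical post-migration check, not part of this commit; the connection URL is a placeholder.

from sqlalchemy import create_engine, inspect

# Placeholder DSN; substitute the real mainframe database URL.
engine = create_engine("postgresql://mainframe:password@localhost/mainframe")
columns = {col["name"] for col in inspect(engine).get_columns("scans")}
assert "files" in columns  # the column added by this revision
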
3 changes: 2 additions & 1 deletion docs/source/database_schema.rst
@@ -44,7 +44,8 @@ Database Schema
pending_by text,
finished_by text,
commit_hash text,
fail_reason text,
files jsonb
);
ALTER TABLE ONLY public.download_urls
1 change: 1 addition & 0 deletions pyproject.toml
@@ -132,3 +132,4 @@ omit = [

[tool.coverage.report]
fail_under = 100
exclude_also = ["if TYPE_CHECKING:"]
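
The new exclude_also entry pairs with the if TYPE_CHECKING: guard added in schemas.py below: that block never executes at runtime, so without the exclusion its import line would register as uncovered and trip the fail_under = 100 requirement. A minimal illustration of the excluded pattern:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; coverage now ignores this block.
    from mainframe.models.orm import Scan
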
1 change: 1 addition & 0 deletions src/mainframe/endpoints/package.py
@@ -78,6 +78,7 @@ def submit_results(
    scan.score = result.score
    scan.finished_by = auth.subject
    scan.commit_hash = result.commit
    scan.files = result.files

    # These are the rules that already have an entry in the database
    rules = session.scalars(select(Rule).where(Rule.name.in_(result.rules_matched))).all()
26 changes: 26 additions & 0 deletions src/mainframe/models/__init__.py
@@ -1 +1,27 @@
"""Database models."""

from typing import Optional, Any, Type
from pydantic import BaseModel
from sqlalchemy import Dialect, TypeDecorator
from sqlalchemy.dialects.postgresql import JSONB


class Pydantic[T: BaseModel](TypeDecorator[T]):
    """TypeDecorator to convert between Pydantic models and JSONB."""

    impl = JSONB
    cache_ok = True

    def __init__(self, pydantic_type: Type[T]):
        super().__init__()
        self.pydantic_type = pydantic_type

    def process_bind_param(self, value: Optional[T], dialect: Dialect) -> dict[str, Any]:
        if value:
            return value.model_dump()
        else:
            return {}

    def process_result_value(self, value: Any, dialect: Dialect) -> Optional[T]:
        if value:
            return self.pydantic_type.model_validate(value)
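
To make the conversion concrete, here is a rough round trip through the decorator's two hooks. SQLAlchemy calls these itself during flush and load; invoking them directly is purely illustrative, and the sample Files value is invented.

from sqlalchemy.dialects import postgresql

from mainframe.models import Pydantic
from mainframe.models.schemas import File, Files

decorator = Pydantic(Files)
dialect = postgresql.dialect()  # the hooks do not actually use the dialect argument

files = Files([File(path="pkg/setup.py", matches=[])])
stored = decorator.process_bind_param(files, dialect)     # JSON-serializable data for the JSONB column
loaded = decorator.process_result_value(stored, dialect)  # validated back into the Pydantic model
assert loaded == files
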
5 changes: 5 additions & 0 deletions src/mainframe/models/orm.py
@@ -27,6 +27,9 @@
    relationship,
)

from mainframe.models import Pydantic
from mainframe.models.schemas import Files


class Base(MappedAsDataclass, DeclarativeBase, kw_only=True):
    pass
@@ -99,6 +102,8 @@ class Scan(Base):

    commit_hash: Mapped[Optional[str]] = mapped_column(default=None)

    files: Mapped[Optional[Files]] = mapped_column(Pydantic(Files), default=None)


Index(None, Scan.status, postgresql_where=or_(Scan.status == Status.QUEUED, Scan.status == Status.PENDING))

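With the TypeDecorator attached to the mapped column, callers read scan.files back as a validated Files model, or None for rows created before this migration, without touching raw JSON. A small hypothetical helper (not part of the codebase), assuming an existing Session:

from sqlalchemy import select
from sqlalchemy.orm import Session

from mainframe.models.orm import Scan


def files_for(session: Session, name: str, version: str):
    # Hypothetical helper: fetch one scan and return its per-file match details.
    scan = session.scalar(select(Scan).where(Scan.name == name, Scan.version == version))
    return scan.files if scan is not None else None
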
54 changes: 51 additions & 3 deletions src/mainframe/models/schemas.py
@@ -1,10 +1,54 @@
from __future__ import annotations

import datetime
from enum import Enum
from typing import TYPE_CHECKING, Annotated, Any, Optional

from pydantic import BaseModel, Field, field_serializer, ConfigDict, RootModel

if TYPE_CHECKING:
    from mainframe.models.orm import Scan

type MetaValue = int | float | bool | str | bytes


class Range(BaseModel):
    """Represents the inclusive range in the source file that was matched."""

    start: int
    end: int


class Match(BaseModel):
    """Represents a specific match by a pattern in a rule."""

    range: Range
    data: list[Annotated[int, Field(ge=0, lt=256)]]


class PatternMatch(BaseModel):
    """Represents the data matched by a pattern inside a rule."""

    identifier: str
    matches: list[Match]


class RuleMatch(BaseModel):
    """Represents the matches of a rule on a file."""

    identifier: str
    patterns: list[PatternMatch]
    metadata: dict[str, MetaValue]


class File(BaseModel):
    """Represents a file and the rule matches for it."""

    path: str
    matches: list[RuleMatch]


Files = RootModel[list[File]]

class ServerMetadata(BaseModel):
@@ -44,6 +88,8 @@ class Package(BaseModel):

    commit_hash: Optional[str]

    files: Optional[Files]

    @classmethod
    def from_db(cls, scan: Scan):
        return cls(
@@ -64,6 +110,7 @@ def from_db(cls, scan: Scan):
            finished_at=scan.finished_at,
            finished_by=scan.finished_by,
            commit_hash=scan.commit_hash,
            files=scan.files,
        )

    @field_serializer(
@@ -132,6 +179,7 @@ class PackageScanResult(PackageSpecifier):
    score: int = 0
    inspector_url: Optional[str] = None
    rules_matched: list[str] = []
    files: Optional[Files] = None


class PackageScanResultFail(PackageSpecifier):
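
For a sense of what ends up in the new JSONB column, here is an invented Files value and the JSON it serializes to; the field values are made up, and only the shape follows the schema above.

from mainframe.models.schemas import File, Files, Match, PatternMatch, Range, RuleMatch

files = Files(
    [
        File(
            path="package/setup.py",
            matches=[
                RuleMatch(
                    identifier="example_rule",
                    patterns=[
                        PatternMatch(
                            identifier="$pattern",
                            # bytes 0..3 of the file, i.e. an inclusive 4-byte range
                            matches=[Match(range=Range(start=0, end=3), data=[0x65, 0x76, 0x61, 0x6C])],
                        )
                    ],
                    metadata={"score": 5},
                )
            ],
        )
    ]
)

print(files.model_dump_json(indent=2))
# [{"path": "package/setup.py", "matches": [{"identifier": "example_rule", ...}]}]
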
41 changes: 41 additions & 0 deletions tests/test_package.py
@@ -20,10 +20,16 @@
from mainframe.json_web_token import AuthenticationData
from mainframe.models.orm import Scan, Status
from mainframe.models.schemas import (
    File,
    Files,
    Match,
    Package,
    PackageScanResult,
    PackageScanResultFail,
    PackageSpecifier,
    PatternMatch,
    Range,
    RuleMatch,
)
from mainframe.rules import Rules

@@ -80,6 +86,32 @@ def test_package_lookup_rejects_invalid_combinations(
    assert e.value.status_code == 400


def test_package_lookup_files(db_session: Session):
"""Test that `lookup_package_info` returns detailed file information."""

range_ = Range(start=0, end=4)
match = Match(range=range_, data=[0xDE, 0xAD, 0xBE, 0xEF])
pattern = PatternMatch(identifier="$pat", matches=[match])
rule = RuleMatch(identifier="rule1", patterns=[pattern], metadata={"author": "remmy", "score": 5})
file = File(path="dist1/a/b.py", matches=[rule])
files = Files([file])
scan = Scan(
name="abc",
version="1.0.0",
status=Status.FINISHED,
queued_by="remmy",
files=files,
)

with db_session.begin():
db_session.add(scan)
db_session.commit()

package = lookup_package_info(db_session, name="abc", version="1.0.0")[0]

assert package.files == files


def test_handle_success(db_session: Session, test_data: list[Scan], auth: AuthenticationData, rules_state: Rules):
    job = get_jobs(db_session, auth, rules_state, batch=1)

@@ -88,13 +120,21 @@ def test_handle_success(db_session: Session, test_data: list[Scan], auth: Authen
        name = job.name
        version = job.version

        range_ = Range(start=0, end=4)
        match = Match(range=range_, data=[0xDE, 0xAD, 0xBE, 0xEF])
        pattern = PatternMatch(identifier="$pat", matches=[match])
        rule = RuleMatch(identifier="rule1", patterns=[pattern], metadata={"author": "remmy", "score": 5})
        file = File(path="dist1/a/b.py", matches=[rule])
        files = Files([file])

        body = PackageScanResult(
            name=job.name,
            version=job.version,
            commit=rules_state.rules_commit,
            score=2,
            inspector_url="test inspector url",
            rules_matched=["a", "b", "c"],
            files=files,
        )
        submit_results(body, db_session, auth)

@@ -107,6 +147,7 @@ def test_handle_success(db_session: Session, test_data: list[Scan], auth: Authen
        assert record.score == 2
        assert record.inspector_url == "test inspector url"
        assert {rule.name for rule in record.rules} == {"a", "b", "c"}
        assert record.files == files
    else:
        assert all(scan.status != Status.QUEUED for scan in test_data)

