diff --git a/blueskysight/firehose.py b/blueskysight/firehose.py index dc77285..78174c6 100644 --- a/blueskysight/firehose.py +++ b/blueskysight/firehose.py @@ -1,26 +1,17 @@ import asyncio import json -import re import struct import typing as t from base64 import b32encode from io import BytesIO import websockets -from pyvulnerabilitylookup import PyVulnerabilityLookup - -from blueskysight import config -from blueskysight.utils import get_post_url, remove_case_insensitive_duplicates - -vulnerability_pattern = re.compile( - r"\b(CVE-\d{4}-\d{4,})\b" # CVE pattern - r"|\b(GHSA-[a-zA-Z0-9]{4}-[a-zA-Z0-9]{4}-[a-zA-Z0-9]{4})\b" # GHSA pattern - r"|\b(PYSEC-\d{4}-\d{2,5})\b" # PYSEC pattern - r"|\b(GSD-\d{4}-\d{4,5})\b" # GSD pattern - r"|\b(wid-sec-w-\d{4}-\d{4})\b" # CERT-Bund pattern - r"|\b(cisco-sa-\d{8}-[a-zA-Z0-9]+)\b" # CISCO pattern - r"|\b(RHSA-\d{4}:\d{4})\b", # RedHat pattern - re.IGNORECASE, + +from blueskysight.utils import ( + get_post_url, + push_sighting_to_vulnerability_lookup, + remove_case_insensitive_duplicates, + vulnerability_pattern, ) BSKY_FIREHOSE = "wss://bsky.network/xrpc/com.atproto.sync.subscribeRepos" @@ -286,27 +277,6 @@ def read_firehose_frame(frame: bytes) -> tuple[dict, dict]: return header, body -def push_sighting_to_vulnerability_lookup(status_uri, vulnerability_ids): - """Create a sighting from an incoming status and push it to the Vulnerability Lookup instance.""" - print("Pushing sighting to Vulnerability Lookup…") - vuln_lookup = PyVulnerabilityLookup( - config.vulnerability_lookup_base_url, token=config.vulnerability_auth_token - ) - for vuln in vulnerability_ids: - # Create the sighting - sighting = {"type": "seen", "source": status_uri, "vulnerability": vuln} - - # Post the JSON to Vulnerability Lookup - try: - r = vuln_lookup.create_sighting(sighting=sighting) - if "message" in r: - print(r["message"]) - except Exception as e: - print( - f"Error when sending POST request to the Vulnerability Lookup server:\n{e}" - ) - - async 
"""Watch the Bluesky Jetstream for posts that mention vulnerability identifiers.

Matching posts are printed to the console and reported as sightings through
``push_sighting_to_vulnerability_lookup``.
"""

import asyncio
import json
import os
import platform
import typing as t
from pathlib import Path
from urllib.parse import urlencode

import zstandard as zstd
from httpx_ws import aconnect_ws

from blueskysight.utils import (
    get_post_url,
    push_sighting_to_vulnerability_lookup,
    remove_case_insensitive_duplicates,
    vulnerability_pattern,
)

PUBLIC_URL_FMT = "wss://jetstream{instance}.{geo}.bsky.network/subscribe"


def get_public_jetstream_base_url(
    geo: t.Literal["us-west", "us-east"] = "us-west",
    instance: int = 1,
) -> str:
    """Return a public Jetstream URL with the given options.

    Args:
        geo: Region label substituted into ``PUBLIC_URL_FMT``.
        instance: Instance number substituted into ``PUBLIC_URL_FMT``.

    Returns:
        The ``wss://`` base URL, without any query string.
    """
    return PUBLIC_URL_FMT.format(geo=geo, instance=instance)


def get_jetstream_query_url(
    base_url: str,
    collections: t.Sequence[str],
    dids: t.Sequence[str],
    cursor: int,
    compress: bool,
) -> str:
    """Return a Jetstream URL with the given query parameters.

    Args:
        base_url: URL of the Jetstream ``subscribe`` endpoint.
        collections: Values sent as repeated ``wantedCollections`` parameters.
        dids: Values sent as repeated ``wantedDids`` parameters.
        cursor: Sent as ``cursor`` only when non-zero.
        compress: When true, ``compress=true`` is appended.

    Returns:
        ``base_url`` followed by the encoded query string, or ``base_url``
        unchanged when there is nothing to encode.
    """
    query = [("wantedCollections", collection) for collection in collections]
    query += [("wantedDids", did) for did in dids]
    if cursor:  # Only include the cursor if it is non-zero.
        query.append(("cursor", str(cursor)))
    if compress:
        query.append(("compress", "true"))
    query_enc = urlencode(query)
    return f"{base_url}?{query_enc}" if query_enc else base_url


#
# Utilities to manage zstd decompression of data (use the --compress flag to enable)
#

# Jetstream uses a custom zstd dict to improve compression; here's where to find it:
ZSTD_DICT_URL = "https://raw.githubusercontent.com/bluesky-social/jetstream/main/pkg/models/zstd_dictionary"


def get_cache_directory(app_name: str) -> Path:
    """
    Determines the appropriate cache directory for the application, cross-platform.

    The directory is created if it does not exist yet.

    Args:
        app_name (str): The name of your application.

    Returns:
        Path: The path to the cache directory.
    """
    # `or` (rather than a getenv default) also covers a variable that is set
    # but empty, which would otherwise produce a path relative to the CWD.
    if platform.system() == "Windows":
        # Use %LOCALAPPDATA% for Windows
        base_cache_dir = os.getenv("LOCALAPPDATA") or Path.home() / "AppData" / "Local"
    else:
        # Use XDG_CACHE_HOME or fallback to ~/.cache for Unix-like systems
        base_cache_dir = os.getenv("XDG_CACHE_HOME") or Path.home() / ".cache"

    cache_dir = Path(base_cache_dir) / app_name
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir


def download_zstd_dict(zstd_dict_path: Path):
    """
    Download the Zstandard dictionary from the Jetstream repository.

    The data is streamed to a temporary sibling file and moved into place only
    once the download has completed, so a failed or interrupted download never
    leaves a file at ``zstd_dict_path``.

    Args:
        zstd_dict_path (Path): The path to save the Zstandard dictionary.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    import httpx

    partial_path = zstd_dict_path.with_name(zstd_dict_path.name + ".part")
    try:
        with httpx.stream("GET", ZSTD_DICT_URL) as response:
            # Without this check an error page would be saved as the dictionary.
            response.raise_for_status()
            with partial_path.open("wb") as f:
                for chunk in response.iter_bytes():
                    f.write(chunk)
        partial_path.replace(zstd_dict_path)
    finally:
        # No-op after a successful replace; removes the leftover otherwise.
        partial_path.unlink(missing_ok=True)


def get_zstd_decompressor() -> zstd.ZstdDecompressor:
    """Get a Zstandard decompressor with a pre-trained dictionary.

    The dictionary is read from the "jetstream" cache directory and downloaded
    first when it is not there yet.
    """
    # get_cache_directory() already creates the directory.
    zstd_dict_path = get_cache_directory("jetstream") / "zstd_dict.bin"

    if not zstd_dict_path.exists():
        download_zstd_dict(zstd_dict_path)

    dict_data = zstd.ZstdCompressionDict(zstd_dict_path.read_bytes())
    return zstd.ZstdDecompressor(dict_data=dict_data)


#
# Code to resolve an ATProto handle to a DID
#


def raw_handle(handle: str) -> str:
    """Returns a raw ATProto handle, without the @ prefix."""
    return handle.removeprefix("@")


def resolve_handle_to_did_dns(handle: str) -> str | None:
    """
    Resolves an ATProto handle to a DID using DNS.

    Looks up the TXT records of ``_atproto.<handle>`` and returns the value of
    the first one that starts with ``did=``.

    Returns None if the handle is not found.

    Raises any other ``dns.resolver`` exception (only NoAnswer and NXDOMAIN
    are handled here).
    """
    import dns.resolver

    try:
        answers = dns.resolver.resolve(f"_atproto.{handle}", "TXT")
    except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN):
        return None

    for answer in answers:
        txt = answer.to_text()
        # The text form is expected to be quoted: "did=did:..."; the slice
        # drops the leading `"did=` and the closing quote.
        if txt.startswith('"did='):
            return txt[5:-1]

    return None


def resolve_handle_to_did_well_known(handle: str) -> str | None:
    """
    Resolves an ATProto handle to a DID using a well-known endpoint.

    Returns None if the handle is not found: the connection fails, the request
    times out, the server answers with an error status, or the response body
    does not start with ``did:``.
    """
    import httpx

    try:
        response = httpx.get(f"https://{handle}/.well-known/atproto-did", timeout=5)
        response.raise_for_status()
    except (httpx.ConnectError, httpx.HTTPStatusError, httpx.TimeoutException):
        return None

    # An empty or unrelated body must not be passed on as a DID.
    did = response.text.strip()
    return did if did.startswith("did:") else None


def resolve_handle_to_did(handle: str) -> str | None:
    """
    Resolves an ATProto handle, like @bsky.app, to a DID.

    We resolve as follows:

    1. Check the _atproto DNS TXT record for the handle.
    2. If not found, query for a .well-known/atproto-did

    Returns None if the handle is not found.

    Raises whatever the two resolvers do not handle themselves.
    """
    handle = raw_handle(handle)
    return resolve_handle_to_did_dns(handle) or resolve_handle_to_did_well_known(
        handle
    )


def require_resolve_handle_to_did(handle: str) -> str:
    """
    Resolves an ATProto handle to a DID, raising an error if not found.

    Raises a ValueError if the handle is not found.
    """
    did = resolve_handle_to_did(handle)
    if did is None:
        raise ValueError(f"Could not resolve handle '{handle}' to a DID.")
    return did


def extract_vulnerability_ids(content):
    """
    Extracts vulnerability IDs from post content using the predefined regex pattern.

    Returns the matches after passing them through
    ``remove_case_insensitive_duplicates``.
    """
    matches = vulnerability_pattern.findall(content)
    # Flatten the list of tuples to get only non-empty matched strings
    return remove_case_insensitive_duplicates(
        [match for match_tuple in matches for match in match_tuple if match]
    )


async def _handle_event(event: t.Any) -> None:
    """Print and report one decoded Jetstream event if it mentions vulnerability IDs.

    Events that are not "create" commits, carry no text, or lack the fields
    needed to build the post URI are ignored.
    """
    if not isinstance(event, dict):
        return
    commit = event.get("commit")
    if not isinstance(commit, dict) or commit.get("operation") != "create":
        return

    record = commit.get("record")
    content = record.get("text", "") if isinstance(record, dict) else ""
    if not content or not isinstance(content, str):
        return

    vulnerability_ids = extract_vulnerability_ids(content)
    if not vulnerability_ids:
        return

    did = event.get("did")
    rkey = commit.get("rkey")
    if not did or not rkey:
        return

    uri = f"at://{did}/app.bsky.feed.post/{rkey}"
    post_url = await get_post_url(uri)
    print(f"Post content: {content}")
    print(f"Post URL: {post_url}")
    print(f"Vulnerability IDs detected: {', '.join(vulnerability_ids)}")
    # The push is a synchronous call; run it in a worker thread so the event
    # loop (and with it the WebSocket connection) keeps being serviced.
    await asyncio.to_thread(
        push_sighting_to_vulnerability_lookup, post_url, vulnerability_ids
    )


async def jetstream(
    collections: t.Sequence[str] = ("app.bsky.feed.post",),
    dids: t.Sequence[str] = (),
    handles: t.Sequence[str] = (),
    cursor: int = 0,
    base_url: str | None = None,
    geo: t.Literal["us-west", "us-east"] = "us-west",
    instance: int = 1,
    compress: bool = False,
):
    """Stream Jetstream events and report posts that mention vulnerability IDs.

    For every "create" commit whose record text matches the vulnerability
    pattern, the post content, its URL and the detected IDs are printed and a
    sighting is pushed. Runs until the connection ends or an error is raised.

    Args:
        collections: Collections to subscribe to.
        dids: DIDs to subscribe to.
        handles: Handles that are resolved to DIDs and added to ``dids``.
        cursor: Cursor passed to Jetstream when non-zero.
        base_url: Jetstream endpoint; defaults to the public one for
            ``geo``/``instance``.
        geo: Region of the public endpoint.
        instance: Instance number of the public endpoint.
        compress: Request zstd-compressed messages and decompress them.

    Raises:
        ValueError: If one of ``handles`` cannot be resolved to a DID.
    """
    # Resolve handles and form the final list of DIDs to subscribe to.
    handle_dids = [require_resolve_handle_to_did(handle) for handle in handles]
    dids = list(dids) + handle_dids

    # Build the Zstandard decompressor if compression is enabled.
    decompressor = get_zstd_decompressor() if compress else None

    # Form the Jetstream URL to connect to.
    base_url = base_url or get_public_jetstream_base_url(geo, instance)
    url = get_jetstream_query_url(base_url, collections, dids, cursor, compress)

    # Async client: the synchronous one blocks the event loop while waiting.
    async with aconnect_ws(url) as ws:
        while True:
            if decompressor:
                compressed = await ws.receive_bytes()
                with decompressor.stream_reader(compressed) as reader:
                    message = reader.read().decode("utf-8")
            else:
                message = await ws.receive_text()

            try:
                event = json.loads(message)
            except json.JSONDecodeError as e:
                # One bad frame should not end the whole stream.
                print(f"Skipping a Jetstream message that is not valid JSON: {e}")
                continue

            await _handle_event(event)


def main():
    """Entry point: run ``jetstream()`` with its default arguments."""
    asyncio.run(jetstream())


if __name__ == "__main__":
    main()
asyncio.run(jetstream()) + + +if __name__ == "__main__": + main() diff --git a/blueskysight/stream.py b/blueskysight/stream.py index 3b129c1..bd55498 100644 --- a/blueskysight/stream.py +++ b/blueskysight/stream.py @@ -1,51 +1,18 @@ import asyncio import io -import re import websockets -from pyvulnerabilitylookup import PyVulnerabilityLookup -from blueskysight import config from blueskysight.utils import ( enumerate_mst_records, get_post_url, parse_car, parse_dag_cbor_object, + push_sighting_to_vulnerability_lookup, remove_case_insensitive_duplicates, + vulnerability_pattern, ) -vulnerability_pattern = re.compile( - r"\b(CVE-\d{4}-\d{4,})\b" # CVE pattern - r"|\b(GHSA-[a-zA-Z0-9]{4}-[a-zA-Z0-9]{4}-[a-zA-Z0-9]{4})\b" # GHSA pattern - r"|\b(PYSEC-\d{4}-\d{2,5})\b" # PYSEC pattern - r"|\b(GSD-\d{4}-\d{4,5})\b" # GSD pattern - r"|\b(wid-sec-w-\d{4}-\d{4})\b" # CERT-Bund pattern - r"|\b(cisco-sa-\d{8}-[a-zA-Z0-9]+)\b" # CISCO pattern - r"|\b(RHSA-\d{4}:\d{4})\b", # RedHat pattern - re.IGNORECASE, -) - - -def push_sighting_to_vulnerability_lookup(status_uri, vulnerability_ids): - """Create a sighting from an incoming status and push it to the Vulnerability Lookup instance.""" - print("Pushing sighting to Vulnerability Lookup…") - vuln_lookup = PyVulnerabilityLookup( - config.vulnerability_lookup_base_url, token=config.vulnerability_auth_token - ) - for vuln in vulnerability_ids: - # Create the sighting - sighting = {"type": "seen", "source": status_uri, "vulnerability": vuln} - - # Post the JSON to Vulnerability Lookup - try: - r = vuln_lookup.create_sighting(sighting=sighting) - if "message" in r: - print(r["message"]) - except Exception as e: - print( - f"Error when sending POST request to the Vulnerability Lookup server:\n{e}" - ) - async def stream(): """ diff --git a/blueskysight/utils.py b/blueskysight/utils.py index e9bfabf..6a50880 100644 --- a/blueskysight/utils.py +++ b/blueskysight/utils.py @@ -1,10 +1,46 @@ import base64 import hashlib import io +import re 
import struct from enum import Enum import httpx +from pyvulnerabilitylookup import PyVulnerabilityLookup + +from blueskysight import config + +vulnerability_pattern = re.compile( + r"\b(CVE-\d{4}-\d{4,})\b" # CVE pattern + r"|\b(GHSA-[a-zA-Z0-9]{4}-[a-zA-Z0-9]{4}-[a-zA-Z0-9]{4})\b" # GHSA pattern + r"|\b(PYSEC-\d{4}-\d{2,5})\b" # PYSEC pattern + r"|\b(GSD-\d{4}-\d{4,5})\b" # GSD pattern + r"|\b(wid-sec-w-\d{4}-\d{4})\b" # CERT-Bund pattern + r"|\b(cisco-sa-\d{8}-[a-zA-Z0-9]+)\b" # CISCO pattern + r"|\b(RHSA-\d{4}:\d{4})\b", # RedHat pattern + re.IGNORECASE, +) + + +def push_sighting_to_vulnerability_lookup(status_uri, vulnerability_ids): + """Create a sighting from an incoming status and push it to the Vulnerability Lookup instance.""" + print("Pushing sighting to Vulnerability Lookup…") + vuln_lookup = PyVulnerabilityLookup( + config.vulnerability_lookup_base_url, token=config.vulnerability_auth_token + ) + for vuln in vulnerability_ids: + # Create the sighting + sighting = {"type": "seen", "source": status_uri, "vulnerability": vuln} + + # Post the JSON to Vulnerability Lookup + try: + r = vuln_lookup.create_sighting(sighting=sighting) + if "message" in r: + print(r["message"]) + except Exception as e: + print( + f"Error when sending POST request to the Vulnerability Lookup server:\n{e}" + ) def remove_case_insensitive_duplicates(input_list): diff --git a/poetry.lock b/poetry.lock index 17861ba..c25f5b8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -282,6 +282,31 @@ files = [ {file = "charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3"}, ] +[[package]] +name = "click" +version = "8.1.8" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, + {file = "click-8.1.8.tar.gz", hash = 
"sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + [[package]] name = "distlib" version = "0.3.9" @@ -395,6 +420,23 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "httpx-ws" +version = "0.7.1" +description = "WebSockets support for HTTPX" +optional = false +python-versions = ">=3.9" +files = [ + {file = "httpx_ws-0.7.1-py3-none-any.whl", hash = "sha256:7970e470840d8e6c17bd45ed4e7af06f9144a4a9decab2ff226f3ff9accb65b4"}, + {file = "httpx_ws-0.7.1.tar.gz", hash = "sha256:72f355d4b9b16d8fa59e5e68efdfcb1f3c7dca944901b373791245c8f67f9f95"}, +] + +[package.dependencies] +anyio = ">=4" +httpcore = ">=1.0.4" +httpx = ">=0.23.1" +wsproto = "*" + [[package]] name = "identify" version = "2.6.3" @@ -868,6 +910,20 @@ files = [ {file = "websockets-14.1.tar.gz", hash = "sha256:398b10c77d471c0aab20a845e7a60076b6390bfdaac7a6d2edb0d2c59d75e8d8"}, ] +[[package]] +name = "wsproto" +version = "1.2.0" +description = "WebSockets state-machine based protocol implementation" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736"}, + {file = "wsproto-1.2.0.tar.gz", hash = "sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065"}, +] + +[package.dependencies] +h11 = ">=0.9.0,<1" + [[package]] name = 
"zstandard" version = "0.23.0" @@ -983,4 +1039,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "e9d568eed359b4b5e60b39d5d3cc47dd337782d837b50f162a86192d8d5c4693" +content-hash = "135dc644958fd69a92e2deeca1cbc345ed22b808308cbf3541191e70e4db010f" diff --git a/pyproject.toml b/pyproject.toml index a3b91c6..b5bad3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ include = [ [tool.poetry.scripts] BlueSkySight-Firehose = "blueskysight.stream:main" BlueSkySight-Firehose-v2 = "blueskysight.firehose:main" +BlueSkySight-Jetstream = "blueskysight.jetstream:main" [tool.poetry.dependencies] python = "^3.10" @@ -42,6 +43,8 @@ websockets = "^14.1" cbor2 = "^5.6.5" zstandard = "^0.23.0" httpx = "^0.28.1" +httpx-ws = "^0.7.1" +click = "^8.1.8" [tool.poetry.group.dev.dependencies] mypy = "^1.13.0"