From ca8d3b961e3bb28386407f25938788318e737478 Mon Sep 17 00:00:00 2001 From: Jean-Christophe Morin Date: Sat, 11 Jan 2025 15:30:25 -0500 Subject: [PATCH] Add a new patching mechanism by introducing a new patches hook Signed-off-by: Jean-Christophe Morin --- pyproject.toml | 5 +- scripts/get_pyside6_files.py | 313 +++++++++++++++++++++++++++ src/rez_pip/cli.py | 9 +- src/rez_pip/data/patches/__init__.py | 0 src/rez_pip/patch.py | 51 +++++ src/rez_pip/plugins/__init__.py | 14 ++ 6 files changed, 389 insertions(+), 3 deletions(-) create mode 100644 scripts/get_pyside6_files.py create mode 100644 src/rez_pip/data/patches/__init__.py create mode 100644 src/rez_pip/patch.py diff --git a/pyproject.toml b/pyproject.toml index 3dcbbc5..00b88c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,10 @@ dependencies = [ "importlib_metadata>=4.6 ; python_version < '3.10'", # 1.3 introduces type hints. "pluggy>=1.2", - "typing-extensions; python_version < '3.8'" + "typing-extensions; python_version < '3.8'", + # Patches are finicky... Let's lock on the current latest version. + # We could always relax later if needed. + "patch-ng==1.18.1", ] classifiers = [ diff --git a/scripts/get_pyside6_files.py b/scripts/get_pyside6_files.py new file mode 100644 index 0000000..6f249cc --- /dev/null +++ b/scripts/get_pyside6_files.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import os +import ast +import sys +import bisect +import typing +import difflib +import zipfile +import tempfile +import itertools +import contextlib +import subprocess + +import requests +import requests.models +import packaging.utils + + +# Taken from https://github.com/pypa/pip/blob/bc553db53c264abe3bb63c6bcd6fc6f303c6f6e3/src/pip/_internal/network/lazy_wheel.py +class LazyZipOverHTTP: + """File-like object mapped to a ZIP file over HTTP. + + This uses HTTP range requests to lazily fetch the file's content, + which is supposed to be fed to ZipFile. 
If such requests are not + supported by the server, raise HTTPRangeRequestUnsupported + during initialization. + """ + + def __init__( + self, + url: str, + session: requests.Session, + chunk_size: int = requests.models.CONTENT_CHUNK_SIZE, + ) -> None: + head = session.head(url, headers={"Accept-Encoding": "identity"}) + head.raise_for_status() + assert head.status_code == 200 + self._session, self._url, self._chunk_size = session, url, chunk_size + self._length = int(head.headers["Content-Length"]) + self._file = tempfile.NamedTemporaryFile() + self.truncate(self._length) + self._left: typing.List[int] = [] + self._right: typing.List[int] = [] + if "bytes" not in head.headers.get("Accept-Ranges", "none"): + raise ValueError("range request is not supported") + self._check_zip() + + @property + def mode(self) -> str: + """Opening mode, which is always rb.""" + return "rb" + + @property + def name(self) -> str: + """Path to the underlying file.""" + return self._file.name + + def seekable(self) -> bool: + """Return whether random access is supported, which is True.""" + return True + + def close(self) -> None: + """Close the file.""" + self._file.close() + + @property + def closed(self) -> bool: + """Whether the file is closed.""" + return self._file.closed + + def read(self, size: int = -1) -> bytes: + """Read up to size bytes from the object and return them. + + As a convenience, if size is unspecified or -1, + all bytes until EOF are returned. Fewer than + size bytes may be returned if EOF is reached. 
+ """ + download_size = max(size, self._chunk_size) + start, length = self.tell(), self._length + stop = length if size < 0 else min(start + download_size, length) + start = max(0, stop - download_size) + self._download(start, stop - 1) + return self._file.read(size) + + def readable(self) -> bool: + """Return whether the file is readable, which is True.""" + return True + + def seek(self, offset: int, whence: int = 0) -> int: + """Change stream position and return the new absolute position. + + Seek to offset relative position indicated by whence: + * 0: Start of stream (the default). pos should be >= 0; + * 1: Current position - pos may be negative; + * 2: End of stream - pos usually negative. + """ + return self._file.seek(offset, whence) + + def tell(self) -> int: + """Return the current position.""" + return self._file.tell() + + def truncate(self, size: typing.Optional[int] = None) -> int: + """Resize the stream to the given size in bytes. + + If size is unspecified resize to the current position. + The current stream position isn't changed. + + Return the new file size. + """ + return self._file.truncate(size) + + def writable(self) -> bool: + """Return False.""" + return False + + def __enter__(self) -> "LazyZipOverHTTP": + self._file.__enter__() + return self + + def __exit__(self, *exc: Any) -> None: + self._file.__exit__(*exc) + + @contextlib.contextmanager + def _stay(self) -> typing.Generator[None, None, None]: + """Return a context manager keeping the position. + + At the end of the block, seek back to original position. + """ + pos = self.tell() + try: + yield + finally: + self.seek(pos) + + def _check_zip(self) -> None: + """Check and download until the file is a valid ZIP.""" + end = self._length - 1 + for start in reversed(range(0, end, self._chunk_size)): + self._download(start, end) + with self._stay(): + try: + # For read-only ZIP files, ZipFile only needs + # methods read, seek, seekable and tell. 
+ zipfile.ZipFile(self) + except zipfile.BadZipFile: + pass + else: + break + + def _stream_response( + self, + start: int, + end: int, + base_headers: typing.Dict[str, str] = {"Accept-Encoding": "identity"}, + ) -> requests.Response: + """Return HTTP response to a range request from start to end.""" + headers = base_headers.copy() + headers["Range"] = f"bytes={start}-{end}" + # TODO: Get range requests to be correctly cached + headers["Cache-Control"] = "no-cache" + return self._session.get(self._url, headers=headers, stream=True) + + def _merge( + self, start: int, end: int, left: int, right: int + ) -> typing.Generator[typing.Tuple[int, int], None, None]: + """Return a generator of intervals to be fetched. + + Args: + start (int): Start of needed interval + end (int): End of needed interval + left (int): Index of first overlapping downloaded data + right (int): Index after last overlapping downloaded data + """ + lslice, rslice = self._left[left:right], self._right[left:right] + i = start = min([start] + lslice[:1]) + end = max([end] + rslice[-1:]) + for j, k in zip(lslice, rslice): + if j > i: + yield i, j - 1 + i = k + 1 + if i <= end: + yield i, end + self._left[left:right], self._right[left:right] = [start], [end] + + def _download(self, start: int, end: int) -> None: + """Download bytes from start to end inclusively.""" + with self._stay(): + left = bisect.bisect_left(self._right, start) + right = bisect.bisect_right(self._left, end) + for start, end in self._merge(start, end, left, right): + response = self._stream_response(start, end) + response.raise_for_status() + self.seek(start) + for chunk in response.iter_content(self._chunk_size): + self._file.write(chunk) + + +# https://stackoverflow.com/a/66733795 +def compare_ast( + node1: ast.expr | list[ast.expr], node2: ast.expr | list[ast.expr] +) -> bool: + if type(node1) is not type(node2): + return False + + if isinstance(node1, ast.AST): + for k, v in vars(node1).items(): + if k in {"lineno", 
"end_lineno", "col_offset", "end_col_offset", "ctx"}: + continue + if not compare_ast(v, getattr(node2, k)): + return False + return True + + elif isinstance(node1, list) and isinstance(node2, list): + return all( + compare_ast(n1, n2) for n1, n2 in itertools.zip_longest(node1, node2) + ) + else: + return node1 == node2 + + +def run(): + with requests.get( + "https://pypi.org/simple/pyside6", + headers={"Accept": "application/vnd.pypi.simple.v1+json"}, + ) as resp: + resp.raise_for_status() + + data = resp.json() + + versions: list[str] = [] + for entry in data["files"]: + if not entry["filename"].endswith(".whl"): + continue + + name, version, buildtag, tags = packaging.utils.parse_wheel_filename( + entry["filename"] + ) + if version.pre: + continue + + if not any( + tag.platform.startswith("win_") and not tag.interpreter.startswith("pp") + for tag in tags + ): + continue + + print(entry["filename"]) + + # Store raw files in patches/data/ + # This will allow us to inspect them before deciding on how + # to create patches. 
+ + directory = os.path.join("patches", "data", str(version)) + os.makedirs(directory, exist_ok=True) + + session = requests.Session() + wheel = LazyZipOverHTTP(entry["url"], session) + with zipfile.ZipFile(wheel) as zf: + for info in zf.infolist(): + if info.filename != "PySide6/__init__.py": + continue + + with open( + os.path.join(directory, os.path.basename(info.filename)), "wb" + ) as f: + f.write(zf.read(info)) + break + + versions.append(str(version)) + + print("Comparing files") + first = versions.pop(0) + + while len(versions) > 1: + leftFile = f"patches/data/{versions[0]}/__init__.py" + rightFile = f"patches/data/{versions[1]}/__init__.py" + with open(leftFile, "r") as lfh, open(rightFile, "r") as rfh: + lhs = ast.parse(lfh.read()) + rhs = ast.parse(rfh.read()) + + leftAST = next( + node + for node in lhs.body + if isinstance(node, ast.FunctionDef) + and node.name == "_additional_dll_directories" + ) + + rightAST = next( + node + for node in rhs.body + if isinstance(node, ast.FunctionDef) + and node.name == "_additional_dll_directories" + ) + + if not compare_ast(leftAST, rightAST): + print( + f"{versions[0]} and {versions[1]}'s _additional_dll_directories function differ" + ) + leftCode = ast.unparse(leftAST).splitlines(keepends=True) + rightCode = ast.unparse(rightAST).splitlines(keepends=True) + + result = difflib.unified_diff( + leftCode, rightCode, fromfile=leftFile, tofile=rightFile + ) + + sys.stdout.writelines(result) + + versions.pop(0) + + +run() diff --git a/src/rez_pip/cli.py b/src/rez_pip/cli.py index a0fc2cb..701516e 100644 --- a/src/rez_pip/cli.py +++ b/src/rez_pip/cli.py @@ -23,6 +23,7 @@ import rez_pip.pip import rez_pip.rez import rez_pip.data +import rez_pip.patch import rez_pip.plugins import rez_pip.install import rez_pip.download @@ -251,12 +252,16 @@ def _run(args: argparse.Namespace, pipArgs: typing.List[str], pipWorkArea: str) ): for group in packageGroups: for package in group.packages: - _LOG.info(f"[bold]Installing 
{package.name} {package.path}") + _LOG.info(f"[bold]Installing {package.name!r} {package.path!r}") + targetPath = os.path.join(installedWheelsDir, package.name) dist = rez_pip.install.installWheel( package, package.path, - os.path.join(installedWheelsDir, package.name), + targetPath, ) + + rez_pip.patch.patch(dist, targetPath) + group.dists.append(dist) with rich.get_console().status("[bold]Creating rez packages..."): diff --git a/src/rez_pip/data/patches/__init__.py b/src/rez_pip/data/patches/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/rez_pip/patch.py b/src/rez_pip/patch.py new file mode 100644 index 0000000..20b03b8 --- /dev/null +++ b/src/rez_pip/patch.py @@ -0,0 +1,51 @@ +import os +import typing +import logging + +import patch_ng + +import rez_pip.data.patches +import rez_pip.plugins +import rez_pip.exceptions +from rez_pip.compat import importlib_metadata + +_LOG = logging.getLogger(__name__) + + +class PatchError(rez_pip.exceptions.RezPipError): + pass + + +def getBuiltinPatchesDir() -> str: + """Get the built-in patches directory""" + return os.path.dirname(rez_pip.data.patches.__file__) + + +def patch(dist: importlib_metadata.Distribution, path: str): + """Patch an installed package (wheel)""" + _LOG.debug(f"[bold]Attempting to patch {dist.name!r} at {path!r}") + patchesGroups: typing.List[list[str]] = rez_pip.plugins.getHook().patches( + dist=dist, path=path + ) + + # Flatten the list + patches = [path for group in patchesGroups for path in group] + + if not patches: + _LOG.debug(f"No patches found") + return + + _LOG.info(f"Applying {len(patches)} patches for {dist.name!r} at {path!r}") + + for patch in patches: + _LOG.info(f"Applying patch {patch!r} on {path!r}") + + if not os.path.isabs(patch): + raise PatchError(f"{patch!r} is not an absolute path") + + if not os.path.exists(patch): + raise PatchError(f"Patch at {patch!r} does not exist") + + patchset = patch_ng.fromfile(patch) + if not patchset.apply(root=path): + raise 
PatchError(f"Failed to apply patch {patch!r} on {path!r}") diff --git a/src/rez_pip/plugins/__init__.py b/src/rez_pip/plugins/__init__.py index 34d38a7..dc14f0d 100644 --- a/src/rez_pip/plugins/__init__.py +++ b/src/rez_pip/plugins/__init__.py @@ -81,6 +81,20 @@ def groupPackages( # type: ignore[empty-body] :returns: A list of package groups. """ + @hookspec + def patches( + self, dist: rez_pip.compat.importlib_metadata.Distribution, path: str + ) -> typing.Sequence[str]: + """ + Provide paths to patches to be applied on the source code of a package. + + :param dist: Python distribution. + :param path: Root path of the installed content. + """ + # TODO: This will alter files (obviously) and change their hashes. + # This could be a problem to verify the integrity of the package. + # https://packaging.python.org/en/latest/specifications/recording-installed-packages/#the-record-file + @hookspec def cleanup( self, dist: rez_pip.compat.importlib_metadata.Distribution, path: str