diff --git a/oras/provider.py b/oras/provider.py index d53f61f..10a68fa 100644 --- a/oras/provider.py +++ b/oras/provider.py @@ -598,6 +598,7 @@ def chunked_upload( """ # Start an upload session headers = {"Content-Type": "application/octet-stream", "Content-Length": "0"} + headers.update(self.headers) upload_url = f"{self.prefix}://{container.upload_blob_url()}" r = self.do_request(upload_url, "POST", headers=headers) @@ -611,11 +612,6 @@ def chunked_upload( start = 0 with open(blob, "rb") as fd: for chunk in oras.utils.read_in_chunks(fd, chunk_size=chunk_size): - print("uploading chunk starting at " + str(start)) - - if not chunk: - break - end = start + len(chunk) - 1 content_range = "%s-%s" % (start, end) headers = { @@ -623,13 +619,19 @@ def chunked_upload( "Content-Length": str(len(chunk)), "Content-Type": "application/octet-stream", } + headers.update(self.headers) # Important to update with auth token if acquired # TODO call to auth here start = end + 1 self._check_200_response( - self.do_request(session_url, "PATCH", data=chunk, headers=headers) + r := self.do_request( + session_url, "PATCH", data=chunk, headers=headers + ) ) + session_url = self._get_location(r, container) + if not session_url: + raise ValueError(f"Issue retrieving session url: {r.json()}") # Finally, issue a PUT request to close blob session_url = oras.utils.append_url_params( diff --git a/oras/tests/test_provider.py b/oras/tests/test_provider.py index 2babd5a..7652ae0 100644 --- a/oras/tests/test_provider.py +++ b/oras/tests/test_provider.py @@ -3,6 +3,8 @@ __license__ = "Apache-2.0" import os +import platform +import subprocess from pathlib import Path import pytest @@ -13,7 +15,8 @@ import oras.provider import oras.utils -here = os.path.abspath(os.path.dirname(__file__)) +here = Path(__file__).resolve().parent +OS_GB = "g" if platform.uname()[0].lower() == "darwin" else "G" @pytest.mark.with_auth(False) @@ -62,6 +65,55 @@ def test_annotated_registry_push(tmp_path, registry, credentials, target): ) +@pytest.mark.with_auth(False) +def test_chunked_push(tmp_path, registry, credentials, target): + """ + Basic tests for oras chunked push + """ + # Direct access to registry functions + client = oras.client.OrasClient(hostname=registry, insecure=True) + artifact = os.path.join(here, "artifact.txt") + + assert os.path.exists(artifact) + + res = client.push(files=[artifact], target=target, do_chunked=True) + assert res.status_code in [200, 201, 202] + + files = client.pull(target, outdir=tmp_path) + assert str(tmp_path / "artifact.txt") in files + assert oras.utils.get_file_hash(artifact) == oras.utils.get_file_hash(files[0]) + + # large file upload + tmp_chunked = here / "chunked" + try: + subprocess.run( + [ + "dd", + "if=/dev/null", + f"of={tmp_chunked}", + "bs=1", + "count=0", + f"seek=15{OS_GB}", + ], + ) + + res = client.push(files=[tmp_chunked], target=target, do_chunked=True) + assert res.status_code in [200, 201, 202] + + files = client.pull(target, outdir=tmp_path / "download") + download = str(tmp_path / "download/chunked") + assert download in files + assert oras.utils.get_file_hash(str(tmp_chunked)) == oras.utils.get_file_hash( + download + ) + finally: + tmp_chunked.unlink() + + # File that doesn't exist + with pytest.raises(FileNotFoundError): + res = client.push(files=[tmp_path / "none"], target=target) + + def test_parse_manifest(registry): """ Test parse manifest function.