From 1d607e820740122bf1d432faf36b220fc656c46d Mon Sep 17 00:00:00 2001
From: Vamsi Thiriveedhi
Date: Sat, 23 Mar 2024 23:33:08 +0000
Subject: [PATCH] feat: report download progress

There is no way to track the download sequentially, because the s5cmd
run or cp command blocks the thread it runs on. Progress reporting is
therefore outsourced to a separate thread that runs alongside the
download thread, in both the series and the manifest downloader.

For manifest downloads, the total download size is calculated first and
progress is tracked against it as a whole. Because the index only
contains AWS URLs, when a manifest contains GCS URLs the CRDC series
instance UUID is extracted from each URL and matched against the AWS
URLs in the index to obtain the download size.

The manifest validator in download_from_manifest is offloaded to a
dedicated function that checks every line, not only the first, and
raises an exception if the manifest contains URLs from both GCP and AWS.

s5cmd cp is replaced with sync to gracefully avoid downloading the same
data again.

Get functions now raise an error reporting that no data was found for
the values given for a key.

The queries folder is removed, as the queries persist in idc-index-data.
---
 idc_index/index.py    | 436 +++++++++++++++++++++++++++++-------------
 queries/idc_index.sql |  34 ----
 2 files changed, 304 insertions(+), 166 deletions(-)
 delete mode 100644 queries/idc_index.sql

diff --git a/idc_index/index.py b/idc_index/index.py
index b1d15e10..991f0e49 100644
--- a/idc_index/index.py
+++ b/idc_index/index.py
@@ -13,10 +13,15 @@
 else:
     from importlib.metadata import distribution
 
+import threading
+import time
+from pathlib import Path
+
 import duckdb
 import idc_index_data
 import pandas as pd
 import psutil
+from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
 
@@ -66,7 +71,11 @@ def _filter_dataframe_by_id(key, dataframe, _id):
         values = _id
         if isinstance(_id, str):
             values = [_id]
-        return dataframe[dataframe[key].isin(values)].copy()
+        filtered_df = dataframe[dataframe[key].isin(values)].copy()
+        if filtered_df.empty:
+            error_message = f"No data found for the {key} with the values {values}."
+            raise ValueError(error_message)
+        return filtered_df
 
     @staticmethod
     def _filter_by_collection_id(df_index, collection_id):
@@ -259,47 +268,100 @@ def get_dicom_series(self, studyInstanceUID=None, outputFormat="dict"):
 
         return response
 
-    def download_dicom_series(
-        self, seriesInstanceUID, downloadDir, dry_run=False, quiet=True
+    def _track_download_progress(
+        self, size_MB: float, downloadDir: str, download_thread: threading.Thread
     ):
         """
-        Download the files corresponding to the seriesInstanceUID to the specified directory.
-
-        Args:
-            seriesInstanceUID: string containing the value of DICOM SeriesInstanceUID to filter by
-            downloadDir: string containing the path to the directory to download the files to
-            dry_run: boolean indicating if the download should be a dry run (default: False)
-            quiet: boolean indicating if the output should be suppressed (default: True)
+        Track progress by continuously checking the downloaded file size and updating the progress bar.
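+
+        Args:
+            size_MB (float): The expected total download size in MB.
+            downloadDir (str): The directory whose file sizes are polled.
+            download_thread (threading.Thread): The thread performing the
+                download; tracking stops once it is no longer alive.
+
+        Minimal usage sketch (target and args below are placeholders; see
+        download_dicom_series for the real call):
+
+            download_thread = threading.Thread(target=..., args=(...,))
+            download_thread.start()
+            self._track_download_progress(size_MB, downloadDir, download_thread)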
+        """
+        total_size_bytes = size_MB * 10**6  # Convert MB to bytes
+        pbar = tqdm(
+            total=total_size_bytes,
+            unit="B",
+            unit_scale=True,
+            desc="Downloading data",
+        )
+        while True:
+            downloaded_bytes = sum(
+                f.stat().st_size for f in Path(downloadDir).iterdir() if f.is_file()
+            )
+            pbar.n = min(
+                downloaded_bytes, total_size_bytes
+            )  # Prevent the progress bar from exceeding 100%
+            pbar.refresh()
+            if not download_thread.is_alive() or pbar.n >= total_size_bytes:
+                break
+            time.sleep(0.5)
+        pbar.close()
+
+    def _download_series_process(self, series_url: str, download_dir: str) -> int:
+        """
+        Download series files using the s5cmd sync command. Sync makes sure
+        files are not downloaded again if their size and modification time
+        are unchanged.
+
+        See https://github.com/peak/s5cmd?tab=readme-ov-file#sync
 
         Args:
+            series_url (str): AWS series URL
+            download_dir (str): The directory to download the series files to.
+
+        Returns:
+            int: The return code of the s5cmd process.
         """
-        series_url = self.index[self.index["SeriesInstanceUID"] == seriesInstanceUID][
-            "series_aws_url"
-        ].iloc[0]
-        logger.debug("AWS Bucket Location: " + series_url)
         cmd = [
             self.s5cmdPath,
             "--no-sign-request",
             "--endpoint-url",
             aws_endpoint_url,
-            "cp",
-            "--show-progress",
+            "sync",
             series_url,
-            downloadDir,
+            download_dir,
         ]
-        if not dry_run:
-            process = subprocess.run(
-                cmd, capture_output=(not quiet), text=(not quiet), check=False
+        process = subprocess.Popen(
+            cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE, universal_newlines=True
+        )
+
+        process.communicate()  # Wait for the process to finish
+
+        return process.returncode
+
+    def download_dicom_series(self, seriesInstanceUID: str, downloadDir: str) -> None:
+        """
+        Download the files corresponding to the seriesInstanceUID to the specified directory.
+
+        Args:
+            seriesInstanceUID (str): The DICOM SeriesInstanceUID of the series to download.
+            downloadDir (str): The directory to download the files to.
+
+        Returns: None
+
+        """
+        series_df = self.index[self.index["SeriesInstanceUID"] == seriesInstanceUID]
+        if series_df.empty:
+            error_message = (
+                f"No series found with the SeriesInstanceUID '{seriesInstanceUID}'."
             )
-            if not quiet:
-                print(process.stderr)
-            if process.returncode == 0:
-                logger.debug(f"Successfully downloaded files to {downloadDir}")
-            else:
-                logger.error("Failed to download files.")
+            raise ValueError(error_message)
+
+        series_info = series_df.iloc[0]
+        series_url = series_info["series_aws_url"]
+        series_size_MB = series_info["series_size_MB"]
+
+        logger.debug("AWS Bucket Location: " + series_url)
+
+        # Start downloading series files using a subprocess on a worker thread
+        download_thread = threading.Thread(
+            target=self._download_series_process, args=(series_url, downloadDir)
+        )
+        download_thread.start()
+
+        # Track progress using tqdm on a second thread
+        track_thread = threading.Thread(
+            target=self._track_download_progress,
+            args=(series_size_MB, downloadDir, download_thread),
+        )
+        track_thread.start()
+
+        # Wait for the download and progress threads to finish
+        download_thread.join()
+        track_thread.join()
+
+        logger.debug(f"Successfully downloaded files to {downloadDir}")
 
     def get_series_file_URLs(self, seriesInstanceUID):
         """
@@ -444,6 +506,221 @@ def get_viewer_URL(
 
         return viewer_url
 
+    def _get_series_size_from_crdc_series_uuid(
+        self, crdc_series_instance_uuid: str
+    ) -> float:
+        """
+        Retrieves the size of a series from the index based on the given CRDC series instance UUID.
+
+        As the index only contains AWS series URLs, there is no direct way to
+        get the series size from a GCS URL. However, this function leverages
+        the fact that both the GCS and AWS URLs share the same folder name,
+        which is the CRDC series instance UUID.
+
+        Args:
+            crdc_series_instance_uuid (str): The UUID of the CRDC series instance.
+
+        Returns:
+            float: The size of the series in MB.
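+
+        Illustrative URL pair (bucket names are placeholders); both URLs share
+        the same folder name, which is the CRDC series instance UUID matched
+        by the LIKE query below:
+
+            s3://<aws-bucket>/<crdc_series_uuid>/*
+            s3://<gcs-bucket>/<crdc_series_uuid>/*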
+        """
+        index = self.index
+        series_size_sql = f"""
+        SELECT
+            series_size_MB
+        FROM
+            index
+        WHERE
+            series_aws_url LIKE '%{crdc_series_instance_uuid}%'
+        """
+        return duckdb.query(series_size_sql).to_df().series_size_MB.iloc[0]
+
+    def _validate_manifest_and_get_download_size(
+        self, manifestFile: str
+    ) -> tuple[float, str]:
+        """
+        Validates the manifest file by checking the URLs and their availability.
+
+        The function reads the manifest file line by line. For each line, it
+        checks whether the URL is valid and accessible, using s5cmd against
+        both the AWS and GCP endpoints. If a URL is accessible at neither
+        endpoint, it raises a ValueError. In addition, it calculates the
+        total size of all series in the manifest file.
+
+        Args:
+            manifestFile (str): The path to the manifest file.
+
+        Returns:
+            total_size (float): The total size of all series in the manifest file.
+            endpoint_to_use (str): The endpoint URL to use (either AWS or GCP).
+
+        Raises:
+            ValueError: If the manifest file does not exist, if any URL in the
+                manifest file is invalid, or if any URL is inaccessible in both
+                AWS and GCP.
+            Exception: If the manifest contains URLs from both AWS and GCP.
+        """
+        if not os.path.exists(manifestFile):
+            raise ValueError("Manifest does not exist.")
+
+        endpoint_to_use = None
+        aws_found = False
+        gcp_found = False
+        total_size = 0
+
+        with open(manifestFile) as f:
+            for line in f:
+                if not line.startswith("#"):
+                    series_folder_pattern = r"(s3:\/\/.*)\/\*"
+                    match = re.search(series_folder_pattern, line)
+                    if match is None:
+                        raise ValueError("Invalid URL format in manifest file.")
+                    folder_url = match.group(1)
+
+                    # Extract the CRDC series UUID from the line
+                    crdc_series_uuid_pattern = r"(?:.*?\/){3}([^\/?#]+)"
+                    match_uuid = re.search(crdc_series_uuid_pattern, line)
+                    if match_uuid is None:
+                        raise ValueError("Invalid URL format in manifest file.")
+                    crdc_series_uuid = match_uuid.group(1)
+
+                    # Check the AWS endpoint
+                    cmd = [
+                        self.s5cmdPath,
+                        "--no-sign-request",
+                        "--endpoint-url",
+                        aws_endpoint_url,
+                        "ls",
+                        folder_url,
+                    ]
+                    process = subprocess.run(
+                        cmd, capture_output=True, text=True, check=False
+                    )
+                    if process.stderr and process.stderr.startswith("ERROR"):
+                        # Not in AWS; check the GCP endpoint
+                        cmd = [
+                            self.s5cmdPath,
+                            "--no-sign-request",
+                            "--endpoint-url",
+                            gcp_endpoint_url,
+                            "ls",
+                            folder_url,
+                        ]
+                        process = subprocess.run(
+                            cmd, capture_output=True, text=True, check=False
+                        )
+                        if process.stderr and process.stderr.startswith("ERROR"):
+                            error_message = f"Manifest contains invalid or inaccessible URLs. Please check line '{line}'"
+                            raise ValueError(error_message)
+                        if aws_found:
+                            raise Exception(
+                                "The manifest contains URLs from both AWS and GCP. Please use only one provider."
+                            )
+                        endpoint_to_use = gcp_endpoint_url
+                        gcp_found = True
+                    else:
+                        if gcp_found:
+                            raise Exception(
+                                "The manifest contains URLs from both AWS and GCP. Please use only one provider."
+                            )
+                        endpoint_to_use = aws_endpoint_url
+                        aws_found = True
+
+                    # Get the size of the series
+                    series_size = self._get_series_size_from_crdc_series_uuid(
+                        crdc_series_uuid
+                    )
+                    total_size += series_size
+
+        if not endpoint_to_use:
+            raise ValueError("No valid URLs found in the manifest.")
+
+        return total_size, endpoint_to_use
+
+    def _download_manifest_process(
+        self, manifestFile: str, downloadDir: str, endpoint_to_use: str, quiet: bool
+    ) -> int:
+        """
+        Download the files listed in the manifest using an s5cmd subprocess;
+        this function is intended to be run on a worker thread.
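+
+        Args:
+            manifestFile (str): Path to the s5cmd run file; each non-comment
+                line is expected to look like
+                ' sync s3://<bucket>/<crdc_series_uuid>/* <downloadDir>'.
+            downloadDir (str): The directory the files are downloaded to.
+            endpoint_to_use (str): The endpoint URL to use (either AWS or GCP).
+            quiet (bool): If True, suppresses the output of the subprocess.
+
+        Returns:
+            int: The return code of the s5cmd process.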
+        """
+        # Create the command to download files
+        cmd = [
+            self.s5cmdPath,
+            "--no-sign-request",
+            "--endpoint-url",
+            endpoint_to_use,
+            "run",
+            manifestFile,
+        ]
+
+        # Run the command and wait for it to finish
+        process = subprocess.Popen(
+            cmd,
+            stderr=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+            universal_newlines=True,
+        )
+
+        stdout, stderr = process.communicate()
+
+        if not quiet:
+            logger.info(stdout)
+            logger.info(stderr)
+
+        if process.returncode == 0:
+            logger.debug(f"Successfully downloaded manifest to {downloadDir}")
+        else:
+            logger.error("Failed to download manifest.")
+
+        return process.returncode
+
+    def download_from_manifest(
+        self, manifestFile: str, downloadDir: str, quiet: bool = True
+    ) -> None:
+        """
+        Download the files listed in the manifest file. In a series of steps,
+        the manifest file is first validated to ensure every line contains a
+        valid URL. The total download size is then computed, and the download
+        runs on one thread while progress is tracked on another.
+
+        Args:
+            manifestFile (str): The path to the manifest file.
+            downloadDir (str): The directory to download the files to.
+            quiet (bool, optional): If True, suppresses the output of the subprocess. Defaults to True.
+
+        Raises:
+            ValueError: If the download directory does not exist.
+        """
+        total_size, endpoint_to_use = self._validate_manifest_and_get_download_size(
+            manifestFile
+        )
+
+        downloadDir = os.path.abspath(downloadDir).replace("\\", "/")
+        if not os.path.exists(downloadDir):
+            raise ValueError("Download directory does not exist.")
+
+        # Create a temporary manifest file with updated destination directories
+        with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_manifest_file:
+            with open(manifestFile) as f:
+                for line in f:
+                    if not line.startswith("#"):
+                        pattern = r"s3:\/\/.*\*"
+                        match = re.search(pattern, line)
+                        if match is None:
+                            raise ValueError(
+                                "Could not find the bucket URL in a line of the manifest file."
+                            )
+                        folder_url = match.group(0)
+                        temp_manifest_file.write(
+                            " sync " + folder_url + " " + downloadDir + "\n"
+                        )
+
+        # Start downloading manifest files using a subprocess on a worker thread
+        download_thread = threading.Thread(
+            target=self._download_manifest_process,
+            args=(temp_manifest_file.name, downloadDir, endpoint_to_use, quiet),
+        )
+        download_thread.start()
+
+        # Track progress using tqdm on a second thread
+        track_thread = threading.Thread(
+            target=self._track_download_progress,
+            args=(total_size, downloadDir, download_thread),
+        )
+        track_thread.start()
+
+        # Wait for the download and progress threads to finish
+        download_thread.join()
+        track_thread.join()
+
     def download_from_selection(
         self,
         downloadDir,
@@ -532,111 +809,6 @@ def download_from_selection(
         )
         self.download_from_manifest(manifest_file, downloadDir)
 
-    def download_from_manifest(self, manifestFile, downloadDir, quiet=True):
-        """Download the files corresponding to the manifest file from IDC. The manifest file should be a text file with each line containing the s5cmd command to download the file. The URLs in the file must correspond to those in the AWS buckets!
- - Args: - manifest_file: string containing the path to the manifest file - downloadDir: string containing the path to the directory to download the files to - - Returns: - - Raises: - """ - - downloadDir = os.path.abspath(downloadDir).replace("\\", "/") - - if not os.path.exists(downloadDir): - raise ValueError("Download directory does not exist.") - if not os.path.exists(manifestFile): - raise ValueError("Manifest does not exist.") - - # open manifest_file and read the first line that does not start from '#' - with open(manifestFile) as f: - for line in f: - if not line.startswith("#"): - break - pattern = r"(s3:\/\/.*)\/\*" - match = re.search(pattern, line) - if match is None: - logger.error( - "Could not find the bucket URL in the first line of the manifest file." - ) - return - folder_url = match.group(1) - - cmd = [ - self.s5cmdPath, - "--no-sign-request", - "--endpoint-url", - aws_endpoint_url, - "ls", - folder_url, - ] - process = subprocess.run(cmd, capture_output=True, text=True, check=False) - # check if output starts with ERROR - if process.stderr and process.stderr.startswith("ERROR"): - logger.debug( - "Folder not available in AWS. Checking in Google Cloud Storage." - ) - - cmd = [ - self.s5cmdPath, - "--no-sign-request", - "--endpoint-url", - gcp_endpoint_url, - "ls", - folder_url, - ] - process = subprocess.run(cmd, capture_output=True, text=True, check=False) - if process.stderr and process.stdout.startswith("ERROR"): - logger.debug( - "Folder not available in GCP. Manifest appears to be invalid." - ) - raise ValueError - else: - endpoint_to_use = gcp_endpoint_url - else: - endpoint_to_use = aws_endpoint_url - - # create an updated manifest to include the specified destination directory - with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_manifest_file: - with open(manifestFile) as f: - for line in f: - if not line.startswith("#"): - pattern = r"s3:\/\/.*\*" - match = re.search(pattern, line) - if folder_url is None: - logger.error( - "Could not find the bucket URL in the first line of the manifest file." - ) - return - folder_url = match.group(0) - temp_manifest_file.write( - " cp " + folder_url + " " + downloadDir + "\n" - ) - - cmd = [ - self.s5cmdPath, - "--no-sign-request", - "--endpoint-url", - endpoint_to_use, - "run", - temp_manifest_file.name, - ] - - logger.debug("Running command: %s", " ".join(cmd)) - process = subprocess.run( - cmd, capture_output=(not quiet), text=(not quiet), check=False - ) - logger.debug(process.stderr) - logger.debug(process.stdout) - if process.returncode == 0: - logger.debug(f"Successfully downloaded files to {downloadDir}") - logger.debug("Downloaded files: " + "\n".join(os.listdir(downloadDir))) - else: - logger.error("Failed to download files.") - def sql_query(self, sql_query): """Execute SQL query against the table in the index using duckdb. 
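 
         A minimal usage sketch (the client class name IDCClient is assumed
         here; the `index` table name matches the queries above):
 
             from idc_index import index
             client = index.IDCClient()
             df = client.sql_query("SELECT DISTINCT collection_id FROM index")
 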
diff --git a/queries/idc_index.sql b/queries/idc_index.sql deleted file mode 100644 index 57475e1c..00000000 --- a/queries/idc_index.sql +++ /dev/null @@ -1,34 +0,0 @@ -SELECT - # collection level attributes - ANY_VALUE(collection_id) AS collection_id, - ANY_VALUE(source_DOI) AS source_DOI, - # patient level attributes - ANY_VALUE(PatientID) AS PatientID, - ANY_VALUE(PatientAge) AS PatientAge, - ANY_VALUE(PatientSex) AS PatientSex, - # study level attributes - ANY_VALUE(StudyInstanceUID) AS StudyInstanceUID, - ANY_VALUE(StudyDate) AS StudyDate, - ANY_VALUE(StudyDescription) AS StudyDescription, - ANY_VALUE(dicom_curated.BodyPartExamined) AS BodyPartExamined, - # series level attributes - SeriesInstanceUID, - ANY_VALUE(Modality) AS Modality, - ANY_VALUE(Manufacturer) AS Manufacturer, - ANY_VALUE(ManufacturerModelName) AS ManufacturerModelName, - ANY_VALUE(SAFE_CAST(SeriesDate AS STRING)) AS SeriesDate, - ANY_VALUE(SeriesDescription) AS SeriesDescription, - ANY_VALUE(SeriesNumber) AS SeriesNumber, - COUNT(dicom_all.SOPInstanceUID) AS instanceCount, - ANY_VALUE(license_short_name) as license_short_name, - # download related attributes - ANY_VALUE(CONCAT("s3://", aws_bucket, "/", crdc_series_uuid, "/*")) AS series_aws_url, - ROUND(SUM(SAFE_CAST(instance_size AS float64))/1000000, 2) AS series_size_MB, -FROM - bigquery-public-data.idc_current.dicom_all AS dicom_all -JOIN - bigquery-public-data.idc_current.dicom_metadata_curated AS dicom_curated -ON - dicom_all.SOPInstanceUID = dicom_curated.SOPInstanceUID -GROUP BY - SeriesInstanceUID