Skip to content

Commit

Permalink
reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Reza committed Mar 18, 2024
1 parent d68d48f commit e68b652
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 51 deletions.
10 changes: 4 additions & 6 deletions pyalfe/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pyalfe.data_structure import DefaultALFEDataDir, Modality
from pyalfe.image_processing import Convert3DProcessor, NilearnProcessor
from pyalfe.image_registration import GreedyRegistration, AntsRegistration
from pyalfe.inference import NNUnet, NNUnetV2
from pyalfe.inference import NNUnetV2
from pyalfe.models import MODELS_PATH
from pyalfe.pipeline import PyALFEPipelineRunner
from pyalfe.tasks.initialization import Initialization
Expand Down Expand Up @@ -58,9 +58,7 @@ class Container(containers.DeclarativeContainer):
NNUnetV2,
model_dir=str(
MODELS_PATH.joinpath(
'nnunetv2',
'Dataset502_SS',
'nnUNetTrainer__nnUNetPlans__3d_fullres'
'nnunetv2', 'Dataset502_SS', 'nnUNetTrainer__nnUNetPlans__3d_fullres'
)
),
folds=(2,),
Expand All @@ -84,7 +82,7 @@ class Container(containers.DeclarativeContainer):
MODELS_PATH.joinpath(
'nnunetv2',
'Dataset503_Enhancement',
'nnUNetTrainer__nnUNetPlans__3d_fullres'
'nnUNetTrainer__nnUNetPlans__3d_fullres',
)
),
folds=(0,),
Expand All @@ -96,7 +94,7 @@ class Container(containers.DeclarativeContainer):
MODELS_PATH.joinpath(
'nnunetv2',
'Dataset510_Tissue_W_Prior',
'nnUNetTrainer__nnUNetPlans__3d_fullres'
'nnUNetTrainer__nnUNetPlans__3d_fullres',
)
),
folds=(3,),
Expand Down
2 changes: 0 additions & 2 deletions pyalfe/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pyalfe.models import MODELS_PATH, models_url
from pyalfe.tools import C3D_PATH, GREEDY_PATH, c3d_url, greedy_url
from pyalfe.utils import download_archive, extract_binary_from_archive
from pyalfe.utils.archive import extract_tar

DEFAULT_CFG = os.path.expanduser(os.path.join('~', '.config', 'pyalfe', 'config.ini'))
# importlib.resources.files('pyalfe').joinpath('config.ini')
Expand Down Expand Up @@ -59,7 +58,6 @@ def download(assets):
click.print(f'asset {asset} is not recognized.')



@main.command()
@click.argument('accession')
@click.option(
Expand Down
12 changes: 6 additions & 6 deletions pyalfe/tasks/quantification.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,12 @@ def get_lesion_stats(
stats[f'median_{modality_name}_signal'] = np.median(
modality_image[lesion_indices]
)
stats[
f'five_percentile_{modality_name}_signal'
] = np.percentile(modality_image[lesion_indices], 5)
stats[
f'ninety_five_percentile_{modality_name}_signal'
] = np.percentile(modality_image[lesion_indices], 95)
stats[f'five_percentile_{modality_name}_signal'] = np.percentile(
modality_image[lesion_indices], 5
)
stats[f'ninety_five_percentile_{modality_name}_signal'] = np.percentile(
modality_image[lesion_indices], 95
)

if Modality.T1 in modality_images and Modality.T1Post in modality_images:
t1_image = modality_images[Modality.T1]
Expand Down
4 changes: 1 addition & 3 deletions pyalfe/tasks/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,7 @@ def run(self, accession):
)
return
if not os.path.exists(t1trim_upsampled):
self.logger.info(
'T1 trim upsampled is missing. Skipping T1Registration'
)
self.logger.info('T1 trim upsampled is missing. Skipping T1Registration')
return

template = roi_dict['template']['source']
Expand Down
32 changes: 14 additions & 18 deletions pyalfe/tasks/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(
pipeline_dir: PipelineDataDir,
modality_list: list[str],
output_modality: str,
image_type_input: str= 'skullstripped',
image_type_input: str = 'skullstripped',
image_type_output: str = 'abnormal_seg',
image_type_mask: str = None,
segmentation_dir: str = 'abnormalmap',
Expand Down Expand Up @@ -303,11 +303,11 @@ def __init__(
inference_model: InferenceModel,
image_processor: ImageProcessor,
pipeline_dir: PipelineDataDir,
image_type_input: str= 'trim_upsampled',
image_type_output: str='tissue_seg',
template_name: str='Tissue',
image_type_input: str = 'trim_upsampled',
image_type_output: str = 'tissue_seg',
template_name: str = 'Tissue',
overwrite: bool = True,
):
):
super().__init__(inference_model, image_processor)
self.pipeline_dir = pipeline_dir
self.image_type_input = image_type_input
Expand All @@ -319,31 +319,27 @@ def __init__(

def run(self, accession):
t1_image_path = self.pipeline_dir.get_output_image(
accession,
Modality.T1,
image_type=self.image_type_input,
)
accession,
Modality.T1,
image_type=self.image_type_input,
)
tissue_prior_path = self.pipeline_dir.get_output_image(
accession,
Modality.T1,
resampling_origin=self.template_name,
resampling_target=Modality.T1,
sub_dir_name=roi_dict[self.template_name]['sub_dir']
sub_dir_name=roi_dict[self.template_name]['sub_dir'],
)
pred_path = self.pipeline_dir.get_output_image(
accession,
Modality.T1,
image_type=f'{self.image_type_output}_pred'
accession, Modality.T1, image_type=f'{self.image_type_output}_pred'
)

if self.overwrite or not os.path.exists(pred_path):
self.predict([t1_image_path, tissue_prior_path], pred_path)

seg_path = self.pipeline_dir.get_output_image(
accession,
Modality.T1,
image_type=self.image_type_output
accession, Modality.T1, image_type=self.image_type_output
)

if self.overwrite or not os.path.exists(seg_path):
self.post_process(pred_path, seg_path)
self.post_process(pred_path, seg_path)
12 changes: 5 additions & 7 deletions tests/integration/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import importlib.resources
import os
import pathlib
import shutil
from unittest import TestCase

Expand All @@ -11,9 +10,8 @@
from pyalfe.main import run
from tests.utils import download_and_extract

class TestIntegration(TestCase):


class TestIntegration(TestCase):
def setUp(self) -> None:
self.test_dir = os.path.join('/tmp', 'integration_test')

Expand All @@ -22,7 +20,9 @@ def tearDown(self) -> None:

def test_run(self):

test_data_url = 'https://github.com/reghbali/pyalfe-test-data/archive/master.zip'
test_data_url = (
'https://github.com/reghbali/pyalfe-test-data/archive/master.zip'
)
test_data_dir_name = 'pyalfe-test-data-main'
accession = 'UPENNGBM0000511'

Expand All @@ -39,9 +39,7 @@ def test_run(self):
Modality.ADC,
]
targets = [Modality.T1Post, Modality.FLAIR]
pipeline_dir = DefaultALFEDataDir(
output_dir=output_dir, input_dir=input_dir
)
pipeline_dir = DefaultALFEDataDir(output_dir=output_dir, input_dir=input_dir)

download_and_extract(test_data_url, self.test_dir)

Expand Down
13 changes: 5 additions & 8 deletions tests/unit/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,15 @@ def test_run(self):
input_path = self.pipeline_dir.get_output_image(
accession, Modality.T1, image_type=task.image_type_input
)
prior_path = self.pipeline_dir.get_output_image(
prior_path = self.pipeline_dir.get_output_image(
accession,
Modality.T1,
resampling_origin=task.template_name,
resampling_target=Modality.T1,
sub_dir_name=roi_dict[task.template_name]['sub_dir']
sub_dir_name=roi_dict[task.template_name]['sub_dir'],
)
output_path = self.pipeline_dir.get_output_image(
accession,
Modality.T1,
image_type=task.image_type_output
accession, Modality.T1, image_type=task.image_type_output
)
shutil.copy(
os.path.join('tests', 'data', 'brats10', 'BraTS19_2013_10_1_t1.nii.gz'),
Expand Down Expand Up @@ -464,13 +462,12 @@ def test_run(self):
accession, Modality.T1, image_type='trim_upsampled'
)
shutil.copy(
os.path.join('tests', 'data', 'brainomics02', 'anat_t1.nii.gz'),
input_image
os.path.join('tests', 'data', 'brainomics02', 'anat_t1.nii.gz'), input_image
)
Convert3DProcessor.binarize(input_image, input_mask)
shutil.copy(
os.path.join('tests', 'data', 'brainomics02', 'anat_t1.nii.gz'),
input_image_trim_upsampled
input_image_trim_upsampled,
)
task.run(accession)

Expand Down
1 change: 0 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,3 @@ def download_and_extract(url: str, dest_dir: str, archive_name: str = None):
file.write(response.content)

shutil.unpack_archive(archive_file_path, dest_dir)

0 comments on commit e68b652

Please sign in to comment.