From 8eb4ec0dff2d0fa5ac9b79517ddb95810621d3e4 Mon Sep 17 00:00:00 2001 From: mjq2020 Date: Fri, 22 Dec 2023 10:56:51 +0000 Subject: [PATCH] add: init import --- sscma/datasets/transforms/__init__.py | 3 ++- sscma/datasets/transforms/wrappers.py | 2 +- tools/export.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sscma/datasets/transforms/__init__.py b/sscma/datasets/transforms/__init__.py index 797403bb..d9d59dc8 100644 --- a/sscma/datasets/transforms/__init__.py +++ b/sscma/datasets/transforms/__init__.py @@ -1,4 +1,5 @@ from .formatting import PackSensorInputs from .loading import LoadSensorFromFile +from .wrappers import MutiBranchPipe -__all__ = ['PackSensorInputs', 'LoadSensorFromFile'] +__all__ = ['PackSensorInputs', 'LoadSensorFromFile', 'MutiBranchPipe'] diff --git a/sscma/datasets/transforms/wrappers.py b/sscma/datasets/transforms/wrappers.py index 92bc17a1..edea1e4f 100644 --- a/sscma/datasets/transforms/wrappers.py +++ b/sscma/datasets/transforms/wrappers.py @@ -17,7 +17,7 @@ def transform(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]: multi_results[branch] = {'inputs': None, 'data_samples': None} for branch, pipeline in self.branch_pipelines.items(): branch_results = pipeline(copy.deepcopy(results)) - if branch == 'unsup_teacher': + if branch == self.piece_key: results['img'] = branch_results['inputs'].permute(1, 2, 0).cpu().numpy() # If one branch pipeline returns None, # it will sample another data from dataset. diff --git a/tools/export.py b/tools/export.py index a077123d..0bf443b5 100644 --- a/tools/export.py +++ b/tools/export.py @@ -15,7 +15,7 @@ import sscma.evaluation # noqa import sscma.models # noqa import sscma.visualization # noqa -from sscma.utils.check import check_lib +from sscma.utils.check import check_lib # noqa def parse_args():