diff --git a/docs/index.md b/docs/index.md index 3f08f487..c65b7fb8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -17,8 +17,8 @@ Currently, five approaches are implemented, including their original hyperparame Three approaches are implemented without their original hyperparameters: -* **[ConditionalDANN][rul_adapt.approach.conditional]** by Cheng et al. (2021) -* **[ConditionalMMD][rul_adapt.approach.conditional]** by Cheng et al. (2021) +* **[ConditionalDANN][rul_adapt.approach.conditional.ConditionalDannApproach]** by Cheng et al. (2021) +* **[ConditionalMMD][rul_adapt.approach.conditional.ConditionalMmdApproach]** by Cheng et al. (2021) * **[PseudoLabels][rul_adapt.approach.pseudo_labels]** as used by Wang et al. (2022) This includes the following general approaches adapted for RUL estimation: @@ -26,7 +26,7 @@ This includes the following general approaches adapted for RUL estimation: * **Domain Adaption Neural Networks (DANN)** by Ganin et al. (2016) * **Multi-Kernel Maximum Mean Discrepancy (MMD)** by Long et al. (2015) -Each approach has an example notebook which can be found in the [examples](examples) folder. +Each approach has an example notebook which can be found in the [examples](examples/index.md) folder. ## Installation diff --git a/poetry.lock b/poetry.lock index 1651aa90..9db617ba 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4132,13 +4132,13 @@ pyasn1 = ">=0.1.3" [[package]] name = "rul-datasets" -version = "0.14.0" +version = "0.14.1" description = "A collection of datasets for RUL estimation as Lightning Data Modules." optional = false python-versions = ">=3.8,<4.0" files = [ - {file = "rul_datasets-0.14.0-py3-none-any.whl", hash = "sha256:df7d6cf599faa44c40ef13812102e8f68541df260d9a8c0695b1261dc084c3aa"}, - {file = "rul_datasets-0.14.0.tar.gz", hash = "sha256:6552eeda3404cdd7c44e537c16308a1a1176cb73f6d8ae64cbef76ef5ec5b4e0"}, + {file = "rul_datasets-0.14.1-py3-none-any.whl", hash = "sha256:626172ff3004cf45befe75e05bcd195728d33d829ccb362880363494ece2185a"}, + {file = "rul_datasets-0.14.1.tar.gz", hash = "sha256:c677149790a1aa37255c18bd8defaa82bffd2e6e19d8123678550272913b466f"}, ] [package.dependencies] @@ -4909,4 +4909,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "4983b2cc34da2fb4a75bb15d48a1227ce635040e1da10e2d9c7aec265651725d" +content-hash = "f1696e97b39a98629f7b50c75fb2b77cdd94e867119f10ec31b0ef516671c484" diff --git a/pyproject.toml b/pyproject.toml index f69e9853..148b611d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ packages = [{include = "rul_adapt"}] [tool.poetry.dependencies] python = "^3.8" pytorch-lightning = ">1.8.0.post1" -rul-datasets = ">=0.14.0" +rul-datasets = ">=0.14.1" tqdm = "^4.62.2" hydra-core = "^1.3.1" pywavelets = "^1.4.1" diff --git a/rul_adapt/model/__init__.py b/rul_adapt/model/__init__.py index e92000e3..33cc4993 100644 --- a/rul_adapt/model/__init__.py +++ b/rul_adapt/model/__init__.py @@ -27,3 +27,4 @@ from .head import FullyConnectedHead from .rnn import LstmExtractor, GruExtractor from .wrapper import ActivationDropoutWrapper +from .two_stage import TwoStageExtractor diff --git a/rul_adapt/model/two_stage.py b/rul_adapt/model/two_stage.py new file mode 100644 index 00000000..949632c6 --- /dev/null +++ b/rul_adapt/model/two_stage.py @@ -0,0 +1,59 @@ +import torch +from torch import nn + + +class TwoStageExtractor(nn.Module): + """This module combines two feature extractors into a single network. + + The input data is expected to be of shape `[batch_size, upper_seq_len, + input_channels, lower_seq_len]`. An example would be vibration data recorded in + spaced intervals, where lower_seq_len is the length of an interval and + upper_seq_len is the window size of a sliding window over the intervals. + + The lower_stage is applied to each interval individually to extract features. + The upper_stage is then applied to the extracted features of the window. + The resulting feature vector should represent the window without the need to + manually extract features from the raw data of the intervals. + """ + + def __init__( + self, + lower_stage: nn.Module, + upper_stage: nn.Module, + ): + """ + Create a new two-stage extractor. + + The lower stage needs to take a tensor of shape `[batch_size, input_channels, + seq_len]` and return a tensor of shape `[batch_size, lower_output_units]`. The + upper stage needs to take a tensor of shape `[batch_size, upper_seq_len, + lower_output_units]` and return a tensor of shape `[batch_size, + upper_output_units]`. Args: lower_stage: upper_stage: + """ + super().__init__() + + self.lower_stage = lower_stage + self.upper_stage = upper_stage + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Apply the two-stage extractor to the input tensor. + + The input tensor is expected to be of shape `[batch_size, upper_seq_len, + input_channels, lower_seq_len]`. The output tensor will be of shape + `[batch_size, upper_output_units]`. + + Args: + inputs: the input tensor + + Returns: + an output tensor of shape `[batch_size, upper_output_units]` + """ + batch_size, upper_seq_len, input_channels, lower_seq_len = inputs.shape + inputs = inputs.reshape(-1, input_channels, lower_seq_len) + inputs = self.lower_stage(inputs) + inputs = inputs.reshape(batch_size, upper_seq_len, -1) + inputs = torch.transpose(inputs, 1, 2) + inputs = self.upper_stage(inputs) + + return inputs diff --git a/tests/test_model/test_two_stage.py b/tests/test_model/test_two_stage.py new file mode 100644 index 00000000..186b351e --- /dev/null +++ b/tests/test_model/test_two_stage.py @@ -0,0 +1,46 @@ +import pytest +import torch + +from rul_adapt.model import TwoStageExtractor + + +@pytest.fixture() +def extractor(): + lower_stage = torch.nn.Sequential( + torch.nn.Conv1d(3, 8, 3), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(8 * 62, 8), + ) + upper_stage = torch.nn.Sequential( + torch.nn.Conv1d(8, 8, 2), + torch.nn.ReLU(), + torch.nn.Flatten(), + torch.nn.Linear(8 * 3, 8), + ) + extractor = TwoStageExtractor(lower_stage, upper_stage) + + return extractor + + +@pytest.fixture() +def inputs(): + return torch.rand(16, 4, 3, 64) + + +def test_forward_shape(inputs, extractor): + outputs = extractor(inputs) + + assert outputs.shape == (16, 8) + + +def test_forward_upper_lower_interaction(inputs, extractor): + one_sample = inputs[3] + + lower_outputs = extractor.lower_stage(one_sample) + upper_outputs = extractor.upper_stage( + torch.transpose(lower_outputs.unsqueeze(0), 1, 2) + ) + outputs = extractor(inputs) + + assert torch.allclose(upper_outputs, outputs[3])