Skip to content

Commit

Permalink
fix order of operations in parse_feature and parse_label
Browse files Browse the repository at this point in the history
  • Loading branch information
MRossol committed Aug 24, 2020
1 parent 7a73b99 commit 7ec1119
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions phygnn/model_interfaces/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,8 @@ def _parse_features(self, features, names=None, process_one_hot=True,
Parsed features array normalized and with str columns converted
to one hot vectors if desired
"""
features, feature_names = self._parse_data(features, names=names)

if len(features.shape) != 2:
msg = ('{} can only use 2D data as input!'
.format(self.__class__.__name__))
Expand All @@ -747,8 +749,6 @@ def _parse_features(self, features, names=None, process_one_hot=True,
logger.error(msg)
raise RuntimeError(msg)

features, feature_names = self._parse_data(features, names=names)

if self._feature_names is None:
self._feature_names = feature_names
elif self.feature_names != feature_names:
Expand Down Expand Up @@ -786,6 +786,8 @@ def _parse_labels(self, labels, names=None, normalize=True):
labels : ndarray
Parsed labels array, normalized if desired
"""
labels, label_names = self._parse_data(labels, names=names)

if self.label_names is not None:
if len(labels.shape) == 1:
n_labels = len(labels)
Expand All @@ -798,8 +800,6 @@ def _parse_labels(self, labels, names=None, normalize=True):
logger.error(msg)
raise RuntimeError(msg)

labels, label_names = self._parse_data(labels, names=names)

if self._label_names is None:
self._label_names = label_names
elif self.label_names != label_names:
Expand Down

0 comments on commit 7ec1119

Please sign in to comment.