Skip to content

Commit

Permalink
add RandomForestModel tests
Browse files Browse the repository at this point in the history
patch bugs
update version
  • Loading branch information
MRossol committed Aug 25, 2020
1 parent 7ec1119 commit e1e2cf2
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 14 deletions.
37 changes: 29 additions & 8 deletions phygnn/model_interfaces/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
Base Model Interface
"""
from abc import ABC
import logging
import numpy as np
import pandas as pd
Expand All @@ -12,7 +13,7 @@
logger = logging.getLogger(__name__)


class ModelBase:
class ModelBase(ABC):
"""
Base Model Interface
"""
Expand Down Expand Up @@ -381,6 +382,28 @@ def _parse_data(data, names=None):

return data, names

@staticmethod
def _get_item_number(arr):
"""
Get number of items in array (labels or features)
Parameters
----------
arr : ndarray
1 or 2D array
Returns
-------
n : int
Number of items
"""
if len(arr.shape) == 1:
n = 1
else:
n = arr.shape[1]

return n

def get_norm_params(self, names):
"""
Get means and stdevs for given feature/label names
Expand Down Expand Up @@ -527,7 +550,8 @@ def _normalize_arr(self, arr, names):
norm_arr : ndarray
Normalized features/label
"""
if len(names) != arr.shape[1]:
n_names = self._get_item_number(arr)
if len(names) != n_names:
msg = ("Number of item names ({}) does not match number of items "
"({})".format(len(names), arr.shape[1]))
logger.error(msg)
Expand Down Expand Up @@ -647,7 +671,8 @@ def _unnormalize_arr(self, arr, names):
native_arr : ndarray
Native features/label array
"""
if len(names) != arr.shape[1]:
n_names = self._get_item_number(arr)
if len(names) != n_names:
msg = ("Number of item names ({}) does not match number of items "
"({})".format(len(names), arr.shape[1]))
logger.error(msg)
Expand Down Expand Up @@ -789,11 +814,7 @@ def _parse_labels(self, labels, names=None, normalize=True):
labels, label_names = self._parse_data(labels, names=names)

if self.label_names is not None:
if len(labels.shape) == 1:
n_labels = len(labels)
else:
n_labels = labels.shape[1]

n_labels = self._get_item_number(labels)
if n_labels != len(self.label_names):
msg = ('data has {} labels but expected {}'
.format(labels.shape[1], self.label_dims))
Expand Down
9 changes: 4 additions & 5 deletions phygnn/model_interfaces/random_forest_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, model, feature_names=None, label_name=None,
parameters (mean, stdev), by default None
"""
super().__init__(model, feature_names=feature_names,
label_name=label_name, norm_params=norm_params)
label_names=label_name, norm_params=norm_params)

if len(self.label_names) > 1:
msg = ("Only a single label can be supplied to {}, but {} were"
Expand Down Expand Up @@ -133,8 +133,7 @@ def train_model(self, features, label, norm_label=True, parse_kwargs=None,

features = self._parse_features(features, **parse_kwargs)

label = self._parse_data(label, normalize=norm_label,
names=True)
label = self._parse_labels(label, normalize=norm_label)

if fit_kwargs is None:
fit_kwargs = {}
Expand Down Expand Up @@ -204,10 +203,10 @@ def train(cls, features, label, norm_label=True, save_path=None,
compile_kwargs = {}

_, feature_names = cls._parse_data(features)
_, label_names = cls._parse_data(label)
_, label_name = cls._parse_data(label)

model = cls(cls.compile_model(**compile_kwargs),
feature_names=feature_names, label_names=label_names)
feature_names=feature_names, label_name=label_name)

model.train_model(features, label, norm_label=norm_label,
parse_kwargs=parse_kwargs, fit_kwargs=fit_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion phygnn/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Physics Guided Neural Network version."""

__version__ = '0.0.0'
__version__ = '0.0.1'
31 changes: 31 additions & 0 deletions tests/test_random_forest_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
Tests for basic phygnn functionality and execution.
"""
# pylint: disable=W0613
import numpy as np
import pandas as pd

from phygnn.model_interfaces.random_forest_model import RandomForestModel


N = 100
A = np.linspace(-1, 1, N)
B = np.linspace(-1, 1, N)
A, B = np.meshgrid(A, B)
A = np.expand_dims(A.flatten(), axis=1)
B = np.expand_dims(B.flatten(), axis=1)

Y = np.sqrt(A ** 2 + B ** 2)
X = np.hstack((A, B))
features = pd.DataFrame(X, columns=['a', 'b'])

Y_NOISE = Y * (1 + (np.random.random(Y.shape) - 0.5) * 0.5) + 0.1
labels = pd.DataFrame(Y_NOISE, columns=['c'])


def test_random_forest():
"""Test the RandomForestModel """
model = RandomForestModel.train(features, labels)

test_mae = np.mean(np.abs(model[X].values.ravel() - Y))
assert test_mae < 0.4

0 comments on commit e1e2cf2

Please sign in to comment.