forked from nebula-chat-fork-originals/rasa
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6fd3fff
commit ac764ec
Showing
3 changed files
with
128 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Implement a new interface `run_inference` inside `RasaModel` which performs batch inferencing through tensorflow models. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import pytest | ||
from typing import Dict, Text, Union | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from rasa.utils.tensorflow.models import _merge_batch_outputs, RasaModel | ||
from rasa.utils.tensorflow.model_data import RasaModelData | ||
from rasa.shared.constants import DIAGNOSTIC_DATA | ||
from rasa.utils.tensorflow.model_data import FeatureArray | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"existing_outputs, new_batch_outputs, expected_output", | ||
[ | ||
( | ||
{"a": np.array([1, 2]), "b": np.array([3, 1])}, | ||
{"a": np.array([5, 6]), "b": np.array([2, 4])}, | ||
{"a": np.array([1, 2, 5, 6]), "b": np.array([3, 1, 2, 4])}, | ||
), | ||
( | ||
{}, | ||
{"a": np.array([5, 6]), "b": np.array([2, 4])}, | ||
{"a": np.array([5, 6]), "b": np.array([2, 4])}, | ||
), | ||
( | ||
{"a": np.array([1, 2]), "b": {"c": np.array([3, 1])}}, | ||
{"a": np.array([5, 6]), "b": {"c": np.array([2, 4])}}, | ||
{"a": np.array([1, 2, 5, 6]), "b": {"c": np.array([3, 1, 2, 4])}}, | ||
), | ||
], | ||
) | ||
def test_merging_batch_outputs( | ||
existing_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]], | ||
new_batch_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]], | ||
expected_output: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]], | ||
): | ||
|
||
predicted_output = _merge_batch_outputs(existing_outputs, new_batch_outputs) | ||
|
||
def test_equal_dicts( | ||
dict1: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]], | ||
dict2: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]], | ||
): | ||
assert dict2.keys() == dict1.keys() | ||
for key in dict1: | ||
val_1 = dict1[key] | ||
val_2 = dict2[key] | ||
assert type(val_1) == type(val_2) | ||
|
||
if isinstance(val_2, np.ndarray): | ||
assert np.array_equal(val_1, val_2) | ||
|
||
elif isinstance(val_2, dict): | ||
test_equal_dicts(val_1, val_2) | ||
|
||
test_equal_dicts(predicted_output, expected_output) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"batch_size, number_of_data_points, expected_number_of_batch_iterations", | ||
[(2, 3, 2), (1, 3, 3), (5, 3, 1),], | ||
) | ||
def test_batch_inference( | ||
batch_size: int, | ||
number_of_data_points: int, | ||
expected_number_of_batch_iterations: int, | ||
): | ||
model = RasaModel() | ||
|
||
def batch_predict(batch_in: np.ndarray): | ||
|
||
dummy_output = batch_in[0] | ||
output = { | ||
"dummy_output": dummy_output, | ||
DIAGNOSTIC_DATA: tf.constant(np.array([[1, 2]]), dtype=tf.int32), | ||
} | ||
return output | ||
|
||
# Monkeypatch batch predict so that run_inference interface can be tested | ||
model.batch_predict = batch_predict | ||
|
||
# Create dummy model data to pass to model | ||
model_data = RasaModelData( | ||
label_key="label", | ||
label_sub_key="ids", | ||
data={ | ||
"text": { | ||
"sentence": [ | ||
FeatureArray( | ||
np.random.rand(number_of_data_points, 2), | ||
number_of_dimensions=2, | ||
), | ||
] | ||
} | ||
}, | ||
) | ||
output = model.run_inference(model_data, batch_size=batch_size) | ||
|
||
# Firstly, the number of data points in dummy_output should be equal | ||
# to the number of data points sent as input. | ||
assert output["dummy_output"].shape[0] == number_of_data_points | ||
|
||
# Secondly, the number of data points inside diagnostic_data should be | ||
# equal to the number of batches passed to the model because for every | ||
# batch passed as input, it would have created a | ||
# corresponding diagnostic data entry. | ||
assert output[DIAGNOSTIC_DATA].shape == (expected_number_of_batch_iterations, 2) |