diff --git a/responsibleai/tests/counterfactual/test_counterfactual_advanced_features.py b/responsibleai/tests/counterfactual/test_counterfactual_advanced_features.py index f80911a757..9c344af3cf 100644 --- a/responsibleai/tests/counterfactual/test_counterfactual_advanced_features.py +++ b/responsibleai/tests/counterfactual/test_counterfactual_advanced_features.py @@ -4,6 +4,7 @@ import os +import numpy as np import pytest from rai_test_utils.models.lightgbm import create_lightgbm_classifier @@ -21,10 +22,15 @@ class TestCounterfactualAdvancedFeatures(object): @pytest.mark.parametrize('vary_all_features', [True, False]) @pytest.mark.parametrize('feature_importance', [True, False]) + @pytest.mark.parametrize('encode_target_as_strings', [True, False]) def test_counterfactual_vary_features( - self, vary_all_features, feature_importance): + self, vary_all_features, feature_importance, + encode_target_as_strings): X_train, X_test, y_train, y_test, feature_names, _ = \ create_iris_data() + if encode_target_as_strings: + y_train = y_train.astype(str) + y_test = y_test.astype(str) model = create_lightgbm_classifier(X_train, y_train) X_train['target'] = y_train @@ -50,6 +56,26 @@ def test_counterfactual_vary_features( cf_obj = rai_insights.counterfactual.get()[0] assert cf_obj is not None + for index in range(0, len(cf_obj.cf_examples_list)): + if encode_target_as_strings: + assert isinstance( + cf_obj.cf_examples_list[ + index].test_instance_df['target'].values[0], str) + else: + assert isinstance( + cf_obj.cf_examples_list[ + index].test_instance_df['target'].values[0], np.int32) + assert cf_obj.cf_examples_list[ + index].test_instance_df['target'].values[0] in set(y_train) + + cf_target_array = cf_obj.cf_examples_list[0].final_cfs_df[ + 'target'].values + for inner_index in range(0, 10): + if encode_target_as_strings: + assert isinstance(cf_target_array[inner_index], str) + else: + assert isinstance(cf_target_array[inner_index], np.int32) + assert cf_target_array[inner_index] in set(y_train) @pytest.mark.parametrize('feature_importance', [True, False]) def test_counterfactual_permitted_range(self, feature_importance):