Skip to content

Commit

Permalink
added generative text explainer (#2516)
Browse files Browse the repository at this point in the history
* added generative text explainer, modified setup.py

Signed-off-by: Mohsin Shah <[email protected]>

* isort fix

Signed-off-by: Mohsin Shah <[email protected]>

* improved error handling

Signed-off-by: Mohsin Shah <[email protected]>

* explainer now works without context and questions columns

Signed-off-by: Mohsin Shah <[email protected]>

---------

Signed-off-by: Mohsin Shah <[email protected]>
  • Loading branch information
mohsinposts authored Jan 30, 2024
1 parent 13e1782 commit 7aa72fb
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,26 @@
Tokens)
from responsibleai_text.utils.question_answering import QAPredictor

try:
from interpret_text.generative.lime_tools.explainers import \
LocalExplanationSentenceEmbedder
interpret_text_explainers_installed = True
except ImportError:
interpret_text_explainers_installed = False

try:
from interpret_text.generative.model_lib.openai_tooling import ChatOpenAI
interpret_text_openai_tooling_installed = True
except ImportError:
interpret_text_openai_tooling_installed = False

try:
from sentence_transformers import SentenceTransformer
sentence_transformers_installed = True
except ImportError:
sentence_transformers_installed = False


CONTEXT = QuestionAnsweringFields.CONTEXT
QUESTIONS = QuestionAnsweringFields.QUESTIONS
SEP = Tokens.SEP
Expand All @@ -42,6 +62,7 @@
MODEL = Metadata.MODEL
EXPLANATION = '_explanation'
TASK_TYPE = '_task_type'
PROMPT = 'prompt'


class ExplainerManager(BaseManager):
Expand Down Expand Up @@ -74,10 +95,13 @@ def __init__(self, model: Any, evaluation_examples: pd.DataFrame,
"""
self._model = model
self._target_column = target_column
if not isinstance(target_column, list):
if not isinstance(target_column, (list, type(None))):
target_column = [target_column]
self._evaluation_examples = \
evaluation_examples.drop(columns=target_column)
if target_column is None:
self._evaluation_examples = evaluation_examples
else:
self._evaluation_examples = \
evaluation_examples.drop(columns=target_column)
self._is_run = False
self._is_added = False
self._features = list(self._evaluation_examples.columns)
Expand Down Expand Up @@ -131,6 +155,73 @@ def compute(self):
eval_examples.append(question + SEP + context)
self._explanation = [explainer_start(eval_examples),
explainer_end(eval_examples)]
elif self._task_type == ModelTask.GENERATIVE_TEXT:
if not interpret_text_explainers_installed:
error = (
"The required module"
"'interpret_text.generative.lime_tools.explainers' "
"is not installed."
)
raise RuntimeError(error)
if not interpret_text_openai_tooling_installed:
error = (
"The required module"
"'interpret_text.generative.model_lib.openai_tooling' "
"is not installed."
)
raise RuntimeError(error)
if not sentence_transformers_installed:
error = (
"The required package"
"'sentence_transformers' "
"is not installed."
)
raise RuntimeError(error)

if CONTEXT in self._evaluation_examples.columns and \
QUESTIONS in self._evaluation_examples.columns:
context = self._evaluation_examples[CONTEXT]
questions = self._evaluation_examples[QUESTIONS]
eval_examples = []
for context, question in zip(context, questions):
eval_examples.append(question + SEP + context)
elif PROMPT in self._evaluation_examples.columns:
eval_examples = self._evaluation_examples[PROMPT].tolist()
else:
raise ValueError(
"Neither 'context'/'questions' nor 'prompt' columns "
"are present in the evaluation_examples DataFrame"
)
sentence_embedder = SentenceTransformer('all-MiniLM-L6-v2')
explainer = LocalExplanationSentenceEmbedder(
sentence_embedder=sentence_embedder,
perturbation_model="removal",
partition_fn="sentences",
progress_bar=None)
max_completion = 50 # Define max tokens for the completion

api_settings = {
"api_type": self._model.model.api_type,
"api_base": self._model.model.api_base,
"api_version": self._model.model.api_version,
"api_key": self._model.model.api_key
}
model_wrapped = ChatOpenAI(
engine=self._model.model.engine,
encoding="cl100k_base",
api_settings=api_settings)
completions = model_wrapped.sample(
eval_examples, max_new_tokens=max_completion)

explanation = []
for i, completion in enumerate(completions):
attribution, parts = explainer.attribution(model_wrapped,
eval_examples[i],
completion,
)
explanation.append((attribution, parts))

self._explanation = explanation
else:
raise ValueError("Unknown task type: {}".format(self._task_type))

Expand Down
4 changes: 4 additions & 0 deletions responsibleai_text/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
'bert_score',
'nltk',
'rouge_score'
],
"generative_text": [
'interpret_text',
'sentence_transformers'
]
}
setuptools.setup(
Expand Down

0 comments on commit 7aa72fb

Please sign in to comment.