Skip to content

Commit

Permalink
Merge pull request RasaHQ#8565 from RasaHQ/merge-2.5.x-main
Browse files Browse the repository at this point in the history
Merge 2.5.x into main
  • Loading branch information
Maxime Vdb authored Apr 28, 2021
2 parents 98666b1 + ec2a9b8 commit 3eee413
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 39 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ https://github.com/RasaHQ/rasa/tree/main/changelog/ . -->

<!-- TOWNCRIER -->

## [2.5.1] - 2021-04-28


### Bugfixes
- [#8446](https://github.com/rasahq/rasa/issues/8446): Fixed prediction for rules with multiple entities.
- [#8545](https://github.com/rasahq/rasa/issues/8545): Mitigated Matplotlib backend issue using lazy configuration
and added a more explicit error message to guide users.


## [2.5.0] - 2021-04-12


Expand Down
2 changes: 1 addition & 1 deletion docs/docs/setting-up-ci-cd.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ would be incompatible with the current production model.
## Example CI/CD pipelines

As examples, see the CI/CD pipelines for
[Sara](https://github.com/RasaHQ/rasa-demo/blob/master/.github/workflows/build_and_deploy.yml),
[Sara](https://github.com/RasaHQ/rasa-demo/blob/main/.github/workflows/continuous_integration.yml),
the Rasa assistant that you can talk to in the Rasa Docs, and
[Carbon Bot](https://github.com/RasaHQ/carbon-bot/blob/master/.github/workflows/model_ci.yml).
Both use [Github Actions](https://github.com/features/actions) as a CI/CD tool.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ exclude = "((.eggs | .git | .pytest_cache | build | dist))"

[tool.poetry]
name = "rasa"
version = "2.5.0"
version = "2.5.1"
description = "Open source machine learning framework to automate text- and voice-based conversations: NLU, dialogue management, connect to Slack, Facebook, and more - Create chatbots and voice assistants"
authors = [ "Rasa Technologies GmbH <[email protected]>",]
maintainers = [ "Tom Bocklisch <[email protected]>",]
Expand Down
29 changes: 18 additions & 11 deletions rasa/core/policies/rule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,26 +803,33 @@ def train(
def _does_rule_match_state(rule_state: State, conversation_state: State) -> bool:
for state_type, rule_sub_state in rule_state.items():
conversation_sub_state = conversation_state.get(state_type, {})
for key, value in rule_sub_state.items():
if isinstance(value, list):
# json dumps and loads tuples as lists,
# so we need to convert them back
value = tuple(value)
for key, value_from_rules in rule_sub_state.items():
# json dumps and loads tuples as lists
if isinstance(value_from_rules, list):
# sort values before comparing
# `sorted` returns a list
value_from_rules = sorted(value_from_rules)

value_from_conversation = conversation_sub_state.get(key)
if isinstance(value_from_conversation, tuple):
# sort values before comparing
# `sorted` returns a list
value_from_conversation = sorted(value_from_conversation)

if (
# value should be set, therefore
# check whether it is the same as in the state
value
and value != SHOULD_NOT_BE_SET
and conversation_sub_state.get(key) != value
value_from_rules
and value_from_rules != SHOULD_NOT_BE_SET
and value_from_conversation != value_from_rules
) or (
# value shouldn't be set, therefore
# it should be None or non existent in the state
value == SHOULD_NOT_BE_SET
and conversation_sub_state.get(key)
value_from_rules == SHOULD_NOT_BE_SET
and value_from_conversation
# during training `SHOULD_NOT_BE_SET` is provided. Hence, we also
# have to check for the value of the slot state
and conversation_sub_state.get(key) != SHOULD_NOT_BE_SET
and value_from_conversation != SHOULD_NOT_BE_SET
):
return False

Expand Down
42 changes: 36 additions & 6 deletions rasa/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
import itertools
import os
from functools import wraps

import numpy as np
from typing import List, Text, Optional, Union, Any
from typing import Any, Callable, List, Optional, Text, TypeVar, Union
import matplotlib
from matplotlib.ticker import FormatStrFormatter

Expand All @@ -14,10 +15,20 @@


def _fix_matplotlib_backend() -> None:
"""Tries to fix a broken matplotlib backend..."""
"""Tries to fix a broken matplotlib backend."""
try:
backend = matplotlib.get_backend()
except Exception: # skipcq:PYL-W0703
logger.error(
"Cannot retrieve Matplotlib backend, likely due to a compatibility "
"issue with system dependencies. Please refer to the documentation: "
"https://matplotlib.org/stable/tutorials/introductory/usage.html#backends"
)
raise

# At first, matplotlib will be initialized with default OS-specific
# available backend
if matplotlib.get_backend() == "TkAgg":
if backend == "TkAgg":
try:
# on OSX sometimes the tkinter package is broken and can't be imported.
# we'll try to import it and if it fails we will use a different backend
Expand All @@ -27,7 +38,7 @@ def _fix_matplotlib_backend() -> None:
matplotlib.use("agg")

# if no backend is set by default, we'll try to set it up manually
elif matplotlib.get_backend() is None: # pragma: no cover
elif backend is None: # pragma: no cover
try:
# If the `tkinter` package is available, we can use the `TkAgg` backend
import tkinter # noqa: 401
Expand All @@ -39,10 +50,27 @@ def _fix_matplotlib_backend() -> None:
matplotlib.use("agg")


# we call the fix as soon as this package gets imported
_fix_matplotlib_backend()
ReturnType = TypeVar("ReturnType")
FuncType = Callable[..., ReturnType]
_MATPLOTLIB_BACKEND_FIXED = False


def _needs_matplotlib_backend(func: FuncType) -> FuncType:
"""Decorator to fix matplotlib backend before calling a function."""

@wraps(func)
def inner(*args: Any, **kwargs: Any) -> ReturnType:
"""Replacement function that fixes matplotlib backend."""
global _MATPLOTLIB_BACKEND_FIXED
if not _MATPLOTLIB_BACKEND_FIXED:
_fix_matplotlib_backend()
_MATPLOTLIB_BACKEND_FIXED = True
return func(*args, **kwargs)

return inner


@_needs_matplotlib_backend
def plot_confusion_matrix(
confusion_matrix: np.ndarray,
classes: Union[np.ndarray, List[Text]],
Expand Down Expand Up @@ -117,6 +145,7 @@ def plot_confusion_matrix(
fig.savefig(output_file, bbox_inches="tight")


@_needs_matplotlib_backend
def plot_histogram(
hist_data: List[List[float]], title: Text, output_file: Optional[Text] = None
) -> None:
Expand Down Expand Up @@ -217,6 +246,7 @@ def plot_histogram(
fig.savefig(output_file, bbox_inches="tight")


@_needs_matplotlib_backend
def plot_curve(
output_directory: Text,
number_of_examples: List[int],
Expand Down
2 changes: 1 addition & 1 deletion rasa/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# this file will automatically be changed,
# do not add anything but the version number here!
__version__ = "2.5.0"
__version__ = "2.5.1"
88 changes: 69 additions & 19 deletions tests/core/policies/test_rule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
RULE_ONLY_SLOTS,
RULE_ONLY_LOOPS,
)
from rasa.shared.nlu.constants import TEXT, INTENT, ACTION_NAME
from rasa.shared.nlu.constants import TEXT, INTENT, ACTION_NAME, ENTITY_ATTRIBUTE_TYPE
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import (
ActionExecuted,
Expand Down Expand Up @@ -2101,9 +2101,7 @@ def test_hide_rule_turn():
],
)
policy = RulePolicy()
policy.train(
[GREET_RULE, chitchat_story], domain, RegexInterpreter(),
)
policy.train([GREET_RULE, chitchat_story], domain, RegexInterpreter())

conversation_events = [
ActionExecuted(ACTION_LISTEN_NAME),
Expand Down Expand Up @@ -2664,9 +2662,7 @@ def test_remove_action_listen_prediction_if_contradicts_with_story():
],
)
policy = RulePolicy()
policy.train(
[rule, story], domain, RegexInterpreter(),
)
policy.train([rule, story], domain, RegexInterpreter())
prediction_source = [{PREVIOUS_ACTION: {ACTION_NAME: utter_1}}]
key = policy._create_feature_key(prediction_source)
assert key not in policy.lookup[RULES]
Expand Down Expand Up @@ -2718,9 +2714,7 @@ def test_keep_action_listen_prediction_after_predictable_action():
# prediction of action_listen should only be removed if it occurs after the first
# action (unpredictable)
with pytest.raises(InvalidRule):
policy.train(
[rule, story], domain, RegexInterpreter(),
)
policy.train([rule, story], domain, RegexInterpreter())


def test_keep_action_listen_prediction_if_last_prediction():
Expand Down Expand Up @@ -2762,9 +2756,7 @@ def test_keep_action_listen_prediction_if_last_prediction():
policy = RulePolicy()
# prediction of action_listen should only be removed if it's not the last prediction
with pytest.raises(InvalidRule):
policy.train(
[rule, story], domain, RegexInterpreter(),
)
policy.train([rule, story], domain, RegexInterpreter())


def test_keep_action_listen_prediction_if_contradicts_with_rule():
Expand Down Expand Up @@ -2807,9 +2799,7 @@ def test_keep_action_listen_prediction_if_contradicts_with_rule():
)
policy = RulePolicy()
with pytest.raises(InvalidRule):
policy.train(
[rule, other_rule], domain, RegexInterpreter(),
)
policy.train([rule, other_rule], domain, RegexInterpreter())


def test_raise_contradiction_if_rule_contradicts_with_story():
Expand Down Expand Up @@ -2850,6 +2840,66 @@ def test_raise_contradiction_if_rule_contradicts_with_story():
)
policy = RulePolicy()
with pytest.raises(InvalidRule):
policy.train(
[rule, story], domain, RegexInterpreter(),
)
policy.train([rule, story], domain, RegexInterpreter())


def test_rule_with_multiple_entities():
intent_1 = "intent_1"
entity_1 = "entity_1"
entity_2 = "entity_2"
utter_1 = "utter_1"
domain = Domain.from_yaml(
f"""
version: "2.0"
intents:
- {intent_1}
entities:
- {entity_1}
- {entity_2}
actions:
- {utter_1}
"""
)

rule = TrackerWithCachedStates.from_events(
"rule without action_listen",
domain=domain,
slots=domain.slots,
evts=[
ActionExecuted(RULE_SNIPPET_ACTION_NAME),
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(
intent={"name": intent_1},
entities=[
{ENTITY_ATTRIBUTE_TYPE: entity_1},
{ENTITY_ATTRIBUTE_TYPE: entity_2},
],
),
ActionExecuted(utter_1),
ActionExecuted(ACTION_LISTEN_NAME),
],
is_rule_tracker=True,
)
policy = RulePolicy()
policy.train([rule], domain, RegexInterpreter())

# the order of entities in the entities list doesn't matter for prediction
conversation_events = [
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(
"haha",
intent={"name": intent_1},
entities=[
{ENTITY_ATTRIBUTE_TYPE: entity_2},
{ENTITY_ATTRIBUTE_TYPE: entity_1},
],
),
]
prediction = policy.predict_action_probabilities(
DialogueStateTracker.from_events(
"casd", evts=conversation_events, slots=domain.slots
),
domain,
RegexInterpreter(),
)
assert_predicted_action(prediction, domain, utter_1)

0 comments on commit 3eee413

Please sign in to comment.