Rasa培训课程：Rasa微服务Action自定义及Slot Validation详解
第3课:Rasa微服务Action自定义及Slot Validation详解
做Rasa智能对话机器人时，微服务Action的自定义开发是绕不开的内容：只要做开发，就一定会用到Actions、Tracker、Dispatcher和Events。
Action类是所有自定义操作(Action)的基类。要定义自定义操作，请创建Action类的子类，并覆盖两个必需的方法：name和run。当Action服务器收到运行某个操作的请求时，会根据各操作name方法的返回值找到并调用对应的操作。
官网文档的一个示例:
class MyCustomAction(Action):
    """Minimal custom action skeleton, as shown in the Rasa SDK docs."""

    def name(self) -> Text:
        """Return the identifier the action server matches requests against."""
        action_id = "action_name"
        return action_id

    async def run(
        self,
        dispatcher: "CollectingDispatcher",
        tracker: Tracker,
        domain: Dict[Text, Any],
    ) -> List[Dict[Text, Any]]:
        """Do nothing and return no events."""
        no_events: List[Dict[Text, Any]] = []
        return no_events
Rasa Server（Rasa核心）中的Action是一个普通的类：
class Action:
    """Base class for the next action to be taken in response to a dialogue state."""

    def name(self) -> Text:
        """Return the unique identifier of this action.

        Raises:
            NotImplementedError: always; concrete actions must override this.
        """
        raise NotImplementedError

    async def run(
        self,
        output_channel: "OutputChannel",
        nlg: "NaturalLanguageGenerator",
        tracker: "DialogueStateTracker",
        domain: "Domain",
    ) -> List[Event]:
        """Execute the side effects of this action.

        Args:
            output_channel: the channel to which resulting messages are sent.
            nlg: the natural language generator used to build responses.
            tracker: the state tracker for the current user. Slot values are
                available via ``tracker.get_slot(slot_name)`` and the most
                recent user message via ``tracker.latest_message.text``.
            domain: the bot's domain.

        Returns:
            A list of :class:`rasa.core.events.Event` instances.

        Raises:
            NotImplementedError: always; concrete actions must override this.
        """
        raise NotImplementedError

    def __str__(self) -> Text:
        """Return ``ClassName('action_name')`` for this action."""
        class_name = type(self).__name__
        return f"{class_name}('{self.name()}')"

    def event_for_successful_execution(
        self, prediction: PolicyPrediction
    ) -> ActionExecuted:
        """Build the event to log after this action executed successfully.

        Args:
            prediction: the policy prediction that led to running this action.

        Returns:
            An ``ActionExecuted`` event holding this action's name plus the
            policy name, max confidence, ``hide_rule_turn`` flag and action
            metadata taken from ``prediction``.
        """
        action_name = self.name()
        policy = prediction.policy_name
        confidence = prediction.max_confidence
        return ActionExecuted(
            action_name,
            policy,
            confidence,
            hide_rule_turn=prediction.hide_rule_turn,
            metadata=prediction.action_metadata,
        )
Rasa SDK（Action Server）中的Action类：
class Action:
    """Base class for custom actions executed by the Rasa SDK action server."""

    def name(self) -> Text:
        """Return the unique name under which this action is registered.

        Raises:
            NotImplementedError: unless a subclass overrides it.
        """
        raise NotImplementedError("An action must implement a name")

    async def run(
        self,
        dispatcher: "CollectingDispatcher",
        tracker: Tracker,
        domain: "DomainDict",
    ) -> List[Dict[Text, Any]]:
        """Execute the side effects of this action.

        Args:
            dispatcher: collects messages to send back to the user; use
                `dispatcher.utter_message()` to queue a message.
            tracker: the state tracker for the current user. Slot values are
                available via `tracker.get_slot(slot_name)`, the latest user
                message via `tracker.latest_message.text`, plus any other
                `rasa_sdk.Tracker` property.
            domain: the bot's domain.

        Returns:
            A list of event dictionaries (e.g. built with `rasa_sdk.events`)
            that is returned through the endpoint.

        Raises:
            NotImplementedError: unless a subclass overrides it.
        """
        raise NotImplementedError("An action must implement its run method")

    def __str__(self) -> Text:
        """Return ``Action('<name>')`` (the literal prefix, not the class name)."""
        return "Action('{}')".format(self.name())
reminder机器人示例重写(override)了run方法，重点关注其中的dispatcher: CollectingDispatcher和tracker: Tracker两个参数：
class ActionSetReminder(Action):
    """Schedule a reminder five seconds from now that carries the entities of
    the most recent user message."""

    def name(self) -> Text:
        return "action_set_reminder"

    async def run(
        self,
        dispatcher: CollectingDispatcher,
        tracker: Tracker,
        domain: Dict[Text, Any],
    ) -> List[Dict[Text, Any]]:
        """Tell the user a reminder is coming and return the scheduling event.

        The reminder is created with ``"EXTERNAL_reminder"`` (presumably the
        intent to trigger — confirm against ``ReminderScheduled``), named
        ``my_reminder``, with ``kill_on_user_message=False``.
        """
        dispatcher.utter_message("I will remind you in 5 seconds.")

        fire_at = datetime.datetime.now() + datetime.timedelta(seconds=5)
        latest_entities = tracker.latest_message.get("entities")

        return [
            ReminderScheduled(
                "EXTERNAL_reminder",
                trigger_date_time=fire_at,
                entities=latest_entities,
                name="my_reminder",
                kill_on_user_message=False,
            )
        ]
rasa_sdk中executor.py的CollectingDispatcher类：
class CollectingDispatcher:
    """Collects the messages an action wants to send back to the user.

    Every ``utter_*`` call appends one message dict to ``self.messages``;
    nothing is sent from this class directly.
    """

    def __init__(self) -> None:
        self.messages: List[Dict[Text, Any]] = []

    def utter_message(
        self,
        text: Optional[Text] = None,
        image: Optional[Text] = None,
        json_message: Optional[Dict[Text, Any]] = None,
        template: Optional[Text] = None,
        response: Optional[Text] = None,
        attachment: Optional[Text] = None,
        buttons: Optional[List[Dict[Text, Any]]] = None,
        elements: Optional[List[Dict[Text, Any]]] = None,
        **kwargs: Any,
    ) -> None:
        """Queue a message for the output channel.

        Args:
            text: plain text of the message.
            image: image URL.
            json_message: custom payload, stored under the ``custom`` key.
            template: deprecated alias of ``response`` (emits FutureWarning);
                only used when ``response`` is not given.
            response: name of a response defined in the domain.
            attachment: attachment to send.
            buttons: list of button dicts (defaults to ``[]``).
            elements: list of element dicts (defaults to ``[]``).
            **kwargs: extra keys merged into the message dict last, so they
                override the keys above.
        """
        if template and not response:
            response = template
            warnings.warn(
                "Please pass the parameter `response` instead of `template` "
                "to `utter_message`. `template` will be deprecated in Rasa 3.0.0. ",
                FutureWarning,
            )
        message = {
            "text": text,
            "buttons": buttons or [],
            "elements": elements or [],
            "custom": json_message or {},
            # Both keys are populated so consumers reading either one work.
            "template": response,
            "response": response,
            "image": image,
            "attachment": attachment,
        }
        message.update(kwargs)
        self.messages.append(message)

    # deprecated
    def utter_custom_message(self, *elements: Dict[Text, Any], **kwargs: Any) -> None:
        """Deprecated: queue a message with custom elements."""
        warnings.warn(
            "Use of `utter_custom_message` is deprecated. "
            "Use `utter_message(elements=<list of elements>)` instead.",
            FutureWarning,
        )
        self.utter_message(elements=list(elements), **kwargs)

    def utter_elements(self, *elements: Dict[Text, Any], **kwargs: Any) -> None:
        """Deprecated: sends a message with custom elements to the output channel."""
        warnings.warn(
            "Use of `utter_elements` is deprecated. "
            "Use `utter_message(elements=<list of elements>)` instead.",
            FutureWarning,
        )
        self.utter_message(elements=list(elements), **kwargs)

    def utter_button_message(
        self, text: Text, buttons: List[Dict[Text, Any]], **kwargs: Any
    ) -> None:
        """Deprecated: sends a message with buttons to the output channel."""
        warnings.warn(
            "Use of `utter_button_message` is deprecated. "
            "Use `utter_message(text=<text>, buttons=<list of buttons>)` instead.",
            FutureWarning,
        )
        self.utter_message(text=text, buttons=buttons, **kwargs)

    def utter_attachment(self, attachment: Text, **kwargs: Any) -> None:
        """Deprecated: send a message to the client with attachments."""
        warnings.warn(
            "Use of `utter_attachment` is deprecated. "
            "Use `utter_message(attachment=<attachment>)` instead.",
            FutureWarning,
        )
        self.utter_message(attachment=attachment, **kwargs)

    # noinspection PyUnusedLocal
    def utter_button_template(
        self,
        template: Text,
        buttons: List[Dict[Text, Any]],
        tracker: "Tracker",
        silent_fail: bool = False,
        **kwargs: Any,
    ) -> None:
        """Deprecated: sends a response with buttons to the output channel.

        ``tracker`` and ``silent_fail`` are accepted for compatibility only
        and are not used.
        """
        warnings.warn(
            "Use of `utter_button_template` is deprecated. "
            "Use `utter_message(response=<response name>, buttons=<list of buttons>)` "
            "instead.",
            FutureWarning,
        )
        self.utter_message(template=template, buttons=buttons, **kwargs)

    # noinspection PyUnusedLocal
    def utter_template(
        self, template: Text, tracker: "Tracker", silent_fail: bool = False, **kwargs: Any
    ) -> None:
        """Deprecated: send a message to the client based on a response name.

        ``tracker`` and ``silent_fail`` are accepted for compatibility only
        and are not used.
        """
        warnings.warn(
            "Use of `utter_template` is deprecated. "
            "Use `utter_message(response=<response name>)` instead.",
            FutureWarning,
        )
        self.utter_message(response=template, **kwargs)

    def utter_custom_json(self, json_message: Dict[Text, Any], **kwargs: Any) -> None:
        """Deprecated: sends custom json to the output channel."""
        warnings.warn(
            "Use of `utter_custom_json` is deprecated. "
            "Use `utter_message(json_message=<message dict>)` instead.",
            FutureWarning,
        )
        self.utter_message(json_message=json_message, **kwargs)

    def utter_image_url(self, image: Text, **kwargs: Any) -> None:
        """Deprecated: sends url of image attachment to the output channel."""
        warnings.warn(
            "Use of `utter_image_url` is deprecated. "
            "Use `utter_message(image=<image url>)` instead.",
            FutureWarning,
        )
        self.utter_message(image=image, **kwargs)
Tracker类:
class Tracker:
    """Maintains the state of a conversation as seen by custom actions."""

    @classmethod
    def from_dict(cls, state: "TrackerState") -> "Tracker":
        """Create a tracker from a serialized tracker state (dump).

        Missing keys fall back to empty/neutral defaults; a legacy
        ``active_form`` entry is used when ``active_loop`` is absent.
        """
        return Tracker(
            state["sender_id"],
            state.get("slots", {}),
            state.get("latest_message", {}),
            state.get("events", []),
            state.get("paused", False),
            state.get("followup_action"),
            state.get("active_loop", state.get("active_form", {})),
            state.get("latest_action_name"),
        )

    def __init__(
        self,
        sender_id: Text,
        slots: Dict[Text, Any],
        latest_message: Optional[Dict[Text, Any]],
        events: List[Dict[Text, Any]],
        paused: bool,
        followup_action: Optional[Text],
        active_loop: Dict[Text, Any],
        latest_action_name: Optional[Text],
    ) -> None:
        """Initialize the tracker.

        ``slots`` and ``events`` are stored as given (not copied), so
        ``add_slots`` mutates the caller's containers too.
        """
        # list of previously seen events (each a dict with an "event" key)
        self.events = events
        # id of the source of the messages
        self.sender_id = sender_id
        # slots that can be filled in this domain
        self.slots = slots
        self.followup_action = followup_action
        self._paused = paused
        # latest_message is `parse_data`,
        # which is a dict: {"intent": UserUttered.intent,
        #                   "entities": UserUttered.entities,
        #                   "text": text}
        self.latest_message = latest_message if latest_message else {}
        self.active_loop = active_loop
        self.latest_action_name = latest_action_name

    @property
    def active_form(self) -> Dict[Text, Any]:
        """Deprecated alias of ``active_loop``; emits a ``DeprecationWarning``."""
        warnings.warn(
            "Use of `active_form` is deprecated. Please use `active_loop` instead.",
            DeprecationWarning,
        )
        return self.active_loop

    def current_state(self) -> Dict[Text, Any]:
        """Return the current tracker state as a plain dict.

        ``latest_event_time`` is the ``timestamp`` of the last event, or
        ``None`` when there are no events or the last one has no timestamp.
        """
        latest_event_time = self.events[-1].get("timestamp") if self.events else None
        return {
            "sender_id": self.sender_id,
            "slots": self.slots,
            "latest_message": self.latest_message,
            "latest_event_time": latest_event_time,
            "paused": self.is_paused(),
            "events": self.events,
            "latest_input_channel": self.get_latest_input_channel(),
            "active_loop": self.active_loop,
            "latest_action_name": self.latest_action_name,
        }

    def current_slot_values(self) -> Dict[Text, Any]:
        """Return the currently set slot values (the live dict, not a copy)."""
        return self.slots

    def get_slot(self, key: Text) -> Optional[Any]:
        """Retrieve the value of a slot; returns ``None`` (and logs) if unknown."""
        if key in self.slots:
            return self.slots[key]
        logger.info(f"Tried to access non existent slot '{key}'.")
        return None

    def get_latest_entity_values(
        self,
        entity_type: Text,
        entity_role: Optional[Text] = None,
        entity_group: Optional[Text] = None,
    ) -> Iterator[Text]:
        """Get entity values found for the passed entity type and optional role and
        group in latest message.

        Role and group must match exactly: with the defaults (``None``) only
        entities without a role/group are returned.

        If you are only interested in the first entity of a given type use
        `next(tracker.get_latest_entity_values("my_entity_name"), None)`,
        which evaluates to `None` when no matching entity exists.

        Args:
            entity_type: the entity type of interest
            entity_role: optional entity role of interest
            entity_group: optional entity group of interest

        Returns:
            A lazy iterator (generator) over the matching entity values.
        """
        entities = self.latest_message.get("entities", [])
        return (
            x.get("value")
            for x in entities
            if x.get("entity") == entity_type
            and x.get("group") == entity_group
            and x.get("role") == entity_role
        )

    def get_latest_input_channel(self) -> Optional[Text]:
        """Get the ``input_channel`` of the latest ``user`` event, if any."""
        for e in reversed(self.events):
            if e.get("event") == "user":
                return e.get("input_channel")
        return None

    def is_paused(self) -> bool:
        """State whether the tracker is currently paused."""
        return self._paused

    def idx_after_latest_restart(self) -> int:
        """Return the index just after the most recent ``restart`` event.

        If the conversation has not been restarted, `0` is returned.
        """
        idx = 0
        for i, event in enumerate(self.events):
            if event.get("event") == "restart":
                idx = i + 1
        return idx

    def events_after_latest_restart(self) -> List[dict]:
        """Return a list of events after the most recent restart."""
        return list(self.events)[self.idx_after_latest_restart() :]

    @property
    def active_loop_name(self) -> Optional[Text]:
        """Get the name of the currently active loop.

        Returns: `None` if no active loop (or the "should_not_be_set"
        placeholder) or the name of the currently active loop.
        """
        if not self.active_loop or self.active_loop.get("name") == "should_not_be_set":
            return None
        return self.active_loop.get("name")

    def __eq__(self, other: Any) -> bool:
        """Trackers are equal when their events and sender ids are equal."""
        if isinstance(self, type(other)):
            return other.events == self.events and self.sender_id == other.sender_id
        else:
            return False

    def __ne__(self, other: Any) -> bool:
        return not self.__eq__(other)

    def copy(self) -> "Tracker":
        """Return an independent copy of this tracker.

        All mutable state (slots, latest message, events and the active loop)
        is deep-copied so that changes to the copy never leak back.
        """
        return Tracker(
            self.sender_id,
            copy.deepcopy(self.slots),
            copy.deepcopy(self.latest_message),
            copy.deepcopy(self.events),
            self._paused,
            self.followup_action,
            # previously shared by reference between original and copy
            copy.deepcopy(self.active_loop),
            self.latest_action_name,
        )

    def last_executed_action_has(self, name: Text, skip: int = 0) -> bool:
        """Check whether the last executed action (ignoring action_listen,
        optionally skipping `skip` more recent ones) is named `name`."""
        last = self.get_last_event_for(
            "action", exclude=[ACTION_LISTEN_NAME], skip=skip
        )
        return last is not None and last["name"] == name

    def get_last_event_for(
        self,
        event_type: Text,
        exclude: Optional[List[Text]] = None,
        skip: int = 0,
    ) -> Optional[Dict[Text, Any]]:
        """Return the most recent applied event of `event_type`.

        Args:
            event_type: value of the event's "event" key to look for.
            exclude: action names to ignore (only applies to "action" events).
            skip: number of matching events to skip, newest first.

        Returns:
            The matching event dict, or `None` if there is none.
        """
        # `None` instead of a mutable `[]` default, which would be shared
        # across calls.
        excluded_names = exclude if exclude is not None else []

        def filter_function(e: Dict[Text, Any]) -> bool:
            has_instance = e["event"] == event_type
            excluded = e["event"] == "action" and e["name"] in excluded_names
            return has_instance and not excluded

        filtered = filter(filter_function, reversed(self.applied_events()))
        for _ in range(skip):
            next(filtered, None)

        return next(filtered, None)

    def applied_events(self) -> List[Dict[Text, Any]]:
        """Returns all events that should be applied - w/o reverted events.

        `restart` discards everything before it, `undo` reverts back to and
        including the last action, and `rewind` reverts the last user
        utterance plus the action before it. These control events themselves
        are not included in the result.
        """

        def undo_till_previous(event_type: Text, done_events: List[Dict[Text, Any]]):
            """Removes events from `done_events` until the first
            occurrence `event_type` is found which is also removed."""
            # list gets modified - hence we need to copy events!
            for e in reversed(done_events[:]):
                del done_events[-1]
                if e["event"] == event_type:
                    break

        applied_events: List[Dict[Text, Any]] = []
        for event in self.events:
            event_type = event.get("event")
            if event_type == "restart":
                applied_events = []
            elif event_type == "undo":
                undo_till_previous("action", applied_events)
            elif event_type == "rewind":
                # Seeing a user uttered event automatically implies there was
                # a listen event right before it, so we'll first rewind the
                # user utterance, then get the action right before it (also removes
                # the `action_listen` action right before it).
                undo_till_previous("user", applied_events)
                undo_till_previous("action", applied_events)
            else:
                applied_events.append(event)
        return applied_events

    def slots_to_validate(self) -> Dict[Text, Any]:
        """Get slots which were recently set.

        This can e.g. be used to validate form slots after they were extracted.
        Only the trailing run of `slot` events at the end of `self.events` is
        considered; a later value for the same slot wins.

        Returns:
            A mapping of extracted slot candidates and their values.
        """
        slots: Dict[Text, Any] = {}
        count: int = 0
        for event in reversed(self.events):
            # The `FormAction` in Rasa Open Source will append all slot candidates
            # at the end of the tracker events.
            if event["event"] == "slot":
                count += 1
            else:
                # Stop as soon as there is another event type as this means that we
                # checked all potential slot candidates.
                break

        for event in self.events[len(self.events) - count :]:
            slots[event["name"]] = event["value"]

        return slots

    def add_slots(self, slots: List["EventType"]) -> None:
        """Adds slots to the current tracker.

        Non-`slot` events are ignored; each `slot` event updates `self.slots`
        and is appended to `self.events`.

        Args:
            slots: `SlotSet` events.
        """
        for event in slots:
            if not event.get("event") == "slot":
                continue
            self.slots[event["name"]] = event["value"]
            self.events.append(event)

    def get_intent_of_latest_message(
        self, skip_fallback_intent: bool = True
    ) -> Optional[Text]:
        """Retrieves the intent the last user message.

        Args:
            skip_fallback_intent: Optionally skip the nlu_fallback intent
                and return the next.

        Returns:
            Intent of latest message if available.
        """
        latest_message = self.latest_message
        if not latest_message:
            return None

        intent_ranking = latest_message.get("intent_ranking")
        if not intent_ranking:
            return None

        highest_ranking_intent = intent_ranking[0]
        if (
            highest_ranking_intent["name"] == NLU_FALLBACK_INTENT_NAME
            and skip_fallback_intent
        ):
            if len(intent_ranking) >= 2:
                return intent_ranking[1]["name"]
            else:
                return None
        else:
            return highest_ranking_intent["name"]
run方法返回的是由rasa_sdk.events中的事件(Event)组成的列表，
其类型是List[Dict[str, Any]]。
用户输入一条信息后，如果要把提取出的实体(Entities)加进去，可以使用EntitiesAdded事件，它用于将提取的实体添加到跟踪器(Tracker)的状态中。
Rasa Core的events.py中的EntitiesAdded类：
class EntitiesAdded(SkipEventInMDStoryMixin):
    """Event that is used to add extracted entities to the tracker state."""

    type_name = "entities"

    def __init__(
        self,
        entities: List[Dict[Text, Any]],
        timestamp: Optional[float] = None,
        metadata: Optional[Dict[Text, Any]] = None,
    ) -> None:
        """Initializes event.

        Args:
            entities: Entities extracted from previous user message. This can either
                be done by NLU components or end-to-end policy predictions.
            timestamp: the timestamp
            metadata: some optional metadata
        """
        super().__init__(timestamp, metadata)
        self.entities = entities

    def __str__(self) -> Text:
        """Returns the string representation of the event."""
        entity_str = [e[ENTITY_ATTRIBUTE_TYPE] for e in self.entities]
        return f"{self.__class__.__name__}({entity_str})"

    def __hash__(self) -> int:
        """Returns the hash value of the event.

        Keys are sorted before serializing: `__eq__` compares the entity dicts
        irrespective of key order, so equal events must also hash equally.
        """
        return hash(json.dumps(self.entities, sort_keys=True))

    def __eq__(self, other: Any) -> bool:
        """Compares this event with another event (by entities only)."""
        if not isinstance(other, EntitiesAdded):
            return NotImplemented

        return self.entities == other.entities

    @classmethod
    def _from_parameters(cls, parameters: Dict[Text, Any]) -> "EntitiesAdded":
        """Build the event from its serialized parameters dict."""
        return EntitiesAdded(
            parameters.get(ENTITIES),
            parameters.get("timestamp"),
            parameters.get("metadata"),
        )

    def as_dict(self) -> Dict[Text, Any]:
        """Converts the event into a dict.

        Returns:
            A dict that represents this event.
        """
        d = super().as_dict()
        d.update({ENTITIES: self.entities})
        return d

    def apply_to(self, tracker: "DialogueStateTracker") -> None:
        """Applies event to current conversation state.

        Entities are appended to the latest user message only if that message
        directly follows `action_listen`; entities already present are skipped.

        Args:
            tracker: The current conversation state.
        """
        if tracker.latest_action_name != ACTION_LISTEN_NAME:
            # entities belong only to the last user message
            # a user message always comes after action listen
            return

        for entity in self.entities:
            if entity not in tracker.latest_message.entities:
                tracker.latest_message.entities.append(entity)
Rasa官网的例子：
from typing import Text, Dict, Any, List
from rasa_sdk import Action, Tracker
from rasa_sdk.events import SlotSet
from rasa_sdk.executor import CollectingDispatcher


class ActionCheckRestaurants(Action):
    """Look up one restaurant matching the `cuisine` slot and store the
    result in the `matches` slot (an empty list when nothing is found)."""

    def name(self) -> Text:
        return "action_check_restaurants"

    def run(
        self,
        dispatcher: CollectingDispatcher,
        tracker: Tracker,
        domain: Dict[Text, Any],
    ) -> List[Dict[Text, Any]]:
        cuisine = tracker.get_slot("cuisine")

        # Slot values come from user input, so they must never be formatted
        # into the SQL text (SQL injection); bind them as query parameters.
        # NOTE(review): `db` is a placeholder in the docs example; this assumes
        # a DB-API style `query(sql, params)` with the `%s` paramstyle — adapt
        # the placeholder to the real driver (e.g. `?` for sqlite3).
        q = "select * from restaurants where cuisine = %s limit 1"
        result = db.query(q, (cuisine,))

        return [SlotSet("matches", result if result is not None else [])]
————————————————
Rasa完整视频：《Rasa微服务Action自定义及Slot Validation详解》https://appnvzljxc49685.h5.xiaoeknow.com