Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swarm: User-defined AfterWork string for agent selection using LLM #369

Open
wants to merge 12 commits into
base: swarmagenttoconversable
Choose a base branch
from
4 changes: 2 additions & 2 deletions autogen/agentchat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
ON_CONDITION,
AfterWork,
AfterWorkOption,
ContextStr,
OnCondition,
SwarmResult,
UpdateCondition,
a_initiate_swarm_chat,
initiate_swarm_chat,
register_hand_off,
Expand All @@ -47,7 +47,7 @@
"SwarmResult",
"ON_CONDITION",
"OnCondition",
"UpdateCondition",
"ContextStr",
"AFTER_WORK",
"AfterWork",
"AfterWorkOption",
Expand Down
199 changes: 157 additions & 42 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,40 @@
from ..agent import Agent
from ..chat import ChatResult
from ..conversable_agent import __CONTEXT_VARIABLES_PARAM_NAME__, ConversableAgent
from ..groupchat import GroupChat, GroupChatManager
from ..groupchat import SELECT_SPEAKER_PROMPT_TEMPLATE, GroupChat, GroupChatManager
from ..user_proxy_agent import UserProxyAgent


@dataclass
class UpdateCondition:
"""Update the condition string before they reply
class ContextStr:
"""A string that requires context variable substitution.

Use the format method to substitute context variables into the string.

Args:
update_function: The string or function to update the condition string. Can be a string or a Callable.
If a string, it will be used as a template and substitute the context variables.
If a Callable, it should have the signature:
def my_update_function(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str
template: The string to be substituted with context variables. It is expected that the string will contain {var} placeholders
and that string format will be able to replace all values.
"""

update_function: Union[Callable, str]
template: str

def __post_init__(self):
if isinstance(self.update_function, str):
assert self.update_function.strip(), " please provide a non-empty string or a callable"
# find all {var} in the string
vars = re.findall(r"\{(\w+)\}", self.update_function)
if len(vars) == 0:
warnings.warn("Update function string contains no variables. This is probably unintended.")

elif isinstance(self.update_function, Callable):
sig = signature(self.update_function)
if len(sig.parameters) != 2:
raise ValueError(
"Update function must accept two parameters of type ConversableAgent and List[Dict[str, Any]], respectively"
)
if sig.return_annotation != str:
raise ValueError("Update function must return a string")
else:
raise ValueError("Update function must be either a string or a callable")
def __init__(self, template: str):
self.template = template

def format(self, context_variables: dict[str, Any]) -> str:
"""Substitute context variables into the string.

Args:
context_variables: The context variables to substitute into the string.
"""
return OpenAIWrapper.instantiate(
template=self.template,
context=context_variables,
allow_format_str_template=True,
)

def __str__(self) -> str:
return f"ContextStr, unformatted: {self.template}"


# Created tool executor's name
Expand All @@ -75,14 +74,32 @@ class AfterWork:
agent: The agent to hand off to or the after work option. Can be a ConversableAgent, a string name of a ConversableAgent, an AfterWorkOption, or a Callable.
The Callable signature is:
def my_after_work_func(last_speaker: ConversableAgent, messages: List[Dict[str, Any]], groupchat: GroupChat) -> Union[AfterWorkOption, ConversableAgent, str]:
next_agent_selection_msg: Optional[Union[str, Callable]]: Optional message to use for the agent selection (in internal group chat), only valid for when agent is AfterWorkOption.SWARM_MANAGER.
If a string, it will be used as a template and substitute the context variables.
If a Callable, it should have the signature:
def my_selection_message(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str
"""

agent: Union[AfterWorkOption, ConversableAgent, str, Callable]
next_agent_selection_msg: Optional[Union[str, ContextStr, Callable]] = None

def __post_init__(self):
if isinstance(self.agent, str):
self.agent = AfterWorkOption(self.agent.upper())

# next_agent_selection_msg is only valid for when agent is AfterWorkOption.SWARM_MANAGER, but isn't mandatory
if self.next_agent_selection_msg is not None:

if not isinstance(self.next_agent_selection_msg, (str, ContextStr, Callable)):
raise ValueError("next_agent_selection_msg must be a string, ContextStr, or a Callable")

if self.agent != AfterWorkOption.SWARM_MANAGER:
warnings.warn(
"next_agent_selection_msg is only valid for agent=AfterWorkOption.SWARM_MANAGER. Ignoring the value.",
UserWarning,
)
self.next_agent_selection_msg = None


class AFTER_WORK(AfterWork):
"""Deprecated: Use AfterWork instead. This class will be removed in a future version (TBD)."""
Expand All @@ -104,13 +121,20 @@ class OnCondition:
target: The agent to hand off to or the nested chat configuration. Can be a ConversableAgent or a Dict.
If a Dict, it should follow the convention of the nested chat configuration, with the exception of a carryover configuration which is unique to Swarms.
Swarm Nested chat documentation: https://docs.ag2.ai/docs/topics/swarm#registering-handoffs-to-a-nested-chat
condition (str): The condition for transitioning to the target agent, evaluated by the LLM to determine whether to call the underlying function/tool which does the transition.
available (Union[Callable, str]): Optional condition to determine if this OnCondition is available. Can be a Callable or a string.
If a string, it will look up the value of the context variable with that name, which should be a bool.
condition (Union[str, ContextStr, Callable]): The condition for transitioning to the target agent, evaluated by the LLM.
If a string or Callable, no automatic context variable substitution occurs.
If a ContextStr, context variable substitution occurs.
The Callable signature is:
def my_condition_string(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str
available (Union[Callable, str]): Optional condition to determine if this OnCondition is included for the LLM to evaluate. Can be a Callable or a string.
If a string, it will look up the value of the context variable with that name, which should be a bool, to determine whether it should include this condition.
The Callable signature is:
def my_available_func(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> bool

"""

target: Union[ConversableAgent, dict[str, Any]] = None
condition: Union[str, UpdateCondition] = ""
condition: Union[str, ContextStr, Callable] = ""
available: Optional[Union[Callable, str]] = None

def __post_init__(self):
Expand All @@ -124,7 +148,9 @@ def __post_init__(self):
if isinstance(self.condition, str):
assert self.condition.strip(), "'condition' must be a non-empty string"
else:
assert isinstance(self.condition, UpdateCondition), "'condition' must be a string or UpdateOnConditionStr"
assert isinstance(
self.condition, (ContextStr, Callable)
), "'condition' must be a string, ContextStr, or callable"

if self.available is not None:
assert isinstance(self.available, (Callable, str)), "'available' must be a callable or a string"
Expand Down Expand Up @@ -154,6 +180,7 @@ def _swarm_agent_str(self: ConversableAgent) -> str:
return f"Swarm agent --> {self.name}"

agent._swarm_after_work = None
agent._swarm_after_work_selection_msg = None

# Store nested chats hand offs as we'll establish these in the initiate_swarm_chat
# List of Dictionaries containing the nested_chats and condition
Expand Down Expand Up @@ -329,6 +356,58 @@ def _cleanup_temp_user_messages(chat_result: ChatResult) -> None:
del message["name"]


def _prepare_groupchat_auto_speaker(
groupchat: GroupChat,
last_swarm_agent: ConversableAgent,
after_work_next_agent_selection_msg: Optional[Union[str, ContextStr, Callable]],
) -> None:
"""Prepare the group chat for auto speaker selection, includes updating or restore the groupchat speaker selection message.

Tool Executor and Nested Chat agents will be removed from the available agents list.

Args:
groupchat (GroupChat): GroupChat instance.
last_swarm_agent (ConversableAgent): The last swarm agent for which the LLM config is used
after_work_next_agent_selection_msg (Union[str, ContextStr, Callable]): Optional message to use for the agent selection (in internal group chat).
if a string, it will be use the string a the prompt template, no context variable substitution however '{agentlist}' will be substituted for a list of agents.
if a ContextStr, it will substitute the agentlist first and then the context variables
if a Callable, it will not substitute the agentlist or context variables, signature:
def my_selection_message(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str
"""

def substitute_agentlist(template: str) -> str:
# Run through group chat's string substitution first for {agentlist}
# We need to do this so that the next substitution doesn't fail with agentlist
# and we can remove the tool executor and nested chats from the available agents list
agent_list = [
agent
for agent in groupchat.agents
if agent.name != __TOOL_EXECUTOR_NAME__ and not agent.name.startswith("nested_chat_")
]

groupchat.select_speaker_prompt_template = template
return groupchat.select_speaker_prompt(agent_list)

if after_work_next_agent_selection_msg is None:
# If there's no selection message, restore the default and filter out the tool executor and nested chat agents
groupchat.select_speaker_prompt_template = substitute_agentlist(SELECT_SPEAKER_PROMPT_TEMPLATE)
elif isinstance(after_work_next_agent_selection_msg, str):
# No context variable substitution for string, but agentlist will be substituted
groupchat.select_speaker_prompt_template = substitute_agentlist(after_work_next_agent_selection_msg)
elif isinstance(after_work_next_agent_selection_msg, ContextStr):
# Replace the agentlist in the string first, putting it into a new ContextStr
agent_list_replaced_string = ContextStr(substitute_agentlist(after_work_next_agent_selection_msg.template))
Comment on lines +382 to +399
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we let the user to only pass in a str or a callable?
We can always convert the str to a ContextStr class, replace agent_list and context variables.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review. You raise a good point and Chi and I had a chat about this as well prior. The thought behind making it an explicit choice to use ContextStr or a normal, non-substituted, string is so that the developer has complete control over the choice and they're not unsure of what will happen to the value.

@sonichi if you have anything further to add to this consideration?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the ambiguity happens when the str contains "{}", and whether it means a format string or not is ambiguous. That's my understanding of why ContextStr is introduced. Is that right? Shall we add the consideration into docstr?

Copy link
Collaborator Author

@marklysze marklysze Jan 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's right, by using "{}" the user doesn't have to be concerned about unexpected replacements (except {agentlist], which is default Group Chat behaviour)

In terms of adding to docstr, do you mean the ContextStr docstring?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yiranwu0, did you have any further thoughts on this? Are you okay with making it an explicit choice for the developer to choose ContextStr when they want the substitutions and without when they don't?


# Then replace the context variables
groupchat.select_speaker_prompt_template = agent_list_replaced_string.format(
last_swarm_agent._context_variables
)
elif isinstance(after_work_next_agent_selection_msg, Callable):
groupchat.select_speaker_prompt_template = substitute_agentlist(
after_work_next_agent_selection_msg(last_swarm_agent, groupchat.messages)
)


def _determine_next_agent(
last_speaker: ConversableAgent,
groupchat: GroupChat,
Expand Down Expand Up @@ -386,11 +465,14 @@ def _determine_next_agent(
if (user_agent and last_speaker == user_agent) or groupchat.messages[-1]["role"] == "tool":
return last_swarm_speaker

after_work_next_agent_selection_msg = None

# Resolve after_work condition (agent-level overrides global)
after_work_condition = (
last_swarm_speaker._swarm_after_work if last_swarm_speaker._swarm_after_work is not None else swarm_after_work
)
if isinstance(after_work_condition, AfterWork):
after_work_next_agent_selection_msg = after_work_condition.next_agent_selection_msg
after_work_condition = after_work_condition.agent

# Evaluate callable after_work
Expand All @@ -412,6 +494,7 @@ def _determine_next_agent(
elif after_work_condition == AfterWorkOption.STAY:
return last_speaker
elif after_work_condition == AfterWorkOption.SWARM_MANAGER:
_prepare_groupchat_auto_speaker(groupchat, last_swarm_speaker, after_work_next_agent_selection_msg)
return "auto"
else:
raise ValueError("Invalid After Work condition or return value from callable")
Expand Down Expand Up @@ -457,11 +540,44 @@ def swarm_transition(last_speaker: ConversableAgent, groupchat: GroupChat) -> Op
return swarm_transition


def _create_swarm_manager(
groupchat: GroupChat, swarm_manager_args: dict[str, Any], agents: list[ConversableAgent]
) -> GroupChatManager:
"""Create a GroupChatManager for the swarm chat utilising any arguments passed in and ensure an LLM Config exists if needed

Args:
groupchat (GroupChat): Swarm groupchat.
swarm_manager_args (dict[str, Any]): Swarm manager arguments to create the GroupChatManager.

Returns:
GroupChatManager: GroupChatManager instance.
"""
manager_args = (swarm_manager_args or {}).copy()
if "groupchat" in manager_args:
raise ValueError("'groupchat' cannot be specified in swarm_manager_args as it is set by initiate_swarm_chat")
manager = GroupChatManager(groupchat, **manager_args)

# Ensure that our manager has an LLM Config if we have any AfterWorkOption.SWARM_MANAGER after works
if manager.llm_config is False:
for agent in agents:
if (
agent._swarm_after_work
and isinstance(agent._swarm_after_work.agent, AfterWorkOption)
and agent._swarm_after_work.agent == AfterWorkOption.SWARM_MANAGER
):
raise ValueError(
"The swarm manager doesn't have an LLM Config and it is required for AfterWorkOption.SWARM_MANAGER. Use the swarm_manager_args to specify the LLM Config for the swarm manager."
)

return manager


def initiate_swarm_chat(
initial_agent: ConversableAgent,
messages: Union[list[dict[str, Any]], str],
agents: list[ConversableAgent],
user_agent: Optional[UserProxyAgent] = None,
swarm_manager_args: Optional[dict[str, Any]] = None,
max_rounds: int = 20,
context_variables: Optional[dict[str, Any]] = None,
after_work: Optional[Union[AfterWorkOption, Callable]] = AfterWork(AfterWorkOption.TERMINATE),
Expand All @@ -473,6 +589,7 @@ def initiate_swarm_chat(
messages: Initial message(s).
agents: List of swarm agents.
user_agent: Optional user proxy agent for falling back to.
swarm_manager_args: Optional group chat manager arguments used to establish the swarm's groupchat manager, required when AfterWorkOption.SWARM_MANAGER is used.
max_rounds: Maximum number of conversation rounds.
context_variables: Starting context variables.
after_work: Method to handle conversation continuation when an agent doesn't select the next agent. If no agent is selected and no tool calls are output, we will use this method to determine the next agent.
Expand Down Expand Up @@ -513,7 +630,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: List[Dict[st
speaker_selection_method=swarm_transition,
)

manager = GroupChatManager(groupchat)
manager = _create_swarm_manager(groupchat, swarm_manager_args, agents)

# Point all ConversableAgent's context variables to this function's context_variables
_setup_context_variables(tool_execution, agents, manager, context_variables or {})
Expand Down Expand Up @@ -541,6 +658,7 @@ async def a_initiate_swarm_chat(
messages: Union[list[dict[str, Any]], str],
agents: list[ConversableAgent],
user_agent: Optional[UserProxyAgent] = None,
swarm_manager_args: Optional[dict[str, Any]] = None,
max_rounds: int = 20,
context_variables: Optional[dict[str, Any]] = None,
after_work: Optional[Union[AfterWorkOption, Callable]] = AfterWork(AfterWorkOption.TERMINATE),
Expand All @@ -552,6 +670,7 @@ async def a_initiate_swarm_chat(
messages: Initial message(s).
agents: List of swarm agents.
user_agent: Optional user proxy agent for falling back to.
swarm_manager_args: Optional group chat manager arguments used to establish the swarm's groupchat manager, required when AfterWorkOption.SWARM_MANAGER is used.
max_rounds: Maximum number of conversation rounds.
context_variables: Starting context variables.
after_work: Method to handle conversation continuation when an agent doesn't select the next agent. If no agent is selected and no tool calls are output, we will use this method to determine the next agent.
Expand Down Expand Up @@ -592,7 +711,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: List[Dict[st
speaker_selection_method=swarm_transition,
)

manager = GroupChatManager(groupchat)
manager = _create_swarm_manager(groupchat, swarm_manager_args, agents)

# Point all ConversableAgent's context variables to this function's context_variables
_setup_context_variables(tool_execution, agents, manager, context_variables or {})
Expand Down Expand Up @@ -680,6 +799,7 @@ def transfer_to_agent_name() -> ConversableAgent:
transit.agent, (AfterWorkOption, ConversableAgent, str, Callable)
), "Invalid After Work value"
agent._swarm_after_work = transit
agent._swarm_after_work_selection_msg = transit.next_agent_selection_msg
elif isinstance(transit, OnCondition):

if isinstance(transit.target, ConversableAgent):
Expand Down Expand Up @@ -737,15 +857,10 @@ def _update_conditional_functions(agent: ConversableAgent, messages: Optional[li
# then add the function if it is available, so that the function signature is updated
if is_available:
condition = on_condition.condition
if isinstance(condition, UpdateCondition):
if isinstance(condition.update_function, str):
condition = OpenAIWrapper.instantiate(
template=condition.update_function,
context=agent._context_variables,
allow_format_str_template=True,
)
else:
condition = condition.update_function(agent, messages)
if isinstance(condition, ContextStr):
condition = condition.format(context_variables=agent._context_variables)
elif isinstance(condition, Callable):
condition = condition(agent, messages)
agent._add_single_function(func, func_name, condition)


Expand Down
Loading
Loading