from __future__ import annotations import logging import re from typing import TYPE_CHECKING, Any, List, Optional from . import order_registry as OrderRegistry from .base import BaseOrder if TYPE_CHECKING: from agentverse.environments import BaseEnvironment @OrderRegistry.register("classroom") class ClassroomOrder(BaseOrder): """The order for a classroom discussion The agents speak in the following order: 1. The professor speaks first 2. Then the professor can continue to speak, and the students can raise hands 3. The professor can call on a student, then the student can speak or ask a question 4. In the group discussion, the students in the group can speak in turn """ def get_next_agent_idx(self, environment: BaseEnvironment) -> List[int]: # `is_grouped_ended`: whether the group discussion just ended # `is_grouped`: whether it is currently in a group discussion if environment.rule_params.get("is_grouped_ended", False): return [0] if environment.rule_params.get("is_grouped", False): return self.get_next_agent_idx_grouped(environment) else: return self.get_next_agent_idx_ungrouped(environment) def get_next_agent_idx_ungrouped(self, environment: BaseEnvironment) -> List[int]: if len(environment.last_messages) == 0: # If the class just begins or no one speaks in the last turn, we let only the professor speak return [0] elif len(environment.last_messages) == 1: message = environment.last_messages[0] sender = message.sender content = message.content if sender.startswith("Professor"): if content.startswith("[CallOn]"): # 1. professor calls on someone, then the student should speak result = re.search(r"\[CallOn\] Yes, ([sS]tudent )?(\w+)", content) if result is not None: name_to_id = { agent.name[len("Student ") :]: i for i, agent in enumerate(environment.agents) } return [name_to_id[result.group(2)]] else: # 2. professor normally speaks, then anyone can act return list(range(len(environment.agents))) elif sender.startswith("Student"): # 3. student ask question after being called on, or # 4. only one student raises hand, and the professor happens to listen # 5. the group discussion is just over, and there happens to be only a student speaking in the last turn return [0] else: # If len(last_messages) > 1, then # 1. there must be at least one student raises hand or speaks. # 2. the group discussion is just over. return [0] assert ( False ), f"Should not reach here, last_messages: {environment.last_messages}" def get_next_agent_idx_grouped(self, environment: BaseEnvironment) -> List[int]: # Get the grouping information # groups: A list of list of agent ids, the i-th list contains # the agent ids in the i-th group # group_speaker_mapping: A mapping from group id to the id of # the speaker in the group # `groups` should be set in the corresponding `visibility`, # and `group_speaker_mapping` should be maintained here. if "groups" not in environment.rule_params: logging.warning( "The environment is grouped, but the grouping information is not provided." ) groups = environment.rule_params.get( "groups", [list(range(len(environment.agents)))] ) group_speaker_mapping = environment.rule_params.get( "group_speaker_mapping", {i: 0 for i in range(len(groups))} ) # For grouped environment, we let the students speak in turn within each group next_agent_idx = [] for group_id in range(len(groups)): speaker_index = group_speaker_mapping[group_id] speaker = groups[group_id][speaker_index] next_agent_idx.append(speaker) # Maintain the `group_speaker_mapping` for k, v in group_speaker_mapping.items(): group_speaker_mapping[k] = (v + 1) % len(groups[k]) environment.rule_params["group_speaker_mapping"] = group_speaker_mapping return next_agent_idx