Spaces:
Build error
Build error
from __future__ import annotations | |
import random | |
from typing import TYPE_CHECKING, Any, List, Union | |
from . import visibility_registry as VisibilityRegistry | |
from .base import BaseVisibility | |
if TYPE_CHECKING: | |
from agentverse.environments import BaseEnvironment | |
class ClassroomVisibility(BaseVisibility):
    """
    Visibility function for a classroom; supports group discussion.

    When the latest round consists of exactly one message whose content
    starts with ``"[GroupDiscuss]"``, the students are split into groups and
    each grouped agent's receivers are restricted to the members of its own
    group. After ``num_discussion_turn`` further updates, grouping mode is
    switched off and every agent's receiver is reset to ``{"all"}``.

    Args:
        student_per_group:
            The number of students per group. Must be positive when
            ``grouping`` names a grouping method.
        num_discussion_turn:
            The number of turns for group discussion.
        grouping:
            The grouping information. If it is a string, it must name a
            grouping method, one of ["random", "sequential"]. If it is a
            list of lists of int, it is used directly as the groups; each
            inner list holds indices into ``environment.agents``.
    """

    grouping: Union[str, List[List[int]]]
    student_per_group: int = 4
    num_discussion_turn: int = 5
    # Number of updates counted since the current group discussion started.
    current_turn: int = 0

    def update_visible_agents(self, environment: BaseEnvironment):
        """Enter, advance, or leave group-discussion mode for this round.

        Reads ``environment.last_messages``. Reads and writes the
        ``"is_grouped"``, ``"is_grouped_ended"`` and ``"groups"`` keys of
        ``environment.rule_params``, and updates agents' receivers via
        ``update_receiver``.
        """
        last_messages = environment.last_messages
        # Grouping mode is turned on when the professor launches a group discussion.
        if len(last_messages) == 1 and last_messages[0].content.startswith(
            "[GroupDiscuss]"
        ):
            # Start counting turns afresh, even if an earlier discussion was
            # still running when this one was launched; otherwise the new
            # discussion would inherit the old turn count and end early.
            self.reset()
            environment.rule_params["is_grouped"] = True
            # Build the groups (random, sequential, or explicitly configured).
            environment.rule_params["groups"] = self.group_students(environment)
            # Restrict each grouped agent's receivers to its own group.
            self.update_receiver(environment)
            return

        if not environment.rule_params.get("is_grouped", False):
            return

        # In grouping mode: count this turn and end the discussion once the
        # configured number of turns has been used up.
        self.current_turn += 1
        if self.current_turn >= self.num_discussion_turn:
            self.reset()
            environment.rule_params["is_grouped"] = False
            environment.rule_params["is_grouped_ended"] = True
            self.update_receiver(environment, reset=True)

    def group_students(self, environment: BaseEnvironment) -> List[List[int]]:
        """Return the student groups as lists of indices into ``environment.agents``.

        If ``grouping`` is an explicit list of groups, a copy of it is
        returned. Otherwise the agent indices from 1 upward are split into
        consecutive chunks of ``student_per_group`` (the last chunk may be
        smaller); with ``"random"`` the indices are shuffled first. Index 0
        is never placed in a group (presumably the professor — confirm).

        Raises:
            ValueError: if ``grouping`` names an unsupported method, or if
                ``student_per_group`` is not positive.
        """
        if not isinstance(self.grouping, str):
            # Return a copy so that later edits to rule_params["groups"]
            # cannot silently alter the configured grouping, and vice versa.
            return [list(group) for group in self.grouping]

        if self.grouping not in ("random", "sequential"):
            raise ValueError(f"Unsupported grouping method {self.grouping}")

        size = self.student_per_group
        if size <= 0:
            # A zero step would crash inside range(); a negative step would
            # silently yield no groups at all.
            raise ValueError(f"student_per_group must be positive, got {size}")

        student_index = list(range(1, len(environment.agents)))
        if self.grouping == "random":
            random.shuffle(student_index)
        return [
            student_index[start : start + size]
            for start in range(0, len(student_index), size)
        ]

    def update_receiver(self, environment: BaseEnvironment, reset=False):
        """Set agents' receivers according to the current grouping.

        Args:
            environment: the environment whose agents are updated.
            reset: if True, every agent's receiver is set to ``{"all"}``.
                Otherwise, each agent listed in ``rule_params["groups"]``
                gets the set of names of its own group's members; agents
                that appear in no group keep their current receivers.
        """
        if reset:
            for agent in environment.agents:
                agent.set_receiver({"all"})
            return

        agents = environment.agents
        for group in environment.rule_params["groups"]:
            member_names = {agents[i].name for i in group}
            for agent_id in group:
                agents[agent_id].set_receiver(member_names)

    def reset(self):
        """Reset the discussion turn counter to zero."""
        self.current_turn = 0