Spaces:
Build error
Build error
File size: 3,512 Bytes
01523b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
from __future__ import annotations
import random
from typing import TYPE_CHECKING, Any, List, Union
from . import visibility_registry as VisibilityRegistry
from .base import BaseVisibility
if TYPE_CHECKING:
from agentverse.environments import BaseEnvironment
@VisibilityRegistry.register("classroom")
class ClassroomVisibility(BaseVisibility):
    """
    Visibility function for classroom, supports group discussion.

    A group discussion starts when the environment's only last message begins
    with ``[GroupDiscuss]``. While it is running, every grouped agent's
    receiver is restricted to the names of the agents in its own group. After
    ``num_discussion_turn`` further updates, every agent's receiver is set
    back to ``{"all"}``.

    Args:
        student_per_group:
            The number of students per group. Only used when `grouping` is a
            grouping method (a string); must then be at least 1.
        num_discussion_turn:
            The number of turns for group discussion.
        grouping:
            The grouping information. If it is a string, then it should be a
            grouping method, options are ["random", "sequential"]. If it is a
            list of list of int, then it should be the grouping information
            (each int is used as an index into ``environment.agents``).
        current_turn:
            The number of turns counted so far in the ongoing group
            discussion. Set back to 0 by `reset`.
    """

    grouping: Union[str, List[List[int]]]
    student_per_group: int = 4
    num_discussion_turn: int = 5
    current_turn: int = 0

    def update_visible_agents(self, environment: BaseEnvironment):
        """Start, advance, or end a group discussion and update receivers.

        Args:
            environment: The environment whose ``last_messages``,
                ``rule_params`` and ``agents`` are read and updated.

        Raises:
            ValueError: Propagated from `group_students` when the grouping
                configuration is invalid.
        """
        last_messages = environment.last_messages
        # We turn on grouping mode when the professor launches a group discussion
        if len(last_messages) == 1 and last_messages[0].content.startswith(
            "[GroupDiscuss]"
        ):
            # Build the groups first: if the configuration is invalid, we must
            # not leave "is_grouped" switched on without any matching groups.
            groups = self.group_students(environment)
            environment.rule_params["is_grouped"] = True
            environment.rule_params["groups"] = groups
            # Update the receiver for each agent
            self.update_receiver(environment)
        elif environment.rule_params.get("is_grouped", False):
            # In grouping mode: count this turn and check whether the group
            # discussion is over.
            self.current_turn += 1
            if self.current_turn >= self.num_discussion_turn:
                self.reset()
                environment.rule_params["is_grouped"] = False
                environment.rule_params["is_grouped_ended"] = True
                self.update_receiver(environment, reset=True)

    def group_students(self, environment: BaseEnvironment) -> List[List[int]]:
        """Return the groups as lists of indices into ``environment.agents``.

        Args:
            environment: The environment providing the ``agents`` list.

        Returns:
            ``self.grouping`` itself when explicit groups were configured.
            Otherwise the agent indices ``1 .. len(agents) - 1`` split into
            consecutive chunks of at most ``student_per_group`` indices,
            shuffled beforehand for the "random" method.

        Raises:
            ValueError: If the grouping method is not "random" or
                "sequential", or if ``student_per_group`` is smaller than 1.
        """
        # If the grouping information is provided, then we use it directly
        if not isinstance(self.grouping, str):
            return self.grouping

        if self.grouping not in ("random", "sequential"):
            raise ValueError(f"Unsupported grouping method {self.grouping}")

        group_size = self.student_per_group
        if group_size < 1:
            # A zero step makes range() fail with an unrelated message, and a
            # negative one silently produces no groups at all.
            raise ValueError(
                f"student_per_group must be a positive integer, got {group_size}"
            )

        # Index 0 is left out of the groups (presumably the professor --
        # TODO confirm against the environment setup).
        student_index = list(range(1, len(environment.agents)))
        if self.grouping == "random":
            random.shuffle(student_index)
        return [
            student_index[start : start + group_size]
            for start in range(0, len(student_index), group_size)
        ]

    def update_receiver(self, environment: BaseEnvironment, reset=False):
        """Set the receiver of the agents according to the current mode.

        Args:
            environment: The environment providing ``agents`` and, when
                `reset` is false, ``rule_params["groups"]``.
            reset: If true, every agent's receiver becomes ``{"all"}``.
                Otherwise each grouped agent's receiver becomes the set of
                names of the agents in its group; agents that are in no
                group are left unchanged.
        """
        agents = environment.agents
        if reset:
            for agent in agents:
                agent.set_receiver({"all"})
            return

        for group in environment.rule_params["groups"]:
            member_names = {agents[i].name for i in group}
            for agent_id in group:
                # Hand every agent its own copy so that a later in-place
                # change to one receiver cannot leak into its group mates.
                agents[agent_id].set_receiver(set(member_names))

    def reset(self):
        """Restart the discussion turn counter."""
        self.current_turn = 0
|