File size: 6,324 Bytes
4bdab37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# pyright: reportGeneralTypeIssues=false
# pyright: reportOptionalMemberAccess=false
from __future__ import annotations

import re
import random
from typing import List, Tuple
import os

from chatarena.environments.base import TimeStep
from chatarena.message import Message, MessagePool
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage

from chatarena.environments.umshini.base import UmshiniBaseEnv


class DebateEnv(UmshiniBaseEnv):
    """Two-player debate game on a fixed topic, scored by a judge at the end.

    Players take turns through ``step``. Once the round limit is reached the
    conversation is judged (or, when judging is disabled, a winner is drawn
    at random) and the winner receives a reward of 1.
    """

    # Opening moderator message; the debate topic is substituted for
    # ``{moderator_prompt_input}``.
    moderator_prompt = PromptTemplate(
        template="""Welcome to the debate game! The topic for today's debate is:
"{moderator_prompt_input}"
Rules:
The Opponent argues against the topic, while the Proponent argues for it.
Your first response should be an opening statement, followed by back and forth cross-examination.
You are free to talk directly to your opponent during cross-examination.
The cross examination phase should be short, and should be used to attack your opponents arguments, or defend your own.
The winner of the debate will be decided by the judge, based the performance and persuasiveness of each debater, and not the morality of the position.
Do not respond as any other character, only as yourself.
The judge will not interrupt.""",
        input_variables=["moderator_prompt_input"],
    )
    type_name = "debate"

    def __init__(
        self,
        player_names: list[str],
        topic: str,
        round_length: int = 10,
        disable_judging=False,
        **kwargs,
    ):
        """Set up a debate on ``topic`` between the given players.

        Args:
            player_names: Names of the debaters; the first two are the ones
                that receive rewards.
            topic: Statement under debate, inserted into the moderator prompt.
            round_length: Forwarded to the base environment as the turn limit.
            disable_judging: When truthy, ``get_rewards`` picks a random
                winner instead of calling the judge.
            **kwargs: Passed through unchanged to the base environment.
        """
        super().__init__(
            player_names=player_names,
            moderator_prompt_template=self.moderator_prompt,
            moderator_prompt_input=topic,
            round_length=round_length,
            disable_judging=disable_judging,
            **kwargs,
        )
        self.topic = topic
        self.disable_judging = disable_judging

    def get_rewards(self) -> dict[str, float]:
        """Pick a winner, announce the verdict, and return per-player rewards.

        Returns:
            A mapping for the first two players: 1 for the winner and 0 for
            the loser, or 0 for both when no winner could be determined.
        """
        if not self.disable_judging:
            winner, verdict = judge_debate(self.player_names, self.message_pool)
        else:
            # Local API testing: skip the judge and draw a winner at random.
            winner = random.randint(0, 1)
            verdict = f"WINNER: {winner}"
        self._moderator_speak(verdict)

        first, second = self.player_names[0], self.player_names[1]
        # Any winner index other than 0 or 1 leaves both players at zero.
        first_score = 1 if winner == 0 else 0
        second_score = 1 if winner == 1 else 0
        return {first: first_score, second: second_score}

    def step(self, player_name: str, action: str) -> TimeStep:
        """Record ``action`` as ``player_name``'s message and advance the game.

        Returns:
            A non-terminal ``TimeStep`` with zero rewards while turns remain;
            on the last turn, a terminal ``TimeStep`` carrying the rewards
            from ``get_rewards``.
        """
        last_turn = self._current_turn >= self._round_length - 1

        if last_turn:
            self._current_phase = "end"
        else:
            self._current_turn += 1

        spoken = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(spoken)

        if not last_turn:
            return TimeStep(
                observation=self.get_observation(),
                reward=self.get_zero_rewards(),
                terminal=False,
            )

        # Compute the rewards before building the observation so that the
        # moderator's final message (the winner) is part of what is displayed.
        final_rewards = self.get_rewards()
        return TimeStep(
            observation=self.get_observation(),
            reward=final_rewards,
            terminal=True,
        )


def create_debate_env(
    topic: str,
    round_length: int | None = None,
    player_names: list[str] | None = None,
    disable_judging: bool | None = False,
):
    """Build a ``DebateEnv`` for ``topic``, filling in defaults for omitted options.

    Args:
        topic: Statement the players will debate.
        round_length: Turn limit for the debate. When ``None`` the argument is
            not forwarded, so ``DebateEnv``'s own default applies.
        player_names: Names of the debaters; defaults to
            ``["Opponent", "Proponent"]``.
        disable_judging: Forwarded to ``DebateEnv``; when truthy the winner is
            drawn at random instead of being judged.

    Returns:
        The constructed ``DebateEnv``.
    """
    if player_names is None:
        player_names = ["Opponent", "Proponent"]

    # Only forward round_length when the caller actually supplied one.
    # Passing None explicitly would override DebateEnv's integer default and
    # later break the `self._round_length - 1` arithmetic in DebateEnv.step.
    optional_kwargs = {}
    if round_length is not None:
        optional_kwargs["round_length"] = round_length

    return DebateEnv(
        player_names=player_names,
        topic=topic,
        disable_judging=disable_judging,
        **optional_kwargs,
    )

# System prompt for the judge model. judge_debate() sends it as the first
# message of the conversation and then looks for the "WINNER:" marker that
# this prompt asks the model to end its answer with, so the output format
# described here must stay in sync with that parsing.
judge_debate_system_message = SystemMessage(
        content="""You are an impartial debate judge. There is always a winner in every debate, there are no ties.
Do not say: it is a close call. Do not say: it is a tie.
Your purpose is to describe who had the better rhetorical arguments and pick a winner in the format:
<Player1>_Arguments:<summarize>
<Player2>_Arguments:<summarize>
<Player1>_Rhetoric:<rhetorical strength>
<Player2>_Rhetoric:<rhetorical strength>
EXPLANATION:<final explanation>
WINNER:<name>"""
)


def _parse_winner(verdict: str, player_names: List[str]) -> int:
    """Return the index in ``player_names`` named after the last "WINNER:" marker, or -1."""
    _, marker, tail = verdict.rpartition("WINNER:")
    if not marker:
        return -1
    announced = tail.strip()
    # Try the text as written first, then with wrapping punctuation removed,
    # so answers such as "WINNER: <Proponent>" or "WINNER: **Proponent**."
    # still resolve to the player.
    unwrapped = announced.strip("<>\"'*.").strip()
    for candidate in (announced, unwrapped):
        if candidate and candidate in player_names:
            return player_names.index(candidate)
    return -1


def judge_debate(
    player_names: List[str], message_state: MessagePool, model_name: str = "gpt-4"
) -> Tuple[int, str]:
    """Ask a chat model to judge the debate and report who won.

    The whole conversation in ``message_state`` is replayed to the model after
    the judge system prompt: the first message is sent as an AI message and
    every later one as a human message tagged with its speaker and turn.

    Args:
        player_names: Names of the debaters; the returned index refers to
            this list.
        message_state: Message pool holding the debate transcript.
        model_name: OpenAI model used for judging when not running on Azure.

    Returns:
        ``(winner_index, judge_text)``. ``winner_index`` is -1 when the judge
        names no known player or when the Azure judge call fails; in the
        latter case ``judge_text`` is a short failure notice instead of the
        model's answer.

    Raises:
        Exception: On the non-Azure path, whatever the backup model call
            raises if it fails as well.
    """
    langchain_messages = [judge_debate_system_message]
    for i, message in enumerate(message_state.get_all_messages()):
        if i == 0:
            langchain_messages.append(AIMessage(content=f"{message.content}"))
        else:
            langchain_messages.append(
                HumanMessage(
                    content=f"{message.agent_name} -> Turn:{message.turn}:\n{message.content}"
                )
            )

    if os.getenv("OPENAI_API_TYPE") == "azure":
        llm = AzureChatOpenAI(
            temperature=0,
            openai_api_base=os.getenv("OPENAI_API_BASE"),
            openai_api_version=os.getenv("OPENAI_API_VERSION"),
            deployment_name=os.getenv("DEPLOYMENT_NAME"),
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            openai_api_type="azure",
        )
        try:
            response = llm(langchain_messages)
        except Exception as e:
            # Previously the error was only printed and execution fell through
            # to `response.content` with `response` unbound, which raised
            # UnboundLocalError. Report "no winner" explicitly instead.
            print(e)
            return (
                -1,
                "Judging failed: the judge model could not be reached, "
                "so no winner is declared.",
            )
    else:
        llm = ChatOpenAI(temperature=0, model_name=model_name, client="")
        try:
            response = llm(langchain_messages)
        except Exception:
            # Any failure of the primary model triggers one retry on the
            # backup model; a failure there propagates to the caller.
            backup_model = "gpt-3.5-turbo-16k"
            print(f"{model_name} not found, using {backup_model}")
            llm = ChatOpenAI(temperature=0, model_name=backup_model)
            response = llm(langchain_messages)

    verdict = response.content
    return _parse_winner(verdict, player_names), verdict