File size: 4,088 Bytes
4bdab37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from typing import List
import os
from tenacity import retry, stop_after_attempt, wait_random_exponential

from .base import IntelligenceBackend
from ..message import Message

# The Cohere backend is optional: it is usable only when the `cohere` package
# can be imported AND the COHEREAI_API_KEY environment variable is set.
try:
    import cohere
except ImportError:
    is_cohere_available = False
else:
    is_cohere_available = os.environ.get('COHEREAI_API_KEY') is not None

# Default config follows the [Cohere documentation](https://cohere-sdk.readthedocs.io/en/latest/cohere.html#cohere.client.Client.chat)
DEFAULT_TEMPERATURE = 0.8
DEFAULT_MAX_TOKENS = 200
DEFAULT_MODEL = "command-xlarge"


class CohereAIChat(IntelligenceBackend):
    """
    Interface to the Cohere chat API.

    The backend is stateful: it keeps the ``session_id`` returned by the most recent
    chat call and passes it to the next call, and it remembers the hash of the last
    history message it has already processed so that each query only forwards the
    messages that come after it.
    """
    stateful = True
    type_name = "cohere-chat"

    def __init__(self, temperature: float = DEFAULT_TEMPERATURE, max_tokens: int = DEFAULT_MAX_TOKENS,
                 model: str = DEFAULT_MODEL, **kwargs):
        """
        create the Cohere client and initialise the conversation state
        args:
            temperature: sampling temperature forwarded to every chat call
            max_tokens: maximum number of tokens forwarded to every chat call
            model: model name; stored on the instance
            kwargs: extra options forwarded to the base class constructor
        raises:
            AssertionError: if the cohere package could not be imported or the
                COHEREAI_API_KEY environment variable is not set
        """
        super().__init__(temperature=temperature, max_tokens=max_tokens, model=model, **kwargs)

        self.temperature = temperature
        self.max_tokens = max_tokens
        # NOTE(review): `model` is stored but never passed to `client.chat` in this class -
        # confirm whether the installed Cohere SDK version accepts a model argument there.
        self.model = model

        # Raised explicitly (instead of a bare `assert`) so the check still runs under `python -O`;
        # the exception type and message are unchanged.
        if not is_cohere_available:
            raise AssertionError("Cohere package is not installed or the API key is not set")
        self.client = cohere.Client(os.environ.get('COHEREAI_API_KEY'))

        # Stateful variables
        self.session_id = None  # The session id for the last conversation
        self.last_msg_hash = None  # The hash of the last message of the last conversation

    def reset(self):
        """
        forget the current session id and the last processed message,
        so that the next query forwards the whole history again
        """
        self.session_id = None
        self.last_msg_hash = None

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, new_message: str, persona_prompt: str):
        """
        send one message to the Cohere chat endpoint and return the reply
        (retried by tenacity: at most 6 attempts with randomised exponential waits)
        args:
            new_message: the single message to send
            persona_prompt: the persona prompt passed along with the message
        """
        response = self.client.chat(
            new_message,
            persona_prompt=persona_prompt,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            session_id=self.session_id
        )

        self.session_id = response.session_id  # Update the session id
        return response.reply

    def query(self, agent_name: str, role_desc: str, history_messages: List[Message], global_prompt: str = None,
              request_msg: Message = None, *args, **kwargs) -> str:
        """
        format the input and call the Cohere API
        args:
            agent_name: the name of the agent
            role_desc: the description of the role of the agent
            history_messages: the history of the conversation, or the observation for the agent
            global_prompt: the description of the environment; the "Environment" section
                is left out of the persona prompt when it is None or empty
            request_msg: the request for the CohereAI
        returns:
            the reply text returned by the Cohere API
        raises:
            AssertionError: if history_messages contains no message after the last processed one
        """
        # Find the index of the last message of the last conversation
        new_message_start_idx = 0
        if self.last_msg_hash is not None:
            for i, message in enumerate(history_messages):
                if message.msg_hash == self.last_msg_hash:
                    new_message_start_idx = i + 1
                    break

        new_messages = history_messages[new_message_start_idx:]
        # Explicit raise so the invariant is still enforced under `python -O`
        # (otherwise `new_messages[-1]` below would fail only after the API call).
        if not new_messages:
            raise AssertionError("No new messages found (this should not happen)")

        # Skip the agent's own messages; prefix the others with the speaker's name
        # because there is more than one player in the conversation.
        new_conversations = [
            f"[{message.agent_name}]: {message.content}"
            for message in new_messages
            if message.agent_name != agent_name
        ]

        if request_msg:
            new_conversations.append(f"[{request_msg.agent_name}]: {request_msg.content}")

        # Concatenate all new messages into one message because the Cohere API only accepts one message
        new_message = "\n".join(new_conversations)

        # Only include the environment section when a global prompt was given; interpolating
        # the default None would put the literal text "None" into the persona prompt.
        prompt_sections = []
        if global_prompt:
            prompt_sections.append(f"Environment:\n{global_prompt}")
        prompt_sections.append(f"Your role:\n{role_desc}")
        persona_prompt = "\n\n".join(prompt_sections)

        response = self._get_response(new_message, persona_prompt)

        # Only update the last message hash if the API call is successful
        self.last_msg_hash = new_messages[-1].msg_hash

        return response