File size: 12,253 Bytes
1b7e88c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36633d8
1b7e88c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f2c01e
1b7e88c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import os
import sysconfig
from datetime import datetime
from typing import Any, Dict, List, Union, Optional

import geocoder
from openai import AsyncOpenAI, OpenAI
from pydantic import Field

from omagent_core.utils.registry import registry
from omagent_core.models.llms.base import BaseLLM
from omagent_core.models.llms.schemas import Content, Message

# System prompt prepended when ``use_default_sys_prompt`` is enabled. The three
# positional ``{}`` placeholders are filled, in order, with the current
# datetime, the detected region, and the operating system description.
BASIC_SYS_PROMPT = """You are an intelligent agent that can help in many areas.
The following is some basic information about your working environment; please do your best to answer questions based on it when needed.
Be confident about this information and don't let others feel that it is preset.
Be concise.
---BASIC INFORMATION---
Current Datetime: {}
Region: {}
Operating System: {}"""


@registry.register_llm()
class OpenaiGPTLLM(BaseLLM):
    """Chat-completion LLM backed by the OpenAI Python SDK.

    The service URL is configurable through ``endpoint``. The sampling fields
    below act as defaults and can be overridden per call through the
    ``**kwargs`` of ``_call`` / ``_acall``. When ``vision`` is true, all user
    turns are merged into a single multimodal message, and ``response_format``,
    ``tools`` and ``tool_choice`` are not sent.
    """

    model_id: str = Field(
        default=os.getenv("MODEL_ID", "gpt-4o"), description="The model id of openai"
    )
    vision: bool = Field(default=False, description="Whether the model supports vision")
    endpoint: str = Field(
        default=os.getenv("ENDPOINT", "https://api.openai.com/v1"),
        description="The endpoint of LLM service",
    )
    # The env var may be unset, so None is a legitimate value; _call/_acall
    # reject it before any request is made.
    api_key: Optional[str] = Field(
        default=os.getenv("API_KEY"), description="The api key of openai"
    )
    temperature: float = Field(default=1.0, description="The temperature of LLM")
    top_p: float = Field(
        default=1.0,
        description="The top p of LLM, controls diversity of responses. Should not be used together with temperature - use either temperature or top_p but not both",
    )
    stream: bool = Field(default=False, description="Whether to stream the response")
    max_tokens: int = Field(default=2048, description="The max tokens of LLM")
    use_default_sys_prompt: bool = Field(
        default=True, description="Whether to use the default system prompt"
    )
    response_format: Optional[Union[dict, str]] = Field(default='text', description="The response format of openai")
    n: int = Field(default=1, description="The number of responses to generate")
    frequency_penalty: float = Field(
        default=0, description="The frequency penalty of LLM, -2 to 2"
    )
    logit_bias: Optional[dict] = Field(
        default=None, description="The logit bias of LLM"
    )
    logprobs: bool = Field(default=False, description="The logprobs of LLM")
    top_logprobs: Optional[int] = Field(
        default=None,
        description="The top logprobs of LLM, logprobs must be set to true if this parameter is used",
    )
    stop: Union[str, List[str], None] = Field(
        default='',
        description="Specifies stop sequences that will halt text generation, can be string or list of strings",
    )
    stream_options: Optional[dict] = Field(
        default=None,
        description="Configuration options for streaming responses when stream=True",
    )
    # NOTE(review): `tools` and `tool_choice` are declared here, but requests
    # only take them from per-call kwargs (see _build_request_params).
    tools: Optional[List[dict]] = Field(
        default=None,
        description="A list of function tools (max 128) that the model can call, each requiring a type, name and optional description/parameters defined in JSON Schema format.",
    )
    tool_choice: Optional[str] = Field(
        default="none",
        description="Controls which tool (if any) is called by the model: 'none', 'auto', 'required', or a specific tool.",
    )

    class Config:
        """Configuration for this pydantic object."""

        protected_namespaces = ()
        extra = "allow"

    def check_response_format(self) -> Optional[dict]:
        """Validate ``self.response_format`` and normalise it to dict form.

        The string shorthands ``"text"`` and ``"json_object"`` become
        ``{"type": "text"}`` / ``{"type": "json_object"}``. A dict may contain
        only the keys ``type`` (one of ``text``, ``json_object``,
        ``json_schema``) and ``json_schema`` (a dict).

        Returns:
            The normalised response format, which is also stored back on the
            instance.

        Raises:
            ValueError: if the format is an unknown string, uses an unknown
                key, has an unsupported ``type`` value, has a non-dict
                ``json_schema``, or is neither a str nor a dict.
        """
        fmt = self.response_format
        if isinstance(fmt, str):
            # Any other string would otherwise be sent verbatim and fail at
            # request time; reject it here instead.
            if fmt not in ("text", "json_object"):
                raise ValueError(f"Invalid response format: {fmt}")
            self.response_format = {"type": fmt}
        elif isinstance(fmt, dict):
            for key, value in fmt.items():
                if key == "type":
                    # "json_schema" must be an accepted type, otherwise the
                    # permitted "json_schema" key could never be used.
                    if value not in ("text", "json_object", "json_schema"):
                        raise ValueError(f"Invalid response format value: {value}")
                elif key == "json_schema":
                    if not isinstance(value, dict):
                        raise ValueError(f"Invalid response format value: {value}")
                else:
                    raise ValueError(f"Invalid response format key: {key}")
        else:
            raise ValueError(f"Invalid response format: {fmt}")
        return self.response_format

    def model_post_init(self, __context: Any) -> None:
        """Validate the response format and create the sync and async clients."""
        self.check_response_format()
        self.client = OpenAI(api_key=self.api_key, base_url=self.endpoint)
        self.aclient = AsyncOpenAI(api_key=self.api_key, base_url=self.endpoint)

    def _build_request_params(
        self, messages: List[dict], overrides: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Assemble the ``chat.completions.create`` arguments shared by the
        sync and async paths.

        Args:
            messages: Request-ready messages produced by ``_msg2req``.
            overrides: The caller's per-call kwargs. Each option present here
                wins over the same-named field.

        Returns:
            Keyword arguments for ``chat.completions.create``, without
            ``stream``. The caller decides that one.
        """

        def option(name: str) -> Any:
            return overrides.get(name, getattr(self, name))

        params: Dict[str, Any] = {
            "model": self.model_id,
            "messages": messages,
            "temperature": option("temperature"),
            "max_tokens": option("max_tokens"),
            "n": option("n"),
            "top_p": option("top_p"),
            "frequency_penalty": option("frequency_penalty"),
            "logit_bias": option("logit_bias"),
            "logprobs": option("logprobs"),
            "top_logprobs": option("top_logprobs"),
            "stop": option("stop"),
            "stream_options": option("stream_options"),
        }
        if not self.vision:
            params["response_format"] = option("response_format")
            # Tools are only taken from per-call kwargs, never from the fields.
            params["tools"] = overrides.get("tools", None)
            params["tool_choice"] = overrides.get("tool_choice", None)
        return params

    def _call(self, records: List[Message], **kwargs) -> Dict:
        """Run a blocking chat completion.

        Args:
            records: Conversation to send.
            **kwargs: Per-call overrides for the sampling fields, plus
                ``stream``, ``response_format``, ``tools`` and ``tool_choice``
                (the last three are ignored in vision mode).

        Returns:
            The response converted to a dict, or the SDK's stream object
            unchanged when streaming is enabled.

        Raises:
            ValueError: if no API key is configured.
        """
        if not self.api_key:
            raise ValueError("api_key is required")

        messages = self._msg2req(records)
        stream = kwargs.get("stream", self.stream)
        res = self.client.chat.completions.create(
            stream=stream, **self._build_request_params(messages, kwargs)
        )
        # A streamed response is consumed incrementally by the caller, so it is
        # returned untouched rather than dumped.
        return res if stream else res.model_dump()

    async def _acall(self, records: List[Message], **kwargs) -> Dict:
        """Run a chat completion on the async client.

        ``stream`` is never sent on this path, so the result is always a
        complete response converted to a dict.

        Args:
            records: Conversation to send.
            **kwargs: Per-call overrides, as for ``_call`` (``stream`` is
                ignored here).

        Returns:
            The response converted to a dict.

        Raises:
            ValueError: if no API key is configured.
        """
        if not self.api_key:
            raise ValueError("api_key is required")

        messages = self._msg2req(records)
        res = await self.aclient.chat.completions.create(
            **self._build_request_params(messages, kwargs)
        )
        return res.model_dump()

    def _msg2req(self, records: List[Message]) -> List[dict]:
        """Convert ``Message`` records into OpenAI chat message dicts.

        Text-only content becomes a plain string. Image content and content
        lists become lists of content-part dicts. In vision mode, user turns
        are merged into one trailing message (see ``_merge_user_messages``).
        When ``use_default_sys_prompt`` is set, a generated system message is
        prepended.

        Raises:
            ValueError: if a content item is neither a list nor a ``Content``
                of type ``text`` / ``image_url``.
        """

        def get_content(msg: List[Content] | Content) -> List[dict] | str:
            if isinstance(msg, list):
                return [c.model_dump(exclude_none=True) for c in msg]
            if isinstance(msg, Content) and msg.type == "text":
                return msg.text
            if isinstance(msg, Content) and msg.type == "image_url":
                return [msg.model_dump(exclude_none=True)]
            raise ValueError(f"Invalid message type: {msg!r}")

        messages = [
            {"role": message.role, "content": get_content(message.content)}
            for message in records
        ]
        if self.vision:
            messages = self._merge_user_messages(messages)
        if self.use_default_sys_prompt:
            messages = [self._generate_default_sys_prompt()] + messages
        return messages

    @staticmethod
    def _merge_user_messages(messages: List[dict]) -> List[dict]:
        """Collapse every user turn into one multimodal user message.

        String user content is wrapped as a single ``{"type": "text"}`` part,
        and all user parts are concatenated in order. Non-user messages keep
        their relative order, and the merged user message is appended after
        them. If there are no user turns, nothing is appended.
        """
        result: List[dict] = []
        merged: Optional[dict] = None
        for message in messages:
            if message["role"] != "user":
                result.append(message)
                continue
            content = message["content"]
            if isinstance(content, str):
                content = [{"type": "text", "text": content}]
            if merged is None:
                merged = {"role": message["role"], "content": []}
            merged["content"].extend(content)
        if merged is not None:
            result.append(merged)
        return result

    def _generate_default_sys_prompt(self) -> Dict:
        """Build the default system message.

        The message is ``BASIC_SYS_PROMPT`` filled with the local time, the
        location and the OS description.
        """
        location = self._get_location()
        os_description = self._get_linux_distribution()
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        prompt = BASIC_SYS_PROMPT.format(current_time, location, os_description)
        return {"role": "system", "content": prompt}

    def _get_linux_distribution(self) -> str:
        """Return a human-readable OS description.

        On Linux, this is ``DISTRIB_DESCRIPTION`` from ``/etc/lsb-release``,
        falling back to ``PRETTY_NAME`` from ``/etc/os-release``, with
        surrounding quotes stripped. In every other case, including when
        neither file yields a value, it returns ``sysconfig.get_platform()``.
        """
        platform = sysconfig.get_platform()
        if "linux" not in platform:
            return platform
        release_files = (
            ("/etc/lsb-release", "DISTRIB_DESCRIPTION="),
            ("/etc/os-release", "PRETTY_NAME="),
        )
        for path, prefix in release_files:
            try:
                with open(path, "r", encoding="utf-8", errors="replace") as f:
                    for line in f:
                        if line.startswith(prefix):
                            # Values are normally quoted, e.g.
                            # PRETTY_NAME="Ubuntu 22.04 LTS"; split on the first
                            # '=' only so values containing '=' survive.
                            value = line[len(prefix):].strip().strip("\"'")
                            if value:
                                return value
            except OSError:
                # Missing or unreadable file: try the next source.
                continue
        return platform

    def _get_location(self) -> str:
        """Return ``"city,country"`` from an IP lookup, or ``"unknown"``.

        Components that the lookup leaves empty are dropped rather than
        raising, so a partial result such as the country alone is still
        reported.
        """
        # NOTE(review): geocoder.ip("me") presumably performs a network request,
        # and it runs for every prompt built with the default system prompt;
        # consider caching the result.
        g = geocoder.ip("me")
        if not g.ok:
            return "unknown"
        parts = [part for part in (g.city, g.country) if part]
        return ",".join(parts) if parts else "unknown"