# -*- coding: utf-8 -*-
"""
@Time    : 2023/5/5 23:08
@Author  : alexanderwu
@File    : openai.py
@From    : https://github.com/geekan/MetaGPT/blob/main/metagpt/provider/openai_api.py
"""
import asyncio
import time
from functools import wraps
from typing import NamedTuple

import openai
import litellm

from autoagents.system.config import CONFIG
from autoagents.system.logs import logger
from autoagents.system.provider.base_gpt_api import BaseGPTAPI
from autoagents.system.utils.singleton import Singleton
from autoagents.system.utils.token_counter import (
    TOKEN_COSTS,
    count_message_tokens,
    count_string_tokens,
)


def retry(max_retries):
    """Retry an async function up to ``max_retries`` times with exponential backoff."""
    def decorator(f):
        @wraps(f)
        async def wrapper(*args, **kwargs):
            for i in range(max_retries):
                try:
                    return await f(*args, **kwargs)
                except Exception:
                    if i == max_retries - 1:
                        raise
                    await asyncio.sleep(2 ** i)  # back off: 1s, 2s, 4s, ...
        return wrapper
    return decorator
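

# Usage sketch (assumption, not part of the original module): the decorator above can
# wrap any coroutine, retrying transient failures with exponential backoff. The wrapped
# function and its arguments here are hypothetical.
#
#     @retry(max_retries=3)
#     async def fetch_completion(llm, messages):
#         return await llm.acompletion(messages)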


class RateLimiter:
    """Rate-control helper: every call goes through wait_if_needed, which sleeps when throttling is required."""

    def __init__(self, rpm):
        self.last_call_time = 0
        # The 1.1 factor adds a 10% safety margin: even when calls are spaced exactly on
        # schedule, the server may still throttle them. Consider switching to a simple
        # retry-on-error strategy later.
        self.interval = 1.1 * 60 / rpm
        self.rpm = rpm

    def split_batches(self, batch):
        return [batch[i:i + self.rpm] for i in range(0, len(batch), self.rpm)]

    async def wait_if_needed(self, num_requests):
        current_time = time.time()
        elapsed_time = current_time - self.last_call_time
        if elapsed_time < self.interval * num_requests:
            remaining_time = self.interval * num_requests - elapsed_time
            logger.info(f"sleep {remaining_time}")
            await asyncio.sleep(remaining_time)
        self.last_call_time = time.time()
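

# Usage sketch (assumption, not from the original module): split a large batch into
# rpm-sized chunks and pause between chunks so the request rate stays under the limit.
# `requests` is a hypothetical list of message lists.
#
#     limiter = RateLimiter(rpm=10)
#     for chunk in limiter.split_batches(requests):
#         await limiter.wait_if_needed(len(chunk))
#         ...  # dispatch the chunk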


class Costs(NamedTuple):
    total_prompt_tokens: int
    total_completion_tokens: int
    total_cost: float
    total_budget: float


class CostManager(metaclass=Singleton):
    """Track the cost of API calls."""

    def __init__(self):
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.total_cost = 0
        self.total_budget = 0

    def update_cost(self, prompt_tokens, completion_tokens, model):
        """
        Update the total cost, prompt tokens, and completion tokens.

        Args:
            prompt_tokens (int): The number of tokens used in the prompt.
            completion_tokens (int): The number of tokens used in the completion.
            model (str): The model used for the API call.
        """
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens
        cost = (
            prompt_tokens * TOKEN_COSTS[model]["prompt"]
            + completion_tokens * TOKEN_COSTS[model]["completion"]
        ) / 1000
        self.total_cost += cost
        logger.info(f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
                    f"Current cost: ${cost:.3f}, {prompt_tokens=}, {completion_tokens=}")
        CONFIG.total_cost = self.total_cost

    def get_total_prompt_tokens(self):
        """
        Get the total number of prompt tokens.

        Returns:
            int: The total number of prompt tokens.
        """
        return self.total_prompt_tokens

    def get_total_completion_tokens(self):
        """
        Get the total number of completion tokens.

        Returns:
            int: The total number of completion tokens.
        """
        return self.total_completion_tokens

    def get_total_cost(self):
        """
        Get the total cost of API calls.

        Returns:
            float: The total cost of API calls.
        """
        return self.total_cost

    def get_costs(self) -> Costs:
        """Get all cost totals as a Costs tuple."""
        return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
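

# Worked example (illustrative prices, not taken from TOKEN_COSTS; the /1000 in
# update_cost implies per-1K-token pricing): if a model were priced at $0.0015 per 1K
# prompt tokens and $0.002 per 1K completion tokens, a call with 1,200 prompt tokens
# and 300 completion tokens would cost
#     (1200 * 0.0015 + 300 * 0.002) / 1000 = $0.0024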


class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
    """
    Check https://platform.openai.com/examples for examples
    """

    def __init__(self, proxy='', api_key=''):
        self.proxy = proxy
        self.api_key = api_key
        self.__init_openai(CONFIG)
        self.llm = openai
        self.stops = None
        self.model = CONFIG.openai_api_model
        self._cost_manager = CostManager()
        RateLimiter.__init__(self, rpm=self.rpm)

    def __init_openai(self, config):
        if self.proxy != '':
            openai.proxy = self.proxy
        else:
            litellm.api_key = config.openai_api_key
        if self.api_key != '':
            litellm.api_key = self.api_key
        else:
            litellm.api_key = config.openai_api_key
        if config.openai_api_base:
            litellm.api_base = config.openai_api_base
        if config.openai_api_type:
            litellm.api_type = config.openai_api_type
            litellm.api_version = config.openai_api_version
        self.rpm = int(config.get("RPM", 10))

    async def _achat_completion_stream(self, messages: list[dict]) -> str:
        response = await litellm.acompletion(
            **self._cons_kwargs(messages),
            stream=True
        )
        # create variables to collect the stream of chunks
        collected_chunks = []
        collected_messages = []
        # iterate through the stream of events
        async for chunk in response:
            collected_chunks.append(chunk)  # save the event response
            chunk_message = chunk['choices'][0]['delta']  # extract the message
            collected_messages.append(chunk_message)  # save the message
            if "content" in chunk_message:
                print(chunk_message["content"], end="")

        full_reply_content = ''.join([m.get('content', '') for m in collected_messages])
        usage = self._calc_usage(messages, full_reply_content)
        self._update_costs(usage)
        return full_reply_content

    def _cons_kwargs(self, messages: list[dict]) -> dict:
        if CONFIG.openai_api_type == 'azure':
            kwargs = {
                "deployment_id": CONFIG.deployment_id,
                "messages": messages,
                "max_tokens": CONFIG.max_tokens_rsp,
                "n": 1,
                "stop": self.stops,
                "temperature": 0.3
            }
        else:
            kwargs = {
                "model": self.model,
                "messages": messages,
                "max_tokens": CONFIG.max_tokens_rsp,
                "n": 1,
                "stop": self.stops,
                "temperature": 0.3
            }
        return kwargs
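
    # Illustrative result of _cons_kwargs for the non-Azure branch (model name and
    # max_tokens are hypothetical; the real values come from CONFIG and self.stops):
    #
    #     {"model": "gpt-3.5-turbo", "messages": [...], "max_tokens": 2048,
    #      "n": 1, "stop": None, "temperature": 0.3}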

    async def _achat_completion(self, messages: list[dict]) -> dict:
        rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
        self._update_costs(rsp.get('usage'))
        return rsp

    def _chat_completion(self, messages: list[dict]) -> dict:
        rsp = self.llm.ChatCompletion.create(**self._cons_kwargs(messages))
        self._update_costs(rsp.get('usage'))
        return rsp

    def completion(self, messages: list[dict]) -> dict:
        # if isinstance(messages[0], Message):
        #     messages = self.messages_to_dict(messages)
        return self._chat_completion(messages)

    async def acompletion(self, messages: list[dict]) -> dict:
        # if isinstance(messages[0], Message):
        #     messages = self.messages_to_dict(messages)
        return await self._achat_completion(messages)

    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
        """When streaming, print each token in place."""
        if stream:
            return await self._achat_completion_stream(messages)
        rsp = await self._achat_completion(messages)
        return self.get_choice_text(rsp)

    def _calc_usage(self, messages: list[dict], rsp: str) -> dict:
        usage = {}
        prompt_tokens = count_message_tokens(messages, self.model)
        completion_tokens = count_string_tokens(rsp, self.model)
        usage['prompt_tokens'] = prompt_tokens
        usage['completion_tokens'] = completion_tokens
        return usage

    async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]:
        """Return the full JSON responses."""
        split_batches = self.split_batches(batch)
        all_results = []
        for small_batch in split_batches:
            logger.info(small_batch)
            await self.wait_if_needed(len(small_batch))
            future = [self.acompletion(prompt) for prompt in small_batch]
            results = await asyncio.gather(*future)
            logger.info(results)
            all_results.extend(results)
        return all_results

    async def acompletion_batch_text(self, batch: list[list[dict]]) -> list[str]:
        """Return plain text only."""
        raw_results = await self.acompletion_batch(batch)
        results = []
        for idx, raw_result in enumerate(raw_results, start=1):
            result = self.get_choice_text(raw_result)
            results.append(result)
            logger.info(f"Result of task {idx}: {result}")
        return results

    def _update_costs(self, usage: dict):
        prompt_tokens = int(usage['prompt_tokens'])
        completion_tokens = int(usage['completion_tokens'])
        self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)

    def get_costs(self) -> Costs:
        return self._cost_manager.get_costs()
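

# Usage sketch (assumption, not part of the original module): request a completion and
# read the accumulated cost afterwards. The message content is hypothetical.
#
#     llm = OpenAIGPTAPI()
#     reply = asyncio.run(llm.acompletion_text(
#         [{"role": "user", "content": "Say hello"}], stream=False))
#     print(reply, llm.get_costs())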