import asyncio | |
import json | |
import os | |
from typing import Dict, List, Optional, Union | |
import anthropic | |
import httpcore | |
import httpx | |
from anthropic import NOT_GIVEN | |
from requests.exceptions import ProxyError | |
from .base_api import AsyncBaseAPILLM, BaseAPILLM | |
class ClaudeAPI(BaseAPILLM):
    """Synchronous client for the Anthropic Claude messages API.

    Requests are executed through ``anthropic.AsyncAnthropic`` clients and
    driven to completion with ``nest_asyncio`` so the public ``chat``
    interface stays blocking.  Multiple API keys are rotated round-robin;
    keys rejected as invalid or out of quota are retired for the lifetime
    of the instance.
    """

    is_api: bool = True

    def __init__(
        self,
        model_type: str = 'claude-3-5-sonnet-20241022',
        retry: int = 5,
        key: Union[str, List[str]] = 'ENV',
        proxies: Optional[Dict] = None,
        meta_template: Optional[List[Dict]] = None,
        temperature: float = NOT_GIVEN,
        max_new_tokens: int = 512,
        top_p: float = NOT_GIVEN,
        top_k: int = NOT_GIVEN,
        repetition_penalty: float = 0.0,
        stop_words: Union[List[str], str] = None,
    ):
        """Initialize the client.

        Args:
            model_type (str): Claude model identifier.
            retry (int): max attempts per request before giving up.
            key (str | List[str]): one or more API keys, or ``'ENV'`` to
                read the ``Claude_API_KEY`` environment variable.
            proxies (dict, optional): proxy settings forwarded to the
                ``anthropic`` client.
            meta_template (List[dict], optional): role-mapping template.
                Defaults to mapping ``environment`` onto the ``user`` role.
            temperature, max_new_tokens, top_p, top_k, repetition_penalty,
                stop_words: default generation parameters.
        """
        # Build the default per instance instead of using a mutable default
        # argument, which would be shared (and mutable) across instances.
        if meta_template is None:
            meta_template = [
                dict(role='system', api_role='system'),
                dict(role='user', api_role='user'),
                dict(role='assistant', api_role='assistant'),
                dict(role='environment', api_role='user'),
            ]
        super().__init__(
            meta_template=meta_template,
            model_type=model_type,
            retry=retry,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
        )
        key = os.getenv('Claude_API_KEY') if key == 'ENV' else key
        if isinstance(key, str):
            self.keys = [key]
        else:
            # Deduplicate keys supplied as a list.
            self.keys = list(set(key))
        self.clients = {key: anthropic.AsyncAnthropic(proxies=proxies, api_key=key) for key in self.keys}
        # Record keys rejected by the API (bad credentials / insufficient
        # quota) and skip them when rotating.
        self.invalid_keys = set()
        self.key_ctr = 0
        self.model_type = model_type
        self.proxies = proxies

    def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        session_ids: Union[int, List[int]] = None,
        **gen_params,
    ) -> Union[str, List[str]]:
        """Generate responses given the contexts.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                or list of lists of messages
            session_ids: unused; kept for interface compatibility.
            gen_params: additional generation configuration

        Returns:
            Union[str, List[str]]: generated string(s)
        """
        assert isinstance(inputs, list)
        gen_params = {**self.gen_params, **gen_params}

        # nest_asyncio lets us drive the async client to completion even
        # when this sync method is called from inside a running loop.
        import nest_asyncio
        nest_asyncio.apply()

        async def run_async_tasks():
            tasks = [
                self._chat(self.template_parser(messages), **gen_params)
                for messages in ([inputs] if isinstance(inputs[0], dict) else inputs)
            ]
            return await asyncio.gather(*tasks)

        try:
            loop = asyncio.get_running_loop()
            # A loop is already running: schedule and block on the task.
            future = asyncio.ensure_future(run_async_tasks())
            ret = loop.run_until_complete(future)
        except RuntimeError:
            # No running event loop: start a fresh one.
            ret = asyncio.run(run_async_tasks())
        return ret[0] if isinstance(inputs[0], dict) else ret

    def generate_request_data(self, model_type, messages, gen_params):
        """Translate lagent generation parameters into Claude request data.

        Args:
            model_type (str): Claude model identifier.
            messages (list): messages to send; a leading ``system`` message
                is moved into the dedicated top-level ``system`` field.
            gen_params (dict): lagent-style generation parameters.

        Returns:
            dict: keyword arguments for ``client.messages.create``.  If the
            token budget is non-positive, ``('', '')`` is returned
            (preserved legacy behavior — TODO confirm no caller relies on it).
        """
        # Work on a copy so the caller's dict is not mutated.
        gen_params = gen_params.copy()

        # Clamp the completion budget to the API's 4096-token cap.
        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''

        # Drop / rename parameters the Claude API does not accept.
        gen_params.pop('repetition_penalty')
        if 'stop_words' in gen_params:
            gen_params['stop_sequences'] = gen_params.pop('stop_words')
        gen_params['max_tokens'] = max_tokens
        gen_params.pop('skip_special_tokens', None)
        gen_params.pop('session_id', None)

        # Claude takes the system prompt as a top-level field, not as a
        # message; the `name` key is likewise not part of the API schema.
        system = None
        if messages[0]['role'] == 'system':
            system = messages.pop(0)
            system = system['content']
        for message in messages:
            message.pop('name', None)
        data = {'model': model_type, 'messages': messages, **gen_params}
        if system:
            data['system'] = system
        return data

    async def _chat(self, messages: List[dict], **gen_params) -> str:
        """Generate a completion for one conversation.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Returns:
            str: the generated text, stripped of surrounding whitespace.

        Raises:
            RuntimeError: if every key is out of quota or retries are
                exhausted.
        """
        assert isinstance(messages, list)
        data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params)
        max_num_retries = 0
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')
            # Advance round-robin to the next key that is still valid.
            while True:
                self.key_ctr += 1
                if self.key_ctr == len(self.keys):
                    self.key_ctr = 0
                if self.keys[self.key_ctr] not in self.invalid_keys:
                    break
            key = self.keys[self.key_ctr]
            client = self.clients[key]
            try:
                response = await client.messages.create(**data)
                response = json.loads(response.json())
                return response['content'][0]['text'].strip()
            except (anthropic.RateLimitError, anthropic.APIConnectionError) as e:
                print(f'API请求错误: {e}')
                await asyncio.sleep(5)
            except (httpcore.ProxyError, ProxyError) as e:
                print(f'代理服务器错误: {e}')
                await asyncio.sleep(5)
            except httpx.TimeoutException as e:
                print(f'请求超时: {e}')
                await asyncio.sleep(5)
            except KeyboardInterrupt:
                raise
            except Exception as error:
                # Read the structured error payload defensively: not every
                # exception carries a ``body`` attribute, and indexing it
                # directly used to raise AttributeError/KeyError from inside
                # the handler, masking the real failure.
                body = getattr(error, 'body', None)
                err = body.get('error', {}) if isinstance(body, dict) else {}
                message = err.get('message', '')
                if message == 'invalid x-api-key':
                    self.invalid_keys.add(key)
                    # Logger.warn is a deprecated alias of warning.
                    self.logger.warning(f'invalid key: {key}')
                elif err.get('type') == 'overloaded_error':
                    await asyncio.sleep(5)
                elif message == 'Internal server error':
                    await asyncio.sleep(5)
                elif message == (
                    'Your credit balance is too low to access the Anthropic API. Please go to Plans & Billing to '
                    'upgrade or purchase credits.'
                ):
                    self.invalid_keys.add(key)
                    print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}')
                else:
                    # Unknown API error: re-raise instead of retrying
                    # silently (consistent with AsyncClaudeAPI).
                    raise
            max_num_retries += 1
        raise RuntimeError(
            'Calling Claude failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.'
        )
class AsyncClaudeAPI(AsyncBaseAPILLM):
    """Asynchronous client for the Anthropic Claude messages API.

    Mirrors :class:`ClaudeAPI` but exposes a native ``async`` interface.
    Multiple API keys are rotated round-robin; keys rejected as invalid or
    out of quota are retired for the lifetime of the instance.
    """

    is_api: bool = True

    def __init__(
        self,
        model_type: str = 'claude-3-5-sonnet-20241022',
        retry: int = 5,
        key: Union[str, List[str]] = 'ENV',
        proxies: Optional[Dict] = None,
        meta_template: Optional[List[Dict]] = None,
        temperature: float = NOT_GIVEN,
        max_new_tokens: int = 512,
        top_p: float = NOT_GIVEN,
        top_k: int = NOT_GIVEN,
        repetition_penalty: float = 0.0,
        stop_words: Union[List[str], str] = None,
    ):
        """Initialize the client.

        Args:
            model_type (str): Claude model identifier.
            retry (int): max attempts per request before giving up.
            key (str | List[str]): one or more API keys, or ``'ENV'`` to
                read the ``Claude_API_KEY`` environment variable.
            proxies (dict, optional): proxy settings forwarded to the
                ``anthropic`` client.
            meta_template (List[dict], optional): role-mapping template.
                Defaults to mapping ``environment`` onto the ``user`` role.
            temperature, max_new_tokens, top_p, top_k, repetition_penalty,
                stop_words: default generation parameters.
        """
        # Build the default per instance instead of using a mutable default
        # argument, which would be shared (and mutable) across instances.
        if meta_template is None:
            meta_template = [
                dict(role='system', api_role='system'),
                dict(role='user', api_role='user'),
                dict(role='assistant', api_role='assistant'),
                dict(role='environment', api_role='user'),
            ]
        super().__init__(
            model_type=model_type,
            retry=retry,
            meta_template=meta_template,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
        )
        key = os.getenv('Claude_API_KEY') if key == 'ENV' else key
        if isinstance(key, str):
            self.keys = [key]
        else:
            # Deduplicate keys supplied as a list.
            self.keys = list(set(key))
        self.clients = {key: anthropic.AsyncAnthropic(proxies=proxies, api_key=key) for key in self.keys}
        # Record keys rejected by the API (bad credentials / insufficient
        # quota) and skip them when rotating.
        self.invalid_keys = set()
        self.key_ctr = 0
        self.model_type = model_type
        self.proxies = proxies

    async def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        session_ids: Union[int, List[int]] = None,
        **gen_params,
    ) -> Union[str, List[str]]:
        """Generate responses given the contexts.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                or list of lists of messages
            session_ids: unused; kept for interface compatibility.
            gen_params: additional generation configuration

        Returns:
            Union[str, List[str]]: generated string(s)
        """
        assert isinstance(inputs, list)
        gen_params = {**self.gen_params, **gen_params}
        tasks = [
            self._chat(messages, **gen_params) for messages in ([inputs] if isinstance(inputs[0], dict) else inputs)
        ]
        ret = await asyncio.gather(*tasks)
        return ret[0] if isinstance(inputs[0], dict) else ret

    def generate_request_data(self, model_type, messages, gen_params):
        """Translate lagent generation parameters into Claude request data.

        Args:
            model_type (str): Claude model identifier.
            messages (list): messages to send; a leading ``system`` message
                is moved into the dedicated top-level ``system`` field.
            gen_params (dict): lagent-style generation parameters.

        Returns:
            dict: keyword arguments for ``client.messages.create``.  If the
            token budget is non-positive, ``('', '')`` is returned
            (preserved legacy behavior — TODO confirm no caller relies on it).
        """
        # Work on a copy so the caller's dict is not mutated.
        gen_params = gen_params.copy()

        # Clamp the completion budget to the API's 4096-token cap.
        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''

        # Drop / rename parameters the Claude API does not accept.
        gen_params.pop('repetition_penalty')
        if 'stop_words' in gen_params:
            gen_params['stop_sequences'] = gen_params.pop('stop_words')
        gen_params['max_tokens'] = max_tokens
        gen_params.pop('skip_special_tokens', None)
        gen_params.pop('session_id', None)

        # Claude takes the system prompt as a top-level field, not as a
        # message; the `name` key is likewise not part of the API schema.
        system = None
        if messages[0]['role'] == 'system':
            system = messages.pop(0)
            system = system['content']
        for message in messages:
            message.pop('name', None)
        data = {'model': model_type, 'messages': messages, **gen_params}
        if system:
            data['system'] = system
        return data

    async def _chat(self, messages: List[dict], **gen_params) -> str:
        """Generate a completion for one conversation.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Returns:
            str: the generated text, stripped of surrounding whitespace.

        Raises:
            RuntimeError: if every key is out of quota or retries are
                exhausted.
        """
        assert isinstance(messages, list)
        messages = self.template_parser(messages)
        data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params)
        max_num_retries = 0
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')
            # Advance round-robin to the next key that is still valid.
            while True:
                self.key_ctr += 1
                if self.key_ctr == len(self.keys):
                    self.key_ctr = 0
                if self.keys[self.key_ctr] not in self.invalid_keys:
                    break
            key = self.keys[self.key_ctr]
            client = self.clients[key]
            try:
                response = await client.messages.create(**data)
                response = json.loads(response.json())
                return response['content'][0]['text'].strip()
            except (anthropic.RateLimitError, anthropic.APIConnectionError) as e:
                print(f'API请求错误: {e}')
                await asyncio.sleep(5)
            except (httpcore.ProxyError, ProxyError) as e:
                print(f'代理服务器错误: {e}')
                await asyncio.sleep(5)
            except httpx.TimeoutException as e:
                print(f'请求超时: {e}')
                await asyncio.sleep(5)
            except KeyboardInterrupt:
                raise
            except Exception as error:
                # Read the structured error payload defensively: not every
                # exception carries a ``body`` attribute, and indexing it
                # directly used to raise AttributeError/KeyError from inside
                # the handler, masking the real failure.
                body = getattr(error, 'body', None)
                err = body.get('error', {}) if isinstance(body, dict) else {}
                message = err.get('message', '')
                if message == 'invalid x-api-key':
                    self.invalid_keys.add(key)
                    # Logger.warn is a deprecated alias of warning.
                    self.logger.warning(f'invalid key: {key}')
                elif err.get('type') == 'overloaded_error':
                    await asyncio.sleep(5)
                elif message == 'Internal server error':
                    await asyncio.sleep(5)
                elif message == (
                    'Your credit balance is too low to access the Anthropic API. Please go to Plans & Billing to'
                    ' upgrade or purchase credits.'
                ):
                    self.invalid_keys.add(key)
                    print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}')
                else:
                    # Unknown API error: re-raise rather than retry blindly.
                    raise
            max_num_retries += 1
        raise RuntimeError(
            'Calling Claude failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.'
        )