# hgagent/lagent/llms/anthropic_llm.py
# Chisleep
# update01
# f8d0193
import asyncio
import json
import os
from typing import Dict, List, Optional, Union
import anthropic
import httpcore
import httpx
from anthropic import NOT_GIVEN
from requests.exceptions import ProxyError
from .base_api import AsyncBaseAPILLM, BaseAPILLM
class ClaudeAPI(BaseAPILLM):
    """Blocking chat wrapper around the Anthropic Messages API.

    Requests are issued through ``anthropic.AsyncAnthropic`` clients (one per
    API key) and driven to completion from synchronous code by :meth:`chat`.
    Keys are used round-robin; keys reported as invalid or out of credit are
    remembered in ``invalid_keys`` and skipped on later attempts.

    Args:
        model_type (str): Model identifier sent as ``model`` in each request.
        retry (int): Forwarded to the base class; read back as ``self.retry``
            to bound the number of attempts per request.
        key (Union[str, List[str]]): One API key, a list of keys, or the
            literal ``'ENV'`` to read the ``Claude_API_KEY`` environment
            variable.
        proxies (Optional[Dict]): Passed unchanged to every
            ``anthropic.AsyncAnthropic`` client.
        meta_template (Optional[List[Dict]]): Role mapping forwarded to the
            base class.
        temperature, top_p, top_k: Sampling parameters forwarded to the base
            class (``NOT_GIVEN`` by default).
        max_new_tokens (int): Generation budget; capped at 4096 and sent as
            ``max_tokens``.
        repetition_penalty (float): Forwarded to the base class but removed
            from the request payload before it is sent.
        stop_words (Union[List[str], str]): Sent as ``stop_sequences`` when
            non-empty.

    Raises:
        ValueError: If no API key is available.
    """

    is_api: bool = True

    # Seconds to wait before the next attempt after a transient failure.
    _RETRY_DELAY = 5
    # Upper bound applied to ``max_new_tokens`` before it becomes ``max_tokens``.
    _MAX_TOKENS_CAP = 4096
    # Error message that marks a key as out of credit.
    _NO_CREDIT_MESSAGE = (
        'Your credit balance is too low to access the Anthropic API. '
        'Please go to Plans & Billing to upgrade or purchase credits.'
    )

    def __init__(
        self,
        model_type: str = 'claude-3-5-sonnet-20241022',
        retry: int = 5,
        key: Union[str, List[str]] = 'ENV',
        proxies: Optional[Dict] = None,
        meta_template: Optional[List[Dict]] = [
            dict(role='system', api_role='system'),
            dict(role='user', api_role='user'),
            dict(role='assistant', api_role='assistant'),
            dict(role='environment', api_role='user'),
        ],
        temperature: float = NOT_GIVEN,
        max_new_tokens: int = 512,
        top_p: float = NOT_GIVEN,
        top_k: int = NOT_GIVEN,
        repetition_penalty: float = 0.0,
        stop_words: Optional[Union[List[str], str]] = None,
    ):
        if isinstance(meta_template, list):
            # The default list is built once at definition time and would be
            # shared by every instance; hand the base class a private copy.
            meta_template = [dict(item) if isinstance(item, dict) else item for item in meta_template]
        super().__init__(
            meta_template=meta_template,
            model_type=model_type,
            retry=retry,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
        )
        if key == 'ENV':
            key = os.getenv('Claude_API_KEY')
        if key is None:
            raise ValueError(
                'No Anthropic API key provided: pass `key` or set the Claude_API_KEY environment variable.'
            )
        if isinstance(key, str):
            self.keys = [key]
        else:
            # De-duplicate while keeping the caller's order so that key
            # rotation is deterministic across runs.
            self.keys = list(dict.fromkeys(key))
        self.clients = {
            api_key: anthropic.AsyncAnthropic(proxies=proxies, api_key=api_key) for api_key in self.keys
        }
        # Keys that were rejected (invalid or out of credit); they are
        # skipped when picking the key for the next request.
        self.invalid_keys = set()
        self.key_ctr = 0
        self.model_type = model_type
        self.proxies = proxies

    def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        session_ids: Union[int, List[int]] = None,
        **gen_params,
    ) -> Union[str, List[str]]:
        """Generate responses given the contexts.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                (one conversation) or a list of lists of messages (a batch).
            session_ids (Union[int, List[int]]): accepted for interface
                compatibility; not used by this method.
            gen_params: generation settings overriding ``self.gen_params``.

        Returns:
            Union[str, List[str]]: one string for a single conversation, a
            list of strings for a batch (``[]`` for an empty ``inputs``).
        """
        assert isinstance(inputs, list)
        if not inputs:
            # Nothing to send; also avoids indexing into an empty list below.
            return []
        gen_params = {**self.gen_params, **gen_params}
        is_single = isinstance(inputs[0], dict)
        batch = [inputs] if is_single else inputs

        import nest_asyncio

        # NOTE(review): applied unconditionally, as before. It is required to
        # re-enter an already running loop; it presumably also affects how
        # asyncio.run behaves across repeated calls -- confirm before moving.
        nest_asyncio.apply()

        async def run_async_tasks():
            tasks = [self._chat(self.template_parser(messages), **gen_params) for messages in batch]
            return await asyncio.gather(*tasks)

        # Only the loop lookup is guarded: a RuntimeError raised by the
        # requests themselves must propagate instead of being mistaken for
        # "no running loop" and causing the whole batch to be sent again.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is None:
            ret = asyncio.run(run_async_tasks())
        else:
            ret = loop.run_until_complete(asyncio.ensure_future(run_async_tasks()))
        return ret[0] if is_single else ret

    def generate_request_data(self, model_type, messages, gen_params):
        """Build the keyword arguments for ``client.messages.create``.

        Args:
            model_type (str): Model identifier placed under ``model``.
            messages (list): Chat messages as dicts with ``role`` and
                ``content``. A leading ``system`` message is moved to the
                top-level ``system`` field. The input is not modified.
            gen_params (dict): Generation parameters; must contain
                ``max_new_tokens``. The input is not modified.

        Returns:
            dict: The request payload. For a non-positive token budget the
            historical sentinel ``('', '')`` is returned instead, which
            :meth:`_chat` treats as "nothing to generate".
        """
        # Work on copies so neither the caller's dict nor its messages change.
        gen_params = gen_params.copy()
        max_tokens = min(gen_params.pop('max_new_tokens'), self._MAX_TOKENS_CAP)
        if max_tokens <= 0:
            return '', ''

        # Not part of the request payload built here.
        gen_params.pop('repetition_penalty', None)
        gen_params.pop('skip_special_tokens', None)
        gen_params.pop('session_id', None)

        # Forward stop words only when there are any, always as a list.
        stop_words = gen_params.pop('stop_words', None)
        if isinstance(stop_words, str):
            stop_words = [stop_words]
        if stop_words:
            gen_params['stop_sequences'] = stop_words
        gen_params['max_tokens'] = max_tokens

        messages = [dict(message) for message in messages]
        system = None
        if messages and messages[0]['role'] == 'system':
            system = messages.pop(0)['content']
        for message in messages:
            message.pop('name', None)

        data = {'model': model_type, 'messages': messages, **gen_params}
        if system:
            data['system'] = system
        return data

    def _next_valid_key(self) -> str:
        """Advance the round-robin counter to the next key not marked invalid.

        The caller must ensure at least one valid key exists, otherwise this
        loops forever.
        """
        while True:
            self.key_ctr = (self.key_ctr + 1) % len(self.keys)
            candidate = self.keys[self.key_ctr]
            if candidate not in self.invalid_keys:
                return candidate

    @staticmethod
    def _error_details(error):
        """Return ``(type, message)`` from ``error.body['error']``, or ``(None, None)``."""
        body = getattr(error, 'body', None)
        detail = body.get('error') if isinstance(body, dict) else None
        if not isinstance(detail, dict):
            return None, None
        return detail.get('type'), detail.get('message')

    async def _chat(self, messages: List[dict], **gen_params) -> str:
        """Generate a completion for one already-templated conversation.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: generation configuration

        Returns:
            str: The stripped text of the first content block, or ``''`` when
            the token budget is not positive.

        Raises:
            RuntimeError: If every key has been marked invalid, or if all
                ``self.retry`` attempts failed.
        """
        assert isinstance(messages, list)
        data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params)
        if not isinstance(data, dict):
            # Non-positive token budget: there is nothing to ask the API for.
            return ''

        attempts = 0
        while attempts < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')
            key = self._next_valid_key()
            client = self.clients[key]
            try:
                response = await client.messages.create(**data)
            except (anthropic.RateLimitError, anthropic.APIConnectionError) as e:
                print(f'API请求错误: {e}')
                await asyncio.sleep(self._RETRY_DELAY)
            except (httpcore.ProxyError, ProxyError) as e:
                print(f'代理服务器错误: {e}')
                await asyncio.sleep(self._RETRY_DELAY)
            except httpx.TimeoutException as e:
                print(f'请求超时: {e}')
                await asyncio.sleep(self._RETRY_DELAY)
            except Exception as error:
                err_type, err_message = self._error_details(error)
                if err_type is None and err_message is None:
                    # No structured error body: surface the original
                    # exception rather than masking it with a lookup error.
                    raise
                if err_message == 'invalid x-api-key':
                    self.invalid_keys.add(key)
                    self.logger.warning(f'invalid key: {key}')
                elif err_type == 'overloaded_error':
                    await asyncio.sleep(self._RETRY_DELAY)
                elif err_message == 'Internal server error':
                    await asyncio.sleep(self._RETRY_DELAY)
                elif err_message == self._NO_CREDIT_MESSAGE:
                    self.invalid_keys.add(key)
                    print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}')
                else:
                    # Unrecognised API errors are still retried (as before),
                    # but no longer silently.
                    print(f'Unexpected API error: {error}')
            else:
                payload = json.loads(response.json())
                return payload['content'][0]['text'].strip()
            attempts += 1
        raise RuntimeError(
            f'Calling Claude failed after retrying for {attempts} times. Check the logs for details.'
        )
class AsyncClaudeAPI(AsyncBaseAPILLM):
    """Asynchronous chat wrapper around the Anthropic Messages API.

    One ``anthropic.AsyncAnthropic`` client is created per API key. Keys are
    used round-robin; keys reported as invalid or out of credit are
    remembered in ``invalid_keys`` and skipped on later attempts.

    Args:
        model_type (str): Model identifier sent as ``model`` in each request.
        retry (int): Forwarded to the base class; read back as ``self.retry``
            to bound the number of attempts per request.
        key (Union[str, List[str]]): One API key, a list of keys, or the
            literal ``'ENV'`` to read the ``Claude_API_KEY`` environment
            variable.
        proxies (Optional[Dict]): Passed unchanged to every
            ``anthropic.AsyncAnthropic`` client.
        meta_template (Optional[List[Dict]]): Role mapping forwarded to the
            base class.
        temperature, top_p, top_k: Sampling parameters forwarded to the base
            class (``NOT_GIVEN`` by default).
        max_new_tokens (int): Generation budget; capped at 4096 and sent as
            ``max_tokens``.
        repetition_penalty (float): Forwarded to the base class but removed
            from the request payload before it is sent.
        stop_words (Union[List[str], str]): Sent as ``stop_sequences`` when
            non-empty.

    Raises:
        ValueError: If no API key is available.
    """

    is_api: bool = True

    # Seconds to wait before the next attempt after a transient failure.
    _RETRY_DELAY = 5
    # Upper bound applied to ``max_new_tokens`` before it becomes ``max_tokens``.
    _MAX_TOKENS_CAP = 4096
    # Error message that marks a key as out of credit.
    _NO_CREDIT_MESSAGE = (
        'Your credit balance is too low to access the Anthropic API. '
        'Please go to Plans & Billing to upgrade or purchase credits.'
    )

    def __init__(
        self,
        model_type: str = 'claude-3-5-sonnet-20241022',
        retry: int = 5,
        key: Union[str, List[str]] = 'ENV',
        proxies: Optional[Dict] = None,
        meta_template: Optional[List[Dict]] = [
            dict(role='system', api_role='system'),
            dict(role='user', api_role='user'),
            dict(role='assistant', api_role='assistant'),
            dict(role='environment', api_role='user'),
        ],
        temperature: float = NOT_GIVEN,
        max_new_tokens: int = 512,
        top_p: float = NOT_GIVEN,
        top_k: int = NOT_GIVEN,
        repetition_penalty: float = 0.0,
        stop_words: Optional[Union[List[str], str]] = None,
    ):
        if isinstance(meta_template, list):
            # The default list is built once at definition time and would be
            # shared by every instance; hand the base class a private copy.
            meta_template = [dict(item) if isinstance(item, dict) else item for item in meta_template]
        super().__init__(
            model_type=model_type,
            retry=retry,
            meta_template=meta_template,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
        )
        if key == 'ENV':
            key = os.getenv('Claude_API_KEY')
        if key is None:
            raise ValueError(
                'No Anthropic API key provided: pass `key` or set the Claude_API_KEY environment variable.'
            )
        if isinstance(key, str):
            self.keys = [key]
        else:
            # De-duplicate while keeping the caller's order so that key
            # rotation is deterministic across runs.
            self.keys = list(dict.fromkeys(key))
        self.clients = {
            api_key: anthropic.AsyncAnthropic(proxies=proxies, api_key=api_key) for api_key in self.keys
        }
        # Keys that were rejected (invalid or out of credit); they are
        # skipped when picking the key for the next request.
        self.invalid_keys = set()
        self.key_ctr = 0
        self.model_type = model_type
        self.proxies = proxies

    async def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        session_ids: Union[int, List[int]] = None,
        **gen_params,
    ) -> Union[str, List[str]]:
        """Generate responses given the contexts.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                (one conversation) or a list of lists of messages (a batch).
            session_ids (Union[int, List[int]]): accepted for interface
                compatibility; not used by this method.
            gen_params: generation settings overriding ``self.gen_params``.

        Returns:
            Union[str, List[str]]: one string for a single conversation, a
            list of strings for a batch (``[]`` for an empty ``inputs``).
        """
        assert isinstance(inputs, list)
        if not inputs:
            # Nothing to send; also avoids indexing into an empty list below.
            return []
        gen_params = {**self.gen_params, **gen_params}
        is_single = isinstance(inputs[0], dict)
        batch = [inputs] if is_single else inputs
        # All conversations of the batch are requested concurrently.
        ret = await asyncio.gather(*(self._chat(messages, **gen_params) for messages in batch))
        return ret[0] if is_single else ret

    def generate_request_data(self, model_type, messages, gen_params):
        """Build the keyword arguments for ``client.messages.create``.

        Args:
            model_type (str): Model identifier placed under ``model``.
            messages (list): Chat messages as dicts with ``role`` and
                ``content``. A leading ``system`` message is moved to the
                top-level ``system`` field. The input is not modified.
            gen_params (dict): Generation parameters; must contain
                ``max_new_tokens``. The input is not modified.

        Returns:
            dict: The request payload. For a non-positive token budget the
            historical sentinel ``('', '')`` is returned instead, which
            :meth:`_chat` treats as "nothing to generate".
        """
        # Work on copies so neither the caller's dict nor its messages change.
        gen_params = gen_params.copy()
        max_tokens = min(gen_params.pop('max_new_tokens'), self._MAX_TOKENS_CAP)
        if max_tokens <= 0:
            return '', ''

        # Not part of the request payload built here.
        gen_params.pop('repetition_penalty', None)
        gen_params.pop('skip_special_tokens', None)
        gen_params.pop('session_id', None)

        # Forward stop words only when there are any, always as a list.
        stop_words = gen_params.pop('stop_words', None)
        if isinstance(stop_words, str):
            stop_words = [stop_words]
        if stop_words:
            gen_params['stop_sequences'] = stop_words
        gen_params['max_tokens'] = max_tokens

        messages = [dict(message) for message in messages]
        system = None
        if messages and messages[0]['role'] == 'system':
            system = messages.pop(0)['content']
        for message in messages:
            message.pop('name', None)

        data = {'model': model_type, 'messages': messages, **gen_params}
        if system:
            data['system'] = system
        return data

    def _next_valid_key(self) -> str:
        """Advance the round-robin counter to the next key not marked invalid.

        The caller must ensure at least one valid key exists, otherwise this
        loops forever.
        """
        while True:
            self.key_ctr = (self.key_ctr + 1) % len(self.keys)
            candidate = self.keys[self.key_ctr]
            if candidate not in self.invalid_keys:
                return candidate

    @staticmethod
    def _error_details(error):
        """Return ``(type, message)`` from ``error.body['error']``, or ``(None, None)``."""
        body = getattr(error, 'body', None)
        detail = body.get('error') if isinstance(body, dict) else None
        if not isinstance(detail, dict):
            return None, None
        return detail.get('type'), detail.get('message')

    async def _chat(self, messages: List[dict], **gen_params) -> str:
        """Generate a completion for one conversation.

        The messages are first passed through ``self.template_parser``.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: generation configuration

        Returns:
            str: The stripped text of the first content block, or ``''`` when
            the token budget is not positive.

        Raises:
            RuntimeError: If every key has been marked invalid, or if all
                ``self.retry`` attempts failed.
            Exception: Any error that is not one of the recognised
                retryable/key-related failures is re-raised unchanged.
        """
        assert isinstance(messages, list)
        messages = self.template_parser(messages)
        data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params)
        if not isinstance(data, dict):
            # Non-positive token budget: there is nothing to ask the API for.
            return ''

        attempts = 0
        while attempts < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')
            key = self._next_valid_key()
            client = self.clients[key]
            try:
                response = await client.messages.create(**data)
            except (anthropic.RateLimitError, anthropic.APIConnectionError) as e:
                print(f'API请求错误: {e}')
                await asyncio.sleep(self._RETRY_DELAY)
            except (httpcore.ProxyError, ProxyError) as e:
                print(f'代理服务器错误: {e}')
                await asyncio.sleep(self._RETRY_DELAY)
            except httpx.TimeoutException as e:
                print(f'请求超时: {e}')
                await asyncio.sleep(self._RETRY_DELAY)
            except Exception as error:
                # Errors without a structured body yield (None, None) and fall
                # through to the final ``raise``, so the original exception is
                # surfaced instead of being masked by a lookup error.
                err_type, err_message = self._error_details(error)
                if err_message == 'invalid x-api-key':
                    self.invalid_keys.add(key)
                    self.logger.warning(f'invalid key: {key}')
                elif err_type == 'overloaded_error':
                    await asyncio.sleep(self._RETRY_DELAY)
                elif err_message == 'Internal server error':
                    await asyncio.sleep(self._RETRY_DELAY)
                elif err_message is not None and err_message == self._NO_CREDIT_MESSAGE:
                    self.invalid_keys.add(key)
                    print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}')
                else:
                    raise
            else:
                payload = json.loads(response.json())
                return payload['content'][0]['text'].strip()
            attempts += 1
        raise RuntimeError(
            f'Calling Claude failed after retrying for {attempts} times. Check the logs for details.'
        )