import os
from typing import List, Optional, Union

from litellm import completion

from lagent.llms import (GPTAPI, INTERNLM2_META, BaseAPIModel,
                         HFTransformerCasualLM, LMDeployClient, LMDeployServer)
from lagent.schema import ModelStatusCode
from lagent.utils.util import filter_suffix

internlm_server = dict(type=LMDeployServer,
                       path='internlm/internlm2_5-7b-chat',
                       model_name='internlm2',
                       meta_template=INTERNLM2_META,
                       top_p=0.8,
                       top_k=1,
                       temperature=0,
                       max_new_tokens=8192,
                       repetition_penalty=1.02,
                       stop_words=['<|im_end|>'])

internlm_client = dict(type=LMDeployClient,
                       model_name='internlm2_5-7b-chat',
                       url='http://127.0.0.1:23333',
                       meta_template=INTERNLM2_META,
                       top_p=0.8,
                       top_k=1,
                       temperature=0,
                       max_new_tokens=8192,
                       repetition_penalty=1.02,
                       stop_words=['<|im_end|>'])

internlm_hf = dict(type=HFTransformerCasualLM,
                   path='internlm/internlm2_5-7b-chat',
                   meta_template=INTERNLM2_META,
                   top_p=0.8,
                   top_k=None,
                   temperature=1e-6,
                   max_new_tokens=8192,
                   repetition_penalty=1.02,
                   stop_words=['<|im_end|>'])
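
# The dicts above are lagent-style model configs: the 'type' entry holds the
# model class and the remaining keys are its constructor arguments. A minimal
# sketch of turning such a config into a model instance (the helper name
# `build_model` is illustrative, not part of lagent):
def build_model(cfg: dict):
    cfg = dict(cfg)              # copy so the module-level config stays untouched
    model_cls = cfg.pop('type')  # e.g. LMDeployServer, GPTAPI, litellmCompletion
    return model_cls(**cfg)      # instantiate with the remaining keyword arguments

# Loading the checkpoint is expensive, so the call itself is left commented out:
# internlm = build_model(internlm_hf)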

# openai_api_base must be the full chat completions endpoint,
# e.g. https://api.openai.com/v1/chat/completions
gpt4 = dict(type=GPTAPI,
            model_type='gpt-4-turbo',
            key=os.environ.get('OPENAI_API_KEY', 'YOUR OPENAI API KEY'),
            openai_api_base=os.environ.get(
                'OPENAI_API_BASE',
                'https://api.openai.com/v1/chat/completions'),
            )

url = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation'
qwen = dict(type=GPTAPI,
            model_type='qwen-max-longcontext',
            key=os.environ.get('QWEN_API_KEY', 'YOUR QWEN API KEY'),
            openai_api_base=url,
            meta_template=[
                dict(role='system', api_role='system'),
                dict(role='user', api_role='user'),
                dict(role='assistant', api_role='assistant'),
                dict(role='environment', api_role='system')
            ],
            top_p=0.8,
            top_k=1,
            temperature=0,
            max_new_tokens=4096,
            repetition_penalty=1.02,
            stop_words=['<|im_end|>'])

internlm_silicon = dict(type=GPTAPI,
                        model_type='internlm/internlm2_5-7b-chat',
                        key=os.environ.get('SILICON_API_KEY',
                                           'YOUR SILICON API KEY'),
                        openai_api_base='https://api.siliconflow.cn/v1/chat/completions',
                        meta_template=[
                            dict(role='system', api_role='system'),
                            dict(role='user', api_role='user'),
                            dict(role='assistant', api_role='assistant'),
                            dict(role='environment', api_role='system')
                        ],
                        top_p=0.8,
                        top_k=1,
                        temperature=0,
                        max_new_tokens=8192,
                        repetition_penalty=1.02,
                        stop_words=['<|im_end|>'])


class litellmCompletion(BaseAPIModel):
    """A lagent model wrapper that routes completions through litellm.

    Args:
        path (str): kept for compatibility with ``BaseAPIModel``; it is not
            used to locate a local model.
        model_name (str): the litellm model identifier to call, such as
            "command-r" or "deepseek/deepseek-chat". litellm resolves the
            provider from this name and reads the matching API key from the
            environment.
    """

    def __init__(self,
                 path='',
                 model_name='command-r',
                 **kwargs):
        self.model_name = model_name
        super().__init__(path, **kwargs)

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: bool = None,
                 skip_special_tokens: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.

        Returns:
            (a list of/batched) text/chat completion
        """
        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        prompts = inputs
        # Each prompt is sent as a user message in a single litellm request.
        messages = [{'role': 'user', 'content': prompt} for prompt in prompts]
        # NOTE: gen_params is currently not forwarded to litellm in non-stream mode.
        gen_params = self.update_gen_params(**kwargs)
        response = completion(model=self.model_name, messages=messages)
        response = [resp.message.content for resp in response.choices]
        # remove stop_words
        response = filter_suffix(response, self.gen_params.get('stop_words'))
        if batched:
            return response
        return response[0]

    def stream_chat(self,
                    inputs: List[dict],
                    stream: bool = True,
                    ignore_eos: bool = False,
                    skip_special_tokens: Optional[bool] = False,
                    timeout: int = 30,
                    **kwargs):
        """Start a new round of conversation and return the chat completions
        in stream mode.

        Args:
            inputs (List[dict]): user's inputs in this round of conversation
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            timeout (int): max time to wait for response

        Returns:
            tuple(Status, str, int): status, text/chat completion,
                generated token number
        """
        gen_params = self.update_gen_params(**kwargs)
        # litellm expects `max_tokens` rather than lagent's `max_new_tokens`.
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        resp = ''
        finished = False
        stop_words = gen_params.get('stop_words')
        if stop_words is None:
            stop_words = []
        # Convert lagent-style messages into the provider's role schema.
        messages = self.template_parser._prompt2api(inputs)
        for text in completion(
                self.model_name,
                messages,
                stream=stream,
                **gen_params):
            if not text.choices[0].delta.content:
                continue
            resp += text.choices[0].delta.content
            if not resp:
                continue
            # Stop early and strip any stop word that shows up in the output.
            for sw in stop_words:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None


litellm_completion = dict(type=litellmCompletion,
                          # model_name='deepseek/deepseek-chat',
                          meta_template=[
                              dict(role='system', api_role='system'),
                              dict(role='user', api_role='user'),
                              dict(role='assistant', api_role='assistant'),
                              dict(role='environment', api_role='system')
                          ])
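
# A minimal usage sketch of the litellm-backed model defined above. It is only
# a sketch: it assumes the provider API key that litellm expects for the chosen
# model (the default model_name is 'command-r') is set in the environment, and
# some forwarded generation params may need adjusting per provider. It runs
# only when this file is executed directly.
if __name__ == '__main__':
    cfg = dict(litellm_completion)
    llm = cfg.pop('type')(**cfg)

    # Non-stream completion for a single prompt.
    print(llm.generate('Say hello in one short sentence.'))

    # Streaming chat: stream_chat yields (status, accumulated_text, token_count).
    for status, text, _ in llm.stream_chat(
            [dict(role='user', content='Say hello in one short sentence.')]):
        if status == ModelStatusCode.END:
            print(text)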