import os
from typing import List, Optional, Union

from litellm import completion

from lagent.llms import (GPTAPI, INTERNLM2_META, BaseAPIModel,
                         HFTransformerCasualLM, LMDeployClient,
                         LMDeployServer)
from lagent.schema import ModelStatusCode
from lagent.utils.util import filter_suffix
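
# Each config below follows the lagent convention of deferred construction:
# the class goes under `type` and everything else is passed as keyword
# arguments, so the consuming code can build a model roughly via
# `cfg.pop('type')(**cfg)` (the exact mechanism depends on the caller).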
internlm_server = dict(type=LMDeployServer,
                       path='internlm/internlm2_5-7b-chat',
                       model_name='internlm2',
                       meta_template=INTERNLM2_META,
                       top_p=0.8,
                       top_k=1,
                       temperature=0,
                       max_new_tokens=8192,
                       repetition_penalty=1.02,
                       stop_words=['<|im_end|>'])

internlm_client = dict(type=LMDeployClient,
                       model_name='internlm2_5-7b-chat',
                       url='http://127.0.0.1:23333',
                       meta_template=INTERNLM2_META,
                       top_p=0.8,
                       top_k=1,
                       temperature=0,
                       max_new_tokens=8192,
                       repetition_penalty=1.02,
                       stop_words=['<|im_end|>'])

internlm_hf = dict(type=HFTransformerCasualLM,
                   path='internlm/internlm2_5-7b-chat',
                   meta_template=INTERNLM2_META,
                   top_p=0.8,
                   top_k=None,
                   temperature=1e-6,
                   max_new_tokens=8192,
                   repetition_penalty=1.02,
                   stop_words=['<|im_end|>'])
# openai_api_base must be the full chat completions endpoint,
# e.g. https://api.openai.com/v1/chat/completions
gpt4 = dict(type=GPTAPI,
            model_type='gpt-4-turbo',
            key=os.environ.get('OPENAI_API_KEY', 'YOUR OPENAI API KEY'),
            openai_api_base=os.environ.get(
                'OPENAI_API_BASE',
                'https://api.openai.com/v1/chat/completions'),
            )
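
# The `meta_template` lists below map lagent's internal conversation roles
# to the roles accepted by the remote chat API; note the agent-side
# `environment` role is sent as `system`.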
url = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation'
qwen = dict(type=GPTAPI,
            model_type='qwen-max-longcontext',
            key=os.environ.get('QWEN_API_KEY', 'YOUR QWEN API KEY'),
            openai_api_base=url,
            meta_template=[
                dict(role='system', api_role='system'),
                dict(role='user', api_role='user'),
                dict(role='assistant', api_role='assistant'),
                dict(role='environment', api_role='system')
            ],
            top_p=0.8,
            top_k=1,
            temperature=0,
            max_new_tokens=4096,
            repetition_penalty=1.02,
            stop_words=['<|im_end|>'])

internlm_silicon = dict(type=GPTAPI,
                        model_type='internlm/internlm2_5-7b-chat',
                        key=os.environ.get('SILICON_API_KEY',
                                           'YOUR SILICON API KEY'),
                        openai_api_base='https://api.siliconflow.cn/v1/chat/completions',
                        meta_template=[
                            dict(role='system', api_role='system'),
                            dict(role='user', api_role='user'),
                            dict(role='assistant', api_role='assistant'),
                            dict(role='environment', api_role='system')
                        ],
                        top_p=0.8,
                        top_k=1,
                        temperature=0,
                        max_new_tokens=8192,
                        repetition_penalty=1.02,
                        stop_words=['<|im_end|>'])
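

# Generic fallback: expose any litellm-supported provider through lagent's
# BaseAPIModel interface, configured the same way as the API models above.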
class litellmCompletion(BaseAPIModel):
    """Wrap litellm's unified ``completion`` API as a lagent model.

    Args:
        path (str): Kept for interface compatibility with
            ``BaseAPIModel``; unused by litellm itself.
        model_name (str): The litellm model identifier, e.g. "command-r"
            or "deepseek/deepseek-chat". litellm routes the request to the
            matching provider and reads the API key from the environment.
    """

    def __init__(self, path='', model_name='command-r', **kwargs):
        self.model_name = model_name
        super().__init__(path, **kwargs)

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: bool = None,
                 skip_special_tokens: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be
                applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.

        Returns:
            (a list of/batched) text/chat completion
        """
        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False  # single string in -> single string out
        gen_params = self.update_gen_params(**kwargs)
        # litellm expects `max_tokens`; `stop_words` is applied locally
        # below. Providers that reject extra params (e.g. top_k) may also
        # need `litellm.drop_params = True`.
        gen_params.update(max_tokens=gen_params.pop('max_new_tokens'))
        stop_words = gen_params.pop('stop_words')
        # Send one request per prompt: packing all prompts into a single
        # message list would be interpreted as one multi-turn conversation.
        response = []
        for prompt in inputs:
            messages = [{'role': 'user', 'content': prompt}]
            resp = completion(
                model=self.model_name, messages=messages, **gen_params)
            response.append(resp.choices[0].message.content)
        # remove stop_words
        response = filter_suffix(response, stop_words)
        if batched:
            return response
        return response[0]

    def stream_chat(self,
                    inputs: List[dict],
                    stream: bool = True,
                    ignore_eos: bool = False,
                    skip_special_tokens: Optional[bool] = False,
                    timeout: int = 30,
                    **kwargs):
        """Start a new round of conversation and return the chat
        completions in stream mode.

        Args:
            inputs (List[dict]): user's inputs in this round of
                conversation
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            timeout (int): max time to wait for response

        Yields:
            tuple(Status, str, int): status, text/chat completion and
                generated token number (token counting is not implemented
                here, so the last element is always None)
        """
        gen_params = self.update_gen_params(**kwargs)
        # litellm expects `max_tokens`; `stop_words` is applied locally below
        gen_params.update(max_tokens=gen_params.pop('max_new_tokens'))
        stop_words = gen_params.pop('stop_words')
        if stop_words is None:
            stop_words = []
        resp = ''
        finished = False
        messages = self.template_parser._prompt2api(inputs)
        for chunk in completion(model=self.model_name,
                                messages=messages,
                                stream=stream,
                                **gen_params):
            if not chunk.choices[0].delta.content:
                continue
            resp += chunk.choices[0].delta.content
            # remove stop_words
            for sw in stop_words:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None


litellm_completion = dict(type=litellmCompletion,
                          # model_name='deepseek/deepseek-chat',
                          meta_template=[
                              dict(role='system', api_role='system'),
                              dict(role='user', api_role='user'),
                              dict(role='assistant', api_role='assistant'),
                              dict(role='environment', api_role='system')
                          ])
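

# --- Usage sketch (illustrative; not part of the original configs) ---
# Assumes the consuming code instantiates a config by popping `type`.
# Running it for real needs the provider's API key in the environment
# (e.g. OPENAI_API_KEY) and, because lagent's defaults include params such
# as top_k that OpenAI-style endpoints reject, `litellm.drop_params = True`.
if __name__ == '__main__':
    import litellm
    litellm.drop_params = True  # silently drop provider-unsupported params
    cfg = dict(litellm_completion)
    # The model name here is illustrative, not prescribed by this module.
    llm = cfg.pop('type')(model_name='gpt-4o-mini', **cfg)
    print(llm.generate('Say hello in one short sentence.'))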