from typing import List, Optional, Union

from lagent.llms.base_llm import BaseModel
from lagent.schema import ModelStatusCode
from lagent.utils.util import filter_suffix


class LMDeployServer(BaseModel):
"""
Args:
path (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download from
ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when model_path is a pytorch model on
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
server_name (str): host ip for serving
server_port (int): server port
tp (int): tensor parallel
log_level (str): set log level whose value among
[CRITICAL, ERROR, WARNING, INFO, DEBUG]
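
    Examples:
        A minimal usage sketch; the model path and prompt are
        illustrative, and any lmdeploy-supported model id should work:

        >>> from lagent.llms import LMDeployServer
        >>> llm = LMDeployServer(path='internlm/internlm2-chat-7b')
        >>> llm.generate('Hello!')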
"""

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 server_name: str = '0.0.0.0',
                 server_port: int = 23333,
                 tp: int = 1,
                 log_level: str = 'WARNING',
                 serve_cfg: Optional[dict] = None,
                 **kwargs):
super().__init__(path=path, **kwargs)
self.model_name = model_name
# TODO get_logger issue in multi processing
import lmdeploy
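        # `lmdeploy.serve` launches an OpenAI-style api_server in the
        # background and returns an `APIClient` bound to it.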
self.client = lmdeploy.serve(
model_path=self.path,
model_name=model_name,
server_name=server_name,
server_port=server_port,
tp=tp,
log_level=log_level,
            **(serve_cfg or {}))

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 ignore_eos: bool = False,
                 skip_special_tokens: bool = False,
                 timeout: int = 30,
                 **kwargs) -> Union[str, List[str]]:
"""Start a new round conversation of a session. Return the chat
completions in non-stream mode.
Args:
inputs (str, List[str]): user's prompt(s) in this round
session_id (int): the identical id of a session
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
ignore_eos (bool): indicator for ignoring eos
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
timeout (int): max time to wait for response
Returns:
(a list of/batched) text/chat completion
"""
batched = True
if isinstance(inputs, str):
inputs = [inputs]
batched = False
gen_params = self.update_gen_params(**kwargs)
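        # lagent's gen params use `max_new_tokens`, while the completions
        # API expects `max_tokens`; rename the field before the request.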
max_new_tokens = gen_params.pop('max_new_tokens')
gen_params.update(max_tokens=max_new_tokens)
resp = [''] * len(inputs)
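        # Even with stream=False, `completions_v1` yields response dicts,
        # each carrying one choice per prompt; concatenate each choice's
        # text onto the matching prompt's result.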
for text in self.client.completions_v1(
self.model_name,
inputs,
session_id=session_id,
sequence_start=sequence_start,
sequence_end=sequence_end,
stream=False,
ignore_eos=ignore_eos,
skip_special_tokens=skip_special_tokens,
timeout=timeout,
**gen_params):
resp = [
resp[i] + item['text']
for i, item in enumerate(text['choices'])
]
# remove stop_words
resp = filter_suffix(resp, self.gen_params.get('stop_words'))
if not batched:
return resp[0]
return resp

    def stream_chat(self,
                    inputs: List[dict],
                    session_id: int = 0,
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    stream: bool = True,
                    ignore_eos: bool = False,
                    skip_special_tokens: bool = False,
                    timeout: int = 30,
                    **kwargs):
"""Start a new round conversation of a session. Return the chat
completions in stream mode.
Args:
session_id (int): the identical id of a session
inputs (List[dict]): user's inputs in this round conversation
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
stream (bool): return in a streaming format if enabled
ignore_eos (bool): indicator for ignoring eos
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be False.
timeout (int): max time to wait for response
Returns:
tuple(Status, str, int): status, text/chat completion,
generated token number
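
        Examples:
            A sketch of consuming the stream; the message content is
            illustrative:

            >>> msgs = [dict(role='user', content='Hello')]
            >>> for status, text, _ in llm.stream_chat(msgs):
            ...     pass  # `text` holds the reply accumulated so far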
"""
gen_params = self.update_gen_params(**kwargs)
max_new_tokens = gen_params.pop('max_new_tokens')
gen_params.update(max_tokens=max_new_tokens)
prompt = self.template_parser(inputs)
resp = ''
finished = False
        stop_words = self.gen_params.get('stop_words') or []
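        # The client yields incremental deltas; accumulate them into
        # `resp` and finish early once any stop word appears.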
for text in self.client.completions_v1(
self.model_name,
prompt,
session_id=session_id,
sequence_start=sequence_start,
sequence_end=sequence_end,
stream=stream,
ignore_eos=ignore_eos,
skip_special_tokens=skip_special_tokens,
timeout=timeout,
**gen_params):
resp += text['choices'][0]['text']
if not resp:
continue
# remove stop_words
for sw in stop_words:
if sw in resp:
resp = filter_suffix(resp, stop_words)
finished = True
break
yield ModelStatusCode.STREAM_ING, resp, None
if finished:
break
yield ModelStatusCode.END, resp, None