Spaces:
Runtime error
Runtime error
File size: 6,779 Bytes
d2ca3e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import copy
import logging
from typing import List, Optional, Union
from lagent.llms.base_llm import BaseModel
from lagent.schema import ModelStatusCode
from lagent.utils.util import filter_suffix
class LMDeployServer(BaseModel):
    """Launch a model with ``lmdeploy.serve`` and query the returned client.

    Constructing an instance calls ``lmdeploy.serve`` and stores the returned
    client on ``self.client``; ``generate`` and ``stream_chat`` request
    completions through that client's ``completions_v1`` method.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or download from
                    ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when ``path`` is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int): tensor parallel
        log_level (str): set log level whose value among
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
        serve_cfg (dict, optional): extra keyword arguments forwarded
            verbatim to ``lmdeploy.serve``. ``None`` means no extras.
        **kwargs: forwarded to ``BaseModel.__init__``.
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 server_name: str = '0.0.0.0',
                 server_port: int = 23333,
                 tp: int = 1,
                 log_level: str = 'WARNING',
                 serve_cfg: Optional[dict] = None,
                 **kwargs):
        super().__init__(path=path, **kwargs)
        self.model_name = model_name
        # TODO get_logger issue in multi processing
        # Function-scope import: lmdeploy is only required once an instance
        # is actually created.
        import lmdeploy
        self.client = lmdeploy.serve(
            model_path=self.path,
            model_name=model_name,
            server_name=server_name,
            server_port=server_port,
            tp=tp,
            log_level=log_level,
            # A fresh dict per call; avoids a shared mutable default argument.
            **(serve_cfg or {}))

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 ignore_eos: bool = False,
                 skip_special_tokens: Optional[bool] = False,
                 timeout: int = 30,
                 **kwargs) -> Union[str, List[str]]:
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the identical id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            timeout (int): max time to wait for response
            **kwargs: per-call generation parameters, merged through
                ``self.update_gen_params``.

        Returns:
            str | List[str]: a single completion when ``inputs`` is a string,
            otherwise a list of completions.
        """
        # Remember whether the caller passed a single prompt so the result
        # can be unwrapped to a plain string at the end.
        batched = not isinstance(inputs, str)
        if not batched:
            inputs = [inputs]

        gen_params = self.update_gen_params(**kwargs)
        # The client expects ``max_tokens`` rather than ``max_new_tokens``.
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        resp = [''] * len(inputs)
        for text in self.client.completions_v1(
                self.model_name,
                inputs,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=False,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            # Append each choice's text to the completion at the same index.
            resp = [
                resp[i] + item['text']
                for i, item in enumerate(text['choices'])
            ]
        # remove stop_words (taken from the instance-level gen_params)
        resp = filter_suffix(resp, self.gen_params.get('stop_words'))
        if not batched:
            return resp[0]
        return resp

    def stream_chat(self,
                    inputs: List[dict],
                    session_id=0,
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    stream: bool = True,
                    ignore_eos: bool = False,
                    skip_special_tokens: Optional[bool] = False,
                    timeout: int = 30,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            session_id (int): the identical id of a session
            inputs (List[dict]): user's inputs in this round conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            timeout (int): max time to wait for response
            **kwargs: per-call generation parameters, merged through
                ``self.update_gen_params``.

        Yields:
            tuple(ModelStatusCode, str, None): ``STREAM_ING`` with the text
            accumulated so far for every non-empty update, then a final
            ``END`` with the complete text. The third item is always
            ``None`` in this implementation.
        """
        gen_params = self.update_gen_params(**kwargs)
        # The client expects ``max_tokens`` rather than ``max_new_tokens``.
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        prompt = self.template_parser(inputs)

        resp = ''
        finished = False
        # ``.get`` returns None when no stop words are configured; fall back
        # to an empty list so the membership scan below cannot raise.
        stop_words = self.gen_params.get('stop_words') or []
        for text in self.client.completions_v1(
                self.model_name,
                prompt,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=stream,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp += text['choices'][0]['text']
            if not resp:
                continue
            # remove stop_words: once any stop word shows up, trim the text
            # and stop consuming further chunks.
            if any(sw in resp for sw in stop_words):
                resp = filter_suffix(resp, stop_words)
                finished = True
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None