import copy
import logging
from typing import List, Optional, Union

from lagent.llms.base_llm import BaseModel
from lagent.schema import ModelStatusCode
from lagent.utils.util import filter_suffix


class LMDeployServer(BaseModel):
    """A wrapper that launches an LMDeploy API server and queries it.

    Args:
        path (str): The path to the model. It could be one of the following
            options:
                - i) A local directory path of a turbomind model which is
                    converted by the `lmdeploy convert` command or downloaded
                    from ii) and iii).
                - ii) The model_id of an lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when `path` is a pytorch model on
            huggingface.co, such as "internlm-chat-7b", "Qwen-7B-Chat",
            "Baichuan2-7B-Chat" and so on.
        server_name (str): host IP of the server
        server_port (int): server port
        tp (int): tensor parallelism degree
        log_level (str): the log level, one of
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 server_name: str = '0.0.0.0',
                 server_port: int = 23333,
                 tp: int = 1,
                 log_level: str = 'WARNING',
                 serve_cfg=dict(),
                 **kwargs):
        super().__init__(path=path, **kwargs)
        self.model_name = model_name
        # TODO get_logger issue in multi processing
        import lmdeploy
        # Launch the LMDeploy API server and keep a client bound to it.
        self.client = lmdeploy.serve(
            model_path=self.path,
            model_name=model_name,
            server_name=server_name,
            server_port=server_port,
            tp=tp,
            log_level=log_level,
            **serve_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 ignore_eos: bool = False,
                 skip_special_tokens: bool = False,
                 timeout: int = 30,
                 **kwargs) -> List[str]:
        """Start a new round of conversation in a session. Return the chat
        completions in non-streaming mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the unique id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): whether to ignore the eos token
            skip_special_tokens (bool): Whether or not to remove special
                tokens during decoding. Defaults to False.
            timeout (int): max time (seconds) to wait for the response
        Returns:
            (a list of/batched) text/chat completion
        """

        batched = True
        if isinstance(inputs, str):
            # Wrap a single prompt into a batch of one.
            inputs = [inputs]
            batched = False

        gen_params = self.update_gen_params(**kwargs)
        # The API server expects `max_tokens` rather than `max_new_tokens`.
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        resp = [''] * len(inputs)
        for text in self.client.completions_v1(
                self.model_name,
                inputs,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=False,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            # Accumulate the generated text of each prompt in the batch.
            resp = [
                resp[i] + item['text']
                for i, item in enumerate(text['choices'])
            ]
        # Remove trailing stop words from the completions.
        resp = filter_suffix(resp, self.gen_params.get('stop_words'))
        if not batched:
            return resp[0]
        return resp

    def stream_chat(self,
                    inputs: List[dict],
                    session_id: int = 0,
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    stream: bool = True,
                    ignore_eos: bool = False,
                    skip_special_tokens: bool = False,
                    timeout: int = 30,
                    **kwargs):
        """Start a new round of conversation in a session. Return the chat
        completions in streaming mode.

        Args:
            inputs (List[dict]): user's inputs in this round conversation
            session_id (int): the unique id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): whether to ignore the eos token
            skip_special_tokens (bool): Whether or not to remove special
                tokens during decoding. Defaults to False.
            timeout (int): max time (seconds) to wait for the response
        Returns:
            tuple(ModelStatusCode, str, int): status, text/chat completion,
                generated token number
        """
        gen_params = self.update_gen_params(**kwargs)
        # The API server expects `max_tokens` rather than `max_new_tokens`.
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        # Render the chat messages into a single prompt string.
        prompt = self.template_parser(inputs)

        resp = ''
        finished = False
        # Guard against `stop_words` being unset (None).
        stop_words = self.gen_params.get('stop_words') or []
        for text in self.client.completions_v1(
                self.model_name,
                prompt,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=stream,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp += text['choices'][0]['text']
            if not resp:
                continue
            # Stop once a stop word shows up in the accumulated text.
            for sw in stop_words:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None
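

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the class above). It assumes
# lmdeploy is installed, a GPU is available, and uses the hypothetical model
# id 'internlm/internlm2-chat-7b'; substitute whatever model you actually
# serve. A model-specific meta_template may also be needed for `stream_chat`;
# it is omitted here for brevity.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = LMDeployServer(
        path='internlm/internlm2-chat-7b',  # assumed model id
        server_port=23333,
        tp=1)

    # Non-streaming completion for a single prompt (returns a single string).
    print(model.generate('Hello! Please introduce yourself.'))

    # Streaming chat over a list of role/content messages; each yield carries
    # the text accumulated so far, so only the final text is printed here.
    messages = [dict(role='user', content='Summarize what LMDeploy does.')]
    last = ''
    for status, text, _ in model.stream_chat(messages):
        last = text
    print(last)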