from typing import List, Optional

import torch
from fastchat.conversation import (compute_skip_echo_len,
                                    get_default_conv_template)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoModel, AutoTokenizer


class ModelLoader(LLM):
    tokenizer: object = None
    model: object = None
    max_token: int = 10000
    temperature: float = 0.1
    top_p: float = 0.9
    history: list = []

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatLLM"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # Build a ChatGLM-style conversation prompt via FastChat's template
        # (the template is selected by matching against the model path).
        conv = get_default_conv_template(
            "/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8").copy()
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        inputs = self.tokenizer([prompt])
        output_ids = self.model.generate(
            torch.as_tensor(inputs.input_ids).cuda(),
            do_sample=True,
            temperature=self.temperature,
            max_new_tokens=self.max_token,
        )
        outputs = self.tokenizer.batch_decode(
            output_ids, skip_special_tokens=True)[0]

        # Strip the echoed prompt so only the newly generated reply remains.
        skip_echo_len = compute_skip_echo_len(
            "/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8", conv, prompt)
        response = outputs[skip_echo_len:]

        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        self.history = [[None, response]]
        return response

    def load_model(
        self,
        model_name_or_path: str = "/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8",
    ):
        # The tokenizer must come from the same checkpoint as the model,
        # otherwise its token ids will not match the model's vocabulary.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            model_name_or_path, trust_remote_code=True)
        # Move the model to the GPU so it matches the .cuda() inputs in _call.
        self.model = self.model.half().cuda().eval()
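

# --- Usage sketch (an assumption, not part of the original snippet) ---
# The checkpoint path and the example prompt below are illustrative only;
# a CUDA device is assumed to be available, since _call moves inputs to the GPU.
if __name__ == "__main__":
    llm = ModelLoader()
    llm.load_model("/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8")

    # The LangChain LLM base class makes the wrapper callable like a function.
    answer = llm("What is LangChain?", stop=["Observation:"])
    print(answer)
    print(llm.history)  # [[None, answer]] after the call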