from typing import List, Optional

import torch
from fastchat.conversation import compute_skip_echo_len, get_default_conv_template
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoModel, AutoTokenizer


class ModelLoader(LLM):
    tokenizer: object = None
    model: object = None
    max_token: int = 10000
    temperature: float = 0.1
    # Annotated so the pydantic-based LLM base class treats these as fields
    # and allows assignment in _call().
    top_p: float = 0.9
    history: List = []

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatLLM"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # Wrap the raw prompt in the FastChat conversation template that
        # matches this model (resolved from the checkpoint path).
        conv = get_default_conv_template("/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8").copy()
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        inputs = self.tokenizer([prompt])
        output_ids = self.model.generate(
            torch.as_tensor(inputs.input_ids).cuda(),
            do_sample=True,
            temperature=self.temperature,
            max_new_tokens=self.max_token,
        )
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

        # Drop the echoed prompt so only the newly generated text remains.
        skip_echo_len = compute_skip_echo_len("/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8", conv, prompt)
        response = outputs[skip_echo_len:]
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        self.history = [[None, response]]
        return response

    def load_model(self, model_name_or_path: str = "/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8"):
        # Tokenizer and weights come from the same local checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
        )
        self.model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
        # _call() moves the input ids to the GPU, so the model must live
        # there as well.
        self.model = self.model.half().cuda().eval()
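

if __name__ == "__main__":
    # Minimal usage sketch, assuming a CUDA device and the local checkpoint
    # path used above; the prompt text is only an illustration.
    llm = ModelLoader()
    llm.load_model("/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8")
    print(llm("Briefly introduce yourself."))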