from typing import List, Optional

import torch
from fastchat.conversation import (compute_skip_echo_len,
                                   get_default_conv_template)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoModel, AutoTokenizer

# Single source of truth for the checkpoint path, which was previously
# repeated as a literal in three places.
MODEL_PATH = "/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8"


class ModelLoader(LLM):
    """LangChain LLM wrapper around a locally loaded ChatGLM-6B checkpoint."""

    tokenizer: object = None
    model: object = None
    max_token: int = 10000
    temperature: float = 0.1
    top_p: float = 0.9
    history: List[List[Optional[str]]] = []

    @property
    def _llm_type(self) -> str:
        return "ChatLLM"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # Wrap the user prompt in the ChatGLM conversation template.
        conv = get_default_conv_template(MODEL_PATH).copy()
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Tokenize the full conversation prompt and sample a completion.
        inputs = self.tokenizer([prompt])
        output_ids = self.model.generate(
            torch.as_tensor(inputs.input_ids).cuda(),
            do_sample=True,
            temperature=self.temperature,
            top_p=self.top_p,
            max_new_tokens=self.max_token,
        )
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

        # Drop the echoed prompt so only the model's reply is returned.
        skip_echo_len = compute_skip_echo_len(MODEL_PATH, conv, prompt)
        response = outputs[skip_echo_len:]
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        self.history = [[None, response]]
        return response

    def load_model(self, model_name_or_path: str = MODEL_PATH):
        # Load the tokenizer from the same checkpoint as the model so their
        # vocabularies match.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
        )
        # Move the model to the GPU, since _call feeds it CUDA tensors.
        self.model = AutoModel.from_pretrained(
            model_name_or_path, trust_remote_code=True
        ).half().cuda()
        self.model = self.model.eval()
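

# Minimal usage sketch, assuming the ChatGLM checkpoint exists at MODEL_PATH
# and a CUDA device is available. This block is illustrative and not part of
# the original class.
if __name__ == "__main__":
    llm = ModelLoader()
    llm.load_model()
    # LangChain LLM objects are callable; this routes through _call above.
    print(llm("Briefly introduce the LangChain framework."))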