# ceshi2 / llmLoader.py
# Uploaded by zxcgqq ("Upload 8 files"), revision b5e593e.
from typing import List, Optional
from langchain.llms.base import LLM
import torch
from transformers import AutoModel, AutoTokenizer
from langchain.llms.utils import enforce_stop_tokens
from fastchat.conversation import (compute_skip_echo_len,
get_default_conv_template)
class ModelLoader(LLM):
    """LangChain ``LLM`` wrapper around a locally stored Hugging Face chat model.

    Call :meth:`load_model` before using the instance. Until then
    ``tokenizer`` and ``model`` are ``None`` and :meth:`_call` will fail.

    Each prompt is handled as follows:
    1. It is wrapped in a FastChat conversation template selected from the
       loaded checkpoint path.
    2. A reply is generated with sampling.
    3. The echoed prompt is stripped from the decoded output.
    """

    # Populated by load_model(); None until a checkpoint has been loaded.
    tokenizer: object = None
    model: object = None
    # Path of the loaded checkpoint. It is also handed to FastChat to pick the
    # conversation template and the echo-stripping rule, so it must always
    # match the checkpoint that was actually loaded (load_model keeps it in sync).
    checkpoint_path: str = "/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8"
    # Forwarded to generate() as max_new_tokens.
    max_token: int = 10000
    temperature: float = 0.1
    top_p: float = 0.9
    # Last exchange, stored as [[None, response]]; overwritten by every _call.
    history: List[List[Optional[str]]] = []

    def __init__(self):
        """Create an empty loader; call load_model() before generating."""
        super().__init__()

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain reports for this LLM type."""
        return "ChatLLM"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a single-turn chat reply for ``prompt``.

        Args:
            prompt: The user message. It is wrapped in the checkpoint's
                FastChat conversation template before tokenization.
            stop: Optional stop strings. If given, the response is truncated
                with LangChain's ``enforce_stop_tokens``.

        Returns:
            The decoded model reply with the echoed prompt removed.

        Side effects:
            Sets ``self.history`` to ``[[None, response]]``.
        """
        conv = get_default_conv_template(self.checkpoint_path).copy()
        conv.append_message(conv.roles[0], prompt)
        # A None message leaves an open assistant turn for the model to fill.
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()

        inputs = self.tokenizer([full_prompt])
        # Place the ids on whatever device the model lives on. A hard-coded
        # .cuda() breaks with a device mismatch because load_model() never
        # moves the model off the CPU.
        input_ids = torch.as_tensor(inputs.input_ids).to(self.model.device)
        output_ids = self.model.generate(
            input_ids,
            do_sample=True,
            temperature=self.temperature,
            top_p=self.top_p,
            max_new_tokens=self.max_token,
        )
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

        # FastChat computes how many leading characters of the decoded text
        # echo the prompt; everything after that is the model's reply.
        skip_echo_len = compute_skip_echo_len(self.checkpoint_path, conv, full_prompt)
        response = outputs[skip_echo_len:]
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        self.history = [[None, response]]
        return response

    def load_model(
        self,
        model_name_or_path: str = "/DATA/gpt/lang/model_cache/THUDM/chatglm-6b-int8",
        tokenizer_name_or_path: Optional[str] = None,
    ):
        """Load the tokenizer and model weights and switch the model to eval mode.

        Note: ``trust_remote_code=True`` executes Python code shipped with the
        checkpoint, so only load checkpoints from trusted sources.

        Args:
            model_name_or_path: Local directory (or hub id) of the checkpoint.
            tokenizer_name_or_path: Where to load the tokenizer from. Defaults
                to ``model_name_or_path`` so the tokenizer vocabulary matches
                the model. A tokenizer from a different model family produces
                token ids the model cannot interpret.
        """
        if tokenizer_name_or_path is None:
            tokenizer_name_or_path = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            trust_remote_code=True,
        )
        self.model = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
        ).eval()
        # Record the path only after both loads succeed, so _call never sees
        # a path that disagrees with the loaded weights.
        self.checkpoint_path = model_name_or_path