import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextStreamer

from vigogne.preprocess import generate_inference_chat_prompt

# os.environ["TRANSFORMERS_CACHE"] = "./.cache"


class CaesarFrenchLLM:
    def __init__(self) -> None:
        # Conversation history as a list of [user_message, assistant_message] pairs
        self.history = []

        base_model_name_or_path = "bofenghuang/vigogne-2-7b-chat"

        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_name_or_path,
            padding_side="right",
            use_fast=False,
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name_or_path,
            torch_dtype=torch.float32,
            device_map="auto",
            # load_in_8bit=True,
            # trust_remote_code=True,
            # low_cpu_mem_usage=True,
        )
        # lora_model_name_or_path = ""
        # self.model = PeftModel.from_pretrained(self.model, lora_model_name_or_path)
        self.model.eval()

        # torch.compile is only available from PyTorch 2.x onwards
        if torch.__version__ >= "2":
            self.model = torch.compile(self.model)

        # Streams generated tokens to stdout as they are produced
        self.streamer = TextStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    def infer(
        self,
        user_query,
        temperature=0.1,
        top_p=1.0,
        top_k=0,
        max_new_tokens=512,
        **kwargs,
    ):
        # Build the chat prompt from the conversation turns and tokenize it
        prompt = generate_inference_chat_prompt(user_query, tokenizer=self.tokenizer)
        input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(self.model.device)
        input_length = input_ids.shape[1]

        generated_outputs = self.model.generate(
            input_ids=input_ids,
            generation_config=GenerationConfig(
                temperature=temperature,
                do_sample=temperature > 0.0,
                top_p=top_p,
                top_k=top_k,
                max_new_tokens=max_new_tokens,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs,
            ),
            streamer=self.streamer,
            return_dict_in_generate=True,
        )
        # Keep only the newly generated tokens, dropping the prompt
        generated_tokens = generated_outputs.sequences[0, input_length:]
        generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return generated_text
    def chat(self, user_input, **kwargs):
        print(f">> <|user|>: {user_input}")
        print(">> <|assistant|>: ", end="")
        # Run inference on the full history plus the new, still unanswered turn
        model_response = self.infer([*self.history, [user_input, ""]], **kwargs)
        self.history.append([user_input, model_response])
        return self.history[-1][1]
        # print(f">> <|assistant|>: {history[-1][1]}")
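

# Illustrative usage sketch, not part of the original Space code: assumes the
# vigogne package is installed and the model weights can be downloaded. The
# example prompt below is hypothetical.
if __name__ == "__main__":
    llm = CaesarFrenchLLM()
    # The answer is streamed to stdout by TextStreamer and also returned as a string.
    response = llm.chat("Bonjour, peux-tu te présenter ?")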