import os

# TRANSFORMERS_CACHE must be set before transformers is imported for it to take effect.
os.environ["TRANSFORMERS_CACHE"] = "T:/CaesarLLModel/.cache"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextStreamer
# generate_inference_chat_prompt comes from the bofenghuang/vigogne repository, which must be
# installed (or on the PYTHONPATH) for this import to succeed.
from vigogne.preprocess import generate_inference_chat_prompt
if __name__ == "__main__":
    base_model_name_or_path = "bofenghuang/vigogne-2-7b-chat"

    tokenizer = AutoTokenizer.from_pretrained(
        base_model_name_or_path, padding_side="right", use_fast=False
    )
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        base_model_name_or_path,
        torch_dtype=torch.float32,
        device_map="auto",
        offload_folder="T:/CaesarLLModel/.cache/offload",
        # load_in_8bit=True,
        # trust_remote_code=True,
        # low_cpu_mem_usage=True,
    )

    # lora_model_name_or_path = ""
    # model = PeftModel.from_pretrained(model, lora_model_name_or_path)

    model.eval()

    # if torch.__version__ >= "2":
    #     model = torch.compile(model)

    # Prints generated tokens to stdout as they are produced.
    streamer = TextStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    def infer(
        user_query,
        temperature=0.1,
        top_p=1.0,
        top_k=0,
        max_new_tokens=512,
        **kwargs,
    ):
        prompt = generate_inference_chat_prompt(user_query, tokenizer=tokenizer)
        input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)
        input_length = input_ids.shape[1]

        generated_outputs = model.generate(
            input_ids=input_ids,
            generation_config=GenerationConfig(
                temperature=temperature,
                do_sample=temperature > 0.0,
                top_p=top_p,
                top_k=top_k,
                max_new_tokens=max_new_tokens,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                **kwargs,
            ),
            streamer=streamer,
            return_dict_in_generate=True,
        )

        # Keep only the newly generated tokens (everything after the prompt) and decode them.
        generated_tokens = generated_outputs.sequences[0, input_length:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return generated_text
    def chat(**kwargs):
        history = []
        while True:
            user_input = input(">> <|user|>: ")
            print(">> <|assistant|>: ", end="")
            model_response = infer([*history, [user_input, ""]], **kwargs)
            history.append([user_input, model_response])
            # The streamer already echoes the answer token by token; uncomment to reprint it.
            # print(f">> <|assistant|>: {history[-1][1]}")

    chat()
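
    # A minimal usage sketch (not part of the original script): infer() can also be called
    # directly, bypassing the interactive chat() loop above. Its argument mirrors the history
    # format that chat() builds: a list of [user, assistant] pairs, with an empty string for
    # the turn to be generated. The French prompt below is only an illustrative placeholder.
    # answer = infer([["Bonjour, peux-tu te présenter ?", ""]], temperature=0.7, max_new_tokens=256)
    # print(answer)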