Spaces:
Runtime error
Runtime error
import spaces | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import pandas as pd | |
import torch | |
import pprint | |
import json | |
import gradio as gr | |
# Markers that open and close a single chat turn in the serialized text.
start_token = "<|im_start|>"
end_token = "<eos>"
# Opening of a turn: start marker, role name, newline.
TURN_PREFIX = start_token + "{role}\n"
# A complete turn: prefix, content, end marker, newline.
TURN_TEMPLATE = TURN_PREFIX + "{content}" + end_token + "\n"
# One-element tensor moved to CUDA; its `.device` is the target device used
# for the model below and inside `generate`.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # original author's note: this reports 'cpu' here
# NOTE(review): `spaces` is imported but nothing in this file is decorated
# with `spaces.GPU`; if this Space targets ZeroGPU hardware, confirm whether
# the generation entry point needs that decorator.

# Fine-tuned checkpoint and its tokenizer, both loaded at import time.
surMuy_model_id = "AingHongsin/SurMuy_v1_512512201"
model = AutoModelForCausalLM.from_pretrained(
    surMuy_model_id,
    revision="main",
    torch_dtype=torch.bfloat16,
    device_map={'': 0},  # map the whole model (the '' key) to device index 0
)
tokenizer = AutoTokenizer.from_pretrained(surMuy_model_id)
model.eval()  # evaluation mode: this script only runs generation
model.to(zero.device)
def deFormat(data, start_marker=None, end_marker=None):
    """Parse chat-formatted text back into a list of turns.

    Each turn is expected to look like
    ``<start_marker>{role}\\n{content}<end_marker>``.

    Args:
        data: Serialized conversation text.
        start_marker: Marker that opens a turn. Defaults to the module-level
            ``start_token``.
        end_marker: Marker that closes a turn. Defaults to the module-level
            ``end_token``.

    Returns:
        A list of ``{'role': ..., 'content': ...}`` dicts, one per start
        marker found, in order of appearance. Empty when ``data`` contains no
        start marker.

    Raises:
        ValueError: If the start marker is empty (the scan could not advance).
    """
    if start_marker is None:
        start_marker = start_token
    if end_marker is None:
        end_marker = end_token
    if not start_marker:
        raise ValueError("start_marker must be a non-empty string")

    turns = []
    cursor = data.find(start_marker)
    while cursor != -1:
        body_start = cursor + len(start_marker)
        next_turn = data.find(start_marker, body_start)
        # A turn never extends past the next start marker, even when its own
        # end marker is missing (e.g. generation was cut off mid-turn).
        limit = len(data) if next_turn == -1 else next_turn
        end_at = data.find(end_marker, body_start, limit) if end_marker else -1
        closed = end_at != -1
        body = data[body_start:end_at if closed else limit]

        # The role is the first line; everything after the first newline is
        # content. With no newline the whole body is the role.
        role, _, content = body.partition("\n")
        if not closed:
            # An unterminated turn may carry trailing whitespace/newlines.
            role = role.rstrip() if not content else role
            content = content.rstrip()
        turns.append({'role': role, 'content': content})
        cursor = next_turn
    return turns
def generate(text):
    """Sample a reply to one user message and return the parsed turns.

    The message is wrapped as a single ``user`` turn, rendered through the
    tokenizer's chat template, and passed to ``model.generate`` with sampling
    enabled and at most 512 new tokens.

    Args:
        text: The user's message.

    Returns:
        Whatever ``deFormat`` yields for the first decoded sequence: a list
        of ``{'role', 'content'}`` dicts.
    """
    target = zero.device
    prompt_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": text}],
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(target)
    model.to(target)
    output_ids = model.generate(
        prompt_ids,
        max_new_tokens=512,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return deFormat(texts[0])
def beam_search(model, start_token, beam_width=3, max_length=10, end_token='<end>'):
    """Run a simple beam search over ``model.predict_next``.

    Scores are cumulative negative log-probabilities, so lower is better.

    Args:
        model: Object exposing ``predict_next(seq)``, which returns an
            iterable of probabilities; the position of each probability is
            used as the candidate token.
        start_token: The first token, or a list/tuple of tokens to start from.
        beam_width: Number of best sequences kept after each step.
        max_length: Expansion stops once the best sequence has this many
            tokens.
        end_token: Token that marks a sequence as finished; finished
            sequences are carried forward without being expanded.

    Returns:
        A list of ``(sequence, score)`` pairs ordered from best to worst.
    """
    # Local import: the file does not import a log function at module level.
    import math

    # A sequence must be a list so `seq + [token]` and `len(seq)` work on
    # tokens rather than on the characters of a string token.
    if isinstance(start_token, (list, tuple)):
        prefix = list(start_token)
    else:
        prefix = [start_token]
    sequences = [(prefix, 0.0)]

    def is_finished(seq):
        return bool(seq) and seq[-1] == end_token

    while sequences and len(sequences[0][0]) < max_length:
        # Once every beam has ended nothing can grow; stop instead of
        # spinning forever.
        if all(is_finished(seq) for seq, _ in sequences):
            break

        all_candidates = []
        for seq, score in sequences:
            if is_finished(seq):
                all_candidates.append((seq, score))
                continue
            for token, prob in enumerate(model.predict_next(seq)):
                if prob <= 0:
                    # log(0) is undefined; such a token could never be chosen.
                    continue
                all_candidates.append((seq + [token], score - math.log(prob)))

        if not all_candidates:
            break
        # Keep the `beam_width` lowest-scoring candidates.
        sequences = sorted(all_candidates, key=lambda pair: pair[1])[:beam_width]
    return sequences
def beam_search_generate(text, beam_width=8, max_length=512):
    """Generate a reply to one user message using beam search.

    Args:
        text: The user's message.
        beam_width: Number of beams passed to ``model.generate``.
        max_length: Maximum number of newly generated tokens.

    Returns:
        The reply text. When the decoded output parses into at least two
        turns, this is the content of the second turn (the one following the
        user turn). Otherwise it is the decoded text of only the newly
        generated tokens.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    messages = [{"role": "user", "content": text}]
    encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
    model_inputs = encodeds.to(device)
    model.to(device)
    generated_ids = model.generate(
        model_inputs,
        max_new_tokens=max_length,
        num_beams=beam_width,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    turns = deFormat(decoded[0])
    if len(turns) > 1:
        return turns[1]['content']

    # Decoding with skip_special_tokens=True can remove the turn markers, in
    # which case fewer than two turns are parsed. The generated sequence
    # begins with the prompt tokens, so drop them and decode only the new
    # part instead of failing with an IndexError.
    prompt_length = model_inputs.shape[-1]
    new_tokens = generated_ids[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def yes_man(message, history):
    """Chat callback: answer ``message`` via ``beam_search_generate``.

    Args:
        message: The latest user message.
        history: Prior conversation supplied by the chat UI; not used.

    Returns:
        The value returned by ``beam_search_generate(message)``.
    """
    reply = beam_search_generate(message)
    return reply
# Build the chat UI around `yes_man`, then start serving it.
chat_ui = gr.ChatInterface(
    yes_man,
    chatbot=gr.Chatbot(height=650),
    textbox=gr.Textbox(
        placeholder="Write your message here ",
        container=False,
        scale=7,
    ),
    # slider=gr.Slider(minimum=6, maximum=8, step=1, label="Beam Width"),
    title="Sur Muy",
    description="I am your assistant",
    # examples=["Hello", "Am I cool?", "Are tomatoes vegetables?"],
    cache_examples=True,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)
chat_ui.launch()