SurMuy / app.py
AingHongsin's picture
Update app.py
7a916ca verified
raw
history blame
4.87 kB
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import torch
import pprint
import json
import gradio as gr
# Chat-transcript markers. deFormat() locates turns in decoded model output
# by searching for start_token / end_token.
start_token = "<|im_start|>"  # opens a turn; the role name follows it
end_token = "<eos>"  # closes a turn's content

# Text layout of a single turn: header line, then content, then end marker.
TURN_PREFIX = "<|im_start|>{role}\n"
TURN_TEMPLATE = "<|im_start|>{role}\n{content}<eos>\n"
# Tensor moved to CUDA at import time; its `.device` is reused below as the
# target for model.to() and inside generate().
zero = torch.Tensor([0]).cuda()
print(zero.device) # NOTE(review): the original author observed 'cpu' printed here; presumably ZeroGPU (the `spaces` package) only attaches a real GPU inside @spaces.GPU calls -- confirm
# Load your fine-tuned model and tokenizer
surMuy_model_id = "AingHongsin/SurMuy_v1_512512201"
# device_map={'': 0} maps the whole model (root module '') onto device 0;
# weights are loaded in bfloat16 from the "main" revision.
model = AutoModelForCausalLM.from_pretrained(surMuy_model_id,
    device_map={'': 0},
    revision="main",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(surMuy_model_id)
# Inference only: put the model in evaluation mode.
model.eval()
# NOTE(review): device_map already placed the model on device 0; this extra
# .to() presumably exists for the ZeroGPU device-patching flow -- confirm.
model.to(zero.device)
def deFormat(data, start_marker=None, end_marker=None):
    """Split decoded chat text into a list of ``{'role': ..., 'content': ...}`` turns.

    Each turn is expected to follow ``TURN_TEMPLATE``: the start marker and the
    role name on the first line, then the content, then the end marker.

    A turn ends right after its end marker. If no end marker appears before the
    next start marker, the turn ends where the next turn starts (or at the end
    of ``data``). This way an unterminated turn never swallows the turns that
    follow it.

    Args:
        data: Decoded text to parse.
        start_marker: Text that opens a turn. Defaults to the module-level
            ``start_token``.
        end_marker: Text that closes a turn. Defaults to the module-level
            ``end_token``.

    Returns:
        A list of dicts in order of appearance, each with two keys:
        ``'role'`` is the text between the start marker and the first newline.
        ``'content'`` is the rest of the turn, without the end marker or
        trailing whitespace.
        The list is empty if ``data`` contains no start marker.

    Raises:
        ValueError: If ``start_marker`` is empty, because the scan could
            never advance.
    """
    if start_marker is None:
        start_marker = start_token
    if end_marker is None:
        end_marker = end_token
    if not start_marker:
        raise ValueError("start_marker must be a non-empty string")

    turns = []
    turn_start = data.find(start_marker)
    while turn_start != -1:
        body_start = turn_start + len(start_marker)
        next_start = data.find(start_marker, body_start)
        # An empty end marker would match immediately; treat it as "absent".
        end_index = data.find(end_marker, body_start) if end_marker else -1

        if end_index != -1 and (next_start == -1 or end_index < next_start):
            turn_end = end_index + len(end_marker)
        elif next_start != -1:
            # Unterminated turn: stop at the next turn instead of overlapping it.
            turn_end = next_start
        else:
            turn_end = len(data)

        body = data[body_start:turn_end].rstrip()
        if end_marker and body.endswith(end_marker):
            body = body[:-len(end_marker)]
        # First line is the role; everything after the newline is content.
        # Without a newline the whole body is the role and content is empty.
        role, _, content = body.partition("\n")
        turns.append({'role': role, 'content': content})

        turn_start = next_start
    return turns
@spaces.GPU
def generate(text):
    """Sample a reply to a single user message and return the parsed transcript.

    The message is wrapped in the tokenizer's chat template with a generation
    prompt appended. Up to 512 new tokens are sampled (``do_sample=True``) on
    ``zero.device``. The first decoded sequence is then split into turns by
    ``deFormat``.

    Args:
        text: The user's message.

    Returns:
        The list of ``{'role', 'content'}`` dicts produced by ``deFormat``.
        For a causal LM the decoded sequence presumably includes the prompt
        turn as well as the reply.
    """
    target = zero.device
    prompt_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": text}],
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(target)
    model.to(target)
    output_ids = model.generate(
        prompt_ids,
        max_new_tokens=512,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    first_sequence = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    return deFormat(first_sequence)
@spaces.GPU
def beam_search(model, start_token, beam_width=3, max_length=10, stop_token='<end>'):
    """Run a generic beam search over a model exposing ``predict_next(seq)``.

    Nothing in this file appears to call this function. ``model`` is assumed
    to provide ``predict_next(seq)``, which returns per-token probabilities
    indexed by token id. Confirm this against the intended model.

    Args:
        model: Object with a ``predict_next(seq)`` method (assumed interface).
        start_token: First token of every beam, or a list/tuple of tokens to
            use as a shared prefix.
        beam_width: Number of beams kept after each expansion step.
        max_length: A beam stops growing once it holds this many tokens.
        stop_token: Token that marks a beam as finished. Defaults to the
            previously hard-coded ``'<end>'``.

    Returns:
        Up to ``beam_width`` ``(sequence, score)`` pairs, sorted by ascending
        score. The score is the summed negative log-probability, so lower is
        better. A zero-probability step costs ``math.inf``.
    """
    import math

    if isinstance(start_token, (list, tuple)):
        prefix = list(start_token)
    else:
        # A bare token must be wrapped so `seq + [token]` works.
        prefix = [start_token]

    def is_open(seq):
        # A beam keeps growing until it emits stop_token or hits max_length.
        return seq[-1] != stop_token and len(seq) < max_length

    beams = [(prefix, 0.0)]
    # Stop once every beam is finished. The old check only looked at the top
    # beam's length, which looped forever when a finished beam stayed on top.
    while any(is_open(seq) for seq, _ in beams):
        candidates = []
        for seq, score in beams:
            if not is_open(seq):
                candidates.append((seq, score))
                continue
            for token, prob in enumerate(model.predict_next(seq)):
                cost = -math.log(prob) if prob > 0 else math.inf
                candidates.append((seq + [token], score + cost))
        candidates.sort(key=lambda cand: cand[1])
        beams = candidates[:beam_width]
    return beams
@spaces.GPU
def beam_search_generate(text, beam_width=8, max_length=512):
    """Answer a single user message with beam-search decoding.

    Args:
        text: The user's message.
        beam_width: Number of beams (``num_beams``) passed to ``model.generate``.
        max_length: Maximum number of *new* tokens to generate. It is passed
            as ``max_new_tokens``, so the prompt does not count toward it.

    Returns:
        The content of the second turn parsed by ``deFormat``, as a string.
        That turn is presumably the model's reply to the user turn. If fewer
        than two turns can be parsed, for example because the turn markers
        were stripped during decoding, returns the decoded newly generated
        tokens instead of raising IndexError.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    messages = [{"role": "user", "content": text}]
    encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
    model_inputs = encodeds.to(device)
    model.to(device)
    generated_ids = model.generate(
        model_inputs,
        max_new_tokens=max_length,
        num_beams=beam_width,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    turns = deFormat(decoded[0])
    if len(turns) >= 2:
        return turns[1]['content']
    # Fallback: decode only the tokens generated after the prompt.
    new_tokens = generated_ids[0, model_inputs.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def yes_man(message, history):
    """Chat callback for ``gr.ChatInterface``.

    Args:
        message: The latest user message.
        history: Earlier turns supplied by Gradio. They are ignored, so every
            reply is generated from ``message`` alone.

    Returns:
        Whatever ``beam_search_generate`` returns for ``message``.
    """
    reply = beam_search_generate(message)
    return reply
# Build the chat UI around yes_man and start the Gradio server.
chat_window = gr.Chatbot(height=650)
message_box = gr.Textbox(placeholder="Write your message here ", container=False, scale=7)
demo = gr.ChatInterface(
    yes_man,
    chatbot=chat_window,
    textbox=message_box,
    title="Sur Muy",
    description="I am your assistant",
    # NOTE(review): no `examples` are passed, so presumably cache_examples has
    # nothing to cache -- confirm against the installed Gradio version.
    cache_examples=True,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)
demo.launch()