SurMuy / app.py
AingHongsin's picture
Update app.py
9c903ec verified
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import torch
import pprint
import json
import gradio as gr
# Chat-markup constants. A turn is rendered as
# "<|im_start|>{role}\n{content}<eos>\n"; deFormat() parses that markup back
# into turns using start_token / end_token.
# TURN_TEMPLATE and TURN_PREFIX are not referenced elsewhere in the visible code.
TURN_TEMPLATE = "<|im_start|>{role}\n{content}<eos>\n"
TURN_PREFIX = "<|im_start|>{role}\n"
start_token = "<|im_start|>"
end_token = "<eos>"
# Create a one-element tensor via .cuda() at import time and print the device it reports.
# NOTE(review): presumably this relies on the `spaces` (ZeroGPU) package making
# .cuda() usable before a GPU is attached - confirm.
zero = torch.Tensor([0]).cuda()
print(zero.device) # <-- 'cpu' (original author's note; its trailing emoji was mis-encoded)
# Load the fine-tuned model and its tokenizer from the same model id.
surMuy_model_id = "AingHongsin/SurMuy_v1_512512201"
model = AutoModelForCausalLM.from_pretrained(surMuy_model_id,
revision="main",
torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(surMuy_model_id)
# Switch to evaluation mode, then place the model on whatever device `zero`
# reported above; beam_search_generate() moves it again on every call.
model.eval()
model.to(zero.device)
def deFormat(data, start_marker=None, end_marker=None):
    """Parse chat markup into a list of ``{'role': ..., 'content': ...}`` turns.

    Each turn is expected to look like ``<start>{role}\\n{content}<end>``.
    A turn begins at a start marker and ends at the first end marker that
    occurs *before the next start marker*; if there is none (for example
    because the decoder stripped it), the turn ends at the next start marker
    or at the end of ``data``. The markers themselves are never part of the
    returned role or content, and both values are whitespace-stripped.

    Args:
        data: The text to parse. Text before the first start marker is ignored.
        start_marker: Marker that opens a turn. Defaults to the module-level
            ``start_token``.
        end_marker: Marker that closes a turn. Defaults to the module-level
            ``end_token``.

    Returns:
        A list of dicts in order of appearance; empty when ``data`` is empty
        or contains no start marker.

    Raises:
        ValueError: If the start marker is empty (the scan could not advance).
    """
    if start_marker is None:
        start_marker = start_token
    if end_marker is None:
        end_marker = end_token
    if not start_marker:
        raise ValueError("start marker must be a non-empty string")
    if not data:
        return []

    turns = []
    turn_start = data.find(start_marker)
    while turn_start != -1:
        body_start = turn_start + len(start_marker)
        next_start = data.find(start_marker, body_start)
        # Never read past the next turn: an unterminated turn must not
        # swallow (or borrow the end marker of) the turns that follow it.
        limit = len(data) if next_start == -1 else next_start
        body_end = data.find(end_marker, body_start, limit) if end_marker else -1
        if body_end == -1:
            body_end = limit
        # partition() copes with a missing newline: the whole body becomes
        # the role and the content is empty.
        role, _, content = data[body_start:body_end].partition("\n")
        turns.append({'role': role.strip(), 'content': content.strip()})
        turn_start = next_start
    return turns
@spaces.GPU(duration=90)
def beam_search_generate(text, beam_width=8, max_length=512):
    """Generate the model's reply to a single user message with beam search.

    Args:
        text: The user's message; it is sent as the only turn of the chat.
        beam_width: Number of beams passed to ``model.generate``.
        max_length: Maximum number of *new* tokens to generate.

    Returns:
        The reply text. It is taken from the second parsed turn of the decoded
        sequence when the chat markup survives decoding; otherwise it is the
        decoded text of the newly generated tokens only.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    messages = [{"role": "user", "content": text}]
    encodeds = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    )
    model_inputs = encodeds.to(device)
    model.to(device)

    # Fall back to the EOS id when the tokenizer defines no pad token, so
    # generate() always receives an explicit padding id.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    generated_ids = model.generate(
        model_inputs,
        max_new_tokens=max_length,
        num_beams=beam_width,
        early_stopping=True,
        pad_token_id=pad_id,
    )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    turns = deFormat(decoded[0])
    if len(turns) > 1:
        # Turn 0 is the user prompt, turn 1 is the generated answer.
        return turns[1]['content']

    # The markup did not survive decoding (e.g. the markers were stripped as
    # special tokens), so indexing turn 1 would raise IndexError. The output
    # sequence starts with the prompt tokens; decode only what follows them.
    prompt_length = model_inputs.shape[-1]
    new_tokens = generated_ids[0, prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def yes_man(message, history):
    """Chat callback: answer ``message`` with the model's generated reply.

    ``history`` is accepted but not used, so every message is answered
    independently of earlier turns.
    """
    reply = beam_search_generate(message)
    return reply
# Build the chat UI piece by piece, then start serving it.
chat_window = gr.Chatbot(height=580)
message_box = gr.Textbox(
    placeholder="Write your message here ",
    container=False,
    scale=7,
)
demo = gr.ChatInterface(
    yes_man,
    chatbot=chat_window,
    textbox=message_box,
    title="SurMuy",
    description="Language Models for a Khmer Q&A System",
    cache_examples=True,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)
demo.launch()