Spaces:
Runtime error
Runtime error
import spaces | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import pandas as pd | |
import torch | |
import pprint | |
import json | |
import gradio as gr | |
# Chat-markup constants. A rendered turn has the form
# "<|im_start|>{role}\n{content}<eos>\n".
start_token = "<|im_start|>"
end_token = "<eos>"
TURN_PREFIX = "<|im_start|>{role}\n"
TURN_TEMPLATE = "<|im_start|>{role}\n{content}<eos>\n"

# Probe tensor moved to CUDA at import time; its device is printed here and
# later used as the initial placement target for the model.
zero = torch.Tensor([0]).cuda()
# NOTE(review): the original comment claimed this prints 'cpu' -- presumably
# true only under ZeroGPU, where `spaces` defers real CUDA placement; confirm.
print(zero.device)
# Hub repository that holds the fine-tuned model and its tokenizer.
surMuy_model_id = "AingHongsin/SurMuy_v1_512512201"

# Both artifacts are fetched from the same repository; the model weights are
# loaded in bfloat16 from the "main" revision.
tokenizer = AutoTokenizer.from_pretrained(surMuy_model_id)
model = AutoModelForCausalLM.from_pretrained(
    surMuy_model_id,
    revision="main",
    torch_dtype=torch.bfloat16,
)

# Put the model on the probe tensor's device and switch it to eval mode.
model.to(zero.device)
model.eval()
def deFormat(data, start_marker=None, end_marker=None):
    """Parse chat-formatted text back into a list of turns.

    The expected layout of one turn is
    ``<start_marker>{role}\\n{content}<end_marker>``. Text that precedes the
    first start marker is ignored.

    Args:
        data: The chat transcript to parse.
        start_marker: Text that opens a turn. Defaults to the module-level
            ``start_token``.
        end_marker: Text that closes a turn. Defaults to the module-level
            ``end_token``.

    Returns:
        A list of ``{'role': ..., 'content': ...}`` dicts, one per start
        marker found, in order of appearance. An input without any start
        marker yields an empty list.
    """
    if start_marker is None:
        start_marker = start_token
    if end_marker is None:
        end_marker = end_token

    turns = []
    # Element 0 of the split is whatever precedes the first start marker and
    # belongs to no turn. Splitting also bounds every turn by the next start
    # marker, so an unterminated turn can no longer swallow the turns after it.
    for chunk in data.split(start_marker)[1:]:
        terminator = ""
        body = chunk
        if end_marker:
            # Keep only the text before the end marker; the marker itself is
            # framing, not content.
            body, terminator, _ = chunk.partition(end_marker)
        if not terminator:
            # Unterminated turn: drop trailing whitespace/newlines that only
            # separate it from the next turn or the end of the text.
            body = body.rstrip()
        # partition() copes with a missing newline: the whole body is then
        # the role and the content is empty.
        role, _, content = body.partition("\n")
        turns.append({'role': role, 'content': content})
    return turns
def beam_search_generate(text, beam_width=8, max_length=512):
    """Generate the model's reply to a single user message with beam search.

    Args:
        text: The user's message; it is sent as the only turn of the chat.
        beam_width: Number of beams forwarded to ``model.generate``.
        max_length: Upper bound on the number of newly generated tokens
            (forwarded as ``max_new_tokens``).

    Returns:
        The reply text as a string.
    """
    # NOTE(review): if this Space runs on ZeroGPU hardware, GPU work is
    # expected to happen inside a function decorated with `@spaces.GPU`;
    # confirm the hardware and add the decorator if so.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    messages = [{"role": "user", "content": text}]
    encodeds = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    model_inputs = encodeds.to(device)
    model.to(device)
    generated_ids = model.generate(
        model_inputs,
        max_new_tokens=max_length,
        num_beams=beam_width,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
    )

    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    turns = deFormat(decoded[0])
    if len(turns) > 1:
        # The decoded text starts with the prompt (the user turn), so the
        # second parsed turn is the generated reply.
        return turns[1]['content']

    # Fewer than two turns were recognised, e.g. because the turn markers
    # were dropped as special tokens during decoding. Rather than failing
    # with an IndexError, decode only the tokens produced after the prompt.
    prompt_length = model_inputs.shape[-1]
    reply_ids = generated_ids[0][prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
def yes_man(message, history):
    """Chat callback: return the generated reply for ``message``.

    ``history`` is accepted but not used, so every message is answered
    without reference to earlier turns.
    """
    reply = beam_search_generate(message)
    return reply
# Build the individual widgets first, then assemble and start the chat app.
chat_window = gr.Chatbot(height=580)
message_box = gr.Textbox(
    placeholder="Write your message here ",
    container=False,
    scale=7,
)

gr.ChatInterface(
    yes_man,
    chatbot=chat_window,
    textbox=message_box,
    title="SurMuy",
    description="Language Models for a Khmer Q&A System",
    cache_examples=True,
    undo_btn="Delete Previous",
    clear_btn="Clear",
).launch()