import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# ChatML-style turn markup used to parse the model's decoded output
TURN_TEMPLATE = "<|im_start|>{role}\n{content}\n"
TURN_PREFIX = "<|im_start|>{role}\n"
start_token = "<|im_start|>"
# NOTE: assumed to be the ChatML end marker; an empty string here would make
# every turn found below come out empty. If the marker is stripped during
# decoding, deFormat falls back to end-of-string boundaries.
end_token = "<|im_end|>"

# On a ZeroGPU Space, CUDA is only attached inside functions decorated with
# @spaces.GPU, so this tensor still reports 'cpu' at import time.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔

# Load the fine-tuned model and tokenizer
surMuy_model_id = "AingHongsin/SurMuy_v1_512512201"
model = AutoModelForCausalLM.from_pretrained(
    surMuy_model_id,
    revision="main",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(surMuy_model_id)
model.eval()
model.to(zero.device)


def deFormat(data):
    """Split a decoded ChatML-style transcript into {'role', 'content'} turns."""
    # Find the start and end indices of each turn in the data
    turn_indices = []
    start_index = data.find(start_token)
    while start_index != -1:
        end_index = data.find(end_token, start_index)
        if end_index != -1:
            turn_indices.append((start_index, end_index + len(end_token)))
        else:
            # No end marker (e.g. stripped by skip_special_tokens): run to end
            turn_indices.append((start_index, len(data)))
        start_index = data.find(start_token, start_index + len(start_token))

    # Extract role and content from each turn, which follows TURN_TEMPLATE:
    # "<|im_start|>{role}\n{content}..."
    turns = []
    for turn_start, turn_end in turn_indices:
        turn_data = data[turn_start:turn_end].strip()

        role_start = len(start_token)
        role_end = turn_data.find("\n", role_start)
        role = turn_data[role_start:role_end]

        content = turn_data[role_end + 1:]
        if end_token and content.endswith(end_token):
            content = content[: -len(end_token)]
        turns.append({'role': role, 'content': content.strip()})

    return turns


@spaces.GPU(duration=90)
def beam_search_generate(text, beam_width=8, max_length=512):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    messages = [{"role": "user", "content": text}]
    encodeds = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    )
    model_inputs = encodeds.to(device)
    model.to(device)

    generated_ids = model.generate(
        model_inputs,
        max_new_tokens=max_length,
        num_beams=beam_width,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Turn 0 is the user prompt; turn 1 is the assistant's reply
    predict_object = deFormat(decoded[0])
    if len(predict_object) > 1:
        return predict_object[1]['content']
    return decoded[0]  # fallback if the turn markers were stripped during decoding


def yes_man(message, history):
    return beam_search_generate(message)


gr.ChatInterface(
    yes_man,
    chatbot=gr.Chatbot(height=580),
    textbox=gr.Textbox(placeholder="Write your message here", container=False, scale=7),
    title="SurMuy",
    description="Language Models for a Khmer Q&A System",
    undo_btn="Delete Previous",
    clear_btn="Clear",
).launch()
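
# --- Illustration only (not executed by the app) ---
# A minimal sketch of what deFormat is expected to return, assuming the
# <|im_start|>/<|im_end|> markers survive decoding; the transcript below is a
# made-up sample, not real model output:
#
#   sample = ("<|im_start|>user\nHello<|im_end|>"
#             "<|im_start|>assistant\nHi there!<|im_end|>")
#   deFormat(sample)
#   # -> [{'role': 'user', 'content': 'Hello'},
#   #     {'role': 'assistant', 'content': 'Hi there!'}]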