import gc
import logging
import os

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
# Set up logging
logging.basicConfig(level=logging.INFO)

# Authenticate with the Hugging Face Hub (token stored in the LA_NAME secret)
login(token=os.environ.get("LA_NAME"))
# Constants
THRESHOLD = 2  # Minimum reward score for an acceptable translation (from Plan2Align)
# Initialize device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load models once
print("Loading models...")
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16
)
class RewardModel:
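    """Wrapper around the Plan2Align reward model (a causal LM with a value head)
    that scores candidate translations of a Chinese source text."""
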
    def __init__(self, device, tokenizer, torch_dtype=torch.float16):
        self.device = device
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Set chat template if not already set
        if not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
            # Using Llama 3's default chat template
            self.tokenizer.chat_template = "<|begin_of_text|>{% for message in messages %}{{'<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>'}}{% endfor %}"
print("Loading reward model...") | |
self.RM = AutoModelForCausalLMWithValueHead.from_pretrained( | |
"ray24724919/plan2align_rm", | |
device_map={"": 0}, # Force model to stay on GPU | |
torch_dtype=torch_dtype | |
) | |
self.RM.eval() | |
print("Reward model loaded successfully!") | |
    def _create_single_message(self, language, source, translation):
        return [
            {
                "role": "system",
                "content": "You are a helpful translator and only output the result."
            },
            {
                "role": "user",
                "content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:"
            },
            {
                "role": "assistant",
                "content": translation
            }
        ]
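    # Tokenize a chat with the template above, build a matching attention mask,
    # and guarantee a batch dimension before moving everything to the device.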
    def _process_inputs(self, messages):
        try:
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            attention_mask = torch.ones_like(input_ids)
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
            if len(input_ids.shape) == 1:
                input_ids = input_ids.unsqueeze(0)
                attention_mask = attention_mask.unsqueeze(0)
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask
            }
        except Exception as e:
            logging.error(f"Error processing inputs: {str(e)}")
            raise
    def reward_fn(self, language, source, translations):
        try:
            all_rewards = []
            for translation in translations:
                messages = self._create_single_message(language, source, translation)
                inputs = self._process_inputs(messages)
                with torch.no_grad():
                    outputs = self.RM(**inputs, return_value=True)
                    # With return_value=True the value-head model returns
                    # (lm_logits, loss, values); the value at the last token
                    # serves as the sequence-level reward.
                    rewards = outputs[2]
                    reward = rewards[0, -1].cpu().item()
                all_rewards.append(reward)
            return all_rewards
        except Exception as e:
            logging.error(f"Error in reward_fn: {str(e)}")
            raise
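    # Count the total number of tokens across a list of translations, using the
    # same tokenizer as the reward model (the language argument is currently unused).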
    def get_len(self, language, translations):
        try:
            len_ = 0
            for translation in translations:
                l = self.tokenizer(translation, return_tensors="pt").input_ids.to(self.device).shape[-1]
                len_ += l
            return len_
        except Exception as e:
            logging.error(f"Error in get_len: {str(e)}")
            raise
# Create reward model instance with the already loaded tokenizer
reward_model = RewardModel(device, tokenizer, torch_dtype=torch.float16)
print("Models loaded successfully!")
# Memory management function
def clear_cache():
    """Clear the CUDA cache and run garbage collection to free memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    return "Cache cleared"
# Helper functions from Plan2Align
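# rm_predict_preference returns the index (0 or 1) of the higher-scoring translation;
# rm_find_best_translation returns the best candidate only if its reward clears THRESHOLD.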
def rm_predict_preference(source, translation0, translation1, language="English"):
    translations = [translation0, translation1]
    for t_i in range(len(translations)):
        translations[t_i] = ''.join(translations[t_i]).replace('</s>', ' ')
    rewards = reward_model.reward_fn(language, source.replace('</s>', ' '), translations)
    best_index = rewards.index(max(rewards))
    return best_index
def rm_find_best_translation(source, translations, language="English"):
    copy_translations = translations.copy()
    if len(translations) < 2:
        return translations[0] if translations else None
    for t_i in range(len(translations)):
        translations[t_i] = ''.join(translations[t_i]).replace('</s>', ' ')
    rewards = reward_model.reward_fn(language, ''.join(source).replace('</s>', ' '), translations)
    print(rewards)
    best_index = rewards.index(max(rewards))
    print(f"Total translations length = {len(translations)}, and best translation index is: {best_index}")
    if rewards[best_index] >= THRESHOLD:
        return copy_translations[best_index]
    else:
        return None
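# Illustrative usage of the helpers above (inputs are hypothetical):
#   preferred = rm_predict_preference("你好", "Hello", "Hi there")
#   best = rm_find_best_translation("你好", ["Hello", "Hi there", "Greetings"])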
def translate_chinese_to_english(chinese_text, target_language="English"):
    # Generate multiple translations
    translations = []
    # Generate three different translations with different system prompts
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
    ]
    for prompt in system_prompts:
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": f"Translate the following Chinese text to {target_language}:\n\n{chinese_text}"}
        ]
        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(device)
        outputs = model.generate(
            inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        translations.append(translation)
    # Get rewards for all translations
    rewards = reward_model.reward_fn(target_language, chinese_text.replace('</s>', ' '),
                                     [t.replace('</s>', ' ') for t in translations])
    # Find the best translation
    best_index = rewards.index(max(rewards))
    best_translation = translations[best_index]
    # Return all information
    return {
        "best_translation": best_translation,
        "best_reward": rewards[best_index],
        "all_translations": translations,
        "all_rewards": rewards,
        "best_index": best_index
    }
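# Illustrative direct call, bypassing the Gradio UI (input is hypothetical):
#   result = translate_chinese_to_english("你好,世界!", "English")
#   print(result["best_translation"], result["best_reward"])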
# Process a translation request and clear the cache afterwards
def process_text(text, target_language="English"):
    if not text.strip():
        return "Please enter some text to translate.", "", "", "", ""
    try:
        result = translate_chinese_to_english(text, target_language)
        # Format the candidate translations with their rewards
        candidates = []
        for i, (trans, reward) in enumerate(zip(result["all_translations"], result["all_rewards"])):
            marker = "★ " if i == result["best_index"] else ""
            candidates.append(f"{marker}Candidate {i+1} (Reward: {reward:.4f}):\n{trans}\n")
        candidates_text = "\n".join(candidates)
        # Clear cache after processing
        clear_cache()
        return (
            result["best_translation"],
            f"{result['best_reward']:.4f}",
            candidates_text,
            f"Candidate {result['best_index']+1}",
            "Yes" if result["best_reward"] >= THRESHOLD else "No"
        )
    except Exception as e:
        # Clear cache even if there's an error
        clear_cache()
        return f"Error: {str(e)}", "", "", "", ""
# Available target languages (only the supported ones)
target_languages = ["English", "Russian", "German", "Japanese", "Korean"]
# Create the Gradio interface
with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
    gr.Markdown(
        "This demo uses the Plan2Align approach to translate Chinese text into your chosen language, "
        "showing how the reward model evaluates different translation candidates. "
        "Paper: https://arxiv.org/pdf/2502.20795"
    )
    with gr.Row():
        with gr.Column(scale=1):
            source_text = gr.Textbox(
                label="Chinese Text",
                placeholder="Enter Chinese text here...",
                lines=5
            )
            target_language = gr.Dropdown(
                choices=target_languages,
                value="English",
                label="Target Language"
            )
            translate_button = gr.Button("Translate")
            clear_button = gr.Button("Clear Memory Cache")
        with gr.Column(scale=2):
            with gr.Tab("Best Translation"):
                best_translation = gr.Textbox(
                    label="Best Translation",
                    lines=5,
                    interactive=False
                )
                best_reward = gr.Textbox(
                    label="Reward Score",
                    interactive=False
                )
                best_candidate = gr.Textbox(
                    label="Best Candidate",
                    interactive=False
                )
                meets_threshold = gr.Textbox(
                    label="Meets Quality Threshold",
                    interactive=False
                )
            with gr.Tab("All Candidates"):
                all_candidates = gr.Textbox(
                    label="All Translation Candidates with Rewards",
                    lines=15,
                    interactive=False
                )
    cache_status = gr.Textbox(
        label="Cache Status",
        value="Ready",
        interactive=False
    )
    # Wire up the translation flow
    translate_button.click(
        fn=process_text,
        inputs=[source_text, target_language],
        outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold]
    )

    # Manual cache-clearing button
    clear_button.click(
        fn=clear_cache,
        inputs=[],
        outputs=[cache_status]
    )
    # Examples with more complex sentences in Traditional Chinese about Taiwan
    gr.Examples(
        examples=[
            ["夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "English"],
            ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Russian"],
            ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "German"],
            ["珍珠奶茶起源於台灣,現已成為全球流行的飲品,展現了飲食文化對世界的影響力。", "Japanese"],
            ["原住民文化擁有豐富的傳統和藝術表現形式,包括歌舞、編織和木雕,反映了與自然和諧共處的生活智慧。", "Korean"]
        ],
        inputs=[source_text, target_language],
        outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold],
        fn=process_text
    )
gr.Markdown("## How It Works") | |
gr.Markdown(""" | |
1. The system generates three different translations using different translation styles: | |
- Literal: A word-for-word translation preserving structure | |
- Professional: A clear, formal translation | |
- Creative: A vivid, expressive translation | |
2. The reward model evaluates each translation and assigns a score | |
3. The translation with the highest reward score is selected as the best | |
4. A translation meets the quality threshold if its reward score is ≥ 2.0 | |
""") | |
if __name__ == "__main__":
    demo.launch()