import os
import gc
import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; the token is read from the
# LA_NAME environment variable (a Space secret).
login(token=os.environ.get("LA_NAME"))

# Minimum reward score a translation must reach to count as acceptable quality.
THRESHOLD = 2

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("Loading models...")
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16
)
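

# RewardModel wraps ray24724919/plan2align_rm, a causal LM with a scalar
# value head (trl's AutoModelForCausalLMWithValueHead). It scores a
# (source, translation) chat transcript; the value predicted at the final
# token serves as the translation's reward.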
class RewardModel:
    def __init__(self, device, tokenizer, torch_dtype=torch.float16):
        self.device = device
        self.tokenizer = tokenizer

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Fall back to a minimal Llama-3-style chat template if the tokenizer
        # does not ship one (Llama 3 uses a double newline after the header).
        if getattr(self.tokenizer, "chat_template", None) is None:
            self.tokenizer.chat_template = (
                "<|begin_of_text|>{% for message in messages %}"
                "{{'<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'"
                " + message['content'] + '<|eot_id|>'}}{% endfor %}"
            )

        print("Loading reward model...")
        self.RM = AutoModelForCausalLMWithValueHead.from_pretrained(
            "ray24724919/plan2align_rm",
            device_map={"": 0},
            torch_dtype=torch_dtype
        )
        self.RM.eval()
        print("Reward model loaded successfully!")
    def _create_single_message(self, language, source, translation):
        return [
            {
                "role": "system",
                "content": "You are a helpful translator and only output the result."
            },
            {
                "role": "user",
                "content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:"
            },
            {
                "role": "assistant",
                "content": translation
            }
        ]

    def _process_inputs(self, messages):
        try:
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                return_tensors="pt",
                padding=True,
                truncation=True
            )

            attention_mask = torch.ones_like(input_ids)

            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)

            # Ensure a batch dimension is present.
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
                attention_mask = attention_mask.unsqueeze(0)

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask
            }

        except Exception as e:
            logging.error(f"Error processing inputs: {str(e)}")
            raise

    def reward_fn(self, language, source, translations):
        try:
            all_rewards = []
            for translation in translations:
                messages = self._create_single_message(language, source, translation)
                inputs = self._process_inputs(messages)
                with torch.no_grad():
                    # The value-head model returns (lm_logits, loss, values);
                    # the reward is the value predicted at the final token.
                    outputs = self.RM(**inputs)
                    rewards = outputs[2]
                    reward = rewards[0, -1].cpu().item()
                all_rewards.append(reward)
            return all_rewards
        except Exception as e:
            logging.error(f"Error in reward_fn: {str(e)}")
            raise

    def get_len(self, language, translations):
        try:
            total_len = 0
            for translation in translations:
                # Count tokens per translation; no device transfer is needed
                # just to read the sequence length.
                total_len += self.tokenizer(translation, return_tensors="pt").input_ids.shape[-1]
            return total_len
        except Exception as e:
            logging.error(f"Error in get_len: {str(e)}")
            raise


reward_model = RewardModel(device, tokenizer, torch_dtype=torch.float16)
print("Models loaded successfully!")


def clear_cache():
    """Clear the CUDA cache and run garbage collection to free memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    return "Cache cleared"
def rm_predict_preference(source, translation0, translation1, language="English"):
    translations = [translation0, translation1]
    for t_i in range(len(translations)):
        translations[t_i] = ''.join(translations[t_i]).replace('</s>', ' ')
    rewards = reward_model.reward_fn(language, source.replace('</s>', ' '), translations)
    best_index = rewards.index(max(rewards))
    return best_index


def rm_find_best_translation(source, translations, language="English"):
    copy_translations = translations.copy()

    if len(translations) < 2:
        return translations[0] if translations else None

    for t_i in range(len(translations)):
        translations[t_i] = ''.join(translations[t_i]).replace('</s>', ' ')

    rewards = reward_model.reward_fn(language, ''.join(source).replace('</s>', ' '), translations)
    print(rewards)

    best_index = rewards.index(max(rewards))
    print(f"Total translations = {len(translations)}, best translation index: {best_index}")

    # Only return the winner if it clears the quality threshold.
    if rewards[best_index] >= THRESHOLD:
        return copy_translations[best_index]
    else:
        return None
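

# Best-of-n translation: generate one candidate per system prompt, score all
# candidates with the reward model, and return the highest-scoring one.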
def translate_chinese_to_english(chinese_text, target_language="English"):
    translations = []

    # Three system prompts elicit stylistically different candidates:
    # literal, professional, and creative.
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
    ]

    for prompt in system_prompts:
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": f"Translate the following Chinese text to {target_language}:\n\n{chinese_text}"}
        ]

        inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(device)

        outputs = model.generate(
            inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

        # Decode only the newly generated tokens, skipping the prompt.
        translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        translations.append(translation)

    rewards = reward_model.reward_fn(target_language, chinese_text.replace('</s>', ' '),
                                     [t.replace('</s>', ' ') for t in translations])

    best_index = rewards.index(max(rewards))
    best_translation = translations[best_index]

    return {
        "best_translation": best_translation,
        "best_reward": rewards[best_index],
        "all_translations": translations,
        "all_rewards": rewards,
        "best_index": best_index
    }


def process_text(text, target_language="English"):
    if not text.strip():
        return "Please enter some text to translate.", "", "", "", ""

    try:
        result = translate_chinese_to_english(text, target_language)

        # Format every candidate with its reward; star the winner.
        candidates = []
        for i, (trans, reward) in enumerate(zip(result["all_translations"], result["all_rewards"])):
            marker = "★ " if i == result["best_index"] else ""
            candidates.append(f"{marker}Candidate {i + 1} (Reward: {reward:.4f}):\n{trans}\n")

        candidates_text = "\n".join(candidates)

        clear_cache()

        return (
            result["best_translation"],
            f"{result['best_reward']:.4f}",
            candidates_text,
            f"Candidate {result['best_index'] + 1}",
            "Yes" if result["best_reward"] >= THRESHOLD else "No"
        )
    except Exception as e:
        clear_cache()
        return f"Error: {str(e)}", "", "", "", ""


target_languages = [
    "English", "Russian", "German", "Japanese", "Korean"
]
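

# Gradio UI: source text and target language on the left; the best translation
# and all scored candidates on the right.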
with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
    gr.Markdown("This demo uses the Plan2Align approach to translate Chinese text into your chosen language, showing how the reward model evaluates different translation candidates. Paper: https://arxiv.org/pdf/2502.20795.")

    with gr.Row():
        with gr.Column(scale=1):
            source_text = gr.Textbox(
                label="Chinese Text",
                placeholder="Enter Chinese text here...",
                lines=5
            )
            target_language = gr.Dropdown(
                choices=target_languages,
                value="English",
                label="Target Language"
            )
            translate_button = gr.Button("Translate")
            clear_button = gr.Button("Clear Memory Cache")

        with gr.Column(scale=2):
            with gr.Tab("Best Translation"):
                best_translation = gr.Textbox(
                    label="Best Translation",
                    lines=5,
                    interactive=False
                )
                best_reward = gr.Textbox(
                    label="Reward Score",
                    interactive=False
                )
                best_candidate = gr.Textbox(
                    label="Best Candidate",
                    interactive=False
                )
                meets_threshold = gr.Textbox(
                    label="Meets Quality Threshold",
                    interactive=False
                )

            with gr.Tab("All Candidates"):
                all_candidates = gr.Textbox(
                    label="All Translation Candidates with Rewards",
                    lines=15,
                    interactive=False
                )

    cache_status = gr.Textbox(
        label="Cache Status",
        value="Ready",
        interactive=False
    )

    translate_button.click(
        fn=process_text,
        inputs=[source_text, target_language],
        outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold]
    )

    clear_button.click(
        fn=clear_cache,
        inputs=[],
        outputs=[cache_status]
    )

    gr.Examples(
        examples=[
            ["夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "English"],
            ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Russian"],
            ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "German"],
            ["珍珠奶茶起源於台灣,現已成為全球流行的飲品,展現了飲食文化對世界的影響力。", "Japanese"],
            ["原住民文化擁有豐富的傳統和藝術表現形式,包括歌舞、編織和木雕,反映了與自然和諧共處的生活智慧。", "Korean"]
        ],
        inputs=[source_text, target_language],
        outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold],
        fn=process_text
    )

    gr.Markdown("## How It Works")
    gr.Markdown("""
    1. The system generates three candidate translations using different styles:
       - Literal: a word-for-word translation that preserves structure
       - Professional: a clear, formal translation
       - Creative: a vivid, expressive translation

    2. The reward model evaluates each candidate and assigns it a score.

    3. The candidate with the highest reward score is selected as the best translation.

    4. A translation meets the quality threshold if its reward score is ≥ 2.0.
    """)


if __name__ == "__main__":
    demo.launch()