import os
import gc
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
import logging
from huggingface_hub import login
# Configure logging and authenticate with the Hugging Face Hub
logging.basicConfig(level=logging.INFO)
login(token=os.environ.get("LA_NAME"))  # token stored in the LA_NAME secret
# Constants
THRESHOLD = 2 # From Plan2Align
# Initialize device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load models once
print("Loading models...")
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.float16
)
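# device_map="auto" lets accelerate place the 8B model across the available
# devices; float16 halves the memory footprint relative to float32.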
class RewardModel:
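    """Wraps the Plan2Align value-head reward model to score translations."""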
def __init__(self, device, tokenizer, torch_dtype=torch.float16):
self.device = device
self.tokenizer = tokenizer
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Set chat template if not already set
if not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
            # Fallback: a simplified Llama 3-style template (the official one
            # also puts a blank line after the header)
self.tokenizer.chat_template = "<|begin_of_text|>{% for message in messages %}{{'<|start_header_id|>' + message['role'] + '<|end_header_id|>\n' + message['content'] + '<|eot_id|>'}}{% endfor %}"
print("Loading reward model...")
self.RM = AutoModelForCausalLMWithValueHead.from_pretrained(
"ray24724919/plan2align_rm",
device_map={"": 0}, # Force model to stay on GPU
torch_dtype=torch_dtype
)
self.RM.eval()
print("Reward model loaded successfully!")
def _create_single_message(self, language, source, translation):
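        """Build the three-turn chat (system/user/assistant) that gets scored."""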
return [
{
"role": "system",
"content": "You are a helpful translator and only output the result."
},
{
"role": "user",
"content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:"
},
{
"role": "assistant",
"content": translation
}
]
def _process_inputs(self, messages):
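        """Apply the chat template and return batched input_ids/attention_mask tensors."""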
try:
input_ids = self.tokenizer.apply_chat_template(
messages,
add_generation_prompt=False,
return_tensors="pt",
padding=True,
truncation=True
)
attention_mask = torch.ones_like(input_ids)
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
if len(input_ids.shape) == 1:
input_ids = input_ids.unsqueeze(0)
attention_mask = attention_mask.unsqueeze(0)
return {
"input_ids": input_ids,
"attention_mask": attention_mask
}
except Exception as e:
logging.error(f"Error processing inputs: {str(e)}")
raise
def reward_fn(self, language, source, translations):
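        """Score each candidate translation of `source`; higher is better."""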
try:
all_rewards = []
for translation in translations:
messages = self._create_single_message(language, source, translation)
inputs = self._process_inputs(messages)
with torch.no_grad():
outputs = self.RM(**inputs, return_value=True)
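                    # The value-head model returns (lm_logits, loss, values);
                    # the value at the final token serves as the sequence reward.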
rewards = outputs[2]
reward = rewards[0, -1].cpu().item()
all_rewards.append(reward)
return all_rewards
except Exception as e:
logging.error(f"Error in reward_fn: {str(e)}")
raise
    def get_len(self, language, translations):
        try:
            total_len = 0
            for translation in translations:
                # Token count only, so the ids can stay on the CPU
                total_len += self.tokenizer(translation, return_tensors="pt").input_ids.shape[-1]
            return total_len
except Exception as e:
logging.error(f"Error in get_len: {str(e)}")
raise
# Create reward model instance with the already loaded tokenizer
reward_model = RewardModel(device, tokenizer, torch_dtype=torch.float16)
print("Models loaded successfully!")
# Memory management function
def clear_cache():
"""Clear CUDA cache and run garbage collection to free memory"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
return "Cache cleared"
# Helper functions from Plan2Align
def rm_predict_preference(source, translation0, translation1, language="English"):
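    """Return 0 or 1, the index of the translation the reward model prefers."""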
translations = [translation0, translation1]
for t_i in range(len(translations)):
translations[t_i] = ''.join(translations[t_i]).replace('</s>',' ')
rewards = reward_model.reward_fn(language, source.replace('</s>',' '), translations)
best_index = rewards.index(max(rewards))
return best_index
def rm_find_best_translation(source, translations, language="English"):
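    """Return the highest-scoring translation, or None if no score clears THRESHOLD."""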
copy_translations = translations.copy()
if len(translations) < 2:
return translations[0] if translations else None
for t_i in range(len(translations)):
translations[t_i] = ''.join(translations[t_i]).replace('</s>',' ')
rewards = reward_model.reward_fn(language, ''.join(source).replace('</s>',' '), translations)
    print(f"Rewards: {rewards}")
    best_index = rewards.index(max(rewards))
    print(f"Scored {len(translations)} candidates; best index: {best_index}")
if rewards[best_index] >= THRESHOLD:
return copy_translations[best_index]
else:
return None
def translate_chinese_to_english(chinese_text, target_language="English"):
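    """Generate three candidate translations (into any supported target
    language, despite the function name) and rank them with the reward model."""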
# Generate multiple translations
translations = []
# Generate three different translations with different system prompts
system_prompts = [
"You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
"You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
"You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
]
for prompt in system_prompts:
messages = [
{"role": "system", "content": prompt},
{"role": "user", "content": f"Translate the following Chinese text to {target_language}:\n\n{chinese_text}"}
]
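        # Sampling (temperature 0.7, top-p 0.9) keeps the three candidates diverse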
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
outputs = model.generate(
inputs,
max_new_tokens=512,
temperature=0.7,
top_p=0.9,
do_sample=True
)
translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
translations.append(translation)
# Get rewards for all translations
rewards = reward_model.reward_fn(target_language, chinese_text.replace('</s>',' '),
[t.replace('</s>',' ') for t in translations])
# Find the best translation
best_index = rewards.index(max(rewards))
best_translation = translations[best_index]
# Return all information
return {
"best_translation": best_translation,
"best_reward": rewards[best_index],
"all_translations": translations,
"all_rewards": rewards,
"best_index": best_index
}
# Updated process_text function with cache clearing
def process_text(text, target_language="English"):
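    """Gradio callback: translate `text` and format the results for the UI."""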
if not text.strip():
return "Please enter some text to translate.", "", "", "", ""
try:
result = translate_chinese_to_english(text, target_language)
# Format the candidate translations with their rewards
candidates = []
for i, (trans, reward) in enumerate(zip(result["all_translations"], result["all_rewards"])):
marker = "★ " if i == result["best_index"] else ""
candidates.append(f"{marker}Candidate {i+1} (Reward: {reward:.4f}):\n{trans}\n")
candidates_text = "\n".join(candidates)
# Clear cache after processing
clear_cache()
return (
result["best_translation"],
f"{result['best_reward']:.4f}",
candidates_text,
f"Candidate {result['best_index']+1}",
"Yes" if result["best_reward"] >= THRESHOLD else "No"
)
except Exception as e:
# Clear cache even if there's an error
clear_cache()
return f"Error: {str(e)}", "", "", "", ""
# Define available target languages - only the supported ones
target_languages = [
"English", "Russian", "German", "Japanese", "Korean"
]
# Create an enhanced Gradio interface
with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
gr.Markdown("Showing how the reward model evaluates different translation candidates. Plan2Align 2025. Predictive Planning Based Test-Time Preference Alignmentin Paragraph-Level Machine Translation")
with gr.Row():
with gr.Column(scale=1):
source_text = gr.Textbox(
label="Chinese Text",
placeholder="Enter Chinese text here...",
lines=5
)
target_language = gr.Dropdown(
choices=target_languages,
value="English",
label="Target Language"
)
translate_button = gr.Button("Translate")
clear_button = gr.Button("Clear Memory Cache")
with gr.Column(scale=2):
with gr.Tab("Best Translation"):
best_translation = gr.Textbox(
label="Best Translation",
lines=5,
interactive=False
)
best_reward = gr.Textbox(
label="Reward Score",
interactive=False
)
best_candidate = gr.Textbox(
label="Best Candidate",
interactive=False
)
meets_threshold = gr.Textbox(
label="Meets Quality Threshold",
interactive=False
)
with gr.Tab("All Candidates"):
all_candidates = gr.Textbox(
label="All Translation Candidates with Rewards",
lines=15,
interactive=False
)
cache_status = gr.Textbox(
label="Cache Status",
value="Ready",
interactive=False
)
# Set up the translation flow
translate_button.click(
fn=process_text,
inputs=[source_text, target_language],
outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold]
)
# Add manual cache clearing button
clear_button.click(
fn=clear_cache,
inputs=[],
outputs=[cache_status]
)
# Examples with more complex sentences in Traditional Chinese about Taiwan
gr.Examples(
examples=[
["夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "English"],
["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Russian"],
["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "German"],
["珍珠奶茶起源於台灣,現已成為全球流行的飲品,展現了飲食文化對世界的影響力。", "Japanese"],
["原住民文化擁有豐富的傳統和藝術表現形式,包括歌舞、編織和木雕,反映了與自然和諧共處的生活智慧。", "Korean"]
],
inputs=[source_text, target_language],
outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold],
fn=process_text
)
gr.Markdown("## How It Works")
gr.Markdown("""
1. The system generates three different translations using different styles:
   - Literal: a word-for-word translation preserving structure
   - Professional: a clear, formal translation
   - Creative: a vivid, expressive translation
2. The reward model evaluates each translation and assigns a score.
3. The translation with the highest reward score is selected as the best.
4. A translation meets the quality threshold if its reward score is ≥ 2.0.
""")
if __name__ == "__main__":
demo.launch()
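
# A minimal sketch of exercising the pipeline without the UI (assumes the
# models above loaded successfully; the example sentence is illustrative):
#
#     result = translate_chinese_to_english("台北是台灣的首都。", "English")
#     print(result["best_translation"], f"(reward: {result['best_reward']:.4f})")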