import os
import gc
import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from huggingface_hub import login

# Set up logging and authenticate with the Hugging Face Hub
logging.basicConfig(level=logging.INFO)
login(token=os.environ.get("LA_NAME"))

# Constants
THRESHOLD = 2  # From Plan2Align

# Initialize device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load models once
print("Loading models...")
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16
)


class RewardModel:
    def __init__(self, device, tokenizer, torch_dtype=torch.float16):
        self.device = device
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Set chat template if not already set
        if getattr(self.tokenizer, "chat_template", None) is None:
            # Llama 3-style default chat template
            self.tokenizer.chat_template = (
                "<|begin_of_text|>{% for message in messages %}"
                "{{'<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'"
                " + message['content'] + '<|eot_id|>'}}{% endfor %}"
            )

        print("Loading reward model...")
        self.RM = AutoModelForCausalLMWithValueHead.from_pretrained(
            "ray24724919/plan2align_rm",
            device_map={"": 0},  # Force the reward model to stay on GPU 0
            torch_dtype=torch_dtype
        )
        self.RM.eval()
        print("Reward model loaded successfully!")

    def _create_single_message(self, language, source, translation):
        return [
            {
                "role": "system",
                "content": "You are a helpful translator and only output the result."
            },
            {
                "role": "user",
                "content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:"
            },
            {
                "role": "assistant",
                "content": translation
            }
        ]

    def _process_inputs(self, messages):
        try:
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            attention_mask = torch.ones_like(input_ids)
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
            if len(input_ids.shape) == 1:
                input_ids = input_ids.unsqueeze(0)
                attention_mask = attention_mask.unsqueeze(0)
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask
            }
        except Exception as e:
            logging.error(f"Error processing inputs: {str(e)}")
            raise

    def reward_fn(self, language, source, translations):
        """Score each candidate: the reward is the value head's output at the
        final token of the chat-formatted (source, translation) example."""
        try:
            all_rewards = []
            for translation in translations:
                messages = self._create_single_message(language, source, translation)
                inputs = self._process_inputs(messages)
                with torch.no_grad():
                    outputs = self.RM(**inputs, return_value=True)
                    rewards = outputs[2]  # outputs = (lm_logits, loss, values)
                reward = rewards[0, -1].cpu().item()
                all_rewards.append(reward)
            return all_rewards
        except Exception as e:
            logging.error(f"Error in reward_fn: {str(e)}")
            raise

    def get_len(self, language, translations):
        """Return the total token count across the given translations."""
        try:
            total_len = 0
            for translation in translations:
                total_len += self.tokenizer(translation, return_tensors="pt").input_ids.shape[-1]
            return total_len
        except Exception as e:
            logging.error(f"Error in get_len: {str(e)}")
            raise


# Create reward model instance with the already loaded tokenizer
reward_model = RewardModel(device, tokenizer, torch_dtype=torch.float16)
print("Models loaded successfully!")


# Memory management function
def clear_cache():
    """Clear CUDA cache and run garbage collection to free memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    return "Cache cleared"
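
# Illustrative use of reward_fn (the scores below are hypothetical; the actual
# scale depends on the plan2align_rm checkpoint):
#
#   rewards = reward_model.reward_fn(
#       "English", "早安", ["Good morning.", "Morning good."]
#   )
#   # e.g. [2.31, -0.54] -> the first candidate wins and clears THRESHOLD
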
# Helper functions from Plan2Align
def rm_predict_preference(source, translation0, translation1, language="English"):
    """Return the index (0 or 1) of the higher-scoring translation."""
    translations = [translation0, translation1]
    for t_i in range(len(translations)):
        # Replace zero-width spaces (common in scraped Chinese text) with
        # normal spaces before scoring
        translations[t_i] = ''.join(translations[t_i]).replace('\u200b', ' ')
    rewards = reward_model.reward_fn(language, source.replace('\u200b', ' '), translations)
    best_index = rewards.index(max(rewards))
    return best_index


def rm_find_best_translation(source, translations, language="English"):
    """Return the best-scoring translation, or None if no candidate reaches THRESHOLD."""
    if len(translations) < 2:
        return translations[0] if translations else None
    copy_translations = translations.copy()
    for t_i in range(len(translations)):
        translations[t_i] = ''.join(translations[t_i]).replace('\u200b', ' ')
    rewards = reward_model.reward_fn(language, ''.join(source).replace('\u200b', ' '), translations)
    print(rewards)
    best_index = rewards.index(max(rewards))
    print(f"Total translations length = {len(translations)}, and best translation index is: {best_index}")
    if rewards[best_index] >= THRESHOLD:
        return copy_translations[best_index]
    else:
        return None


def translate_chinese_to_english(chinese_text, target_language="English"):
    """Generate several candidate translations and score them with the reward model."""
    translations = []

    # Generate three different translations with different system prompts
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
    ]

    for prompt in system_prompts:
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": f"Translate the following Chinese text to {target_language}:\n\n{chinese_text}"}
        ]
        inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(device)
        outputs = model.generate(
            inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        # Decode only the newly generated tokens
        translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        translations.append(translation)

    # Get rewards for all translations
    rewards = reward_model.reward_fn(
        target_language,
        chinese_text.replace('\u200b', ' '),
        [t.replace('\u200b', ' ') for t in translations]
    )

    # Find the best translation
    best_index = rewards.index(max(rewards))
    best_translation = translations[best_index]

    # Return all information
    return {
        "best_translation": best_translation,
        "best_reward": rewards[best_index],
        "all_translations": translations,
        "all_rewards": rewards,
        "best_index": best_index
    }


# Translate user input and clear the cache afterwards
def process_text(text, target_language="English"):
    if not text.strip():
        return "Please enter some text to translate.", "", "", "", ""
    try:
        result = translate_chinese_to_english(text, target_language)

        # Format the candidate translations with their rewards
        candidates = []
        for i, (trans, reward) in enumerate(zip(result["all_translations"], result["all_rewards"])):
            marker = "★ " if i == result["best_index"] else ""
            candidates.append(f"{marker}Candidate {i+1} (Reward: {reward:.4f}):\n{trans}\n")
        candidates_text = "\n".join(candidates)

        # Clear cache after processing
        clear_cache()

        return (
            result["best_translation"],
            f"{result['best_reward']:.4f}",
            candidates_text,
            f"Candidate {result['best_index']+1}",
            "Yes" if result["best_reward"] >= THRESHOLD else "No"
        )
    except Exception as e:
        # Clear cache even if there's an error
        clear_cache()
        return f"Error: {str(e)}", "", "", "", ""
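
# Illustrative call (reward values hypothetical): process_text returns the 5-tuple
# (best_translation, reward_score, all_candidates_text, best_candidate_label, meets_threshold),
# e.g. process_text("早安", "English")
#   -> ("Good morning.", "2.8412", "...", "Candidate 2", "Yes")
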
# Define available target languages - only the supported ones
target_languages = [
    "English",
    "Russian",
    "German",
    "Japanese",
    "Korean"
]

# Create an enhanced Gradio interface
with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
    gr.Markdown("# Translation with Plan2Align")
    gr.Markdown("This demo uses the Plan2Align approach to translate Chinese text into your chosen language, showing how the reward model evaluates different translation candidates.")
    gr.Markdown("Predictive Planning Based Test-Time Preference Alignment in Paragraph-Level Machine Translation")

    with gr.Row():
        with gr.Column(scale=1):
            source_text = gr.Textbox(
                label="Chinese Text",
                placeholder="Enter Chinese text here...",
                lines=5
            )
            target_language = gr.Dropdown(
                choices=target_languages,
                value="English",
                label="Target Language"
            )
            translate_button = gr.Button("Translate")
            clear_button = gr.Button("Clear Memory Cache")

        with gr.Column(scale=2):
            with gr.Tab("Best Translation"):
                best_translation = gr.Textbox(
                    label="Best Translation",
                    lines=5,
                    interactive=False
                )
                best_reward = gr.Textbox(
                    label="Reward Score",
                    interactive=False
                )
                best_candidate = gr.Textbox(
                    label="Best Candidate",
                    interactive=False
                )
                meets_threshold = gr.Textbox(
                    label="Meets Quality Threshold",
                    interactive=False
                )
            with gr.Tab("All Candidates"):
                all_candidates = gr.Textbox(
                    label="All Translation Candidates with Rewards",
                    lines=15,
                    interactive=False
                )
            cache_status = gr.Textbox(
                label="Cache Status",
                value="Ready",
                interactive=False
            )

    # Set up the translation flow
    translate_button.click(
        fn=process_text,
        inputs=[source_text, target_language],
        outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold]
    )

    # Add manual cache clearing button
    clear_button.click(
        fn=clear_cache,
        inputs=[],
        outputs=[cache_status]
    )

    # Examples with more complex sentences in Traditional Chinese about Taiwan
    gr.Examples(
        examples=[
            ["夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "English"],
            ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Russian"],
            ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "German"],
            ["珍珠奶茶起源於台灣,現已成為全球流行的飲品,展現了飲食文化對世界的影響力。", "Japanese"],
            ["原住民文化擁有豐富的傳統和藝術表現形式,包括歌舞、編織和木雕,反映了與自然和諧共處的生活智慧。", "Korean"]
        ],
        inputs=[source_text, target_language],
        outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold],
        fn=process_text
    )

    gr.Markdown("## How It Works")
    gr.Markdown("""
    1. The system generates three different translations using different translation styles:
       - Literal: a word-for-word translation preserving structure
       - Professional: a clear, formal translation
       - Creative: a vivid, expressive translation
    2. The reward model evaluates each translation and assigns a score.
    3. The translation with the highest reward score is selected as the best.
    4. A translation meets the quality threshold if its reward score is ≥ 2.0.
    """)

if __name__ == "__main__":
    demo.launch()
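
# Optional launch variants (standard Gradio options, not specific to this app):
# `demo.queue().launch()` enables the request queue so concurrent users are served
# in order, and `demo.launch(share=True)` exposes a temporary public URL for demos.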