import torch import safetensors.torch from transformers import AutoTokenizer, AutoModelForCausalLM from trl import AutoModelForCausalLMWithValueHead # Set device and dtype device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch_dtype = torch.bfloat16 # Load the base LLaMa 3.1 8B model for translation model_id = "meta-llama/Meta-Llama-3.1-8B" # Replace with your actual model ID tokenizer = AutoTokenizer.from_pretrained(model_id) lm_model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch_dtype, device_map="auto" ) # Load the reward model RM = AutoModelForCausalLMWithValueHead.from_pretrained( 'ray24724919/plan2align_rm', torch_dtype=torch_dtype, device_map="auto" ) RM.eval() RM.gradient_checkpointing_enable() # if needed for memory efficiency # Define the load_file function def load_file(file_path): return safetensors.torch.load_file(file_path) # Load value head weights if you have the file # If you don't have the specific file, you might need to download it or use the model as is try: value_head_weights = load_file("value_head.safetensors") # Replace with actual path new_state_dict = {key.replace("v_head.", "") if key.startswith("v_head.") else key: value for key, value in value_head_weights.items()} RM.v_head.load_state_dict(new_state_dict) except FileNotFoundError: print("Value head weights file not found. Using default weights.") # Define translation function with more flexibility def translate(source_text, target_language="English", model=lm_model): """ Translate text from Chinese to the specified target language. Args: source_text (str): The Chinese text to translate target_language (str): The target language for translation model: The model to use for translation Returns: str: The translated text """ # Format the input as per the system prompt messages = [ {"role": "system", "content": "You are a helpful translator and only output the result."}, {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"} ] # Format messages for the model prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) # Tokenize the input inputs = tokenizer(prompt, return_tensors="pt").to(device) # Generate translation with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=512, temperature=0.7, do_sample=True, pad_token_id=tokenizer.eos_token_id ) # Decode the generated text translation = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip() return translation # Evaluate the translation using the reward model def evaluate_translation(source_text, translation, target_language="English"): """ Evaluate the quality of a translation using the reward model. Args: source_text (str): The original Chinese text translation (str): The translated text target_language (str): The target language of the translation Returns: float: The reward score """ messages = [ {"role": "system", "content": "You are a helpful translator and only output the result."}, {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"}, {"role": "assistant", "content": translation} ] # Format messages for the reward model prompt = tokenizer.apply_chat_template(messages, tokenize=False) # Tokenize the input inputs = tokenizer(prompt, return_tensors="pt").to(device) # Get reward score with torch.no_grad(): outputs = RM(input_ids=inputs.input_ids) reward_score = outputs.value.item() return reward_score # Function to translate and evaluate in one step def translate_and_evaluate(source_text, target_language="English"): """ Translate text and evaluate the translation quality in one step. Args: source_text (str): The Chinese text to translate target_language (str): The target language for translation Returns: tuple: (translation, reward_score) """ translation = translate(source_text, target_language) reward_score = evaluate_translation(source_text, translation, target_language) return translation, reward_score # Example usage if __name__ == "__main__": # Example with default target language (English) source = "你好世界" translation, reward_score = translate_and_evaluate(source) print(f"Source: {source}") print(f"Translation to English: {translation}") print(f"Reward Score: {reward_score}") # Example with custom target language target_language = "French" translation, reward_score = translate_and_evaluate(source, target_language) print(f"\nSource: {source}") print(f"Translation to {target_language}: {translation}") print(f"Reward Score: {reward_score}") # Interactive mode print("\n=== Interactive Translation Mode ===") print("Enter 'quit' to exit") while True: user_input = input("\nEnter Chinese text to translate: ") if user_input.lower() == 'quit': break target = input("Enter target language (default: English): ").strip() if not target: target = "English" translation, reward_score = translate_and_evaluate(user_input, target) print(f"Translation to {target}: {translation}") print(f"Reward Score: {reward_score}")