# Test-Time-Translation-LLM-Demo / translation_model.py
# Hugging Face file-viewer header: uploaded by huckiyang, commit "[test] demo" (a024afa), 5.81 kB
import torch
import safetensors.torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
# Set device and dtype
# `device` is the target that tokenized inputs are moved to by the functions
# below; the models themselves are placed by `device_map="auto"`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Both the translation model and the reward model are loaded with this dtype.
torch_dtype = torch.bfloat16
# Load the base LLaMa 3.1 8B model for translation
model_id = "meta-llama/Meta-Llama-3.1-8B" # Replace with your actual model ID
# This single tokenizer is used for both generation prompts and reward-model
# inputs in this file.
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    device_map="auto"
)
# Load the reward model (a causal LM wrapped with a TRL value head).
# NOTE(review): it is loaded after `lm_model`; with `device_map="auto"` on both,
# placement presumably depends on this order -- confirm before reordering.
RM = AutoModelForCausalLMWithValueHead.from_pretrained(
    'ray24724919/plan2align_rm',
    torch_dtype=torch_dtype,
    device_map="auto"
)
# The reward model is only used for scoring, so put it in evaluation mode.
RM.eval()
# NOTE(review): in this file RM is only called under `torch.no_grad()`, so
# gradient checkpointing presumably has no effect here -- confirm whether this
# call can be dropped.
RM.gradient_checkpointing_enable() # if needed for memory efficiency
# Define the load_file function
def load_file(file_path):
    """
    Read a safetensors file from disk.

    Thin wrapper around ``safetensors.torch.load_file``.

    Args:
        file_path: Path of the ``.safetensors`` file to read
    Returns:
        The object returned by ``safetensors.torch.load_file`` for that path
    """
    loaded = safetensors.torch.load_file(file_path)
    return loaded
# Load value head weights if you have the file.
# If the file is missing, the reward model keeps the value head it was loaded
# with and a notice is printed.
try:
    # Only the file read can raise FileNotFoundError, so only it is guarded.
    value_head_weights = load_file("value_head.safetensors")  # Replace with actual path
except FileNotFoundError:
    print("Value head weights file not found. Using default weights.")
else:
    # Keys may carry a leading "v_head." prefix; `RM.v_head.load_state_dict`
    # is given the keys without it. Strip only the leading prefix -- a plain
    # `str.replace` would also remove any later occurrence inside the key.
    _V_HEAD_PREFIX = "v_head."
    new_state_dict = {
        (key[len(_V_HEAD_PREFIX):] if key.startswith(_V_HEAD_PREFIX) else key): value
        for key, value in value_head_weights.items()
    }
    RM.v_head.load_state_dict(new_state_dict)
# Define translation function with more flexibility
def translate(source_text, target_language="English", model=lm_model):
    """
    Translate text from Chinese to the specified target language.

    The request is rendered with the tokenizer's chat template (system +
    user turn, with the generation prompt appended), sampled from ``model``,
    and only the newly generated tokens are decoded.

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation
        model: The model to use for translation
    Returns:
        str: The translated text
    """
    # Format the input as per the system prompt
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
    ]
    # Render the conversation to text, ending with the assistant header so the
    # model continues with the translation.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains the special tokens it needs;
    # tokenizing with the default add_special_tokens=True would prepend a
    # second BOS token and degrade generation.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    prompt_length = inputs.input_ids.shape[1]
    # Generate translation (sampling, no gradients needed)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # `generate` returns prompt + continuation; keep only the continuation.
    generated_ids = outputs[0][prompt_length:]
    translation = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return translation
# Evaluate the translation using the reward model
def evaluate_translation(source_text, translation, target_language="English"):
    """
    Evaluate the quality of a translation using the reward model.

    The source prompt and the candidate translation are rendered as a full
    chat (system, user, assistant) and passed through ``RM``; the value-head
    output at the final token is returned as the score.

    Args:
        source_text (str): The original Chinese text
        translation (str): The translated text
        target_language (str): The target language of the translation
    Returns:
        float: The reward score
    """
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
        {"role": "assistant", "content": translation},
    ]
    # Format messages for the reward model
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    # The rendered template already contains its special tokens; avoid adding
    # a duplicate BOS when tokenizing the rendered text.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    # Get reward score
    with torch.no_grad():
        outputs = RM(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
    # TRL's AutoModelForCausalLMWithValueHead.forward returns the tuple
    # (lm_logits, loss, value), not an object with a `.value` attribute, and
    # `value` holds one score per token with shape (batch, seq_len).
    values = outputs[2]
    # Use the value at the last token of the single sequence as its score.
    # NOTE(review): confirm last-token scoring matches how the reward model
    # was trained.
    reward_score = values[0, -1].float().item()
    return reward_score
# Function to translate and evaluate in one step
def translate_and_evaluate(source_text, target_language="English"):
    """
    Produce a translation and score it with the reward model.

    Calls ``translate`` first, then feeds its output to
    ``evaluate_translation`` with the same source text and target language.

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation
    Returns:
        tuple: (translation, reward_score)
    """
    translated = translate(source_text, target_language)
    return translated, evaluate_translation(source_text, translated, target_language)
# Example usage
if __name__ == "__main__":
    # Example with default target language (English)
    source = "你好世界"
    translation, reward_score = translate_and_evaluate(source)
    print(f"Source: {source}")
    print(f"Translation to English: {translation}")
    print(f"Reward Score: {reward_score}")
    # Example with custom target language
    target_language = "French"
    translation, reward_score = translate_and_evaluate(source, target_language)
    print(f"\nSource: {source}")
    print(f"Translation to {target_language}: {translation}")
    print(f"Reward Score: {reward_score}")
    # Interactive mode
    print("\n=== Interactive Translation Mode ===")
    print("Enter 'quit' to exit")
    try:
        while True:
            user_input = input("\nEnter Chinese text to translate: ").strip()
            if user_input.lower() == 'quit':
                break
            if not user_input:
                # Nothing to translate; prompt again instead of running the
                # models on an empty string.
                continue
            # Fall back to English when the user just presses Enter.
            target = input("Enter target language (default: English): ").strip() or "English"
            translation, reward_score = translate_and_evaluate(user_input, target)
            print(f"Translation to {target}: {translation}")
            print(f"Reward Score: {reward_score}")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C ends the session without a traceback.
        print("\nExiting.")