import torch
import safetensors.torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
# Set device and dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.bfloat16
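# bfloat16 needs hardware support (e.g., NVIDIA Ampere or newer); on CPU-only
# machines float32 is the safer default. This fallback is an assumption about
# the target hardware, not part of the original script.
if device.type != "cuda":
    torch_dtype = torch.float32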
# Load the base LLaMa 3.1 8B model for translation
model_id = "meta-llama/Meta-Llama-3.1-8B"  # Replace with your actual model ID; apply_chat_template below requires a tokenizer that ships a chat template (e.g., an -Instruct checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    device_map="auto"
)
# Load the reward model
RM = AutoModelForCausalLMWithValueHead.from_pretrained(
    'ray24724919/plan2align_rm',
    torch_dtype=torch_dtype,
    device_map="auto"
)
RM.eval()
# Define the load_file function
def load_file(file_path):
    return safetensors.torch.load_file(file_path)
# Load value head weights if you have the file
# If you don't have the specific file, you might need to download it or use the model as is
try:
    value_head_weights = load_file("value_head.safetensors")  # Replace with actual path
    new_state_dict = {
        key.replace("v_head.", "") if key.startswith("v_head.") else key: value
        for key, value in value_head_weights.items()
    }
    RM.v_head.load_state_dict(new_state_dict)
except FileNotFoundError:
    print("Value head weights file not found. Using default weights.")
# Define translation function with more flexibility
def translate(source_text, target_language="English", model=lm_model):
    """
    Translate text from Chinese to the specified target language.

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation
        model: The model to use for translation

    Returns:
        str: The translated text
    """
    # Format the input as per the system prompt
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"}
    ]
    # Format messages for the model
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Tokenize the input and move it to the model's device
    # (with device_map="auto", model.device is the device of the first shard)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate translation
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens (skip the prompt)
    translation = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    return translation
# Evaluate the translation using the reward model
def evaluate_translation(source_text, translation, target_language="English"):
    """
    Evaluate the quality of a translation using the reward model.

    Args:
        source_text (str): The original Chinese text
        translation (str): The translated text
        target_language (str): The target language of the translation

    Returns:
        float: The reward score
    """
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
        {"role": "assistant", "content": translation}
    ]
    # Format messages for the reward model
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Get reward score: the value-head model's forward returns a
    # (lm_logits, loss, values) tuple with one value per token, so the
    # last token's value serves as the sequence-level score
    with torch.no_grad():
        _, _, values = RM(input_ids=inputs.input_ids)
    reward_score = values[0, -1].item()
    return reward_score
# Function to translate and evaluate in one step
def translate_and_evaluate(source_text, target_language="English"):
    """
    Translate text and evaluate the translation quality in one step.

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation

    Returns:
        tuple: (translation, reward_score)
    """
    translation = translate(source_text, target_language)
    reward_score = evaluate_translation(source_text, translation, target_language)
    return translation, reward_score
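# A natural use of the reward model is best-of-n sampling: draw several
# candidate translations and keep the highest-scoring one. This helper is a
# sketch built only on translate() and evaluate_translation() above; the
# function name and the n=4 default are illustrative, not part of the original.
def best_of_n_translate(source_text, target_language="English", n=4):
    best_translation, best_score = None, float("-inf")
    for _ in range(n):
        # do_sample=True in translate() yields a different candidate per call
        candidate = translate(source_text, target_language)
        score = evaluate_translation(source_text, candidate, target_language)
        if score > best_score:
            best_translation, best_score = candidate, score
    return best_translation, best_score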
# Example usage
if __name__ == "__main__":
    # Example with default target language (English)
    source = "你好世界"
    translation, reward_score = translate_and_evaluate(source)
    print(f"Source: {source}")
    print(f"Translation to English: {translation}")
    print(f"Reward Score: {reward_score}")

    # Example with custom target language
    target_language = "French"
    translation, reward_score = translate_and_evaluate(source, target_language)
    print(f"\nSource: {source}")
    print(f"Translation to {target_language}: {translation}")
    print(f"Reward Score: {reward_score}")

    # Interactive mode
    print("\n=== Interactive Translation Mode ===")
    print("Enter 'quit' to exit")
    while True:
        user_input = input("\nEnter Chinese text to translate: ")
        if user_input.lower() == 'quit':
            break
        target = input("Enter target language (default: English): ").strip()
        if not target:
            target = "English"
        translation, reward_score = translate_and_evaluate(user_input, target)
        print(f"Translation to {target}: {translation}")
        print(f"Reward Score: {reward_score}")