|
import torch |
|
import safetensors.torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from trl import AutoModelForCausalLMWithValueHead |
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Both models below are loaded with this dtype.
torch_dtype = torch.bfloat16

# Translation model and the tokenizer used by the helpers in this file.
model_id = "meta-llama/Meta-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    device_map="auto",
)

# Reward model: a causal LM wrapped with a value head, used to score
# translations.
RM = AutoModelForCausalLMWithValueHead.from_pretrained(
    "ray24724919/plan2align_rm",
    torch_dtype=torch_dtype,
    device_map="auto",
)
RM.eval()
# NOTE(review): gradient checkpointing is normally a training-time memory
# optimisation, and this file only calls RM under torch.no_grad() -- confirm
# whether this call is still needed.
RM.gradient_checkpointing_enable()
|
|
|
|
|
def load_file(file_path):
    """Load a safetensors file from disk.

    Args:
        file_path: Path of the safetensors file to read.

    Returns:
        The object returned by ``safetensors.torch.load_file``; the caller in
        this file iterates it as a mapping via ``.items()``.
    """
    tensors = safetensors.torch.load_file(file_path)
    return tensors
|
|
|
|
|
|
|
# Optionally overwrite the reward model's value head with weights from a local
# safetensors file. If the file is missing, the value head is left as loaded.
try:
    value_head_weights = load_file("value_head.safetensors")
    # Keys stored with a "v_head." prefix are renamed so they line up with the
    # parameter names of RM.v_head itself; other keys are kept unchanged.
    new_state_dict = {
        (name.replace("v_head.", "") if name.startswith("v_head.") else name): tensor
        for name, tensor in value_head_weights.items()
    }
    RM.v_head.load_state_dict(new_state_dict)
except FileNotFoundError:
    print("Value head weights file not found. Using default weights.")
|
|
|
|
|
def translate(source_text, target_language="English", model=lm_model):
    """Translate Chinese text into the requested target language.

    The request is wrapped in a system + user chat prompt, rendered with the
    tokenizer's chat template, and completed by sampling from ``model``.

    Args:
        source_text (str): The Chinese text to translate.
        target_language (str): Name of the language to translate into.
        model: Model used for generation; defaults to the module-level
            ``lm_model``.

    Returns:
        str: The decoded completion (prompt tokens removed, special tokens
        skipped, surrounding whitespace stripped).
    """
    messages = [
        {
            "role": "system",
            "content": "You are a helpful translator and only output the result.",
        },
        {
            "role": "user",
            "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:",
        },
    ]

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The rendered chat template already carries the special tokens it needs
    # (see the Transformers chat-templating guide), so tokenizing it with the
    # default add_special_tokens=True would prepend a second BOS token.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)

    # Send the inputs to the device of the model that is actually used; fall
    # back to the module-level device for objects without a ``device``.
    inputs = inputs.to(getattr(model, "device", device))
    prompt_length = inputs.input_ids.shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The generated sequence starts with the prompt; keep only what follows.
    completion_ids = outputs[0][prompt_length:]
    translation = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    return translation
|
|
|
|
|
def evaluate_translation(source_text, translation, target_language="English"):
    """Score a translation with the reward model.

    The source text and translation are rendered as a system / user /
    assistant conversation using the same prompt wording as ``translate``,
    then passed through ``RM``.

    Args:
        source_text (str): The original Chinese text.
        translation (str): The translated text to score.
        target_language (str): The target language of the translation.

    Returns:
        float: The value-head output at the last token of the conversation.
    """
    messages = [
        {
            "role": "system",
            "content": "You are a helpful translator and only output the result.",
        },
        {
            "role": "user",
            "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:",
        },
        {"role": "assistant", "content": translation},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    # The rendered template already includes its special tokens; adding them
    # again here would duplicate the BOS token.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)

    with torch.no_grad():
        outputs = RM(input_ids=inputs.input_ids)

    # TRL's AutoModelForCausalLMWithValueHead.forward returns the tuple
    # (lm_logits, loss, value) rather than an object with a ``.value``
    # attribute, and ``value`` holds one score per token (batch, seq_len).
    # The sequence-level reward is read from the final token.
    values = outputs[2]
    reward_score = values[0, -1].float().item()
    return reward_score
|
|
|
|
|
def translate_and_evaluate(source_text, target_language="English"):
    """Translate ``source_text`` and score the result in one call.

    Args:
        source_text (str): The Chinese text to translate.
        target_language (str): The language to translate into.

    Returns:
        tuple: ``(translation, reward_score)`` -- the output of ``translate``
        and the score ``evaluate_translation`` assigns to it.
    """
    translated = translate(source_text, target_language)
    score = evaluate_translation(source_text, translated, target_language)
    return translated, score
|
|
|
|
|
if __name__ == "__main__":
    # Demo: translate one fixed sentence into two languages and print scores.
    source = "你好世界"
    for index, target_language in enumerate(("English", "French")):
        translation, reward_score = translate_and_evaluate(source, target_language)
        # Only the second and later results are preceded by a blank line.
        separator = "\n" if index else ""
        print(f"{separator}Source: {source}")
        print(f"Translation to {target_language}: {translation}")
        print(f"Reward Score: {reward_score}")

    print("\n=== Interactive Translation Mode ===")
    print("Enter 'quit' to exit")
    try:
        while True:
            user_input = input("\nEnter Chinese text to translate: ").strip()
            # Compare after stripping so " quit " also exits.
            if user_input.lower() == "quit":
                break
            # Nothing to translate; ask again instead of querying the models.
            if not user_input:
                continue

            target = input("Enter target language (default: English): ").strip()
            if not target:
                target = "English"

            translation, reward_score = translate_and_evaluate(user_input, target)
            print(f"Translation to {target}: {translation}")
            print(f"Reward Score: {reward_score}")
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or Ctrl-C ends the session without a traceback.
        print()