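"""Translate Chinese text with Llama 3.1 8B and score the result with the
plan2align reward model (ray24724919/plan2align_rm).

Helpers:
    translate()              -- generate a translation for a Chinese source text
    evaluate_translation()   -- score a (source, translation) pair with the reward model
    translate_and_evaluate() -- run both steps in one call

Running the script directly prints two worked examples and then enters an
interactive translation loop.
"""
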
import os

import torch
import safetensors.torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

# Set device and dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.bfloat16

# Load the Llama 3.1 8B Instruct model for translation. Note that
# apply_chat_template() below requires a tokenizer with a chat template,
# which the base (non-Instruct) checkpoint does not provide.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # Replace with your actual model ID
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    device_map="auto"
)

# Load the reward model
RM = AutoModelForCausalLMWithValueHead.from_pretrained(
    'ray24724919/plan2align_rm',
    torch_dtype=torch_dtype,
    device_map="auto"
)
RM.eval()  # inference only; gradient checkpointing is a training-time optimization and is unnecessary here

# Load value head weights if the file is available; otherwise fall back to
# the weights bundled with the checkpoint.
value_head_path = "value_head.safetensors"  # Replace with actual path
if os.path.exists(value_head_path):
    value_head_weights = safetensors.torch.load_file(value_head_path)
    # Strip the "v_head." prefix so the keys match RM.v_head's state dict
    new_state_dict = {
        key.removeprefix("v_head."): value
        for key, value in value_head_weights.items()
    }
    RM.v_head.load_state_dict(new_state_dict)
else:
    print("Value head weights file not found. Using default weights.")

# Define translation function with more flexibility
def translate(source_text, target_language="English", model=lm_model):
    """
    Translate text from Chinese to the specified target language.
    
    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation
        model: The model to use for translation
        
    Returns:
        str: The translated text
    """
    # Format the input as per the system prompt
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"}
    ]
    
    # Format messages for the model
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    
    # Tokenize the input (add_special_tokens=False: the chat template
    # already inserted the special tokens)
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    
    # Generate translation
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Decode the generated text
    translation = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    return translation

# Evaluate the translation using the reward model
def evaluate_translation(source_text, translation, target_language="English"):
    """
    Evaluate the quality of a translation using the reward model.
    
    Args:
        source_text (str): The original Chinese text
        translation (str): The translated text
        target_language (str): The target language of the translation
        
    Returns:
        float: The reward score
    """
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
        {"role": "assistant", "content": translation}
    ]
    
    # Format messages for the reward model
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    
    # Tokenize the input (add_special_tokens=False: the chat template
    # already inserted the special tokens)
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    
    # Get reward score. AutoModelForCausalLMWithValueHead.forward returns a
    # (lm_logits, loss, values) tuple; values holds one scalar per token, and
    # the value at the final token serves as the sequence-level reward.
    with torch.no_grad():
        _, _, values = RM(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
        reward_score = values[0, -1].item()
    
    return reward_score

# Function to translate and evaluate in one step
def translate_and_evaluate(source_text, target_language="English"):
    """
    Translate text and evaluate the translation quality in one step.
    
    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation
        
    Returns:
        tuple: (translation, reward_score)
    """
    translation = translate(source_text, target_language)
    reward_score = evaluate_translation(source_text, translation, target_language)
    return translation, reward_score
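
# Illustrative extension (not part of the original pipeline): a minimal
# best-of-n sketch that samples several candidate translations (generation
# uses do_sample=True, so candidates differ) and keeps the one the reward
# model scores highest, at n times the generation cost.
def best_of_n_translate(source_text, target_language="English", n=4):
    """
    Sample n candidate translations and return the highest-scoring one.
    
    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation
        n (int): Number of candidates to sample
        
    Returns:
        tuple: (best_translation, best_reward_score)
    """
    best_translation, best_score = None, float("-inf")
    for _ in range(n):
        candidate = translate(source_text, target_language)
        score = evaluate_translation(source_text, candidate, target_language)
        if score > best_score:
            best_translation, best_score = candidate, score
    return best_translation, best_score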

# Example usage
if __name__ == "__main__":
    # Example with default target language (English)
    source = "你好世界"
    translation, reward_score = translate_and_evaluate(source)
    print(f"Source: {source}")
    print(f"Translation to English: {translation}")
    print(f"Reward Score: {reward_score}")
    
    # Example with custom target language
    target_language = "French"
    translation, reward_score = translate_and_evaluate(source, target_language)
    print(f"\nSource: {source}")
    print(f"Translation to {target_language}: {translation}")
    print(f"Reward Score: {reward_score}")
    
    # Interactive mode
    print("\n=== Interactive Translation Mode ===")
    print("Enter 'quit' to exit")
    while True:
        user_input = input("\nEnter Chinese text to translate: ")
        if user_input.lower() == 'quit':
            break
            
        target = input("Enter target language (default: English): ").strip()
        if not target:
            target = "English"
            
        translation, reward_score = translate_and_evaluate(user_input, target)
        print(f"Translation to {target}: {translation}")
        print(f"Reward Score: {reward_score}")