[test] demo
- app.py +202 -142
- requirements.txt +5 -4
- translation_model.py +158 -0
app.py CHANGED
@@ -1,154 +1,214 @@

Old version (the previous diffusers text-to-image demo; runs of deleted lines are collapsed in the diff view and shown as ...):

 import gradio as gr
-import numpy as np
-import random
-
-# import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
 import torch
-
-if torch.cuda.is_available()
...
-):
...
 ]

-with gr.Row():
-    prompt = gr.Text(
-        label="Prompt",
-        show_label=False,
-        max_lines=1,
-        placeholder="Enter your prompt",
-        container=False,
 )
...
-with gr.Accordion("Advanced Settings", open=False):
-    negative_prompt = gr.Text(
-        label="Negative prompt",
-        max_lines=1,
-        placeholder="Enter a negative prompt",
-        visible=False,
 )
...
 )
...
-        minimum=256,
-        maximum=MAX_IMAGE_SIZE,
-        step=32,
-        value=1024,  # Replace with defaults that work for your model
-    )
-
-    height = gr.Slider(
-        label="Height",
-        minimum=256,
-        maximum=MAX_IMAGE_SIZE,
-        step=32,
-        value=1024,  # Replace with defaults that work for your model
-    )
-
 with gr.Row():
...
 ],
-
 )
-
 if __name__ == "__main__":
     demo.launch()
New version:

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

# Set device and dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.bfloat16

# Load models only once at startup
print("Loading models...")
model_id = "meta-llama/Meta-Llama-3.1-8B"  # Replace with your actual model ID
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    device_map="auto"
)

# Load the reward model
RM = AutoModelForCausalLMWithValueHead.from_pretrained(
    'ray24724919/plan2align_rm',
    torch_dtype=torch_dtype,
    device_map="auto"
)
RM.eval()
print("Models loaded successfully!")

# Self-contained translation and evaluation functions
def translate(source_text, target_language="English"):
    """
    Translate text from Chinese to the specified target language.

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation

    Returns:
        str: The translated text
    """
    # Format the input as per the system prompt
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"}
    ]

    # Format messages for the model
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate translation
    with torch.no_grad():
        outputs = lm_model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode the generated text
    translation = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    return translation

def evaluate_translation(source_text, translation, target_language="English"):
    """
    Evaluate the quality of a translation using the reward model.

    Args:
        source_text (str): The original Chinese text
        translation (str): The translated text
        target_language (str): The target language of the translation

    Returns:
        float: The reward score
    """
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
        {"role": "assistant", "content": translation}
    ]

    # Format messages for the reward model
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Get reward score
    with torch.no_grad():
        # trl's value-head model returns a (lm_logits, loss, values) tuple;
        # use the value at the final token as the sequence-level score
        _, _, values = RM(input_ids=inputs.input_ids)
        reward_score = values[0, -1].item()

    return reward_score

# Combined function for the Gradio interface
def translate_text(source_text, target_language):
    """
    Translate text and get reward score

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation

    Returns:
        tuple: (translation, reward_score)
    """
    if not source_text.strip():
        return "Please enter some text to translate.", 0.0

    try:
        translation = translate(source_text, target_language)
        reward_score = evaluate_translation(source_text, translation, target_language)
        return translation, float(reward_score)
    except Exception as e:
        return f"Error: {str(e)}", 0.0

# Define available target languages
target_languages = [
    "English", "French", "Spanish", "German", "Italian",
    "Portuguese", "Russian", "Japanese", "Korean", "Arabic"
]

# Create the Gradio interface
with gr.Blocks(title="Chinese Translation with Reward Scoring") as demo:
    gr.Markdown("# Chinese to Any Language Translation")
    gr.Markdown("This demo translates Chinese text to your chosen language and provides a quality score from our reward model.")

    with gr.Row():
        with gr.Column():
            source_text = gr.Textbox(
                label="Chinese Text",
                placeholder="Enter Chinese text here...",
                lines=5
            )
            target_language = gr.Dropdown(
                choices=target_languages,
                value="English",
                label="Target Language"
            )
            translate_button = gr.Button("Translate")

        with gr.Column():
            translation_output = gr.Textbox(
                label="Translation",
                lines=5,
                interactive=False
            )
            reward_score = gr.Number(
                label="Translation Quality Score (higher is better)",
                precision=4,
                interactive=False
            )

    with gr.Row():
        score_indicator = gr.Label(label="Quality Rating")

    # Function to update the quality rating based on score
    def update_quality_rating(score):
        if score >= 0.8:
            return "Excellent"
        elif score >= 0.6:
            return "Good"
        elif score >= 0.4:
            return "Average"
        elif score >= 0.2:
            return "Poor"
        else:
            return "Very Poor"

    # Set up the translation flow
    translate_outputs = translate_button.click(
        fn=translate_text,
        inputs=[source_text, target_language],
        outputs=[translation_output, reward_score]
    )

    # Update the quality rating whenever the reward score changes
    reward_score.change(
        fn=update_quality_rating,
        inputs=[reward_score],
        outputs=[score_indicator]
    )

    # Examples
    gr.Examples(
        examples=[
            ["你好,世界!", "English"],
            ["我喜欢学习新的语言。", "Spanish"],
            ["北京烤鴨很好吃。", "French"],
            ["人工智能正在改变世界。", "German"],
            ["今天天气真好。", "Japanese"]
        ],
        inputs=[source_text, target_language],
        outputs=[translation_output, reward_score],
        fn=translate_text
    )

    gr.Markdown("## How It Works")
    gr.Markdown("""
    1. Enter Chinese text in the input box
    2. Select your desired target language
    3. Click 'Translate' to get the translation
    4. The system will display the translation and a quality score

    The quality score is generated by a reward model trained to evaluate translation quality.
    Higher scores indicate better translations.
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch()
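
One caveat on the new app.py: model_id points at the base (non-Instruct) Llama 3.1 checkpoint, and base-model tokenizers do not always ship a chat template, in which case apply_chat_template can raise. A minimal guard, sketched under the assumption that a plain role-prefixed prompt is an acceptable fallback (render_prompt and its fallback format are illustrative, not part of the commit):

# Sketch: render a chat prompt whether or not the tokenizer ships a chat template.
def render_prompt(tokenizer, messages, add_generation_prompt=True):
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=add_generation_prompt
        )
    # Illustrative fallback: simple role-prefixed plain text
    text = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
    return text + ("\nassistant:" if add_generation_prompt else "")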

requirements.txt CHANGED
@@ -1,6 +1,7 @@
 accelerate
...
-torch
-transformers
+gradio
+safetensors
+torch>=2.0.0
+transformers>=4.30.0
+trl>=0.7.1
 xformers
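
The trl pin above is what provides AutoModelForCausalLMWithValueHead; a quick sanity check that the pinned stack resolves together (a sketch, not part of the commit):

# Verify the pinned dependencies import cleanly and meet the floors above.
import gradio, safetensors, torch, transformers, trl
from trl import AutoModelForCausalLMWithValueHead  # needs trl>=0.7.1

print(torch.__version__, transformers.__version__, trl.__version__)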

translation_model.py ADDED
@@ -0,0 +1,158 @@

import torch
import safetensors.torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

# Set device and dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.bfloat16

# Load the base Llama 3.1 8B model for translation
model_id = "meta-llama/Meta-Llama-3.1-8B"  # Replace with your actual model ID
tokenizer = AutoTokenizer.from_pretrained(model_id)
lm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    device_map="auto"
)

# Load the reward model
RM = AutoModelForCausalLMWithValueHead.from_pretrained(
    'ray24724919/plan2align_rm',
    torch_dtype=torch_dtype,
    device_map="auto"
)
RM.eval()
RM.gradient_checkpointing_enable()  # if needed for memory efficiency (only has an effect during training)

# Define the load_file function
def load_file(file_path):
    return safetensors.torch.load_file(file_path)

# Load value head weights if you have the file
# If you don't have the specific file, you might need to download it or use the model as is
try:
    value_head_weights = load_file("value_head.safetensors")  # Replace with actual path
    new_state_dict = {key.replace("v_head.", "") if key.startswith("v_head.") else key: value for key, value in value_head_weights.items()}
    RM.v_head.load_state_dict(new_state_dict)
except FileNotFoundError:
    print("Value head weights file not found. Using default weights.")

# Define translation function with more flexibility
def translate(source_text, target_language="English", model=lm_model):
    """
    Translate text from Chinese to the specified target language.

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation
        model: The model to use for translation

    Returns:
        str: The translated text
    """
    # Format the input as per the system prompt
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"}
    ]

    # Format messages for the model
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate translation
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode the generated text
    translation = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    return translation

# Evaluate the translation using the reward model
def evaluate_translation(source_text, translation, target_language="English"):
    """
    Evaluate the quality of a translation using the reward model.

    Args:
        source_text (str): The original Chinese text
        translation (str): The translated text
        target_language (str): The target language of the translation

    Returns:
        float: The reward score
    """
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {target_language}, Chinese:\n{source_text}\n### {target_language}:"},
        {"role": "assistant", "content": translation}
    ]

    # Format messages for the reward model
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Get reward score
    with torch.no_grad():
        # trl's value-head model returns a (lm_logits, loss, values) tuple;
        # use the value at the final token as the sequence-level score
        _, _, values = RM(input_ids=inputs.input_ids)
        reward_score = values[0, -1].item()

    return reward_score

# Function to translate and evaluate in one step
def translate_and_evaluate(source_text, target_language="English"):
    """
    Translate text and evaluate the translation quality in one step.

    Args:
        source_text (str): The Chinese text to translate
        target_language (str): The target language for translation

    Returns:
        tuple: (translation, reward_score)
    """
    translation = translate(source_text, target_language)
    reward_score = evaluate_translation(source_text, translation, target_language)
    return translation, reward_score

# Example usage
if __name__ == "__main__":
    # Example with default target language (English)
    source = "你好世界"
    translation, reward_score = translate_and_evaluate(source)
    print(f"Source: {source}")
    print(f"Translation to English: {translation}")
    print(f"Reward Score: {reward_score}")

    # Example with custom target language
    target_language = "French"
    translation, reward_score = translate_and_evaluate(source, target_language)
    print(f"\nSource: {source}")
    print(f"Translation to {target_language}: {translation}")
    print(f"Reward Score: {reward_score}")

    # Interactive mode
    print("\n=== Interactive Translation Mode ===")
    print("Enter 'quit' to exit")
    while True:
        user_input = input("\nEnter Chinese text to translate: ")
        if user_input.lower() == 'quit':
            break

        target = input("Enter target language (default: English): ").strip()
        if not target:
            target = "English"

        translation, reward_score = translate_and_evaluate(user_input, target)
        print(f"Translation to {target}: {translation}")
        print(f"Reward Score: {reward_score}")