polish the demo
app.py
CHANGED
@@ -155,7 +155,7 @@ def rm_find_best_translation(source, translations, language="English"):
     else:
         return None

-def translate_chinese_to_english(chinese_text):
+def translate_chinese_to_english(chinese_text, target_language="English"):
     # Generate multiple translations
     translations = []

@@ -169,7 +169,7 @@ def translate_chinese_to_english(chinese_text):
     for prompt in system_prompts:
         messages = [
             {"role": "system", "content": prompt},
-            {"role": "user", "content": f"Translate the following Chinese text to
+            {"role": "user", "content": f"Translate the following Chinese text to {target_language}:\n\n{chinese_text}"}
         ]

         inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
@@ -185,26 +185,134 @@ def translate_chinese_to_english(chinese_text):
         translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
         translations.append(translation)

-    #
-
+    # Get rewards for all translations
+    rewards = reward_model.reward_fn(target_language, chinese_text.replace('</s>',' '),
+                                     [t.replace('</s>',' ') for t in translations])

-
-
-
+    # Find the best translation
+    best_index = rewards.index(max(rewards))
+    best_translation = translations[best_index]
+
+    # Return all information
+    return {
+        "best_translation": best_translation,
+        "best_reward": rewards[best_index],
+        "all_translations": translations,
+        "all_rewards": rewards,
+        "best_index": best_index
+    }
+
+# Updated Gradio interface
+def process_text(text, target_language="English"):
+    if not text.strip():
+        return "Please enter some text to translate.", "", "", "", ""

-
+    try:
+        result = translate_chinese_to_english(text, target_language)
+
+        # Format the candidate translations with their rewards
+        candidates = []
+        for i, (trans, reward) in enumerate(zip(result["all_translations"], result["all_rewards"])):
+            marker = "★ " if i == result["best_index"] else ""
+            candidates.append(f"{marker}Candidate {i+1} (Reward: {reward:.4f}):\n{trans}\n")
+
+        candidates_text = "\n".join(candidates)
+
+        return (
+            result["best_translation"],
+            f"{result['best_reward']:.4f}",
+            candidates_text,
+            f"Candidate {result['best_index']+1}",
+            "Yes" if result["best_reward"] >= THRESHOLD else "No"
+        )
+    except Exception as e:
+        return f"Error: {str(e)}", "", "", "", ""

-    #
-
-
+# Define available target languages - only the supported ones
+target_languages = [
+    "English", "Russian", "German", "Japanese", "Korean"
+]

-
-
-
-
-
-
-)
+# Create an enhanced Gradio interface
+with gr.Blocks(title="Chinese Translation with Plan2Align") as demo:
+    gr.Markdown("# Chinese Translation with Plan2Align")
+    gr.Markdown("This demo uses the Plan2Align approach to translate Chinese text to your chosen language, showing how the reward model evaluates different translation candidates.")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            source_text = gr.Textbox(
+                label="Chinese Text",
+                placeholder="Enter Chinese text here...",
+                lines=5
+            )
+            target_language = gr.Dropdown(
+                choices=target_languages,
+                value="English",
+                label="Target Language"
+            )
+            translate_button = gr.Button("Translate")
+
+        with gr.Column(scale=2):
+            with gr.Tab("Best Translation"):
+                best_translation = gr.Textbox(
+                    label="Best Translation",
+                    lines=5,
+                    interactive=False
+                )
+                best_reward = gr.Textbox(
+                    label="Reward Score",
+                    interactive=False
+                )
+                best_candidate = gr.Textbox(
+                    label="Best Candidate",
+                    interactive=False
+                )
+                meets_threshold = gr.Textbox(
+                    label="Meets Quality Threshold",
+                    interactive=False
+                )
+
+            with gr.Tab("All Candidates"):
+                all_candidates = gr.Textbox(
+                    label="All Translation Candidates with Rewards",
+                    lines=15,
+                    interactive=False
+                )
+
+    # Set up the translation flow
+    translate_button.click(
+        fn=process_text,
+        inputs=[source_text, target_language],
+        outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold]
+    )
+
+    # Examples with more complex sentences in Traditional Chinese about Taiwan for the supported languages
+    gr.Examples(
+        examples=[
+            ["夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "English"],
+            ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Russian"],
+            ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "German"],
+            ["珍珠奶茶起源於台灣,現已成為全球流行的飲品,展現了飲食文化對世界的影響力。", "Japanese"],
+            ["原住民文化擁有豐富的傳統和藝術表現形式,包括歌舞、編織和木雕,反映了與自然和諧共處的生活智慧。", "Korean"]
+        ],
+        inputs=[source_text, target_language],
+        outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold],
+        fn=process_text
+    )
+
+    gr.Markdown("## How It Works")
+    gr.Markdown("""
+    1. The system generates three different translations using different translation styles:
+       - Literal: A word-for-word translation preserving structure
+       - Professional: A clear, formal translation
+       - Creative: A vivid, expressive translation
+
+    2. The reward model evaluates each translation and assigns a score
+
+    3. The translation with the highest reward score is selected as the best
+
+    4. A translation meets the quality threshold if its reward score is ≥ 2.0
+    """)

 if __name__ == "__main__":
     demo.launch()
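For reference, here is a minimal, self-contained sketch of the best-of-n selection pattern the updated translate_chinese_to_english follows: generate several candidates, score them with a reward function, keep the argmax, and compare the best score against the 2.0 threshold quoted in the "How It Works" text. score_translation below is a hypothetical stand-in for reward_model.reward_fn (whose setup is not shown in this diff), so treat it as an illustration of the control flow, not the actual scoring.

# Sketch only: score_translation is a placeholder, not the real reward model.
THRESHOLD = 2.0  # quality threshold quoted in the "How It Works" text

def score_translation(source: str, candidate: str) -> float:
    # Placeholder heuristic: prefer candidates whose length is close to twice
    # the source length. A real reward model scores adequacy and fluency.
    return 2.0 - abs(len(candidate) - 2 * len(source)) / max(len(source), 1)

def pick_best(source: str, candidates: list[str]) -> dict:
    # Score every candidate, then keep the one with the highest reward,
    # mirroring the dict returned by translate_chinese_to_english above.
    rewards = [score_translation(source, c) for c in candidates]
    best_index = rewards.index(max(rewards))
    return {
        "best_translation": candidates[best_index],
        "best_reward": rewards[best_index],
        "all_translations": candidates,
        "all_rewards": rewards,
        "best_index": best_index
    }

if __name__ == "__main__":
    result = pick_best("你好,世界", ["Hello, world.", "Hi, world.", "Hello there, wide world."])
    meets = "Yes" if result["best_reward"] >= THRESHOLD else "No"
    print(result["best_translation"], f"(reward {result['best_reward']:.4f}, meets threshold: {meets})")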
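The Gradio wiring added by this commit follows a simple contract: the bound function returns a 5-tuple, and click() fills the five output components positionally in the same order as the outputs list. The stripped-down sketch below shows only that pattern, with a stub in place of process_text; the component variable names here are illustrative, not the exact ones from app.py.

import gradio as gr

# Stub standing in for process_text: returns a 5-tuple matching the outputs list order.
def stub_process(text, target_language="English"):
    if not text.strip():
        return "Please enter some text to translate.", "", "", "", ""
    best = f"[{target_language}] translation of: {text}"
    return best, "2.3456", f"★ Candidate 1 (Reward: 2.3456):\n{best}\n", "Candidate 1", "Yes"

with gr.Blocks(title="Wiring sketch") as sketch:
    src = gr.Textbox(label="Chinese Text", lines=5)
    lang = gr.Dropdown(choices=["English", "Russian", "German", "Japanese", "Korean"],
                       value="English", label="Target Language")
    btn = gr.Button("Translate")
    best_out = gr.Textbox(label="Best Translation", interactive=False)
    reward_out = gr.Textbox(label="Reward Score", interactive=False)
    candidates_out = gr.Textbox(label="All Translation Candidates with Rewards", lines=15, interactive=False)
    which_out = gr.Textbox(label="Best Candidate", interactive=False)
    threshold_out = gr.Textbox(label="Meets Quality Threshold", interactive=False)
    # Outputs are filled positionally from the returned 5-tuple.
    btn.click(fn=stub_process, inputs=[src, lang],
              outputs=[best_out, reward_out, candidates_out, which_out, threshold_out])

if __name__ == "__main__":
    sketch.launch()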