huckiyang commited on
Commit
0966210
·
1 Parent(s): 3a7f01f

polish the demo

Browse files
Files changed (1) hide show
  1. app.py +126 -18
app.py CHANGED
@@ -155,7 +155,7 @@ def rm_find_best_translation(source, translations, language="English"):
155
  else:
156
  return None
157
 
158
- def translate_chinese_to_english(chinese_text):
159
  # Generate multiple translations
160
  translations = []
161
 
@@ -169,7 +169,7 @@ def translate_chinese_to_english(chinese_text):
169
  for prompt in system_prompts:
170
  messages = [
171
  {"role": "system", "content": prompt},
172
- {"role": "user", "content": f"Translate the following Chinese text to English:\n\n{chinese_text}"}
173
  ]
174
 
175
  inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
@@ -185,26 +185,134 @@ def translate_chinese_to_english(chinese_text):
185
  translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
186
  translations.append(translation)
187
 
188
- # Use reward model to find the best translation
189
- best_translation = rm_find_best_translation(chinese_text, translations)
 
190
 
191
- if best_translation is None:
192
- # If no translation meets the threshold, return the first one
193
- return translations[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- return best_translation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- # Gradio interface
198
- def process_text(text):
199
- return translate_chinese_to_english(text)
 
200
 
201
- demo = gr.Interface(
202
- fn=process_text,
203
- inputs=gr.Textbox(lines=5, placeholder="Enter Chinese text here..."),
204
- outputs=gr.Textbox(lines=5),
205
- title="Chinese to English Translation with Plan2Align",
206
- description="This app uses the Plan2Align approach to translate Chinese text to English."
207
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  if __name__ == "__main__":
210
  demo.launch()
 
155
  else:
156
  return None
157
 
158
+ def translate_chinese_to_english(chinese_text, target_language="English"):
159
  # Generate multiple translations
160
  translations = []
161
 
 
169
  for prompt in system_prompts:
170
  messages = [
171
  {"role": "system", "content": prompt},
172
+ {"role": "user", "content": f"Translate the following Chinese text to {target_language}:\n\n{chinese_text}"}
173
  ]
174
 
175
  inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
 
185
  translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
186
  translations.append(translation)
187
 
188
+ # Get rewards for all translations
189
+ rewards = reward_model.reward_fn(target_language, chinese_text.replace('</s>',' '),
190
+ [t.replace('</s>',' ') for t in translations])
191
 
192
+ # Find the best translation
193
+ best_index = rewards.index(max(rewards))
194
+ best_translation = translations[best_index]
195
+
196
+ # Return all information
197
+ return {
198
+ "best_translation": best_translation,
199
+ "best_reward": rewards[best_index],
200
+ "all_translations": translations,
201
+ "all_rewards": rewards,
202
+ "best_index": best_index
203
+ }
204
+
205
+ # Updated Gradio interface
206
+ def process_text(text, target_language="English"):
207
+ if not text.strip():
208
+ return "Please enter some text to translate.", "", "", "", ""
209
 
210
+ try:
211
+ result = translate_chinese_to_english(text, target_language)
212
+
213
+ # Format the candidate translations with their rewards
214
+ candidates = []
215
+ for i, (trans, reward) in enumerate(zip(result["all_translations"], result["all_rewards"])):
216
+ marker = "★ " if i == result["best_index"] else ""
217
+ candidates.append(f"{marker}Candidate {i+1} (Reward: {reward:.4f}):\n{trans}\n")
218
+
219
+ candidates_text = "\n".join(candidates)
220
+
221
+ return (
222
+ result["best_translation"],
223
+ f"{result['best_reward']:.4f}",
224
+ candidates_text,
225
+ f"Candidate {result['best_index']+1}",
226
+ "Yes" if result["best_reward"] >= THRESHOLD else "No"
227
+ )
228
+ except Exception as e:
229
+ return f"Error: {str(e)}", "", "", "", ""
230
 
231
+ # Define available target languages - only the supported ones
232
+ target_languages = [
233
+ "English", "Russian", "German", "Japanese", "Korean"
234
+ ]
235
 
236
+ # Create an enhanced Gradio interface
237
+ with gr.Blocks(title="Chinese Translation with Plan2Align") as demo:
238
+ gr.Markdown("# Chinese Translation with Plan2Align")
239
+ gr.Markdown("This demo uses the Plan2Align approach to translate Chinese text to your chosen language, showing how the reward model evaluates different translation candidates.")
240
+
241
+ with gr.Row():
242
+ with gr.Column(scale=1):
243
+ source_text = gr.Textbox(
244
+ label="Chinese Text",
245
+ placeholder="Enter Chinese text here...",
246
+ lines=5
247
+ )
248
+ target_language = gr.Dropdown(
249
+ choices=target_languages,
250
+ value="English",
251
+ label="Target Language"
252
+ )
253
+ translate_button = gr.Button("Translate")
254
+
255
+ with gr.Column(scale=2):
256
+ with gr.Tab("Best Translation"):
257
+ best_translation = gr.Textbox(
258
+ label="Best Translation",
259
+ lines=5,
260
+ interactive=False
261
+ )
262
+ best_reward = gr.Textbox(
263
+ label="Reward Score",
264
+ interactive=False
265
+ )
266
+ best_candidate = gr.Textbox(
267
+ label="Best Candidate",
268
+ interactive=False
269
+ )
270
+ meets_threshold = gr.Textbox(
271
+ label="Meets Quality Threshold",
272
+ interactive=False
273
+ )
274
+
275
+ with gr.Tab("All Candidates"):
276
+ all_candidates = gr.Textbox(
277
+ label="All Translation Candidates with Rewards",
278
+ lines=15,
279
+ interactive=False
280
+ )
281
+
282
+ # Set up the translation flow
283
+ translate_button.click(
284
+ fn=process_text,
285
+ inputs=[source_text, target_language],
286
+ outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold]
287
+ )
288
+
289
+ # Examples with more complex sentences in Traditional Chinese about Taiwan for the supported languages
290
+ gr.Examples(
291
+ examples=[
292
+ ["夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "English"],
293
+ ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Russian"],
294
+ ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "German"],
295
+ ["珍珠奶茶起源於台灣,現已成為全球流行的飲品,展現了飲食文化對世界的影響力。", "Japanese"],
296
+ ["原住民文化擁有豐富的傳統和藝術表現形式,包括歌舞、編織和木雕,反映了與自然和諧共處的生活智慧。", "Korean"]
297
+ ],
298
+ inputs=[source_text, target_language],
299
+ outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold],
300
+ fn=process_text
301
+ )
302
+
303
+ gr.Markdown("## How It Works")
304
+ gr.Markdown("""
305
+ 1. The system generates three different translations using different translation styles:
306
+ - Literal: A word-for-word translation preserving structure
307
+ - Professional: A clear, formal translation
308
+ - Creative: A vivid, expressive translation
309
+
310
+ 2. The reward model evaluates each translation and assigns a score
311
+
312
+ 3. The translation with the highest reward score is selected as the best
313
+
314
+ 4. A translation meets the quality threshold if its reward score is ≥ 2.0
315
+ """)
316
 
317
  if __name__ == "__main__":
318
  demo.launch()