huckiyang committed
Commit ddcd34b · 1 Parent(s): 0966210

polish the demo

Files changed (1): app.py (+32, -5)
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import gc
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -124,6 +125,14 @@ class RewardModel:
 reward_model = RewardModel(device, tokenizer, torch_dtype=torch.float16)
 print("Models loaded successfully!")

+# Memory management function
+def clear_cache():
+    """Clear CUDA cache and run garbage collection to free memory"""
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
+    return "Cache cleared"
+
 # Helper functions from Plan2Align
 def rm_predict_preference(source, translation0, translation1, language="English"):
     translations = [translation0, translation1]
@@ -202,7 +211,7 @@ def translate_chinese_to_english(chinese_text, target_language="English"):
         "best_index": best_index
     }

-# Updated Gradio interface
+# Updated process_text function with cache clearing
 def process_text(text, target_language="English"):
     if not text.strip():
         return "Please enter some text to translate.", "", "", "", ""
@@ -218,6 +227,9 @@ def process_text(text, target_language="English"):

         candidates_text = "\n".join(candidates)

+        # Clear cache after processing
+        clear_cache()
+
         return (
             result["best_translation"],
             f"{result['best_reward']:.4f}",
@@ -226,6 +238,8 @@ def process_text(text, target_language="English"):
             "Yes" if result["best_reward"] >= THRESHOLD else "No"
         )
     except Exception as e:
+        # Clear cache even if there's an error
+        clear_cache()
         return f"Error: {str(e)}", "", "", "", ""

 # Define available target languages - only the supported ones
@@ -234,9 +248,8 @@ target_languages = [
 ]

 # Create an enhanced Gradio interface
-with gr.Blocks(title="Chinese Translation with Plan2Align") as demo:
-    gr.Markdown("# Chinese Translation with Plan2Align")
-    gr.Markdown("This demo uses the Plan2Align approach to translate Chinese text to your chosen language, showing how the reward model evaluates different translation candidates.")
+with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
+    gr.Markdown("This demo uses the Plan2Align approach to translate Chinese text to your chosen language, showing how the reward model evaluates different translation candidates. Paper: https://arxiv.org/pdf/2502.20795.")

     with gr.Row():
         with gr.Column(scale=1):
@@ -251,6 +264,7 @@ with gr.Blocks(title="Chinese Translation with Plan2Align") as demo:
                 label="Target Language"
             )
             translate_button = gr.Button("Translate")
+            clear_button = gr.Button("Clear Memory Cache")

         with gr.Column(scale=2):
             with gr.Tab("Best Translation"):
@@ -278,6 +292,12 @@ with gr.Blocks(title="Chinese Translation with Plan2Align") as demo:
                 lines=15,
                 interactive=False
             )
+
+        cache_status = gr.Textbox(
+            label="Cache Status",
+            value="Ready",
+            interactive=False
+        )

     # Set up the translation flow
     translate_button.click(
@@ -286,7 +306,14 @@ with gr.Blocks(title="Chinese Translation with Plan2Align") as demo:
         outputs=[best_translation, best_reward, all_candidates, best_candidate, meets_threshold]
     )

-    # Examples with more complex sentences in Traditional Chinese about Taiwan for the supported languages
+    # Add manual cache clearing button
+    clear_button.click(
+        fn=clear_cache,
+        inputs=[],
+        outputs=[cache_status]
+    )
+
+    # Examples with more complex sentences in Traditional Chinese about Taiwan
     gr.Examples(
         examples=[
             ["夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "English"],