hunterschep commited on
Commit
b551379
·
verified ·
1 Parent(s): 7fb93ed

remove big model

Browse files
Files changed (1) hide show
  1. app.py +4 -10
app.py CHANGED
@@ -7,15 +7,11 @@ import unicodedata
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
- # Load the big model
11
- big_tokenizer = NllbTokenizer.from_pretrained("hunterschep/amis-zh-3.3B")
12
- big_model = AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-3.3B").to(device)
13
-
14
  # Load the small model
15
  small_tokenizer = NllbTokenizer.from_pretrained("hunterschep/amis-zh-600M")
16
  small_model = AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-600M").to(device)
17
 
18
- # Fix tokenizers
19
  def fix_tokenizer(tokenizer, new_lang='ami_Latn'):
20
  old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
21
  tokenizer.lang_code_to_id[new_lang] = old_len - 1
@@ -28,12 +24,11 @@ def fix_tokenizer(tokenizer, new_lang='ami_Latn'):
28
  tokenizer.added_tokens_encoder = {}
29
  tokenizer.added_tokens_decoder = {}
30
 
31
- fix_tokenizer(big_tokenizer)
32
  fix_tokenizer(small_tokenizer)
33
 
34
  # Translation function
35
- def translate(text, model_type, src_lang, tgt_lang):
36
- tokenizer, model = (big_tokenizer, big_model) if model_type == "Large" else (small_tokenizer, small_model)
37
  if src_lang == "zho_Hant":
38
  text = preproc_chinese(text)
39
  tokenizer.src_lang = src_lang
@@ -71,7 +66,6 @@ def switch_direction(src_lang, tgt_lang):
71
 
72
  with gr.Blocks() as demo:
73
  gr.Markdown("# AMIS - Chinese Translation Tool")
74
- model_type = gr.Radio(choices=["Small", "Large"], value="Small", label="Model Type")
75
  src_lang = gr.Radio(choices=["zho_Hant", "ami_Latn"], value="zho_Hant", label="Source Language")
76
  tgt_lang = gr.Radio(choices=["ami_Latn", "zho_Hant"], value="ami_Latn", label="Target Language")
77
  input_text = gr.Textbox(label="Input Text", placeholder="Enter text here...")
@@ -79,7 +73,7 @@ with gr.Blocks() as demo:
79
  translate_btn = gr.Button("Translate")
80
  switch_btn = gr.Button("Switch Direction")
81
 
82
- translate_btn.click(translate, inputs=[input_text, model_type, src_lang, tgt_lang], outputs=output_text)
83
  switch_btn.click(switch_direction, inputs=[src_lang, tgt_lang], outputs=[src_lang, tgt_lang])
84
 
85
  if __name__ == "__main__":
 
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
 
 
 
 
10
  # Load the small model
11
  small_tokenizer = NllbTokenizer.from_pretrained("hunterschep/amis-zh-600M")
12
  small_model = AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-600M").to(device)
13
 
14
+ # Fix tokenizer
15
  def fix_tokenizer(tokenizer, new_lang='ami_Latn'):
16
  old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
17
  tokenizer.lang_code_to_id[new_lang] = old_len - 1
 
24
  tokenizer.added_tokens_encoder = {}
25
  tokenizer.added_tokens_decoder = {}
26
 
 
27
  fix_tokenizer(small_tokenizer)
28
 
29
  # Translation function
30
+ def translate(text, src_lang, tgt_lang):
31
+ tokenizer, model = small_tokenizer, small_model
32
  if src_lang == "zho_Hant":
33
  text = preproc_chinese(text)
34
  tokenizer.src_lang = src_lang
 
66
 
67
  with gr.Blocks() as demo:
68
  gr.Markdown("# AMIS - Chinese Translation Tool")
 
69
  src_lang = gr.Radio(choices=["zho_Hant", "ami_Latn"], value="zho_Hant", label="Source Language")
70
  tgt_lang = gr.Radio(choices=["ami_Latn", "zho_Hant"], value="ami_Latn", label="Target Language")
71
  input_text = gr.Textbox(label="Input Text", placeholder="Enter text here...")
 
73
  translate_btn = gr.Button("Translate")
74
  switch_btn = gr.Button("Switch Direction")
75
 
76
+ translate_btn.click(translate, inputs=[input_text, src_lang, tgt_lang], outputs=output_text)
77
  switch_btn.click(switch_direction, inputs=[src_lang, tgt_lang], outputs=[src_lang, tgt_lang])
78
 
79
  if __name__ == "__main__":