Spaces:
Sleeping
Sleeping
remove big model
Browse files
app.py
CHANGED
@@ -7,15 +7,11 @@ import unicodedata
|
|
7 |
|
8 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
|
10 |
-
# Load the big model
|
11 |
-
big_tokenizer = NllbTokenizer.from_pretrained("hunterschep/amis-zh-3.3B")
|
12 |
-
big_model = AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-3.3B").to(device)
|
13 |
-
|
14 |
# Load the small model
|
15 |
small_tokenizer = NllbTokenizer.from_pretrained("hunterschep/amis-zh-600M")
|
16 |
small_model = AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-600M").to(device)
|
17 |
|
18 |
-
# Fix
|
19 |
def fix_tokenizer(tokenizer, new_lang='ami_Latn'):
|
20 |
old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
|
21 |
tokenizer.lang_code_to_id[new_lang] = old_len - 1
|
@@ -28,12 +24,11 @@ def fix_tokenizer(tokenizer, new_lang='ami_Latn'):
|
|
28 |
tokenizer.added_tokens_encoder = {}
|
29 |
tokenizer.added_tokens_decoder = {}
|
30 |
|
31 |
-
fix_tokenizer(big_tokenizer)
|
32 |
fix_tokenizer(small_tokenizer)
|
33 |
|
34 |
# Translation function
|
35 |
-
def translate(text,
|
36 |
-
tokenizer, model =
|
37 |
if src_lang == "zho_Hant":
|
38 |
text = preproc_chinese(text)
|
39 |
tokenizer.src_lang = src_lang
|
@@ -71,7 +66,6 @@ def switch_direction(src_lang, tgt_lang):
|
|
71 |
|
72 |
with gr.Blocks() as demo:
|
73 |
gr.Markdown("# AMIS - Chinese Translation Tool")
|
74 |
-
model_type = gr.Radio(choices=["Small", "Large"], value="Small", label="Model Type")
|
75 |
src_lang = gr.Radio(choices=["zho_Hant", "ami_Latn"], value="zho_Hant", label="Source Language")
|
76 |
tgt_lang = gr.Radio(choices=["ami_Latn", "zho_Hant"], value="ami_Latn", label="Target Language")
|
77 |
input_text = gr.Textbox(label="Input Text", placeholder="Enter text here...")
|
@@ -79,7 +73,7 @@ with gr.Blocks() as demo:
|
|
79 |
translate_btn = gr.Button("Translate")
|
80 |
switch_btn = gr.Button("Switch Direction")
|
81 |
|
82 |
-
translate_btn.click(translate, inputs=[input_text,
|
83 |
switch_btn.click(switch_direction, inputs=[src_lang, tgt_lang], outputs=[src_lang, tgt_lang])
|
84 |
|
85 |
if __name__ == "__main__":
|
|
|
7 |
|
8 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
|
|
|
|
|
|
|
|
|
10 |
# Load the small model
|
11 |
small_tokenizer = NllbTokenizer.from_pretrained("hunterschep/amis-zh-600M")
|
12 |
small_model = AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-600M").to(device)
|
13 |
|
14 |
+
# Fix tokenizer
|
15 |
def fix_tokenizer(tokenizer, new_lang='ami_Latn'):
|
16 |
old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
|
17 |
tokenizer.lang_code_to_id[new_lang] = old_len - 1
|
|
|
24 |
tokenizer.added_tokens_encoder = {}
|
25 |
tokenizer.added_tokens_decoder = {}
|
26 |
|
|
|
27 |
fix_tokenizer(small_tokenizer)
|
28 |
|
29 |
# Translation function
|
30 |
+
def translate(text, src_lang, tgt_lang):
|
31 |
+
tokenizer, model = small_tokenizer, small_model
|
32 |
if src_lang == "zho_Hant":
|
33 |
text = preproc_chinese(text)
|
34 |
tokenizer.src_lang = src_lang
|
|
|
66 |
|
67 |
with gr.Blocks() as demo:
|
68 |
gr.Markdown("# AMIS - Chinese Translation Tool")
|
|
|
69 |
src_lang = gr.Radio(choices=["zho_Hant", "ami_Latn"], value="zho_Hant", label="Source Language")
|
70 |
tgt_lang = gr.Radio(choices=["ami_Latn", "zho_Hant"], value="ami_Latn", label="Target Language")
|
71 |
input_text = gr.Textbox(label="Input Text", placeholder="Enter text here...")
|
|
|
73 |
translate_btn = gr.Button("Translate")
|
74 |
switch_btn = gr.Button("Switch Direction")
|
75 |
|
76 |
+
translate_btn.click(translate, inputs=[input_text, src_lang, tgt_lang], outputs=output_text)
|
77 |
switch_btn.click(switch_direction, inputs=[src_lang, tgt_lang], outputs=[src_lang, tgt_lang])
|
78 |
|
79 |
if __name__ == "__main__":
|