Spaces:
Running
Running
File size: 3,185 Bytes
640a35c 4af5544 52f4023 640a35c 4af5544 640a35c 4af5544 640a35c b551379 4af5544 640a35c 4af5544 640a35c 4af5544 2e7a521 b551379 4af5544 640a35c 4af5544 640a35c 4af5544 640a35c 4af5544 640a35c 4af5544 640a35c 4af5544 2e7a521 4af5544 2e7a521 640a35c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
import torch
from sacremoses import MosesPunctNormalizer
import re
import unicodedata
import sys
# Pick the inference device once at import time: GPU if torch can see one, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the small model. Tokenizer and fine-tuned seq2seq weights come from the
# same Hugging Face Hub repository.
_SMALL_MODEL_ID = "hunterschep/amis-zh-600M"
small_tokenizer = NllbTokenizer.from_pretrained(_SMALL_MODEL_ID)
small_model = AutoModelForSeq2SeqLM.from_pretrained(_SMALL_MODEL_ID)
small_model = small_model.to(device)
# Fix tokenizer: register the custom Amis language code with the NLLB tokenizer.
def fix_tokenizer(tokenizer, new_lang: str = 'ami_Latn') -> None:
    """Register ``new_lang`` as a language-code token on an NLLB tokenizer, in place.

    The tokenizer's internal language-code / fairseq tables are patched so
    that ``new_lang`` maps to id ``old_len - 1`` (see below), is listed among
    the additional special tokens, and the added-token tables are emptied.
    Nothing is returned; ``tokenizer`` is mutated.

    NOTE(review): this relies on internal attributes of the slow
    ``NllbTokenizer`` (``lang_code_to_id``, ``id_to_lang_code``,
    ``fairseq_tokens_to_ids``, ``fairseq_ids_to_tokens``, ``fairseq_offset``,
    ``sp_model``, ``_additional_special_tokens``); confirm they still exist in
    the installed ``transformers`` version before upgrading.

    Args:
        tokenizer: The tokenizer to patch (mutated in place).
        new_lang: Language code to register; defaults to 'ami_Latn'.
    """
    # If new_lang is already present as an added token it is included in
    # len(tokenizer); subtract it so old_len is the same in both cases.
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    # Bind new_lang <-> id (old_len - 1), i.e. the last id of the vocabulary.
    # Presumably the fine-tuned checkpoint reserved that slot for it — TODO confirm.
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    # Recompute the "<mask>" id as sp_model size + number of language codes
    # (now including new_lang) + fairseq_offset; assumes the language codes sit
    # directly after the sentencepiece vocab — verify against the tokenizer.
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    # Copy every language code (including new_lang) into the token -> id table,
    # then rebuild the reverse id -> token table from it.
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    # List new_lang as a special token (only once).
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # Empty the added-token tables; presumably so lookups for new_lang go
    # through the fairseq tables patched above instead — confirm.
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}


# Patch the module-level tokenizer once, right after loading it.
fix_tokenizer(small_tokenizer)
# Translation function
def translate(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` with the small model.

    Args:
        text: Input text from the UI textbox. ``None``, empty or
            whitespace-only input yields ``""`` without running the model.
        src_lang: Language code of the input ("zho_Hant" or "ami_Latn" in the
            UI). Traditional Chinese input is cleaned with ``preproc_chinese``
            first.
        tgt_lang: Language code to translate into; its token id is forced as
            the first generated token.

    Returns:
        The first decoded output sequence with special tokens stripped, or
        ``""`` when there is nothing to translate.
    """
    if src_lang == "zho_Hant" and text:
        text = preproc_chinese(text)
    # Guard after preprocessing: cleanup can turn input that is only control /
    # format characters into whitespace. Running beam search on an empty
    # source would just make the model invent text from the language tags.
    if not text or not text.strip():
        return ""

    tokenizer, model = small_tokenizer, small_model
    # NOTE: mutates the shared module-level tokenizer; per the NLLB tokenizer
    # docs, src_lang selects the language tag added to the encoded input.
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=1024)
    model.eval()
    result = model.generate(
        **inputs.to(model.device),
        # Force the decoder to start with the target-language token.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=256,
        num_beams=4,
    )
    return tokenizer.batch_decode(result, skip_special_tokens=True)[0]
# Preprocessing for Chinese: a Moses punctuation normalizer configured for "zh".
# Its substitution table holds (pattern, replacement) pairs; every pattern is
# compiled once here so preproc_chinese can call .sub() on it directly.
mpn_chinese = MosesPunctNormalizer(lang="zh")
mpn_chinese.substitutions = [
    (re.compile(pattern_src), replacement)
    for pattern_src, replacement in mpn_chinese.substitutions
]
def get_non_printing_char_replacer(replace_by=" "):
    """Build a callable that swaps non-printing characters for ``replace_by``.

    Every code point from U+0000 to ``sys.maxunicode`` whose Unicode general
    category is Cc, Cf, Cs, Co or Cn is mapped to ``replace_by`` in a
    ``str.translate`` table. The table is built once, when this factory is
    called; the returned function only applies it.

    Args:
        replace_by: String substituted for each matching character.

    Returns:
        A function ``line -> str`` that returns ``line`` with those
        characters replaced.
    """
    # "C" by itself never matches (category() returns two-letter codes); it is
    # kept so the accepted category set is exactly the same as before.
    other_categories = {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    table = {}
    for code_point in range(sys.maxunicode + 1):
        if unicodedata.category(chr(code_point)) in other_categories:
            table[code_point] = replace_by

    def replace(line):
        return line.translate(table)

    return replace


# Shared replacer used by preproc_chinese: non-printing characters become spaces.
replace_nonprint = get_non_printing_char_replacer(" ")
def preproc_chinese(text):
    """Clean Chinese source text before tokenization.

    The text goes through three steps, in order:
    1. every compiled Moses substitution in ``mpn_chinese.substitutions``;
    2. ``replace_nonprint`` (non-printing characters -> spaces);
    3. Unicode NFKC normalization.

    Args:
        text: Raw input string.

    Returns:
        The normalized string.
    """
    punct_fixed = text
    for compiled_pattern, replacement in mpn_chinese.substitutions:
        punct_fixed = compiled_pattern.sub(replacement, punct_fixed)
    printable_only = replace_nonprint(punct_fixed)
    return unicodedata.normalize("NFKC", printable_only)
# Gradio UI: choose a translation direction, enter text, press Translate.
with gr.Blocks() as demo:
    gr.Markdown("# AMIS - Chinese Translation Tool")
    src_lang = gr.Radio(choices=["zho_Hant", "ami_Latn"], value="zho_Hant", label="Source Language")
    tgt_lang = gr.Radio(choices=["ami_Latn", "zho_Hant"], value="ami_Latn", label="Target Language")
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text here...")
    output_text = gr.Textbox(label="Translated Text", interactive=False)
    translate_btn = gr.Button("Translate")
    # Only two languages are offered, so a source choice implies the target.
    # Without this link both radios could be left on the same language.
    # Only the source drives the target, so there is no change-event loop.
    src_lang.change(
        lambda lang: "ami_Latn" if lang == "zho_Hant" else "zho_Hant",
        inputs=src_lang,
        outputs=tgt_lang,
    )
    translate_btn.click(translate, inputs=[input_text, src_lang, tgt_lang], outputs=output_text)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|