"""Gradio demo translating between Amis (``ami_Latn``) and Chinese (``zho_Hant``).

Loads a fine-tuned NLLB checkpoint, patches its tokenizer so that the extra
``ami_Latn`` language code is recognised, and exposes a small web UI.
"""

import re
import sys
import unicodedata

import gradio as gr
import torch
from sacremoses import MosesPunctNormalizer
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the small model
small_tokenizer = NllbTokenizer.from_pretrained("hunterschep/amis-zh-600M")
small_model = AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-600M").to(device)


# Fix tokenizer
def fix_tokenizer(tokenizer, new_lang='ami_Latn'):
    """Register ``new_lang`` as a language code on an NLLB tokenizer, in place.

    Must be called after every tokenizer initialisation, because the patched
    mappings live only on the in-memory object.

    NOTE(review): this relies on tokenizer internals (``lang_code_to_id``,
    ``fairseq_tokens_to_ids``, ``_additional_special_tokens``, ...) -- confirm
    that the pinned ``transformers`` version still exposes them.

    Args:
        tokenizer: The ``NllbTokenizer`` instance to patch.
        new_lang: Language code to add to the tokenizer's language tables.
    """
    # Vocabulary size without the new language, in case it was already added.
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)

    # The new language code takes the last id of the old vocabulary.
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang

    # Always move "<mask>" to the position after all language codes. Keying
    # this entry by an empty string instead would leave "<mask>" sharing an id
    # with the new language code and register a bogus empty token.
    tokenizer.fairseq_tokens_to_ids["<mask>"] = (
        len(tokenizer.sp_model)
        + len(tokenizer.lang_code_to_id)
        + tokenizer.fairseq_offset
    )

    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {
        v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()
    }
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)

    # Clear the added-token tables so the language code is resolved through
    # the mappings patched above.
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}


fix_tokenizer(small_tokenizer)


# Translation function
def translate(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

    Args:
        text: The text to translate.
        src_lang: Source language code (``"zho_Hant"`` or ``"ami_Latn"``).
        tgt_lang: Target language code (``"ami_Latn"`` or ``"zho_Hant"``).

    Returns:
        The translated string, or ``""`` when ``text`` is empty or only
        whitespace.
    """
    # Nothing to translate: skip the model instead of decoding from no input.
    if not text or not text.strip():
        return ""

    tokenizer, model = small_tokenizer, small_model

    # Only Chinese input goes through the normalisation pipeline.
    if src_lang == "zho_Hant":
        text = preproc_chinese(text)

    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang

    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )

    # Inference only; calling eval() repeatedly is harmless.
    model.eval()
    result = model.generate(
        **inputs.to(model.device),
        # Force the decoder to start with the target language code.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=256,
        num_beams=4,
    )
    return tokenizer.batch_decode(result, skip_special_tokens=True)[0]


# Preprocessing for Chinese
mpn_chinese = MosesPunctNormalizer(lang="zh")
# Pre-compile the normaliser's patterns once so preproc_chinese can apply them directly.
mpn_chinese.substitutions = [
    (re.compile(r), sub) for r, sub in mpn_chinese.substitutions
]


def get_non_printing_char_replacer(replace_by=" "):
    """Build a function that replaces non-printing characters in a string.

    The translation table is computed once, by scanning every code point up
    to ``sys.maxunicode`` and keeping those whose Unicode category is one of
    the "Other" categories listed below.

    Args:
        replace_by: The string substituted for each non-printing character.

    Returns:
        A callable taking a string and returning it with every non-printing
        character replaced by ``replace_by``.
    """
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }
    return lambda line: line.translate(non_printable_map)


# Built once at import time; the table scan is too expensive to repeat per call.
replace_nonprint = get_non_printing_char_replacer(" ")


def preproc_chinese(text):
    """Normalise Chinese text before tokenisation.

    Applies the Moses punctuation substitutions, replaces non-printing
    characters with spaces, then applies Unicode NFKC normalisation.

    Args:
        text: The raw input string.

    Returns:
        The normalised string.
    """
    clean = text
    for pattern, sub in mpn_chinese.substitutions:
        clean = pattern.sub(sub, clean)
    clean = replace_nonprint(clean)
    return unicodedata.normalize("NFKC", clean)


with gr.Blocks() as demo:
    gr.Markdown("# AMIS - Chinese Translation Tool")
    src_lang = gr.Radio(
        choices=["zho_Hant", "ami_Latn"],
        value="zho_Hant",
        label="Source Language",
    )
    tgt_lang = gr.Radio(
        choices=["ami_Latn", "zho_Hant"],
        value="ami_Latn",
        label="Target Language",
    )
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text here...")
    output_text = gr.Textbox(label="Translated Text", interactive=False)
    translate_btn = gr.Button("Translate")
    translate_btn.click(
        translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()