import gradio as gr
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
import torch
from sacremoses import MosesPunctNormalizer
import re
import unicodedata
import sys

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the small (600M) AMIS-Chinese model and its tokenizer
small_tokenizer = NllbTokenizer.from_pretrained("hunterschep/amis-zh-600M")
small_model = AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-600M").to(device)
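# Optional (an assumption about the deployment environment, not required by the app):
# on a CUDA GPU the model could be loaded in half precision to reduce memory, e.g.
#   AutoModelForSeq2SeqLM.from_pretrained("hunterschep/amis-zh-600M", torch_dtype=torch.float16)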

# Add the new language code to the NLLB tokenizer.
# This has to be re-applied every time the tokenizer is (re)initialized.
def fix_tokenizer(tokenizer, new_lang='ami_Latn'):
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    # always keep "<mask>" as the last token
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added-token caches; otherwise the new code may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}

fix_tokenizer(small_tokenizer)
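# Hypothetical sanity check (not part of the app): after fix_tokenizer, the new language
# code should resolve to a real token id rather than <unk>, e.g.
#   small_tokenizer.convert_tokens_to_ids("ami_Latn")  # expected: an id at the end of the vocab
#   small_tokenizer.lang_code_to_id["ami_Latn"]        # expected: the same id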

# Translation function (src_lang / tgt_lang are NLLB-style codes, e.g. "zho_Hant", "ami_Latn")
def translate(text, src_lang, tgt_lang):
    tokenizer, model = small_tokenizer, small_model
    if src_lang == "zho_Hant":
        # preproc_chinese is defined below; it only needs to exist by the time this is called
        text = preproc_chinese(text)
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=1024)
    model.eval()
    result = model.generate(
        **inputs.to(model.device),
        # force the decoder to start with the target language token
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=256,
        num_beams=4
    )
    return tokenizer.batch_decode(result, skip_special_tokens=True)[0]
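# Illustrative usage (outputs depend on the fine-tuned checkpoint; the example strings
# are placeholders, not verified translations):
#   translate("今天天氣很好", "zho_Hant", "ami_Latn")  # Traditional Chinese -> AMIS
#   translate("Nga'ay ho", "ami_Latn", "zho_Hant")     # AMIS -> Traditional Chinese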

# Preprocessing for Chinese input: Moses punctuation normalization,
# removal of non-printing characters, and Unicode NFKC normalization.
mpn_chinese = MosesPunctNormalizer(lang="zh")
# pre-compile the substitution patterns so they can be applied directly
mpn_chinese.substitutions = [(re.compile(r), sub) for r, sub in mpn_chinese.substitutions]

def get_non_printing_char_replacer(replace_by=" "):
    # map every character in a Unicode "C*" (control/format/surrogate/private/unassigned)
    # category to the replacement string
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }
    return lambda line: line.translate(non_printable_map)

replace_nonprint = get_non_printing_char_replacer(" ")

def preproc_chinese(text):
    clean = text
    for pattern, sub in mpn_chinese.substitutions:
        clean = pattern.sub(sub, clean)
    clean = replace_nonprint(clean)
    return unicodedata.normalize("NFKC", clean)
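# Illustrative effect (a sketch): full-width punctuation and ideographic spaces are folded
# to their ASCII forms by NFKC, and non-printing characters are replaced with spaces, e.g.
#   preproc_chinese("今天\u3000天氣很好！")  # roughly -> "今天 天氣很好!"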

with gr.Blocks() as demo:
    gr.Markdown("# AMIS - Chinese Translation Tool")
    src_lang = gr.Radio(choices=["zho_Hant", "ami_Latn"], value="zho_Hant", label="Source Language")
    tgt_lang = gr.Radio(choices=["ami_Latn", "zho_Hant"], value="ami_Latn", label="Target Language")
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text here...")
    output_text = gr.Textbox(label="Translated Text", interactive=False)
    translate_btn = gr.Button("Translate")

    translate_btn.click(translate, inputs=[input_text, src_lang, tgt_lang], outputs=output_text)

if __name__ == "__main__":
    demo.launch()
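    # Note: demo.launch() serves the app locally / inside the Space; passing share=True
    # would additionally create a temporary public Gradio link, if that is ever desired.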