import gradio as gr

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# German-to-Catalan translation checkpoint: M2M100 (418M) fine-tuned by projecte-aina
checkpoint = "projecte-aina/m2m100-418M-ft-de-ca"

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Configure the M2M100 tokenizer for German source and Catalan target text
tokenizer.src_lang = "de"
tokenizer.tgt_lang = "ca"


def predict(text):
    # Tokenize the German input sentence
    encoded_ref = tokenizer(text, return_tensors="pt")
    # Force the decoder to start generating in the target language (Catalan)
    generated_tokens = model.generate(
        **encoded_ref,
        forced_bos_token_id=tokenizer.get_lang_id(tokenizer.tgt_lang),
    )
    hyp = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return hyp[0]
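# Optional sanity check of the translation function before launching the UI;
# illustrative only, and the exact Catalan wording depends on the checkpoint:
#
#   print(predict("Guten Morgen, wie geht es dir?"))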
gr.Interface(
    predict,
    inputs="text",
    outputs="text",
    title="Translation",
).launch()