# de-ca / app.py
# Author: ksenia-kh · commit 637d81a ("Update app.py") · 673 bytes
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
checkpoint = "projecte-aina/m2m100-418M-ft-de-ca"

# Tokenizer for the fine-tuned checkpoint, configured for the
# German (source) -> Catalan (target) direction.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.src_lang = "de"
tokenizer.tgt_lang = "ca"

# Seq2seq model weights from the same checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
def predict(text):
    """Translate German *text* to Catalan with the fine-tuned M2M100 model.

    Uses the module-level ``tokenizer`` and ``model``. The tokenizer's
    ``src_lang``/``tgt_lang`` attributes are set when it is loaded.

    Args:
        text: Source text submitted from the Gradio text box.

    Returns:
        The decoded translation (first sequence of the generated batch), or
        an empty string when *text* is empty or whitespace-only.
    """
    # The UI can submit an empty box; skip the model call in that case.
    if not text or not text.strip():
        return ""

    encoded = tokenizer(text, return_tensors="pt")
    # M2M100 chooses the output language from the first decoder token, so
    # force it to the configured target language's id.
    target_lang_id = tokenizer.get_lang_id(tokenizer.tgt_lang)
    generated_tokens = model.generate(**encoded, forced_bos_token_id=target_lang_id)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Plain text-in / text-out web UI wrapped around predict().
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Translation",
)
demo.launch()