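# Page header captured from the Hugging Face Space file view:
#   author: darylalim
#   commit: a2533a3 "Update app.py" (verified)
#   views: raw | history | blame
#   file size: 1.14 kB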
import spaces
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from optimum.bettertransformer import BetterTransformer
# Hugging Face Hub repository for the MADLAD-400 3B machine-translation
# checkpoint. Tokenizer and model are loaded from the same repo so their
# vocabularies match.
MODEL_ID = "google/madlad400-3b-mt"

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=True,
)

# Weights are loaded in bfloat16 (2 bytes per parameter instead of float32's 4).
model_hf = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
)

# Swap in BetterTransformer fast-path layers. keep_original=False (optimum's
# default) overrides model_hf rather than keeping a separate copy of the
# original model; model_hf is never used after this point, so keeping a copy
# (the previous keep_original=True) only cost memory for a 3B-parameter model.
model = BetterTransformer.transform(model_hf, keep_original=False)
@spaces.GPU
def translate(text, *, target_lang="haw", max_new_tokens=1000):
    """Translate ``text`` with MADLAD-400-3B-MT (Hawaiian by default).

    MADLAD selects the output language from a ``<2xx>`` tag prepended to the
    source text, so no source-language tag is added here.

    Args:
        text: Text to translate (the UI labels it as English).
        target_lang: Language code placed in the ``<2xx>`` tag; the default
            ``"haw"`` reproduces the original English-to-Hawaiian behavior.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded translation with special tokens removed, or ``""`` when
        ``text`` is empty, ``None``, or only whitespace.
    """
    # Nothing to translate: avoid generating output from the bare language tag.
    if not text or not text.strip():
        return ""

    prompt = f"<2{target_lang}> {text}"
    # Put the input tensors on the same device as the model so this keeps
    # working if the model is moved off the CPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # A single prompt was encoded, so only the first sequence is decoded.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# One text box in (English source) and one text box out (Hawaiian result).
english_textbox = gr.Textbox(label="English")
hawaiian_textbox = gr.Textbox(label="Hawaiian")

demo = gr.Interface(
    fn=translate,
    inputs=[english_textbox],
    outputs=[hawaiian_textbox],
    title="MADLAD-400-3B-MT English-to-Hawaiian Translation",
    description="[Code](https://github.com/darylalim/madlad-400-3b-mt-eng-to-haw-translation)",
)

# Turn on request queuing, then start the web server. Blocks.queue() returns
# the app itself, so the two calls chain.
demo.queue().launch()