# Source: Hugging Face Spaces app (page status shown: Runtime error)
import streamlit as st | |
import os | |
import io | |
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration | |
import time | |
import json | |
from typing import List | |
import torch | |
import random | |
import logging | |
# Pick the inference device once at import time: the first CUDA GPU when one is
# visible to torch, otherwise the CPU (with a warning, because generation on
# CPU is much slower).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    logging.warning("GPU not found, using CPU, translation will be very slow.")
# st.set_page_config must be the first Streamlit command the script executes,
# so no other `st.*` call may run before it.
st.set_page_config(page_title="M2M100 Translator")

# Languages offered in the UI: display name -> M2M100 language code. The code
# for the chosen source language is assigned to `tokenizer.src_lang`; the code
# for the target language is passed to `tokenizer.get_lang_id`.
lang_id = {
    "English": "en",
    "French": "fr",
}
# Streamlit re-executes the whole script on every widget interaction, so heavy
# objects must be cached explicitly. Prefer st.cache_resource (Streamlit >= 1.18,
# shares one instance instead of copying it) and fall back to the legacy
# st.cache on releases that predate it.
if hasattr(st, "cache_resource"):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def _load_cached_model(pretrained_model: str, cache_dir: str):
    """Load the M2M100 model once per (checkpoint, cache_dir) and reuse it.

    The model is moved to the module-level ``device`` and switched to eval
    mode before being cached.
    """
    model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=cache_dir
    ).to(device)
    model.eval()
    return model


def load_model(
    pretrained_model: str = "facebook/m2m100_418M",
    cache_dir: str = "models/",
):
    """Return a ``(tokenizer, model)`` pair for an M2M100 checkpoint.

    Args:
        pretrained_model: Model identifier or path passed to
            ``from_pretrained`` for both the tokenizer and the model.
        cache_dir: Directory passed to ``from_pretrained`` as ``cache_dir``.

    Returns:
        A freshly loaded ``M2M100Tokenizer`` and the cached
        ``M2M100ForConditionalGeneration`` (on ``device``, in eval mode).
    """
    # The tokenizer is loaded per call on purpose: callers assign
    # `tokenizer.src_lang`, and a cached, shared tokenizer would let concurrent
    # sessions overwrite each other's source language. Only the expensive model
    # is cached.
    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
    model = _load_cached_model(pretrained_model, cache_dir)
    return tokenizer, model
st.title("M2M100 Translator")

user_input: str = st.text_area(
    "Input text",
    height=200,
    max_chars=5120,
)

source_lang = st.selectbox(label="Source language", options=list(lang_id.keys()))
target_lang = st.selectbox(label="Target language", options=list(lang_id.keys()))

if st.button("Run"):
    if not user_input or not user_input.strip():
        # Nothing to translate: skip loading the model and generating from an
        # input that would contain only special tokens.
        st.warning("Please enter some text to translate.")
    else:
        # perf_counter is monotonic, so wall-clock adjustments cannot skew the
        # measurement. The reported time includes load_model().
        time_start = time.perf_counter()
        tokenizer, model = load_model()

        src_lang = lang_id[source_lang]
        trg_lang = lang_id[target_lang]
        # Tell the tokenizer which language the input text is written in.
        tokenizer.src_lang = src_lang

        # Inference only: no gradients are needed.
        with torch.no_grad():
            encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
            # Force generation to start with the target-language token, which
            # selects the language the model translates into.
            generated_tokens = model.generate(
                **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
            )
        # A single input string was encoded, so the batch has one output.
        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]
        time_end = time.perf_counter()

        st.success(translated_text)
        st.write(f"Computation time: {round((time_end-time_start),3)} segs")