Spaces:
Runtime error
Runtime error
File size: 1,964 Bytes
3a49715 e981b8c 3a49715 e981b8c 3a49715 e981b8c 3a49715 e981b8c d004a35 e981b8c 3a49715 e981b8c d004a35 3a49715 e981b8c 3a49715 e981b8c 3a49715 e981b8c 3a49715 e981b8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import streamlit as st
import os
import io
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
import time
import json
from typing import List
import torch
import random
import logging
# Pick the first CUDA device when one is present; otherwise run on the CPU and
# warn, since the log message below notes CPU translation is very slow.
device = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
if device.type == "cpu":
    logging.warning("GPU not found, using CPU, translation will be very slow.")
# st.set_page_config must be the first Streamlit command executed by the
# script, so no other st.* call may run before it.
st.set_page_config(page_title="M2M100 Translator")

# Language name shown in the UI -> language code handed to the M2M100
# tokenizer (used as `tokenizer.src_lang` and with `tokenizer.get_lang_id`).
lang_id = {
    "English": "en",
    "French": "fr",
}
# The tokenizer/model pair is an unhashable, process-wide resource, which is
# what st.cache_resource (Streamlit >= 1.18) is designed for. The legacy
# st.cache is deprecated there; on older Streamlit releases that lack
# cache_resource, fall back to st.cache with the original flags.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(suppress_st_warning=True, allow_output_mutation=True)


@_cache_model
def load_model(
    pretrained_model: str = "facebook/m2m100_418M",
    cache_dir: str = "models/",
):
    """Load the M2M100 tokenizer and model once and reuse them across reruns.

    Args:
        pretrained_model: Model id or path passed to ``from_pretrained``.
        cache_dir: Directory ``from_pretrained`` uses to store downloaded files.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is moved to the module-level
        ``device`` and switched to eval mode.
    """
    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=cache_dir
    ).to(device)
    # Inference only: disable training-time behavior such as dropout.
    model.eval()
    return tokenizer, model
st.title("M2M100 Translator")

# Free-text input box; Streamlit enforces the 5120-character limit.
user_input: str = st.text_area(label="Input text", height=200, max_chars=5120)

# Both pickers list the display names from lang_id, in insertion order.
source_lang = st.selectbox(label="Source language", options=list(lang_id))
target_lang = st.selectbox(label="Target language", options=list(lang_id))
if st.button("Run"):
    if not user_input.strip():
        # Nothing to translate: don't load or run the model on empty input.
        st.warning("Please enter some text to translate.")
    else:
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments the way time.time() differences can.
        time_start = time.perf_counter()
        # load_model is cached, but on the first click this also includes the
        # model download/load time, since it runs after time_start.
        tokenizer, model = load_model()
        src_lang = lang_id[source_lang]
        trg_lang = lang_id[target_lang]
        # Tell the tokenizer which language the input text is in.
        tokenizer.src_lang = src_lang
        with torch.no_grad():
            encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
            # Forcing the first generated token to the target-language id
            # makes the model decode into trg_lang.
            generated_tokens = model.generate(
                **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
            )
            # Single input -> single output sequence; drop special tokens.
            translated_text = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )[0]
        time_end = time.perf_counter()
        st.success(translated_text)
        st.write(f"Computation time: {round((time_end-time_start),3)} segs")