import logging
import time

import streamlit as st
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

# Use the GPU when available; M2M100 generation on CPU is very slow.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
    logging.warning("GPU not found, using CPU, translation will be very slow.")

st.set_page_config(page_title="M2M100 Translator")

lang_id = {
    "English": "en",
    "French": "fr",
}


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(
    pretrained_model: str = "facebook/m2m100_418M",
    cache_dir: str = "models/",
):
    """Load and cache the tokenizer and model so they are built only once."""
    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=cache_dir
    ).to(device)
    model.eval()
    return tokenizer, model


st.title("M2M100 Translator")

user_input: str = st.text_area(
    "Input text",
    height=200,
    max_chars=5120,
)
source_lang = st.selectbox(label="Source language", options=list(lang_id.keys()))
target_lang = st.selectbox(label="Target language", options=list(lang_id.keys()))

if st.button("Run"):
    time_start = time.time()
    tokenizer, model = load_model()

    src_lang = lang_id[source_lang]
    trg_lang = lang_id[target_lang]
    # M2M100 needs the source language set on the tokenizer before encoding.
    tokenizer.src_lang = src_lang

    with torch.no_grad():
        encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
        # Force the decoder to start with the target-language token so the
        # model translates into the requested language.
        generated_tokens = model.generate(
            **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
        )
        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]

    time_end = time.time()
    st.success(translated_text)
    st.write(f"Computation time: {round(time_end - time_start, 3)} seconds")
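# ---------------------------------------------------------------------------
# Usage sketch (assumptions: the script is saved as "app.py"; package names
# are the standard PyPI ones, and the M2M100 tokenizer additionally requires
# sentencepiece):
#
#   pip install streamlit torch transformers sentencepiece
#   streamlit run app.py
#
# Streamlit then serves the app on http://localhost:8501 by default. The
# first run downloads the ~2 GB facebook/m2m100_418M checkpoint into the
# "models/" cache directory configured in load_model().
# ---------------------------------------------------------------------------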