File size: 1,964 Bytes
3a49715
e981b8c
 
 
 
 
 
 
 
 
3a49715
e981b8c
 
 
 
 
3a49715
e981b8c
 
3a49715
e981b8c
 
 
d004a35
e981b8c
3a49715
e981b8c
d004a35
3a49715
e981b8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a49715
 
e981b8c
 
3a49715
e981b8c
 
 
3a49715
e981b8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import streamlit as st
import os
import io
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
import time
import json
from typing import List
import torch
import random
import logging

# Select the compute device once at import time: the first CUDA GPU when
# PyTorch can see one, otherwise the CPU (with a warning, because running the
# translation model on CPU is slow).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    logging.warning("GPU not found, using CPU, translation will be very slow.")

# Page-level configuration. Streamlit requires st.set_page_config to be the
# first Streamlit command executed in the script, so nothing else may run
# before it.
st.set_page_config(page_title="M2M100 Translator")

# Maps the language names shown in the UI to the M2M100 language codes used
# for tokenizer.src_lang and tokenizer.get_lang_id.
lang_id = dict(
    English="en",
    French="fr",
)


# Prefer st.cache_resource (added in Streamlit 1.18, replacing the deprecated
# st.cache). It keeps one shared, un-copied object per argument set, which is
# what st.cache(allow_output_mutation=True) was emulating and is what a large
# model needs. Older Streamlit releases without cache_resource fall back to
# the original legacy decorator.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(suppress_st_warning=True, allow_output_mutation=True)
)


@_cache_model
def load_model(
    pretrained_model: str = "facebook/m2m100_418M",
    cache_dir: str = "models/",
):
    """Load the M2M100 tokenizer and model, cached across Streamlit reruns.

    Args:
        pretrained_model: Model id or local path passed to ``from_pretrained``.
        cache_dir: Directory passed to ``from_pretrained`` as its download cache.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been moved to the
        module-level ``device`` and put in evaluation mode.
    """
    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=cache_dir
    ).to(device)
    # Inference only: eval() disables training-time behaviour such as dropout.
    model.eval()
    return tokenizer, model


st.title("M2M100 Translator")

# Text to translate; max_chars caps the input length at 5120 characters.
user_input: str = st.text_area("Input text", height=200, max_chars=5120)

# Both pickers list the same language names, in lang_id's insertion order.
source_lang = st.selectbox(label="Source language", options=[*lang_id])
target_lang = st.selectbox(label="Target language", options=[*lang_id])

if st.button("Run"):
    if not user_input.strip():
        # Nothing to translate: skip the (potentially slow) model load and
        # the generate() call entirely.
        st.warning("Please enter some text to translate.")
    else:
        # perf_counter is monotonic, so the elapsed time is not skewed by
        # system clock adjustments. The measured span includes load_model(),
        # which is only expensive when the cached model is not yet loaded.
        time_start = time.perf_counter()
        tokenizer, model = load_model()

        src_lang = lang_id[source_lang]
        trg_lang = lang_id[target_lang]
        # M2M100Tokenizer uses src_lang to select the source-language prefix
        # token it adds to the encoded input.
        tokenizer.src_lang = src_lang
        with torch.no_grad():
            encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
            # Forcing the first generated token to the target-language id is
            # how M2M100 is told which language to produce.
            generated_tokens = model.generate(
                **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
            )
            translated_text = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )[0]

        time_end = time.perf_counter()
        st.success(translated_text)

        st.write(f"Computation time: {round((time_end-time_start),3)} segs")