import streamlit as st
import pandas as pd
import json
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5Tokenizer
import re
from load_model import translate_T5, translate_BERT_BARTPho, translate_LSTM, translate_BiLSTM, translate_GRU, translate_MarianMT
# from load_model import translate_GRU
import sentencepiece as spm
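# The translate_* helpers imported above are implemented in load_model.py; the
# summarization model is loaded from the local ./Summarization directory below.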


# python -m streamlit run ui.py

# Inputs longer than MAX_LENGTH words are summarized before translation.
MAX_LENGTH = 64

st.title("Machine Translation")
st.markdown('<p style="font-size:24px; font-weight:bold;">English - Vietnamese</p>', 
            unsafe_allow_html=True)

# Load the summarization model once and cache it in st.session_state so it is
# not reloaded on every Streamlit rerun.
if 'summarize_model' not in st.session_state:
    summarize_model_dir = "./Summarization"
    st.session_state.summarize_tokenizer = AutoTokenizer.from_pretrained(summarize_model_dir)
    st.session_state.summarize_model = AutoModelForSeq2SeqLM.from_pretrained(summarize_model_dir)
    print("Summarize model loaded")
    
model_name = st.selectbox("Select Model", ["BERT_BARTPho", "T5", "BiLSTM", "GRU", "LSTM", "MarianMT"], index=None, placeholder="Select a Model")

input_text = st.text_area(
    "Input Text:",
    placeholder="Enter your text here...",
    height=150,
    key="input_text",
    help=f"If your input text is more than {MAX_LENGTH} words, it will be summarized and then translated.",
    value= "Today, I go to school"
)

def summarize(input_text):
    """Summarize the input if it exceeds MAX_LENGTH words; otherwise return it unchanged."""
    if len(input_text.split()) <= MAX_LENGTH:
        return input_text

    st.write(f"Your input paragraph is more than {MAX_LENGTH} words!")

    summarize_tokenizer = st.session_state.summarize_tokenizer
    summarize_model = st.session_state.summarize_model

    inputs = summarize_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    outputs = summarize_model.generate(**inputs, max_length=100, num_beams=5, length_penalty=2.0, early_stopping=True)

    summarized_input_text = summarize_tokenizer.decode(outputs[0], skip_special_tokens=True)

    return summarized_input_text

def cut_sentence(input_text):
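    """Split a paragraph into sentences on '.', '!' or '?' followed by a space.

    Example: cut_sentence("Hello there. How are you?") -> ["Hello there.", "How are you?"]
    """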
    sentences = re.split(r'(?<=[.!?]) +', input_text.strip())
    return sentences

# Summarize the input if it is too long and display the shortened version.
text_to_translate = summarize(input_text)
if text_to_translate != input_text:
    st.write(text_to_translate)
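
# Inject CSS to widen text inputs (width is relative to the parent container).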

st.markdown(
    """
    <style>
    input[type=text] {
        width: 500%;
    }
    </style>
    """,
    unsafe_allow_html=True
)

if st.button("Translate"):
    if model_name is None:
        st.warning("Please select a model before translating.")
    else:
        if model_name == "BERT_BARTPho":
            translated_text = translate_BERT_BARTPho(text_to_translate)
        elif model_name == "T5":
            translated_text = translate_T5(text_to_translate)
        elif model_name == "BiLSTM":
            translated_text = translate_BiLSTM(text_to_translate)
        elif model_name == "GRU":
            translated_text = translate_GRU(text_to_translate)
        elif model_name == "LSTM":
            translated_text = translate_LSTM(text_to_translate)
        else:  # MarianMT
            translated_text = translate_MarianMT(text_to_translate)

        st.write(f"Translation for {model_name}:")
        st.write(translated_text)