File size: 1,892 Bytes
06a326a
e6711f4
1d61ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
16521bd
1d61ce5
 
16521bd
68c03aa
1d61ce5
16521bd
1d61ce5
 
 
 
 
 
93e2510
8df2cd3
 
1d61ce5
8df2cd3
 
1d61ce5
 
 
 
 
 
 
 
8df2cd3
1d61ce5
8df2cd3
1d61ce5
 
 
 
8df2cd3
1d61ce5
 
 
8df2cd3
1d61ce5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
import torch
import scipy

# Page header rendered at the top of the Streamlit app.
st.title("FinalProject")


@st.cache_resource
def load_summarization_model():
    """Build and cache the BART-large-CNN summarization pipeline.

    ``st.cache_resource`` makes Streamlit construct the pipeline once per
    server process instead of on every script rerun.

    Returns:
        transformers.Pipeline: a ready-to-call summarization pipeline.
    """
    print("Loading summarization model...")
    summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
    return summarization_pipeline

# Cached summarization pipeline (built once per server process).
summarizer = load_summarization_model()

# User-supplied article text; shared by both the Summarize and Translate buttons below.
ARTICLE = st.text_area("Enter the article to summarize:", height=300)

# Summary length bounds, forwarded to the summarization pipeline.
# NOTE(review): the widgets allow min_length (up to 450) to exceed
# max_length (down to 10); nothing here enforces min <= max.
max_length = st.number_input("Enter max length for summary:", min_value=10, max_value=500, value=130)
min_length = st.number_input("Enter min length for summary:", min_value=5, max_value=450, value=30)


# Run inference on CPU; the translation model is moved here after loading.
device = 'cpu'

@st.cache_resource
def load_translation_model():
    """Load and cache the T5 EN->RU/ZH translation model and its tokenizer.

    Cached by Streamlit so the (large) checkpoint is downloaded and
    instantiated only once per server process.

    Returns:
        tuple: (model moved to ``device``, T5Tokenizer) for
        'utrobinmv/t5_translate_en_ru_zh_large_1024'.
    """
    checkpoint = 'utrobinmv/t5_translate_en_ru_zh_large_1024'
    translation_model = T5ForConditionalGeneration.from_pretrained(checkpoint).to(device)
    translation_tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    return translation_model, translation_tokenizer



# Cached translation model and tokenizer, used by the Translate button below.
model, tokenizer = load_translation_model()


if st.button("Summarize"):
    # Guard clauses: empty input and an impossible length range both get a
    # user-facing error instead of being passed to the pipeline.
    if not ARTICLE.strip():
        st.error("Please enter an article to summarize.")
    elif int(min_length) > int(max_length):
        # The widgets allow min_length up to 450 and max_length down to 10,
        # so the user can select an inverted range; reject it explicitly.
        st.error("Min length cannot exceed max length.")
    else:
        # do_sample=False -> deterministic (greedy/beam) summaries.
        answer = summarizer(ARTICLE, max_length=int(max_length), min_length=int(min_length), do_sample=False)
        summary = answer[0]['summary_text']
        st.write("### Summary:")
        st.write(summary)

# Target language code used in the T5 task prefix ("translate to ru: ", etc.).
target_language = st.selectbox("Choose target language for translation:", ["ru", "zh"])

if st.button("Translate"):
    if ARTICLE.strip():
        # This T5 checkpoint expects a task prefix naming the target language.
        prefix = f"translate to {target_language}: "
        src_text = prefix + ARTICLE

        # truncation=True caps the input at the tokenizer's model_max_length
        # (the checkpoint name suggests a 1024-token context — TODO confirm),
        # so very long articles no longer overrun the model and crash generate().
        encoded = tokenizer(src_text, return_tensors="pt", truncation=True)
        # Inference only: disabling autograd saves memory and time.
        with torch.no_grad():
            generated_tokens = model.generate(**encoded.to(device))
        result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        st.write(f"### Translation ({target_language.upper()}):")
        st.write(result[0])
    else:
        st.error("Please enter an article to translate.")