OlzhasBatyrkhanov commited on
Commit
1d61ce5
·
1 Parent(s): 8df2cd3

v1.2.1 translate model done

Browse files
Files changed (1) hide show
  1. app.py +45 -36
app.py CHANGED
@@ -1,51 +1,60 @@
1
  import streamlit as st
2
-
3
  from transformers import T5ForConditionalGeneration, T5Tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- device = 'cpu' #or 'cpu' for translate on cpu
 
6
 
7
- model_name = 'utrobinmv/t5_translate_en_ru_zh_large_1024'
8
- model = T5ForConditionalGeneration.from_pretrained(model_name)
9
- model.to(device)
10
- tokenizer = T5Tokenizer.from_pretrained(model_name)
11
 
12
- prefix = 'translate to en: '
13
- src_text = prefix + "Съешь ещё этих мягких французских булок."
14
 
15
- # translate Russian to Chinese
16
- input_ids = tokenizer(src_text, return_tensors="pt")
 
 
 
 
17
 
18
- generated_tokens = model.generate(**input_ids.to(device))
19
 
20
- result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
21
- print(result)
22
- st.write(result[0])
23
- # 再吃这些法国的甜蜜的面包。
24
- # import streamlit as st
25
- # from transformers import pipeline
26
- # import torch
27
- # import scipy
28
 
29
- # st.title("FinalProject")
30
 
31
 
32
- # @st.cache_resource
33
- # def load_summarization_model():
34
- # print("Loading summarization model...")
35
- # return pipeline("summarization", model="facebook/bart-large-cnn")
 
 
 
 
36
 
37
- # summarizer = load_summarization_model()
38
 
39
- # ARTICLE = st.text_area("Enter the article to summarize:", height=300)
 
 
 
40
 
41
- # max_length = st.number_input("Enter max length for summary:", min_value=10, max_value=500, value=130)
42
- # min_length = st.number_input("Enter min length for summary:", min_value=5, max_value=450, value=30)
 
43
 
44
- # if st.button("Summarize"):
45
- # if ARTICLE.strip():
46
- # answer = summarizer(ARTICLE, max_length=int(max_length), min_length=int(min_length), do_sample=False)
47
- # summary = answer[0]['summary_text']
48
- # st.write("### Summary:")
49
- # st.write(summary)
50
- # else:
51
- # st.error("Please enter an article to summarize.")
 
1
  import streamlit as st
 
2
  from transformers import T5ForConditionalGeneration, T5Tokenizer
3
+ from transformers import pipeline
4
+ import torch
5
+ import scipy
6
+
7
+ st.title("FinalProject")
8
+
9
+
10
+ @st.cache_resource
11
+ def load_summarization_model():
12
+ print("Loading summarization model...")
13
+ return pipeline("summarization", model="facebook/bart-large-cnn")
14
+
15
+ summarizer = load_summarization_model()
16
+
17
+ ARTICLE = st.text_area("Enter the article to summarize:", height=300)
18
 
19
+ max_length = st.number_input("Enter max length for summary:", min_value=10, max_value=500, value=130)
20
+ min_length = st.number_input("Enter min length for summary:", min_value=5, max_value=450, value=30)
21
 
 
 
 
 
22
 
23
+ device = 'cpu'
 
24
 
25
+ @st.cache_resource
26
+ def load_translation_model():
27
+ model_name = 'utrobinmv/t5_translate_en_ru_zh_large_1024'
28
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
29
+ model.to(device)
30
+ return model, T5Tokenizer.from_pretrained(model_name)
31
 
 
32
 
 
 
 
 
 
 
 
 
33
 
34
+ model, tokenizer = load_translation_model()
35
 
36
 
37
+ if st.button("Summarize"):
38
+ if ARTICLE.strip():
39
+ answer = summarizer(ARTICLE, max_length=int(max_length), min_length=int(min_length), do_sample=False)
40
+ summary = answer[0]['summary_text']
41
+ st.write("### Summary:")
42
+ st.write(summary)
43
+ else:
44
+ st.error("Please enter an article to summarize.")
45
 
46
+ target_language = st.selectbox("Choose target language for translation:", ["ru", "zh"])
47
 
48
+ if st.button("Translate"):
49
+ if ARTICLE.strip():
50
+ prefix = f"translate to {target_language}: "
51
+ src_text = prefix + ARTICLE
52
 
53
+ input_ids = tokenizer(src_text, return_tensors="pt")
54
+ generated_tokens = model.generate(**input_ids.to(device))
55
+ result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
56
 
57
+ st.write(f"### Translation ({target_language.upper()}):")
58
+ st.write(result[0])
59
+ else:
60
+ st.error("Please enter an article to translate.")