nan-motherboard committed on
Commit
fa77629
·
1 Parent(s): d38c13b
Files changed (3) hide show
  1. __pycache__/utils.cpython-311.pyc +0 -0
  2. app.py +50 -2
  3. utils.py +57 -0
__pycache__/utils.cpython-311.pyc ADDED
Binary file (3.96 kB). View file
 
app.py CHANGED
@@ -1,4 +1,52 @@
1
  import streamlit as st
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from utils import generate_summary
3
 
# Seed every session-state key the script reads, so the first run of the
# script never hits a missing key.
for _key, _default in (
    ("clicked", False),
    ("input_text", ""),
    ("generated_summary", ""),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default

st.title("Dialogue Text Summarization")

st.write("---")

height = 200

# The text area is bound to st.session_state["input_text"] through its key,
# which is what lets clear_all() below reset its contents.
input_text = st.text_area("Dialogue", height=height, key="input_text")

# Submit: validate the input, then stream a summary into fresh placeholders.
if st.button("Submit"):
    if not st.session_state.input_text.strip():
        st.error("Please enter a dialogue!")
    else:
        st.write("---")
        st.write("## Summary")
        st_container = st.empty()
        st_info_container = st.empty()
        # Collapse all runs of whitespace to single spaces before handing
        # the dialogue to the summarizer; keep the result in session state.
        st.session_state.generated_summary = generate_summary(
            " ".join(st.session_state.input_text.split()),
            st_container,
            st_info_container,
        )

# Show whatever summary is stored in session state (nothing when it is empty).
if st.session_state.generated_summary:
    st.write(st.session_state.generated_summary)


def clear_all():
    """Reset the dialogue box and stored summary, and flag that a clear happened."""
    st.session_state.clicked = True
    st.session_state.input_text = ""  # Clear input text
    st.session_state.generated_summary = ""  # Clear summary


st.button("Clear", on_click=clear_all)

# After a clear, drop the flag and rerun once so the page is redrawn from the
# reset state.
if st.session_state.clicked:
    st.session_state.clicked = False
    # st.experimental_rerun was deprecated in favour of st.rerun and is not
    # available in current Streamlit releases; prefer the new name and fall
    # back to the old one only on installations that predate st.rerun.
    _rerun = getattr(st, "rerun", None) or st.experimental_rerun
    _rerun()
utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, GenerationConfig, TextStreamer, AutoModelForSeq2SeqLM
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+ import time
5
+
# Hugging Face Hub id of the seq2seq checkpoint loaded for summarization.
# NOTE: this name was previously bound to "Mia2024/CS5100TextSummarization"
# and then rebound on the very next statement, so only the value below ever
# took effect; the dead assignment has been dropped.
checkpoint = "facebook/bart-large-cnn"
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+
class StreamlitTextStreamer(TextStreamer):
    """TextStreamer that mirrors the streamed text into a Streamlit placeholder.

    Every chunk passed to ``on_finalized_text`` is appended to ``self.text``
    and the whole accumulated text is re-rendered in ``st_container``.

    Attributes:
        st_container: Object whose ``markdown`` method renders the running text.
        st_info_container: Stored for callers; not used by this class itself.
        text: All text received so far, concatenated.
        start_time: ``time.time()`` at the first callback, else ``None``.
        first_token_time: ``time.time()`` at the first non-blank chunk,
            else ``None``.
        total_tokens: Running count of whitespace-separated words received
            (a word count, not a tokenizer token count).
    """

    def __init__(self, tokenizer, st_container, st_info_container, skip_prompt=False, **decode_kwargs):
        """Store the display targets and zero the running statistics.

        ``tokenizer``, ``skip_prompt`` and ``decode_kwargs`` are forwarded
        unchanged to ``TextStreamer``.
        """
        super().__init__(tokenizer, skip_prompt=skip_prompt, **decode_kwargs)
        self.st_container = st_container
        self.st_info_container = st_info_container
        self.text = ""
        self.start_time = None
        self.first_token_time = None
        self.total_tokens = 0

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Append ``text`` to the running output and redraw the placeholder.

        ``stream_end`` is accepted for interface compatibility and ignored.
        """
        # Timestamp the first callback of the stream.
        if self.start_time is None:
            self.start_time = time.time()

        # Timestamp the first chunk that contains something other than whitespace.
        if self.first_token_time is None and text.strip():
            self.first_token_time = time.time()

        self.text = self.text + text
        self.total_tokens += len(text.split())

        # Re-render the full text each time, styled as a level-6 heading.
        self.st_container.markdown(f"###### {self.text}")
        # Brief pause after each redraw (paces the on-screen updates).
        time.sleep(0.03)
33
+
34
+
# Tokenizer/model pairs already loaded in this process, keyed by checkpoint id.
_SUMMARIZER_CACHE = {}


def _load_summarizer(name):
    """Return the ``(tokenizer, model)`` pair for ``name``, loading it on first use only."""
    if name not in _SUMMARIZER_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(name)
        model = AutoModelForSeq2SeqLM.from_pretrained(name).to(device)
        _SUMMARIZER_CACHE[name] = (tokenizer, model)
    return _SUMMARIZER_CACHE[name]


def generate_summary(input_text, st_container, st_info_container) -> str:
    """Summarize a dialogue, streaming the text into Streamlit as it is generated.

    Args:
        input_text: The dialogue to summarize.
        st_container: Streamlit placeholder that receives the streamed text.
        st_info_container: Second placeholder, handed to the streamer as-is.

    Returns:
        The generated summary with special tokens removed and surrounding
        whitespace stripped.
    """
    # do_sample lives in the config next to the sampling parameters it enables,
    # rather than being passed separately to generate().
    generation_config = GenerationConfig(
        min_new_tokens=10,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.9,
        top_p=1.0,
        top_k=50,
    )
    # Reuse the already-loaded tokenizer and model instead of re-reading the
    # checkpoint on every call.
    tokenizer, model = _load_summarizer(checkpoint)

    prefix = "Summarize the following conversation: \n###\n"
    suffix = "\n### Summary:"
    # Ask for roughly 15% of the input's word count, but never less than one word.
    target_length = max(1, int(0.15 * len(input_text.split())))
    # The newline keeps the dialogue's last word from running into the
    # length instruction.
    prompt = (
        f"{prefix}{input_text}\n"
        f"The generated summary should be around {target_length} words."
        f"{suffix}"
    )

    # truncation=True caps over-long inputs at the tokenizer's maximum length;
    # the ids must live on the same device as the model.
    input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True).to(device)

    # Initialize the Streamlit container and streamer.
    # NOTE(review): decoder_start_token_id is forwarded to the streamer as a
    # decode keyword, exactly as before; it looks like a generation option that
    # ended up here by mistake — confirm before removing or relocating it.
    streamer = StreamlitTextStreamer(
        tokenizer,
        st_container,
        st_info_container,
        skip_special_tokens=True,
        decoder_start_token_id=3,
    )

    output_ids = model.generate(input_ids, streamer=streamer, generation_config=generation_config)

    # Return the finished text so the caller can store and redisplay it.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
55
+
56
+
57
+