Jay7478 committed
Commit f3ca51f · 1 Parent(s): 9db796e

Delete app.py

Files changed (1)
  1. app.py +0 -118
app.py DELETED
@@ -1,118 +0,0 @@
- import streamlit as st
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
- import nltk
- import math
- import torch
-
- model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
- # model_name = "afnanmmir/t5-base-axriv-to-abstract-3"
- max_input_length = 1024
- max_output_length = 256
-
- st.header("Generate summaries")
-
- st_model_load = st.text('Loading summary generator model...')
-
- # @st.cache(allow_output_mutation=True)
- @st.cache_resource  # cache model/tokenizer as a resource; st.cache_data would try to hash/serialize them
- def load_model():
-     print("Loading model...")
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-     nltk.download('punkt')  # sentence-tokenizer data used by nltk.sent_tokenize below
-     print("Model loaded!")
-     return tokenizer, model
-
- tokenizer, model = load_model()
- st.success('Model loaded!')
- st_model_load.text("")
-
- # with st.sidebar:
- #     st.header("Model parameters")
- #     if 'num_titles' not in st.session_state:
- #         st.session_state.num_titles = 5
- #     def on_change_num_titles():
- #         st.session_state.num_titles = num_titles
- #     num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
- #     if 'temperature' not in st.session_state:
- #         st.session_state.temperature = 0.7
- #     def on_change_temperatures():
- #         st.session_state.temperature = temperature
- #     temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
- #     st.markdown("_High temperature means that results are more random_")
-
- if 'text' not in st.session_state:
-     st.session_state.text = ""
- st_text_area = st.text_area('Text to generate the summary for', value=st.session_state.text, height=500)
-
- def generate_summary():
-     st.session_state.text = st_text_area
-
-     # tokenize text
-     inputs = ["summarize: " + st_text_area]
-     # print(inputs)
-     inputs = tokenizer(inputs, return_tensors="pt", max_length=max_input_length, truncation=True)
-     # print("Tokenized inputs: ")
-     # print(inputs)
-
-     # inputs = tokenizer(inputs, return_tensors="pt")
-
-     # # compute span boundaries
-     # num_tokens = len(inputs["input_ids"][0])
-     # print(f"Input has {num_tokens} tokens")
-     # max_input_length = 500
-     # num_spans = math.ceil(num_tokens / max_input_length)
-     # print(f"Input has {num_spans} spans")
-     # overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
-     # spans_boundaries = []
-     # start = 0
-     # for i in range(num_spans):
-     #     spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
-     #     start -= overlap
-     # print(f"Span boundaries are {spans_boundaries}")
-     # spans_boundaries_selected = []
-     # j = 0
-     # for _ in range(num_titles):
-     #     spans_boundaries_selected.append(spans_boundaries[j])
-     #     j += 1
-     #     if j == len(spans_boundaries):
-     #         j = 0
-     # print(f"Selected span boundaries are {spans_boundaries_selected}")
-
-     # # transform input with spans
-     # tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
-     # tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
-
-     # inputs = {
-     #     "input_ids": torch.stack(tensor_ids),
-     #     "attention_mask": torch.stack(tensor_masks)
-     # }
-
-     # compute predictions
-     # outputs = model.generate(**inputs, do_sample=True, temperature=temperature, max_length=max_output_length)
-
-     outputs = model.generate(**inputs, do_sample=True, max_length=max_output_length, early_stopping=True, num_beams=8, length_penalty=2.0, no_repeat_ngram_size=2, min_length=64)
-     # print("outputs", outputs)
-     decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-     # print("Decoded_outputs", decoded_outputs)
-     predicted_summaries = nltk.sent_tokenize(decoded_outputs.strip())
-     print("Predicted summaries", predicted_summaries)
-
-     # decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-     # predicted_summaries = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
-
-     st.session_state.summaries = predicted_summaries
-
- # generate summary button
- st_generate_button = st.button('Generate summary', on_click=generate_summary)
-
- # generated summary labels
- if 'summaries' not in st.session_state:
-     st.session_state.summaries = []
-
- if len(st.session_state.summaries) > 0:
-     print("In summaries if")
-     with st.container():
-         st.subheader("Generated summaries")
-         for summary in st.session_state.summaries:
-             st.markdown("__" + summary + "__")
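
For reference, the deleted app's core flow is easy to reproduce outside Streamlit. The sketch below mirrors the deleted code line for line (same checkpoint, prompt prefix, and generation parameters), assuming the afnanmmir/t5-base-abstract-to-plain-language-1 checkpoint is still downloadable and that transformers and nltk are installed:

# Minimal sketch of the deleted app's summarization flow (no Streamlit UI).
import nltk
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
nltk.download("punkt")  # sentence-tokenizer data for nltk.sent_tokenize

abstract = "..."  # placeholder: paste an abstract here
inputs = tokenizer(["summarize: " + abstract], return_tensors="pt",
                   max_length=1024, truncation=True)
outputs = model.generate(**inputs, do_sample=True, num_beams=8,
                         early_stopping=True, length_penalty=2.0,
                         no_repeat_ngram_size=2, min_length=64, max_length=256)
summary = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(nltk.sent_tokenize(summary.strip()))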
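The commented-out span-boundary code in generate_summary was splitting inputs longer than the model's context into overlapping token windows. A standalone sketch of that logic, kept faithful to the commented code (span_boundaries is a hypothetical helper name; the window size of 500 comes from the commented max_input_length):

import math

def span_boundaries(num_tokens, window=500):
    # Overlapping [start, end) token windows covering num_tokens tokens,
    # mirroring the commented-out span logic in the deleted app.
    num_spans = math.ceil(num_tokens / window)
    # spread the slack evenly as overlap between consecutive windows
    overlap = math.ceil((num_spans * window - num_tokens) / max(num_spans - 1, 1))
    boundaries, start = [], 0
    for i in range(num_spans):
        boundaries.append([start + window * i, start + window * (i + 1)])
        start -= overlap
    return boundaries

print(span_boundaries(1200))  # [[0, 500], [350, 850], [700, 1200]]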