zmbfeng committed on
Commit
43d2a91
·
1 Parent(s): 2f33806

show GUI after init

Browse files
Files changed (1) hide show
  1. app.py +40 -39
app.py CHANGED
@@ -69,6 +69,46 @@ def paraphrase(sentence):
69
  #results.append(line)
70
  return line
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  big_text = """
73
  <div style='text-align: center;'>
74
  <h1 style='font-size: 30x;'>Knowledge Extraction A</h1>
@@ -113,46 +153,7 @@ if uploaded_json_file is not None:
113
  except json.JSONDecodeError:
114
  st.write('Invalid JSON file.')
115
  st.rerun()
116
- if 'is_initialized' not in st.session_state:
117
- st.session_state['is_initialized'] = True
118
-
119
- nltk.download('punkt')
120
- nltk.download('stopwords')
121
- st.session_state.stop_words = set(stopwords.words('english'))
122
- st.session_state.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", )
123
- st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased", ).to('cuda')
124
- st.session_state.paraphrase_tokenizer = AutoTokenizer.from_pretrained("Vamsi/T5_Paraphrase_Paws")
125
- st.session_state.paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained("Vamsi/T5_Paraphrase_Paws").to('cuda')
126
- print(str(st.session_state.paraphrase_model ))
127
- if 'list_count' in st.session_state:
128
- st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count }')
129
- if 'paragraph_sentence_encodings' not in st.session_state:
130
- print("start embedding paragarphs")
131
- read_progress_bar = st.progress(0)
132
- st.session_state.paragraph_sentence_encodings = []
133
- for index,paragraph in enumerate(st.session_state.restored_paragraphs):
134
- #print(paragraph)
135
 
136
- progress_percentage = (index) / (st.session_state.list_count - 1)
137
- # print(progress_percentage)
138
- read_progress_bar.progress(progress_percentage)
139
-
140
- sentence_encodings = []
141
- sentences = sent_tokenize(paragraph['text'])
142
- for sentence in sentences:
143
- if sentence.strip().endswith('?'):
144
- sentence_encodings.append(None)
145
- continue
146
- if len(sentence.strip()) < 4:
147
- sentence_encodings.append(None)
148
- continue
149
- sentence_tokens = st.session_state.bert_tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to('cuda')
150
- with torch.no_grad():
151
- sentence_encoding = st.session_state.bert_model(**sentence_tokens).last_hidden_state[:, 0, :].cpu().numpy()
152
- sentence_encodings.append([sentence, sentence_encoding])
153
- # sentence_encodings.append([sentence,bert_model(**sentence_tokens).last_hidden_state[:, 0, :].detach().numpy()])
154
- st.session_state.paragraph_sentence_encodings.append([paragraph, sentence_encodings])
155
- st.rerun()
156
  if 'paragraph_sentence_encodings' in st.session_state:
157
  query = st.text_input("Enter your query")
158
 
 
69
  #results.append(line)
70
  return line
71
 
72
+ if 'is_initialized' not in st.session_state:
73
+ st.session_state['is_initialized'] = True
74
+
75
+ nltk.download('punkt')
76
+ nltk.download('stopwords')
77
+ st.session_state.stop_words = set(stopwords.words('english'))
78
+ st.session_state.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", )
79
+ st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased", ).to('cuda')
80
+ st.session_state.paraphrase_tokenizer = AutoTokenizer.from_pretrained("Vamsi/T5_Paraphrase_Paws")
81
+ st.session_state.paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained("Vamsi/T5_Paraphrase_Paws").to('cuda')
82
+ print(str(st.session_state.paraphrase_model ))
83
+ if 'list_count' in st.session_state:
84
+ st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count }')
85
+ if 'paragraph_sentence_encodings' not in st.session_state:
86
+ print("start embedding paragarphs")
87
+ read_progress_bar = st.progress(0)
88
+ st.session_state.paragraph_sentence_encodings = []
89
+ for index,paragraph in enumerate(st.session_state.restored_paragraphs):
90
+ #print(paragraph)
91
+
92
+ progress_percentage = (index) / (st.session_state.list_count - 1)
93
+ # print(progress_percentage)
94
+ read_progress_bar.progress(progress_percentage)
95
+
96
+ sentence_encodings = []
97
+ sentences = sent_tokenize(paragraph['text'])
98
+ for sentence in sentences:
99
+ if sentence.strip().endswith('?'):
100
+ sentence_encodings.append(None)
101
+ continue
102
+ if len(sentence.strip()) < 4:
103
+ sentence_encodings.append(None)
104
+ continue
105
+ sentence_tokens = st.session_state.bert_tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to('cuda')
106
+ with torch.no_grad():
107
+ sentence_encoding = st.session_state.bert_model(**sentence_tokens).last_hidden_state[:, 0, :].cpu().numpy()
108
+ sentence_encodings.append([sentence, sentence_encoding])
109
+ # sentence_encodings.append([sentence,bert_model(**sentence_tokens).last_hidden_state[:, 0, :].detach().numpy()])
110
+ st.session_state.paragraph_sentence_encodings.append([paragraph, sentence_encodings])
111
+ st.rerun()
112
  big_text = """
113
  <div style='text-align: center;'>
114
  <h1 style='font-size: 30x;'>Knowledge Extraction A</h1>
 
153
  except json.JSONDecodeError:
154
  st.write('Invalid JSON file.')
155
  st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if 'paragraph_sentence_encodings' in st.session_state:
158
  query = st.text_input("Enter your query")
159