zmbfeng commited on
Commit
2f33806
·
1 Parent(s): 774748a

new query working

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -123,7 +123,7 @@ if 'is_initialized' not in st.session_state:
123
  st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased", ).to('cuda')
124
  st.session_state.paraphrase_tokenizer = AutoTokenizer.from_pretrained("Vamsi/T5_Paraphrase_Paws")
125
  st.session_state.paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained("Vamsi/T5_Paraphrase_Paws").to('cuda')
126
-
127
  if 'list_count' in st.session_state:
128
  st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count }')
129
  if 'paragraph_sentence_encodings' not in st.session_state:
@@ -157,7 +157,10 @@ if 'paragraph_sentence_encodings' in st.session_state:
157
  query = st.text_input("Enter your query")
158
 
159
  if query:
160
- if 'paragraph_scores' not in st.session_state:
 
 
 
161
  query_tokens = st.session_state.bert_tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(
162
  'cuda')
163
  with torch.no_grad(): # Disable gradient calculation for inference
 
123
  st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased", ).to('cuda')
124
  st.session_state.paraphrase_tokenizer = AutoTokenizer.from_pretrained("Vamsi/T5_Paraphrase_Paws")
125
  st.session_state.paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained("Vamsi/T5_Paraphrase_Paws").to('cuda')
126
+ print(str(st.session_state.paraphrase_model ))
127
  if 'list_count' in st.session_state:
128
  st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count }')
129
  if 'paragraph_sentence_encodings' not in st.session_state:
 
157
  query = st.text_input("Enter your query")
158
 
159
  if query:
160
+ if 'prev_query' not in st.session_state or st.session_state.prev_query != query:
161
+ st.session_state.prev_query = query
162
+ if 'paraphrased_paragrpahs' in st.session_state:
163
+ del st.session_state['paraphrased_paragrpahs']
164
  query_tokens = st.session_state.bert_tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(
165
  'cuda')
166
  with torch.no_grad(): # Disable gradient calculation for inference