kiyer committed on
Commit
1fa5fdb
·
1 Parent(s): a1e4f2c

bugfix for semantic_search

Browse files
Files changed (1) hide show
  1. app.py +17 -15
app.py CHANGED
@@ -140,20 +140,20 @@ if 'arxiv_corpus' not in st.session_state:
140
  st.session_state.arxiv_corpus = arxiv_corpus
141
  st.toast('loaded arxiv corpus')
142
 
143
- if 'ids' not in st.session_state:
144
- with st.spinner('making the LLM talk to the astro papers...'):
145
- st.session_state.ids = st.session_state.arxiv_corpus['ads_id']
146
- st.session_state.titles = st.session_state.arxiv_corpus['title']
147
- st.session_state.abstracts = st.session_state.arxiv_corpus['abstract']
148
- st.session_state.authors = st.session_state.arxiv_corpus['authors']
149
- st.session_state.cites = st.session_state.arxiv_corpus['cites']
150
- st.session_state.years = st.session_state.arxiv_corpus['date']
151
- st.session_state.kws = st.session_state.arxiv_corpus['keywords']
152
- st.session_state.ads_kws = st.session_state.arxiv_corpus['ads_keywords']
153
- st.session_state.bibcode = st.session_state.arxiv_corpus['bibcode']
154
- st.session_state.umap_x = st.session_state.arxiv_corpus['umap_x']
155
- st.session_state.umap_y = st.session_state.arxiv_corpus['umap_y']
156
- st.toast('done caching. time taken: %.2f sec' %(time.time()-ts))
157
 
158
  def get_keywords(text):
159
  result = []
@@ -192,6 +192,8 @@ class EmbeddingRetrievalSystem():
192
  self.weight_date = weight_date
193
  self.weight_keywords = weight_keywords
194
  self.id_to_index = {self.ids[i]: i for i in range(len(self.ids))}
 
 
195
 
196
  # self.citation_filter = CitationFilter(self.dataset)
197
  # self.date_filter = DateFilter(self.dataset['date'])
@@ -339,7 +341,7 @@ class HydeRetrievalSystem(EmbeddingRetrievalSystem):
339
 
340
  # self.anthropic_key = anthropic_key
341
  # self.generation_client = anthropic.Anthropic(api_key = self.anthropic_key)
342
- self.generation_client = openai_llm(temperature=0,model_name='gpt-4o-mini', openai_api_key = openai_key)
343
 
344
  def retrieve(self, query: str, top_k: int = 10, return_scores = False, time_result = None) -> List[Tuple[str, str, float]]:
345
  if time_result is None:
 
140
  st.session_state.arxiv_corpus = arxiv_corpus
141
  st.toast('loaded arxiv corpus')
142
 
143
+ if 'ids' not in st.session_state:
144
+ with st.spinner('making the LLM talk to the astro papers...'):
145
+ st.session_state.ids = st.session_state.arxiv_corpus['ads_id']
146
+ st.session_state.titles = st.session_state.arxiv_corpus['title']
147
+ st.session_state.abstracts = st.session_state.arxiv_corpus['abstract']
148
+ st.session_state.authors = st.session_state.arxiv_corpus['authors']
149
+ st.session_state.cites = st.session_state.arxiv_corpus['cites']
150
+ st.session_state.years = st.session_state.arxiv_corpus['date']
151
+ st.session_state.kws = st.session_state.arxiv_corpus['keywords']
152
+ st.session_state.ads_kws = st.session_state.arxiv_corpus['ads_keywords']
153
+ st.session_state.bibcode = st.session_state.arxiv_corpus['bibcode']
154
+ st.session_state.umap_x = st.session_state.arxiv_corpus['umap_x']
155
+ st.session_state.umap_y = st.session_state.arxiv_corpus['umap_y']
156
+ st.toast('done caching. time taken: %.2f sec' %(time.time()-ts))
157
 
158
  def get_keywords(text):
159
  result = []
 
192
  self.weight_date = weight_date
193
  self.weight_keywords = weight_keywords
194
  self.id_to_index = {self.ids[i]: i for i in range(len(self.ids))}
195
+
196
+ self.generation_client = openai_llm(temperature=0,model_name='gpt-4o-mini', openai_api_key = openai_key)
197
 
198
  # self.citation_filter = CitationFilter(self.dataset)
199
  # self.date_filter = DateFilter(self.dataset['date'])
 
341
 
342
  # self.anthropic_key = anthropic_key
343
  # self.generation_client = anthropic.Anthropic(api_key = self.anthropic_key)
344
+
345
 
346
  def retrieve(self, query: str, top_k: int = 10, return_scores = False, time_result = None) -> List[Tuple[str, str, float]]:
347
  if time_result is None: