Spaces:
Running
Running
kiyer
committed on
Commit
·
1fa5fdb
1
Parent(s):
a1e4f2c
bugfix for semantic_search
Browse files
app.py
CHANGED
@@ -140,20 +140,20 @@ if 'arxiv_corpus' not in st.session_state:
|
|
140 |
st.session_state.arxiv_corpus = arxiv_corpus
|
141 |
st.toast('loaded arxiv corpus')
|
142 |
|
143 |
-
if 'ids' not in st.session_state:
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
|
158 |
def get_keywords(text):
|
159 |
result = []
|
@@ -192,6 +192,8 @@ class EmbeddingRetrievalSystem():
|
|
192 |
self.weight_date = weight_date
|
193 |
self.weight_keywords = weight_keywords
|
194 |
self.id_to_index = {self.ids[i]: i for i in range(len(self.ids))}
|
|
|
|
|
195 |
|
196 |
# self.citation_filter = CitationFilter(self.dataset)
|
197 |
# self.date_filter = DateFilter(self.dataset['date'])
|
@@ -339,7 +341,7 @@ class HydeRetrievalSystem(EmbeddingRetrievalSystem):
|
|
339 |
|
340 |
# self.anthropic_key = anthropic_key
|
341 |
# self.generation_client = anthropic.Anthropic(api_key = self.anthropic_key)
|
342 |
-
|
343 |
|
344 |
def retrieve(self, query: str, top_k: int = 10, return_scores = False, time_result = None) -> List[Tuple[str, str, float]]:
|
345 |
if time_result is None:
|
|
|
140 |
st.session_state.arxiv_corpus = arxiv_corpus
|
141 |
st.toast('loaded arxiv corpus')
|
142 |
|
143 |
+
if 'ids' not in st.session_state:
|
144 |
+
with st.spinner('making the LLM talk to the astro papers...'):
|
145 |
+
st.session_state.ids = st.session_state.arxiv_corpus['ads_id']
|
146 |
+
st.session_state.titles = st.session_state.arxiv_corpus['title']
|
147 |
+
st.session_state.abstracts = st.session_state.arxiv_corpus['abstract']
|
148 |
+
st.session_state.authors = st.session_state.arxiv_corpus['authors']
|
149 |
+
st.session_state.cites = st.session_state.arxiv_corpus['cites']
|
150 |
+
st.session_state.years = st.session_state.arxiv_corpus['date']
|
151 |
+
st.session_state.kws = st.session_state.arxiv_corpus['keywords']
|
152 |
+
st.session_state.ads_kws = st.session_state.arxiv_corpus['ads_keywords']
|
153 |
+
st.session_state.bibcode = st.session_state.arxiv_corpus['bibcode']
|
154 |
+
st.session_state.umap_x = st.session_state.arxiv_corpus['umap_x']
|
155 |
+
st.session_state.umap_y = st.session_state.arxiv_corpus['umap_y']
|
156 |
+
st.toast('done caching. time taken: %.2f sec' %(time.time()-ts))
|
157 |
|
158 |
def get_keywords(text):
|
159 |
result = []
|
|
|
192 |
self.weight_date = weight_date
|
193 |
self.weight_keywords = weight_keywords
|
194 |
self.id_to_index = {self.ids[i]: i for i in range(len(self.ids))}
|
195 |
+
|
196 |
+
self.generation_client = openai_llm(temperature=0,model_name='gpt-4o-mini', openai_api_key = openai_key)
|
197 |
|
198 |
# self.citation_filter = CitationFilter(self.dataset)
|
199 |
# self.date_filter = DateFilter(self.dataset['date'])
|
|
|
341 |
|
342 |
# self.anthropic_key = anthropic_key
|
343 |
# self.generation_client = anthropic.Anthropic(api_key = self.anthropic_key)
|
344 |
+
|
345 |
|
346 |
def retrieve(self, query: str, top_k: int = 10, return_scores = False, time_result = None) -> List[Tuple[str, str, float]]:
|
347 |
if time_result is None:
|