Stefano Fiorucci commited on
Commit
f79a364
1 Parent(s): 429cb7c

try to get better cache

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -20,7 +20,7 @@ from urllib.parse import unquote
20
 
21
  # FAISS index directory
22
  INDEX_DIR = 'data/index'
23
-
24
 
25
  # the following function is cached to make index and models load only at start
26
  @st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True)
@@ -56,12 +56,13 @@ def set_state_if_absent(key, value):
56
  st.session_state[key] = value
57
 
58
  # hash_funcs={builtins.weakref: my_hash_func}
59
- # @st.cache(persist=True, hash_funcs={"builtins.weakref": lambda _: None}, allow_output_mutation=True)
60
- def query(pipe, question, retriever_top_k=10, reader_top_k=5) -> dict:
61
  """Run query and get answers"""
62
- return (pipe.run(question,
63
- params={"Retriever": {"top_k": retriever_top_k},
64
- "Reader": {"top_k": reader_top_k}}), None)
 
65
 
66
 
67
  def main():
@@ -190,7 +191,7 @@ and see if the AI ​​can find an answer...
190
 
191
  ):
192
  try:
193
- st.session_state.results, st.session_state.raw_json = query(pipe, question)
194
  time_end=time.time()
195
  print(f'elapsed time: {time_end - time_start}')
196
  except JSONDecodeError as je:
 
20
 
21
  # FAISS index directory
22
  INDEX_DIR = 'data/index'
23
+ pipe=None
24
 
25
  # the following function is cached to make index and models load only at start
26
  @st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True)
 
56
  st.session_state[key] = value
57
 
58
  # hash_funcs={builtins.weakref: my_hash_func}
59
+ @st.cache(persist=True, allow_output_mutation=True)
60
+ def query(question, retriever_top_k=10, reader_top_k=5) -> dict:
61
  """Run query and get answers"""
62
+ params = {"Retriever": {"top_k": retriever_top_k},
63
+ "Reader": {"top_k": reader_top_k}}
64
+ results = pipe.run(question, params=params)
65
+ return results
66
 
67
 
68
  def main():
 
191
 
192
  ):
193
  try:
194
+ st.session_state.results = query(pipe)
195
  time_end=time.time()
196
  print(f'elapsed time: {time_end - time_start}')
197
  except JSONDecodeError as je: