ppsingh committed on
Commit
1befddb
·
verified ·
1 Parent(s): f42601b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -13
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import pandas as pd
 
3
  import numpy as np
4
  import os
5
  import time
@@ -22,6 +23,9 @@ from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHan
22
  from langchain_community.llms import HuggingFaceEndpoint
23
  from auditqa.process_chunks import load_chunks, getconfig
24
  from langchain_community.chat_models.huggingface import ChatHuggingFace
 
 
 
25
  from qdrant_client.http import models as rest
26
  #from qdrant_client import QdrantClient
27
  from dotenv import load_dotenv
@@ -64,7 +68,7 @@ def save_logs(logs) -> None:
64
  with JSON_DATASET_PATH.open("a") as f:
65
  json.dump(logs, f)
66
  f.write("\n")
67
- print("logging done")
68
 
69
  def make_html_source(source,i):
70
  """
@@ -119,13 +123,13 @@ async def chat(query,history,sources,reports,subtype,year):
119
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
120
  (messages in gradio format, messages in langchain format, source documents)"""
121
 
122
- print(f">> NEW QUESTION : {query}")
123
- print(f"history:{history}")
124
  #print(f"audience:{audience}")
125
- print(f"sources:{sources}")
126
- print(f"reports:{reports}")
127
- print(f"subtype:{subtype}")
128
- print(f"year:{year}")
129
  docs_html = ""
130
  output_query = ""
131
 
@@ -137,7 +141,7 @@ async def chat(query,history,sources,reports,subtype,year):
137
 
138
  ###-------------------------------------Construct Filter------------------------------------
139
  if len(reports) == 0:
140
- print("defining filter for:",sources,":",subtype,":",year)
141
  filter=rest.Filter(
142
  must=[rest.FieldCondition(
143
  key="metadata.source",
@@ -167,11 +171,15 @@ async def chat(query,history,sources,reports,subtype,year):
167
  for question in question_lst:
168
  retriever = vectorstore.as_retriever(
169
  search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.6, "k": int(model_config.get('retriever','TOP_K')), "filter":filter})
170
-
171
- context_retrieved = retriever.invoke(question)
172
- print(len(context_retrieved))
 
 
 
 
173
  for doc in context_retrieved:
174
- print(doc.metadata)
175
 
176
  def format_docs(docs):
177
  return "\n\n".join(doc.page_content for doc in docs)
@@ -261,7 +269,7 @@ async def chat(query,history,sources,reports,subtype,year):
261
  }
262
  save_logs(logs)
263
  except Exception as e:
264
- print(e)
265
 
266
  #process_pdf()
267
 
 
1
  import gradio as gr
2
  import pandas as pd
3
+ import logging
4
  import numpy as np
5
  import os
6
  import time
 
23
  from langchain_community.llms import HuggingFaceEndpoint
24
  from auditqa.process_chunks import load_chunks, getconfig
25
  from langchain_community.chat_models.huggingface import ChatHuggingFace
26
+ from langchain.retrievers import ContextualCompressionRetriever
27
+ from langchain.retrievers.document_compressors import CrossEncoderReranker
28
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
29
  from qdrant_client.http import models as rest
30
  #from qdrant_client import QdrantClient
31
  from dotenv import load_dotenv
 
68
  with JSON_DATASET_PATH.open("a") as f:
69
  json.dump(logs, f)
70
  f.write("\n")
71
+ logging.info("logging done")
72
 
73
  def make_html_source(source,i):
74
  """
 
123
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
124
  (messages in gradio format, messages in langchain format, source documents)"""
125
 
126
+ logging.info(f">> NEW QUESTION : {query}")
127
+ logging.info(f"history:{history}")
128
  #print(f"audience:{audience}")
129
+ logging.info(f"sources:{sources}")
130
+ logging.info(f"reports:{reports}")
131
+ logging.info(f"subtype:{subtype}")
132
+ logging.info(f"year:{year}")
133
  docs_html = ""
134
  output_query = ""
135
 
 
141
 
142
  ###-------------------------------------Construct Filter------------------------------------
143
  if len(reports) == 0:
144
+ ("defining filter for:{}:{}:{}".format(sources,subtype,year))
145
  filter=rest.Filter(
146
  must=[rest.FieldCondition(
147
  key="metadata.source",
 
171
  for question in question_lst:
172
  retriever = vectorstore.as_retriever(
173
  search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.6, "k": int(model_config.get('retriever','TOP_K')), "filter":filter})
174
+ model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
175
+ compressor = CrossEncoderReranker(model=model, top_n=3)
176
+ compression_retriever = ContextualCompressionRetriever(
177
+ base_compressor=compressor, base_retriever=retriever
178
+ )
179
+ context_retrieved = compression_retriever.invoke(question)
180
+ logging.info(len(context_retrieved))
181
  for doc in context_retrieved:
182
+ logging.info(doc.metadata)
183
 
184
  def format_docs(docs):
185
  return "\n\n".join(doc.page_content for doc in docs)
 
269
  }
270
  save_logs(logs)
271
  except Exception as e:
272
+ logging.error(e)
273
 
274
  #process_pdf()
275