Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
|
|
3 |
import numpy as np
|
4 |
import os
|
5 |
import time
|
@@ -22,6 +23,9 @@ from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHan
|
|
22 |
from langchain_community.llms import HuggingFaceEndpoint
|
23 |
from auditqa.process_chunks import load_chunks, getconfig
|
24 |
from langchain_community.chat_models.huggingface import ChatHuggingFace
|
|
|
|
|
|
|
25 |
from qdrant_client.http import models as rest
|
26 |
#from qdrant_client import QdrantClient
|
27 |
from dotenv import load_dotenv
|
@@ -64,7 +68,7 @@ def save_logs(logs) -> None:
|
|
64 |
with JSON_DATASET_PATH.open("a") as f:
|
65 |
json.dump(logs, f)
|
66 |
f.write("\n")
|
67 |
-
|
68 |
|
69 |
def make_html_source(source,i):
|
70 |
"""
|
@@ -119,13 +123,13 @@ async def chat(query,history,sources,reports,subtype,year):
|
|
119 |
"""taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
|
120 |
(messages in gradio format, messages in langchain format, source documents)"""
|
121 |
|
122 |
-
|
123 |
-
|
124 |
#print(f"audience:{audience}")
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
docs_html = ""
|
130 |
output_query = ""
|
131 |
|
@@ -137,7 +141,7 @@ async def chat(query,history,sources,reports,subtype,year):
|
|
137 |
|
138 |
###-------------------------------------Construct Filter------------------------------------
|
139 |
if len(reports) == 0:
|
140 |
-
|
141 |
filter=rest.Filter(
|
142 |
must=[rest.FieldCondition(
|
143 |
key="metadata.source",
|
@@ -167,11 +171,15 @@ async def chat(query,history,sources,reports,subtype,year):
|
|
167 |
for question in question_lst:
|
168 |
retriever = vectorstore.as_retriever(
|
169 |
search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.6, "k": int(model_config.get('retriever','TOP_K')), "filter":filter})
|
170 |
-
|
171 |
-
|
172 |
-
|
|
|
|
|
|
|
|
|
173 |
for doc in context_retrieved:
|
174 |
-
|
175 |
|
176 |
def format_docs(docs):
|
177 |
return "\n\n".join(doc.page_content for doc in docs)
|
@@ -261,7 +269,7 @@ async def chat(query,history,sources,reports,subtype,year):
|
|
261 |
}
|
262 |
save_logs(logs)
|
263 |
except Exception as e:
|
264 |
-
|
265 |
|
266 |
#process_pdf()
|
267 |
|
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
+
import logging
|
4 |
import numpy as np
|
5 |
import os
|
6 |
import time
|
|
|
23 |
from langchain_community.llms import HuggingFaceEndpoint
|
24 |
from auditqa.process_chunks import load_chunks, getconfig
|
25 |
from langchain_community.chat_models.huggingface import ChatHuggingFace
|
26 |
+
from langchain.retrievers import ContextualCompressionRetriever
|
27 |
+
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
28 |
+
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
29 |
from qdrant_client.http import models as rest
|
30 |
#from qdrant_client import QdrantClient
|
31 |
from dotenv import load_dotenv
|
|
|
68 |
with JSON_DATASET_PATH.open("a") as f:
|
69 |
json.dump(logs, f)
|
70 |
f.write("\n")
|
71 |
+
logging.info("logging done")
|
72 |
|
73 |
def make_html_source(source,i):
|
74 |
"""
|
|
|
123 |
"""taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
|
124 |
(messages in gradio format, messages in langchain format, source documents)"""
|
125 |
|
126 |
+
logging.info(f">> NEW QUESTION : {query}")
|
127 |
+
logging.info(f"history:{history}")
|
128 |
#print(f"audience:{audience}")
|
129 |
+
logging.info(f"sources:{sources}")
|
130 |
+
logging.info(f"reports:{reports}")
|
131 |
+
logging.info(f"subtype:{subtype}")
|
132 |
+
logging.info(f"year:{year}")
|
133 |
docs_html = ""
|
134 |
output_query = ""
|
135 |
|
|
|
141 |
|
142 |
###-------------------------------------Construct Filter------------------------------------
|
143 |
if len(reports) == 0:
|
144 |
+
("defining filter for:{}:{}:{}".format(sources,subtype,year))
|
145 |
filter=rest.Filter(
|
146 |
must=[rest.FieldCondition(
|
147 |
key="metadata.source",
|
|
|
171 |
for question in question_lst:
|
172 |
retriever = vectorstore.as_retriever(
|
173 |
search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.6, "k": int(model_config.get('retriever','TOP_K')), "filter":filter})
|
174 |
+
model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
|
175 |
+
compressor = CrossEncoderReranker(model=model, top_n=3)
|
176 |
+
compression_retriever = ContextualCompressionRetriever(
|
177 |
+
base_compressor=compressor, base_retriever=retriever
|
178 |
+
)
|
179 |
+
context_retrieved = compression_retriever.invoke(question)
|
180 |
+
logging.info(len(context_retrieved))
|
181 |
for doc in context_retrieved:
|
182 |
+
logging.info(doc.metadata)
|
183 |
|
184 |
def format_docs(docs):
|
185 |
return "\n\n".join(doc.page_content for doc in docs)
|
|
|
269 |
}
|
270 |
save_logs(logs)
|
271 |
except Exception as e:
|
272 |
+
logging.error(e)
|
273 |
|
274 |
#process_pdf()
|
275 |
|