# Hugging Face Spaces status: Running on CPU Upgrade
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from qdrant_client.http import models as rest

from auditqa.process_chunks import getconfig
# Model settings, loaded once when this module is imported. The code below
# reads the ('retriever', 'TOP_K'), ('ranker', 'MODEL') and ('ranker', 'TOP_K')
# options from it via ``model_config.get(section, option)``.
# NOTE(review): the path is relative, so it is presumably resolved against the
# process working directory -- confirm against getconfig.
model_config = getconfig("model_params.cfg")
def create_filter(reports: list = None, sources: str = None,
                  subtype: str = None, year: list = None):
    """Build the Qdrant metadata filter that restricts retrieval.

    Two mutually exclusive modes:

    * ``reports`` is non-empty: match points whose ``metadata.filename`` is
      any of the given report names; ``sources``, ``subtype`` and ``year``
      are ignored.
    * ``reports`` is empty or ``None``: require ``metadata.source`` to equal
      ``sources``, ``metadata.subtype`` to equal ``subtype`` and
      ``metadata.year`` to be any of the values in ``year``.

    Args:
        reports: Report filenames to restrict the search to. ``None`` and an
            empty list both select the source/subtype/year mode.
        sources: Value matched exactly against ``metadata.source``.
        subtype: Value matched exactly against ``metadata.subtype``.
        year: Collection of values, any of which may match ``metadata.year``.

    Returns:
        A ``rest.Filter`` whose ``must`` list holds the conditions above.
    """
    # ``None`` replaces the former mutable ``[]`` default; the truthiness
    # test treats ``None`` and an empty list alike, so existing callers that
    # pass ``[]`` take the same branch as before.
    if not reports:
        print("defining filter for sources:{},subtype:{},year:{}".format(sources,subtype,year))
        # NOTE(review): sources/subtype/year are passed straight to the match
        # models; their ``None`` defaults are presumably rejected by the
        # client's validation, so callers should supply all three -- confirm.
        return rest.Filter(
            must=[
                rest.FieldCondition(
                    key="metadata.source",
                    match=rest.MatchValue(value=sources),
                ),
                rest.FieldCondition(
                    key="metadata.subtype",
                    match=rest.MatchValue(value=subtype),
                ),
                rest.FieldCondition(
                    key="metadata.year",
                    match=rest.MatchAny(any=year),
                ),
            ]
        )

    print("defining filter for allreports:",reports)
    return rest.Filter(
        must=[
            rest.FieldCondition(
                key="metadata.filename",
                match=rest.MatchAny(any=reports),
            ),
        ]
    )
# Cross-encoder instances keyed by model name, so a model is constructed once
# per process instead of once per query.
_RERANKER_MODELS = {}


def _get_reranker_model(model_name):
    """Return the cross-encoder for ``model_name``, creating it on first use."""
    if model_name not in _RERANKER_MODELS:
        _RERANKER_MODELS[model_name] = HuggingFaceCrossEncoder(model_name=model_name)
    return _RERANKER_MODELS[model_name]


def get_context(vectorstore, query, reports, sources, subtype, year,
                score_threshold: float = 0.6):
    """Retrieve and re-rank the paragraphs most relevant to ``query``.

    Retrieval is a similarity search with a score threshold on
    ``vectorstore``, restricted by the metadata filter from
    ``create_filter``. The hits are then re-ordered and truncated by a
    cross-encoder re-ranker.

    Args:
        vectorstore: Object exposing ``as_retriever(search_type=...,
            search_kwargs=...)``.
        query: The query passed to the retriever.
        reports: Report filenames; forwarded to ``create_filter``.
        sources: Source value; forwarded to ``create_filter``.
        subtype: Subtype value; forwarded to ``create_filter``.
        year: Year values; forwarded to ``create_filter``.
        score_threshold: Value passed as the retriever's ``score_threshold``
            search argument. Defaults to the previously hard-coded 0.6.

    Returns:
        The result of invoking the compression retriever on ``query``
        (presumably a list of documents -- its length is printed).
    """
    # create metadata filter
    metadata_filter = create_filter(reports=reports, sources=sources,
                                    subtype=subtype, year=year)

    # getting context
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={
            "score_threshold": score_threshold,
            "k": int(model_config.get('retriever','TOP_K')),
            "filter": metadata_filter,
        },
    )

    # re-ranking the retrieved results; the model itself is cached because
    # constructing it on every call repeats the model load for each query
    model = _get_reranker_model(model_config.get('ranker','MODEL'))
    compressor = CrossEncoderReranker(
        model=model,
        top_n=int(model_config.get('ranker','TOP_K')),
    )
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )

    context_retrieved = compression_retriever.invoke(query)
    print(f"retrieved paragraphs:{len(context_retrieved)}")
    return context_retrieved