from functools import lru_cache
from typing import List, Optional

from qdrant_client.http import models as rest
from auditqa.process_chunks import getconfig
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# Parsed once at import time; read below for retriever and ranker settings.
model_config = getconfig("model_params.cfg")


def create_filter(
    reports: Optional[List[str]] = None,
    sources: Optional[str] = None,
    subtype: Optional[str] = None,
    year: Optional[List[str]] = None,
):
    """Build the Qdrant metadata filter used to restrict retrieval.

    Two mutually exclusive modes:

    * ``reports`` is non-empty: match points whose ``metadata.filename`` is
      any of the given report names; ``sources``/``subtype``/``year`` are
      ignored.
    * ``reports`` is empty or ``None``: require ``metadata.source`` to equal
      ``sources``, ``metadata.subtype`` to equal ``subtype`` and
      ``metadata.year`` to be any of the values in ``year``.

    Args:
        reports: Report file names to match against ``metadata.filename``.
        sources: Exact value for ``metadata.source``.
        subtype: Exact value for ``metadata.subtype``.
        year: Collection of accepted values for ``metadata.year`` (it is
            passed to ``MatchAny``, so it must be a list, not a scalar).

    Returns:
        A ``qdrant_client.http.models.Filter`` instance.
    """
    # Truthiness check also covers ``None``; the previous ``len(reports)``
    # raised TypeError for it, and the old ``[]`` default was a shared
    # mutable object.
    if reports:
        print("defining filter for allreports:", reports)
        return rest.Filter(
            must=[
                rest.FieldCondition(
                    key="metadata.filename",
                    match=rest.MatchAny(any=reports),
                ),
            ]
        )

    # NOTE(review): sources/subtype/year are forwarded as-is, including
    # ``None`` -- confirm callers always supply all three in this mode.
    print(f"defining filter for sources:{sources},subtype:{subtype},year:{year}")
    return rest.Filter(
        must=[
            rest.FieldCondition(
                key="metadata.source",
                match=rest.MatchValue(value=sources),
            ),
            rest.FieldCondition(
                key="metadata.subtype",
                match=rest.MatchValue(value=subtype),
            ),
            rest.FieldCondition(
                key="metadata.year",
                match=rest.MatchAny(any=year),
            ),
        ]
    )


@lru_cache(maxsize=4)
def _load_cross_encoder(model_name: str) -> HuggingFaceCrossEncoder:
    """Return the cross-encoder for *model_name*, constructing it only once."""
    return HuggingFaceCrossEncoder(model_name=model_name)


def get_context(vectorstore, query, reports, sources, subtype, year):
    """Retrieve and re-rank context documents for *query*.

    A metadata filter is built from ``reports``/``sources``/``subtype``/
    ``year`` (see ``create_filter``), candidates are fetched from
    ``vectorstore`` with a similarity score threshold of 0.6, and the
    candidates are then re-ranked by a cross-encoder.

    Configuration read from ``model_config``:
        * ``retriever.TOP_K`` -- number of candidates requested (``k``).
        * ``ranker.MODEL``    -- cross-encoder model name.
        * ``ranker.TOP_K``    -- number of documents kept after re-ranking.

    Args:
        vectorstore: Object exposing ``as_retriever(search_type=...,
            search_kwargs=...)``.
        query: Query passed to the compression retriever's ``invoke``.
        reports, sources, subtype, year: Forwarded to ``create_filter``.

    Returns:
        Whatever ``ContextualCompressionRetriever.invoke`` returns for the
        query (the re-ranked documents).
    """
    # create metadata filter (named to avoid shadowing the ``filter`` builtin)
    metadata_filter = create_filter(
        reports=reports, sources=sources, subtype=subtype, year=year
    )

    # getting context
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={
            "score_threshold": 0.6,
            "k": int(model_config.get('retriever', 'TOP_K')),
            "filter": metadata_filter,
        },
    )

    # re-ranking the retrieved results; the model is cached so it is not
    # re-instantiated on every query.
    model = _load_cross_encoder(model_config.get('ranker', 'MODEL'))
    compressor = CrossEncoderReranker(
        model=model, top_n=int(model_config.get('ranker', 'TOP_K'))
    )
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )

    context_retrieved = compression_retriever.invoke(query)
    print(f"retrieved paragraphs:{len(context_retrieved)}")
    return context_retrieved