File size: 2,484 Bytes
71aaf00
 
 
 
 
 
 
 
 
 
 
b3ec1fd
71aaf00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from functools import lru_cache

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from qdrant_client.http import models as rest

from auditqa.process_chunks import getconfig

# Settings loaded once at import time from "model_params.cfg".
# Read further down via model_config.get(section, option), e.g. the
# ('retriever', 'TOP_K'), ('ranker', 'MODEL') and ('ranker', 'TOP_K') entries.
model_config = getconfig("model_params.cfg")

def create_filter(reports: list = None, sources: str = None,
                  subtype: str = None, year: "list | str" = None):
    """Build the Qdrant metadata filter used to narrow a vector search.

    The filter works in one of two modes:

    * ``reports`` is non-empty: keep points whose ``metadata.filename`` is
      any of the given report names. ``sources``, ``subtype`` and ``year``
      are ignored in this mode.
    * ``reports`` is empty or ``None``: keep points whose ``metadata.source``
      equals ``sources``, whose ``metadata.subtype`` equals ``subtype`` and
      whose ``metadata.year`` is any of the values in ``year``.

    Args:
        reports: Report filenames to restrict the search to. ``None`` is
            treated the same as an empty list.
        sources: Value matched exactly against ``metadata.source``.
        subtype: Value matched exactly against ``metadata.subtype``.
        year: Accepted years. A single scalar (``str`` or ``int``) is wrapped
            into a one-element list, because ``MatchAny`` takes a list.

    Returns:
        A ``rest.Filter`` whose ``must`` clause holds the conditions above.
    """
    # Truthiness check also covers ``reports=None``, which used to blow up
    # in ``len(None)``.
    if reports:
        print("defining filter for allreports:",reports)
        return rest.Filter(
            must=[
                rest.FieldCondition(
                    key="metadata.filename",
                    match=rest.MatchAny(any=reports),
                ),
            ],
        )

    print("defining filter for sources:{},subtype:{},year:{}".format(sources,subtype,year))

    # MatchAny takes a list of values; a bare scalar would fail model
    # validation, so promote it to a single-element list.
    years = [year] if isinstance(year, (str, int)) else year

    return rest.Filter(
        must=[
            rest.FieldCondition(
                key="metadata.source",
                match=rest.MatchValue(value=sources),
            ),
            rest.FieldCondition(
                key="metadata.subtype",
                match=rest.MatchValue(value=subtype),
            ),
            rest.FieldCondition(
                key="metadata.year",
                match=rest.MatchAny(any=years),
            ),
        ],
    )


@lru_cache(maxsize=1)
def _load_cross_encoder(model_name):
    """Instantiate the cross-encoder for ``model_name`` once and reuse it."""
    return HuggingFaceCrossEncoder(model_name=model_name)


def get_context(vectorstore, query, reports, sources, subtype, year):
    """Retrieve and re-rank the stored passages relevant to ``query``.

    Steps:
      1. Build a metadata filter from ``reports`` / ``sources`` / ``subtype``
         / ``year`` via ``create_filter``.
      2. Run a similarity search with a score threshold of 0.6, returning up
         to the configured ``retriever.TOP_K`` hits.
      3. Re-rank those hits with the configured cross-encoder and keep the
         best ``ranker.TOP_K``.

    Args:
        vectorstore: Object exposing ``as_retriever(search_type=...,
            search_kwargs=...)`` (presumably a LangChain Qdrant vector
            store; confirm against the caller).
        query: Query passed to the compression retriever.
        reports, sources, subtype, year: Forwarded unchanged to
            ``create_filter``.

    Returns:
        Whatever ``ContextualCompressionRetriever.invoke`` returns for the
        query: a sized collection of the re-ranked results.
    """
    # create metadata filter
    metadata_filter = create_filter(reports=reports, sources=sources,
                                    subtype=subtype, year=year)

    # getting context
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={
            "score_threshold": 0.6,
            "k": int(model_config.get('retriever','TOP_K')),
            "filter": metadata_filter,
        },
    )

    # re-ranking the retrieved results; the model itself is cached so it is
    # not reloaded on every query
    reranker = CrossEncoderReranker(
        model=_load_cross_encoder(model_config.get('ranker','MODEL')),
        top_n=int(model_config.get('ranker','TOP_K')),
    )
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=reranker, base_retriever=retriever
    )

    context_retrieved = compression_retriever.invoke(query)
    print(f"retrieved paragraphs:{len(context_retrieved)}")

    return context_retrieved