mtyrrell committed
Commit ec32e84 · 1 Parent(s): a38e3e8

reranker integration (optional)

Files changed (4):
  1. app/main.py +5 -3
  2. app/retriever.py +108 -22
  3. params.cfg +13 -5
  4. requirements.txt +2 -1
app/main.py CHANGED

@@ -1,5 +1,5 @@
 import gradio as gr
-from .retriever import retrieve_context, get_vectorstore
+from .retriever import get_context, get_vectorstore
 
 # Initialize vector store at startup
 print("Initializing vector store connection...")
@@ -40,7 +40,8 @@ def retrieve_mcp(
     year = [y.strip() for y in year_filter.split(",") if y.strip()] if year_filter else None
 
     # Call retriever function and return raw results
-    results = retrieve_context(
+    results = get_context(
+        vectorstore=vectorstore,
         query=query,
         reports=reports,
         sources=sources,
@@ -64,7 +65,8 @@ def retrieve_ui(query, reports_filter="", sources_filter="", subtype_filter="",
     year = [y.strip() for y in year_filter.split(",") if y.strip()] if year_filter else None
 
     # Call retriever function
-    results = retrieve_context(
+    results = get_context(
+        vectorstore=vectorstore,
         query=query,
         reports=reports,
         sources=sources,
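
For illustration, a minimal sketch of the call pattern after this change, assuming main.py keeps a module-level `vectorstore` assigned from `get_vectorstore()` at startup (only `get_context`, `get_vectorstore`, and the keyword names come from the diff; the query and filter values are invented):

    from app.retriever import get_context, get_vectorstore

    # Created once at app startup instead of at retriever import time
    vectorstore = get_vectorstore()

    results = get_context(
        vectorstore=vectorstore,                   # now passed explicitly on every call
        query="renewable energy audit findings",   # illustrative query
        reports=None,
        sources=None,
        subtype=None,
        year=["2023"],                             # illustrative filter
    )

Passing the store explicitly (instead of relying on the old import-time global) leaves app startup in control of when the connection is created.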
app/retriever.py CHANGED

@@ -1,6 +1,8 @@
 from typing import List, Dict, Any, Optional
 from qdrant_client.http import models as rest
 from langchain.schema import Document
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from langchain.retrievers.document_compressors import CrossEncoderReranker
 from .utils import getconfig
 from .vectorstore_interface import create_vectorstore, VectorStoreInterface
 import logging
@@ -12,18 +14,39 @@ config = getconfig("params.cfg")
 RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
 SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))
 
-# Initialize vector store connection at module import time
-logging.info("Initializing vector store connection...")
-vectorstore = create_vectorstore(config)
-logging.info("Vector store connection initialized successfully")
+# Reranker settings from config
+RERANKER_ENABLED = config.getboolean("reranker", "ENABLED", fallback=False)
+RERANKER_MODEL = config.get("reranker", "MODEL_NAME", fallback="cross-encoder/ms-marco-MiniLM-L-6-v2")
+RERANKER_TOP_K = int(config.get("reranker", "TOP_K", fallback=5))
+RERANKER_TOP_K_SCALE_FACTOR = int(config.get("reranker", "TOP_K_SCALE_FACTOR", fallback=2))
+
+# # Initialize vector store connection at module import time
+# logging.info("Initializing vector store connection...")
+# vectorstore = create_vectorstore(config)
+# logging.info("Vector store connection initialized successfully")
+
+# Initialize reranker if enabled
+reranker = None
+if RERANKER_ENABLED:
+    try:
+        logging.info(f"Initializing reranker with model: {RERANKER_MODEL}")
+        model = HuggingFaceCrossEncoder(model_name=RERANKER_MODEL)
+        reranker = CrossEncoderReranker(model=model, top_n=RERANKER_TOP_K)
+        logging.info("Reranker initialized successfully")
+    except Exception as e:
+        logging.error(f"Failed to initialize reranker: {str(e)}")
+        reranker = None
 
 def get_vectorstore() -> VectorStoreInterface:
     """
-    Return the pre-initialized vector store connection.
+    Create and return a vector store connection.
 
     Returns:
         VectorStoreInterface instance
     """
+    logging.info("Initializing vector store connection...")
+    vectorstore = create_vectorstore(config)
+    logging.info("Vector store connection initialized successfully")
     return vectorstore
 
 def create_filter(
@@ -89,48 +112,111 @@ def create_filter(
         return rest.Filter(must=conditions)
     return None
 
-def retrieve_context(
+def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """
+    Rerank documents using a cross-encoder (specified in params.cfg).
+
+    Args:
+        query: The search query
+        documents: List of documents to rerank
+
+    Returns:
+        Reranked list of documents in the original format
+    """
+    if not reranker or not documents:
+        return documents
+
+    try:
+        logging.info(f"Starting reranking of {len(documents)} documents")
+
+        # Convert to LangChain Document format using the correct keys (review later for portability)
+        langchain_docs = []
+        for doc in documents:
+            # Use the keys produced by the data storage test module
+            content = doc.get('answer', '')
+            metadata = doc.get('answer_metadata', {})
+
+            if not content:
+                logging.warning(f"Document missing content: {doc}")
+                continue
+
+            langchain_doc = Document(
+                page_content=content,
+                metadata=metadata
+            )
+            langchain_docs.append(langchain_doc)
+
+        if not langchain_docs:
+            logging.warning("No valid documents found for reranking")
+            return documents
+
+        # Rerank documents
+        logging.info(f"Reranking {len(langchain_docs)} documents")
+        reranked_docs = reranker.compress_documents(langchain_docs, query)
+
+        # Convert back to the original format
+        result = []
+        for doc in reranked_docs:
+            result.append({
+                'answer': doc.page_content,
+                'answer_metadata': doc.metadata,
+            })
+
+        logging.info(f"Successfully reranked {len(documents)} documents to top {len(result)}")
+        return result
+
+    except Exception as e:
+        logging.error(f"Error during reranking: {str(e)}")
+        # Return the original documents if reranking fails
+        return documents
+
+def get_context(
+    vectorstore: VectorStoreInterface,
     query: str,
     reports: List[str] = None,
     sources: str = None,
     subtype: str = None,
-    year: List[str] = None,
-    top_k: int = None
+    year: List[str] = None
 ) -> List[Dict[str, Any]]:
     """
-    Retrieve semantically similar documents from the vector database.
+    Retrieve semantically similar documents from the vector database with optional reranking.
 
     Args:
+        vectorstore: The vector store interface to search
         query: The search query
-        vectorstore: Pre-initialized vector store instance
         reports: List of specific report filenames to search within
         sources: Source type to filter by
        subtype: Document subtype to filter by
         year: List of years to filter by
-        top_k: Number of results to return (defaults to config value)
 
     Returns:
-        List of dictionaries with 'page_content' and 'metadata' keys
+        List of dictionaries with 'answer', 'answer_metadata', and 'score' keys
     """
     try:
-        # Use the passed vector store instead of calling get_vectorstore()
-        k = top_k or RETRIEVER_TOP_K
+        # Use a higher k for initial retrieval if reranking is enabled (more candidate docs)
+        top_k = RETRIEVER_TOP_K
+        if RERANKER_ENABLED and reranker:
+            top_k = top_k * RERANKER_TOP_K_SCALE_FACTOR
+            logging.info(f"Reranking enabled, retrieving {top_k} candidates")
 
-        # For Hugging Face Spaces, we pass the model name from config
         search_kwargs = {
             "model_name": config.get("embeddings", "MODEL_NAME")
         }
 
-        # Note: Filtering is currently limited for Hugging Face Spaces
-        # as the API doesn't expose filtering capabilities
-        if any([reports, sources, subtype, year]):
-            logging.warning("Filtering not supported for Hugging Face Spaces API")
-
-        # Perform retrieval
-        retrieved_docs = vectorstore.search(query, k, **search_kwargs)
+        # Perform initial retrieval
+        retrieved_docs = vectorstore.search(query, top_k, **search_kwargs)
 
         logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
 
+        # Apply reranking if enabled
+        if RERANKER_ENABLED and reranker and retrieved_docs:
+            logging.info("Applying reranking...")
+            retrieved_docs = rerank_documents(query, retrieved_docs)
+
+            # Trim to the final desired k
+            retrieved_docs = retrieved_docs[:RERANKER_TOP_K]
+
+        logging.info(f"Returning {len(retrieved_docs)} final documents")
         return retrieved_docs
 
     except Exception as e:
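
The retrieve-then-rerank flow above can be exercised in isolation. Below is a self-contained sketch using the same LangChain classes and the params.cfg defaults (the toy candidates and query are invented; in the app they come from vectorstore.search):

    from langchain.schema import Document
    from langchain_community.cross_encoders import HuggingFaceCrossEncoder
    from langchain.retrievers.document_compressors import CrossEncoderReranker

    # Same model and top_n as the params.cfg defaults
    model = HuggingFaceCrossEncoder(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
    reranker = CrossEncoderReranker(model=model, top_n=5)

    # In the app, TOP_K (10) * TOP_K_SCALE_FACTOR (2) = 20 candidates are
    # over-retrieved first; three toy documents stand in for them here.
    candidates = [
        Document(page_content="Findings on renewable energy procurement audits."),
        Document(page_content="Unrelated boilerplate about document formatting."),
        Document(page_content="The 2023 audit covers emissions reporting."),
    ]

    top_docs = reranker.compress_documents(candidates, "renewable energy audit findings")
    for doc in top_docs:
        print(doc.page_content)

CrossEncoderReranker scores each (query, document) pair with the cross-encoder and keeps the top_n highest-scoring documents, which is why get_context over-retrieves by TOP_K_SCALE_FACTOR and then trims back to RERANKER_TOP_K.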
params.cfg CHANGED

@@ -1,7 +1,3 @@
-[retriever]
-TOP_K = 10
-SCORE_THRESHOLD = 0.6
-
 [vectorstore]
 TYPE = huggingface_spaces
 SPACE_URL = GIZ/audit_data
@@ -15,4 +11,16 @@ COLLECTION_NAME = docling
 
 [embeddings]
 MODEL_NAME = BAAI/bge-m3
-# DEVICE = cpu
+# DEVICE = cpu
+
+[retriever]
+TOP_K = 10
+SCORE_THRESHOLD = 0.6
+
+[reranker]
+MODEL_NAME = cross-encoder/ms-marco-MiniLM-L-6-v2
+TOP_K = 5
+ENABLED = true
+# Scales the total docs retrieved prior to reranking (i.e. retriever TOP_K * TOP_K_SCALE_FACTOR)
+TOP_K_SCALE_FACTOR = 2
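
For reference, a sketch of how the new [reranker] section is consumed, mirroring the constants at the top of app/retriever.py (plain configparser is assumed here in place of the repo's getconfig helper):

    import configparser

    config = configparser.ConfigParser()
    config.read("params.cfg")

    # Fallbacks keep reranking optional when the section is absent
    enabled = config.getboolean("reranker", "ENABLED", fallback=False)
    model_name = config.get("reranker", "MODEL_NAME",
                            fallback="cross-encoder/ms-marco-MiniLM-L-6-v2")
    rerank_top_k = int(config.get("reranker", "TOP_K", fallback=5))
    scale = int(config.get("reranker", "TOP_K_SCALE_FACTOR", fallback=2))

    # Candidate pool fetched before reranking: retriever TOP_K * scale factor
    candidates = int(config.get("retriever", "TOP_K", fallback=10)) * scale
    print(enabled, model_name, rerank_top_k, candidates)  # True ... 5 20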
requirements.txt CHANGED

@@ -4,4 +4,5 @@ langchain-community
 qdrant-client
 sentence-transformers
 gradio_client>=0.10.0
-huggingface_hub>=0.20.0
+huggingface_hub>=0.20.0
+torch
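
torch is added because HuggingFaceCrossEncoder runs the cross-encoder through sentence-transformers, which needs a PyTorch backend. A quick sanity check (illustrative):

    import torch

    print(torch.__version__)
    # sentence-transformers picks CUDA when available, otherwise CPU
    print(torch.cuda.is_available())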