mtyrrell commited on
Commit
08a352f
·
1 Parent(s): 87abc73

Updated the test storage module and added a preliminary generalized approach to multiple data sources

Browse files
Files changed (5) hide show
  1. app/main.py +11 -2
  2. app/retriever.py +27 -62
  3. app/vectorstore_interface.py +89 -0
  4. params.cfg +9 -5
  5. requirements.txt +3 -1
app/main.py CHANGED
@@ -1,5 +1,14 @@
1
  import gradio as gr
2
- from .retriever import retrieve_context
 
 
 
 
 
 
 
 
 
3
 
4
  # ---------------------------------------------------------------------
5
  # Gradio Interface with MCP support
@@ -78,7 +87,7 @@ ui = gr.Interface(
78
  if __name__ == "__main__":
79
  ui.launch(
80
  server_name="0.0.0.0",
81
- server_port=7860,
82
  mcp_server=True,
83
  show_error=True
84
  )
 
1
  import gradio as gr
2
+ from .retriever import retrieve_context, get_vectorstore
3
+
4
+ # Initialize vector store at startup
5
+ print("Initializing vector store connection...")
6
+ try:
7
+ vectorstore = get_vectorstore()
8
+ print("Vector store connection initialized successfully")
9
+ except Exception as e:
10
+ print(f"Failed to initialize vector store: {e}")
11
+ raise
12
 
13
  # ---------------------------------------------------------------------
14
  # Gradio Interface with MCP support
 
87
  if __name__ == "__main__":
88
  ui.launch(
89
  server_name="0.0.0.0",
90
+ server_port=7860, # Different port from reranker
91
  mcp_server=True,
92
  show_error=True
93
  )
app/retriever.py CHANGED
@@ -2,6 +2,7 @@ from typing import List, Dict, Any, Optional
2
  from qdrant_client.http import models as rest
3
  from langchain.schema import Document
4
  from .utils import getconfig
 
5
  import logging
6
 
7
  # Load configuration
@@ -11,6 +12,20 @@ config = getconfig("params.cfg")
11
  RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
12
  SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def create_filter(
15
  reports: List[str] = None,
16
  sources: str = None,
@@ -74,37 +89,9 @@ def create_filter(
74
  return rest.Filter(must=conditions)
75
  return None
76
 
77
- def get_vectorstore():
78
- """
79
- Initialize and return the vectorstore connection.
80
- This function should be implemented based on your specific vectorstore setup.
81
-
82
- Returns:
83
- Vectorstore instance (e.g., Qdrant, Pinecone, etc.)
84
- """
85
- # TODO: Implement based on your external vector database
86
- # Example for Qdrant:
87
- # from langchain_community.vectorstores import Qdrant
88
- # from qdrant_client import QdrantClient
89
- #
90
- # client = QdrantClient(
91
- # host=config.get("vectorstore", "HOST"),
92
- # port=config.get("vectorstore", "PORT"),
93
- # api_key=config.get("vectorstore", "API_KEY", fallback=None)
94
- # )
95
- #
96
- # vectorstore = Qdrant(
97
- # client=client,
98
- # collection_name=config.get("vectorstore", "COLLECTION_NAME"),
99
- # embeddings=your_embedding_model # You'll need to configure this
100
- # )
101
- #
102
- # return vectorstore
103
-
104
- raise NotImplementedError("Please implement vectorstore connection based on your setup")
105
-
106
  def retrieve_context(
107
  query: str,
 
108
  reports: List[str] = None,
109
  sources: str = None,
110
  subtype: str = None,
@@ -116,6 +103,7 @@ def retrieve_context(
116
 
117
  Args:
118
  query: The search query
 
119
  reports: List of specific report filenames to search within
120
  sources: Source type to filter by
121
  subtype: Document subtype to filter by
@@ -126,48 +114,25 @@ def retrieve_context(
126
  List of dictionaries with 'page_content' and 'metadata' keys
127
  """
128
  try:
129
- # Get vectorstore instance
130
- vectorstore = get_vectorstore()
131
-
132
- # Create metadata filter
133
- filter_obj = create_filter(
134
- reports=reports or [],
135
- sources=sources,
136
- subtype=subtype,
137
- year=year or []
138
- )
139
-
140
- # Set up search parameters
141
  k = top_k or RETRIEVER_TOP_K
 
 
142
  search_kwargs = {
143
- "score_threshold": SCORE_THRESHOLD,
144
- "k": k
145
  }
146
 
147
- if filter_obj:
148
- search_kwargs["filter"] = filter_obj
149
-
150
- # Create retriever
151
- retriever = vectorstore.as_retriever(
152
- search_type="similarity_score_threshold",
153
- search_kwargs=search_kwargs
154
- )
155
 
156
  # Perform retrieval
157
- retrieved_docs: List[Document] = retriever.invoke(query)
158
 
159
  logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
160
 
161
- # Convert to dictionary format
162
- results = [
163
- {
164
- "page_content": doc.page_content,
165
- "metadata": doc.metadata
166
- }
167
- for doc in retrieved_docs
168
- ]
169
-
170
- return results
171
 
172
  except Exception as e:
173
  logging.error(f"Error during retrieval: {str(e)}")
 
2
  from qdrant_client.http import models as rest
3
  from langchain.schema import Document
4
  from .utils import getconfig
5
+ from .vectorstore_interface import create_vectorstore, VectorStoreInterface
6
  import logging
7
 
8
  # Load configuration
 
12
  RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
13
  SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))
14
 
15
+ # Initialize vector store connection at module import time
16
+ logging.info("Initializing vector store connection...")
17
+ vectorstore = create_vectorstore(config)
18
+ logging.info("Vector store connection initialized successfully")
19
+
20
def get_vectorstore() -> VectorStoreInterface:
    """Expose the shared vector store connection.

    The connection is created once at module import time; this accessor
    simply hands back that pre-initialized instance.

    Returns:
        The module-level VectorStoreInterface instance.
    """
    return vectorstore
28
+
29
  def create_filter(
30
  reports: List[str] = None,
31
  sources: str = None,
 
89
  return rest.Filter(must=conditions)
90
  return None
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def retrieve_context(
93
  query: str,
94
+ vectorstore,
95
  reports: List[str] = None,
96
  sources: str = None,
97
  subtype: str = None,
 
103
 
104
  Args:
105
  query: The search query
106
+ vectorstore: Pre-initialized vector store instance
107
  reports: List of specific report filenames to search within
108
  sources: Source type to filter by
109
  subtype: Document subtype to filter by
 
114
  List of dictionaries with 'page_content' and 'metadata' keys
115
  """
116
  try:
117
+ # Use the passed vector store instead of calling get_vectorstore()
 
 
 
 
 
 
 
 
 
 
 
118
  k = top_k or RETRIEVER_TOP_K
119
+
120
+ # For Hugging Face Spaces, we pass the model name from config
121
  search_kwargs = {
122
+ "model_name": config.get("embeddings", "MODEL_NAME")
 
123
  }
124
 
125
+ # Note: Filtering is currently limited for Hugging Face Spaces
126
+ # as the API doesn't expose filtering capabilities
127
+ if any([reports, sources, subtype, year]):
128
+ logging.warning("Filtering not supported for Hugging Face Spaces API")
 
 
 
 
129
 
130
  # Perform retrieval
131
+ retrieved_docs = vectorstore.search(query, k, **search_kwargs)
132
 
133
  logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
134
 
135
+ return retrieved_docs
 
 
 
 
 
 
 
 
 
136
 
137
  except Exception as e:
138
  logging.error(f"Error during retrieval: {str(e)}")
app/vectorstore_interface.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Dict, Any, Optional
3
+ from gradio_client import Client
4
+ import logging
5
+ import os
6
+ import time
7
+
8
class VectorStoreInterface(ABC):
    """Common contract for vector store backends.

    Concrete implementations wrap one specific storage service and expose a
    single similarity-search entry point so callers stay backend-agnostic.
    """

    @abstractmethod
    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Return up to *top_k* documents most similar to *query*."""
        ...
15
+
16
class HuggingFaceSpacesVectorStore(VectorStoreInterface):
    """Vector store implementation for Hugging Face Spaces with MCP endpoints."""

    def __init__(self, space_url: str, collection_name: str, hf_token: Optional[str] = None):
        """Connect to a Hugging Face Space exposing a /search_text endpoint.

        Args:
            space_url: Space repo id (e.g. "org/space_name").
            collection_name: Collection on the Space to query.
            hf_token: Optional auth token; falls back to the HF_TOKEN env var.
        """
        # Bug fix: the hf_token argument was previously ignored — only the
        # HF_TOKEN environment variable was read. Prefer the explicit argument.
        token = hf_token or os.getenv("HF_TOKEN")
        repo_id = space_url

        logging.info(f"Connecting to Hugging Face Space: {repo_id}")

        if token:
            self.client = Client(repo_id, hf_token=token)
        else:
            self.client = Client(repo_id)

        self.collection_name = collection_name

    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using Hugging Face Spaces MCP API.

        Args:
            query: Free-text search query.
            top_k: Maximum number of documents to return.
            **kwargs: Optional 'model_name' forwarded to the remote endpoint.

        Returns:
            The result list as returned by the Space's /search_text endpoint.

        Raises:
            Exception: Any error from the remote call is logged and re-raised.
        """
        try:
            # Use the /search_text endpoint as documented in the API
            result = self.client.predict(
                query=query,
                collection_name=self.collection_name,
                model_name=kwargs.get('model_name'),
                top_k=top_k,
                api_name="/search_text"
            )

            logging.info(f"Successfully retrieved {len(result) if result else 0} documents")
            return result

        except Exception as e:
            logging.error(f"Error searching Hugging Face Spaces: {str(e)}")
            # Bare raise preserves the original traceback (vs `raise e`).
            raise
50
+
51
+ # class QdrantVectorStore(VectorStoreInterface):
52
+ # """Vector store implementation for direct Qdrant connection."""
53
+ # # needs to be generalized for other vector stores (or add a new class for each vector store)
54
+ # def __init__(self, host: str, port: int, collection_name: str, api_key: Optional[str] = None):
55
+ # from qdrant_client import QdrantClient
56
+ # from langchain_community.vectorstores import Qdrant
57
+
58
+ # self.client = QdrantClient(
59
+ # host=host,
60
+ # port=port,
61
+ # api_key=api_key
62
+ # )
63
+ # self.collection_name = collection_name
64
+ # # Embedding model not implemented
65
+
66
+ # def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
67
+ # """Search using direct Qdrant connection."""
68
+ # # Embedding model not implemented
69
+ # raise NotImplementedError("Direct Qdrant search needs embedding model configuration")
70
+
71
def create_vectorstore(config: Any) -> VectorStoreInterface:
    """Factory function to create the appropriate vector store from configuration.

    Args:
        config: ConfigParser-like object with a [vectorstore] section whose
            TYPE key selects the backend ("huggingface_spaces" or "qdrant").

    Returns:
        A concrete VectorStoreInterface implementation.

    Raises:
        NotImplementedError: For TYPE = qdrant, whose backend class is not
            yet implemented (it is commented out in this module).
        ValueError: For any unrecognised vector store type.
    """
    vectorstore_type = config.get("vectorstore", "TYPE")

    if vectorstore_type.lower() == "huggingface_spaces":
        space_url = config.get("vectorstore", "SPACE_URL")
        collection_name = config.get("vectorstore", "COLLECTION_NAME")
        hf_token = config.get("vectorstore", "HF_TOKEN", fallback=None)
        return HuggingFaceSpacesVectorStore(space_url, collection_name, hf_token)

    elif vectorstore_type.lower() == "qdrant":
        # Bug fix: this branch previously returned QdrantVectorStore, a class
        # that is commented out above — producing a NameError at runtime.
        # Fail with an explicit, catchable error instead.
        raise NotImplementedError(
            "Direct Qdrant support is not implemented yet; "
            "set [vectorstore] TYPE = huggingface_spaces"
        )

    else:
        raise ValueError(f"Unsupported vector store type: {vectorstore_type}")
params.cfg CHANGED
@@ -3,11 +3,15 @@ TOP_K = 10
3
  SCORE_THRESHOLD = 0.6
4
 
5
  [vectorstore]
6
- TYPE = qdrant
7
- HOST = localhost
8
- PORT = 6333
9
- COLLECTION_NAME = "auditqa"
10
- # API_KEY = your_api_key_if_needed
 
 
 
 
11
 
12
  [embeddings]
13
  MODEL_NAME = BAAI/bge-m3
 
3
  SCORE_THRESHOLD = 0.6
4
 
5
  [vectorstore]
6
+ TYPE = huggingface_spaces
7
+ SPACE_URL = GIZ/audit_data
8
+ COLLECTION_NAME = docling
9
+ # For future direct Qdrant usage:
10
+ # TYPE = qdrant
11
+ # HOST = ip address
12
+ # PORT = 6333
13
+ # COLLECTION_NAME = collection_name_here  (do not quote: configparser keeps quotes literally)
14
+ # API_KEY = api key for source
15
 
16
  [embeddings]
17
  MODEL_NAME = BAAI/bge-m3
requirements.txt CHANGED
@@ -2,4 +2,6 @@ gradio[mcp]
2
  langchain
3
  langchain-community
4
  qdrant-client
5
- sentence-transformers
 
 
 
2
  langchain
3
  langchain-community
4
  qdrant-client
5
+ sentence-transformers
6
+ gradio_client>=0.10.0
7
+ huggingface_hub>=0.20.0