Spaces:
Runtime error
Runtime error
updated for test storage module, plus prelim generalized approach to multi data source
Browse files- app/main.py +11 -2
- app/retriever.py +27 -62
- app/vectorstore_interface.py +89 -0
- params.cfg +9 -5
- requirements.txt +3 -1
app/main.py
CHANGED
@@ -1,5 +1,14 @@
|
|
1 |
import gradio as gr
|
2 |
-
from .retriever import retrieve_context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
# ---------------------------------------------------------------------
|
5 |
# Gradio Interface with MCP support
|
@@ -78,7 +87,7 @@ ui = gr.Interface(
|
|
78 |
if __name__ == "__main__":
|
79 |
ui.launch(
|
80 |
server_name="0.0.0.0",
|
81 |
-
server_port=7860,
|
82 |
mcp_server=True,
|
83 |
show_error=True
|
84 |
)
|
|
|
1 |
import gradio as gr
|
2 |
+
from .retriever import retrieve_context, get_vectorstore
|
3 |
+
|
4 |
+
# Initialize vector store at startup
|
5 |
+
print("Initializing vector store connection...")
|
6 |
+
try:
|
7 |
+
vectorstore = get_vectorstore()
|
8 |
+
print("Vector store connection initialized successfully")
|
9 |
+
except Exception as e:
|
10 |
+
print(f"Failed to initialize vector store: {e}")
|
11 |
+
raise
|
12 |
|
13 |
# ---------------------------------------------------------------------
|
14 |
# Gradio Interface with MCP support
|
|
|
87 |
if __name__ == "__main__":
|
88 |
ui.launch(
|
89 |
server_name="0.0.0.0",
|
90 |
+
server_port=7860, # Different port from reranker
|
91 |
mcp_server=True,
|
92 |
show_error=True
|
93 |
)
|
app/retriever.py
CHANGED
@@ -2,6 +2,7 @@ from typing import List, Dict, Any, Optional
|
|
2 |
from qdrant_client.http import models as rest
|
3 |
from langchain.schema import Document
|
4 |
from .utils import getconfig
|
|
|
5 |
import logging
|
6 |
|
7 |
# Load configuration
|
@@ -11,6 +12,20 @@ config = getconfig("params.cfg")
|
|
11 |
RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
|
12 |
SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def create_filter(
|
15 |
reports: List[str] = None,
|
16 |
sources: str = None,
|
@@ -74,37 +89,9 @@ def create_filter(
|
|
74 |
return rest.Filter(must=conditions)
|
75 |
return None
|
76 |
|
77 |
-
def get_vectorstore():
|
78 |
-
"""
|
79 |
-
Initialize and return the vectorstore connection.
|
80 |
-
This function should be implemented based on your specific vectorstore setup.
|
81 |
-
|
82 |
-
Returns:
|
83 |
-
Vectorstore instance (e.g., Qdrant, Pinecone, etc.)
|
84 |
-
"""
|
85 |
-
# TODO: Implement based on your external vector database
|
86 |
-
# Example for Qdrant:
|
87 |
-
# from langchain_community.vectorstores import Qdrant
|
88 |
-
# from qdrant_client import QdrantClient
|
89 |
-
#
|
90 |
-
# client = QdrantClient(
|
91 |
-
# host=config.get("vectorstore", "HOST"),
|
92 |
-
# port=config.get("vectorstore", "PORT"),
|
93 |
-
# api_key=config.get("vectorstore", "API_KEY", fallback=None)
|
94 |
-
# )
|
95 |
-
#
|
96 |
-
# vectorstore = Qdrant(
|
97 |
-
# client=client,
|
98 |
-
# collection_name=config.get("vectorstore", "COLLECTION_NAME"),
|
99 |
-
# embeddings=your_embedding_model # You'll need to configure this
|
100 |
-
# )
|
101 |
-
#
|
102 |
-
# return vectorstore
|
103 |
-
|
104 |
-
raise NotImplementedError("Please implement vectorstore connection based on your setup")
|
105 |
-
|
106 |
def retrieve_context(
|
107 |
query: str,
|
|
|
108 |
reports: List[str] = None,
|
109 |
sources: str = None,
|
110 |
subtype: str = None,
|
@@ -116,6 +103,7 @@ def retrieve_context(
|
|
116 |
|
117 |
Args:
|
118 |
query: The search query
|
|
|
119 |
reports: List of specific report filenames to search within
|
120 |
sources: Source type to filter by
|
121 |
subtype: Document subtype to filter by
|
@@ -126,48 +114,25 @@ def retrieve_context(
|
|
126 |
List of dictionaries with 'page_content' and 'metadata' keys
|
127 |
"""
|
128 |
try:
|
129 |
-
#
|
130 |
-
vectorstore = get_vectorstore()
|
131 |
-
|
132 |
-
# Create metadata filter
|
133 |
-
filter_obj = create_filter(
|
134 |
-
reports=reports or [],
|
135 |
-
sources=sources,
|
136 |
-
subtype=subtype,
|
137 |
-
year=year or []
|
138 |
-
)
|
139 |
-
|
140 |
-
# Set up search parameters
|
141 |
k = top_k or RETRIEVER_TOP_K
|
|
|
|
|
142 |
search_kwargs = {
|
143 |
-
"
|
144 |
-
"k": k
|
145 |
}
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
retriever = vectorstore.as_retriever(
|
152 |
-
search_type="similarity_score_threshold",
|
153 |
-
search_kwargs=search_kwargs
|
154 |
-
)
|
155 |
|
156 |
# Perform retrieval
|
157 |
-
retrieved_docs
|
158 |
|
159 |
logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
|
160 |
|
161 |
-
|
162 |
-
results = [
|
163 |
-
{
|
164 |
-
"page_content": doc.page_content,
|
165 |
-
"metadata": doc.metadata
|
166 |
-
}
|
167 |
-
for doc in retrieved_docs
|
168 |
-
]
|
169 |
-
|
170 |
-
return results
|
171 |
|
172 |
except Exception as e:
|
173 |
logging.error(f"Error during retrieval: {str(e)}")
|
|
|
2 |
from qdrant_client.http import models as rest
|
3 |
from langchain.schema import Document
|
4 |
from .utils import getconfig
|
5 |
+
from .vectorstore_interface import create_vectorstore, VectorStoreInterface
|
6 |
import logging
|
7 |
|
8 |
# Load configuration
|
|
|
12 |
RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
|
13 |
SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))
|
14 |
|
15 |
+
# Initialize vector store connection at module import time
|
16 |
+
logging.info("Initializing vector store connection...")
|
17 |
+
vectorstore = create_vectorstore(config)
|
18 |
+
logging.info("Vector store connection initialized successfully")
|
19 |
+
|
20 |
+
def get_vectorstore() -> VectorStoreInterface:
|
21 |
+
"""
|
22 |
+
Return the pre-initialized vector store connection.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
VectorStoreInterface instance
|
26 |
+
"""
|
27 |
+
return vectorstore
|
28 |
+
|
29 |
def create_filter(
|
30 |
reports: List[str] = None,
|
31 |
sources: str = None,
|
|
|
89 |
return rest.Filter(must=conditions)
|
90 |
return None
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
def retrieve_context(
|
93 |
query: str,
|
94 |
+
vectorstore,
|
95 |
reports: List[str] = None,
|
96 |
sources: str = None,
|
97 |
subtype: str = None,
|
|
|
103 |
|
104 |
Args:
|
105 |
query: The search query
|
106 |
+
vectorstore: Pre-initialized vector store instance
|
107 |
reports: List of specific report filenames to search within
|
108 |
sources: Source type to filter by
|
109 |
subtype: Document subtype to filter by
|
|
|
114 |
List of dictionaries with 'page_content' and 'metadata' keys
|
115 |
"""
|
116 |
try:
|
117 |
+
# Use the passed vector store instead of calling get_vectorstore()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
k = top_k or RETRIEVER_TOP_K
|
119 |
+
|
120 |
+
# For Hugging Face Spaces, we pass the model name from config
|
121 |
search_kwargs = {
|
122 |
+
"model_name": config.get("embeddings", "MODEL_NAME")
|
|
|
123 |
}
|
124 |
|
125 |
+
# Note: Filtering is currently limited for Hugging Face Spaces
|
126 |
+
# as the API doesn't expose filtering capabilities
|
127 |
+
if any([reports, sources, subtype, year]):
|
128 |
+
logging.warning("Filtering not supported for Hugging Face Spaces API")
|
|
|
|
|
|
|
|
|
129 |
|
130 |
# Perform retrieval
|
131 |
+
retrieved_docs = vectorstore.search(query, k, **search_kwargs)
|
132 |
|
133 |
logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
|
134 |
|
135 |
+
return retrieved_docs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
except Exception as e:
|
138 |
logging.error(f"Error during retrieval: {str(e)}")
|
app/vectorstore_interface.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import List, Dict, Any, Optional
|
3 |
+
from gradio_client import Client
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
|
8 |
+
class VectorStoreInterface(ABC):
|
9 |
+
"""Abstract interface for different vector store implementations."""
|
10 |
+
|
11 |
+
@abstractmethod
|
12 |
+
def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
|
13 |
+
"""Search for similar documents."""
|
14 |
+
pass
|
15 |
+
|
16 |
+
class HuggingFaceSpacesVectorStore(VectorStoreInterface):
|
17 |
+
"""Vector store implementation for Hugging Face Spaces with MCP endpoints."""
|
18 |
+
|
19 |
+
def __init__(self, space_url: str, collection_name: str, hf_token: Optional[str] = None):
|
20 |
+
token = os.getenv("HF_TOKEN")
|
21 |
+
repo_id = space_url
|
22 |
+
|
23 |
+
logging.info(f"Connecting to Hugging Face Space: {repo_id}")
|
24 |
+
|
25 |
+
if token:
|
26 |
+
self.client = Client(repo_id, hf_token=token)
|
27 |
+
else:
|
28 |
+
self.client = Client(repo_id)
|
29 |
+
|
30 |
+
self.collection_name = collection_name
|
31 |
+
|
32 |
+
def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
|
33 |
+
"""Search using Hugging Face Spaces MCP API."""
|
34 |
+
try:
|
35 |
+
# Use the /search_text endpoint as documented in the API
|
36 |
+
result = self.client.predict(
|
37 |
+
query=query,
|
38 |
+
collection_name=self.collection_name,
|
39 |
+
model_name=kwargs.get('model_name'),
|
40 |
+
top_k=top_k,
|
41 |
+
api_name="/search_text"
|
42 |
+
)
|
43 |
+
|
44 |
+
logging.info(f"Successfully retrieved {len(result) if result else 0} documents")
|
45 |
+
return result
|
46 |
+
|
47 |
+
except Exception as e:
|
48 |
+
logging.error(f"Error searching Hugging Face Spaces: {str(e)}")
|
49 |
+
raise e
|
50 |
+
|
51 |
+
# class QdrantVectorStore(VectorStoreInterface):
|
52 |
+
# """Vector store implementation for direct Qdrant connection."""
|
53 |
+
# # needs to be generalized for other vector stores (or add a new class for each vector store)
|
54 |
+
# def __init__(self, host: str, port: int, collection_name: str, api_key: Optional[str] = None):
|
55 |
+
# from qdrant_client import QdrantClient
|
56 |
+
# from langchain_community.vectorstores import Qdrant
|
57 |
+
|
58 |
+
# self.client = QdrantClient(
|
59 |
+
# host=host,
|
60 |
+
# port=port,
|
61 |
+
# api_key=api_key
|
62 |
+
# )
|
63 |
+
# self.collection_name = collection_name
|
64 |
+
# # Embedding model not implemented
|
65 |
+
|
66 |
+
# def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
|
67 |
+
# """Search using direct Qdrant connection."""
|
68 |
+
# # Embedding model not implemented
|
69 |
+
# raise NotImplementedError("Direct Qdrant search needs embedding model configuration")
|
70 |
+
|
71 |
+
def create_vectorstore(config: Any) -> VectorStoreInterface:
|
72 |
+
"""Factory function to create appropriate vector store based on configuration."""
|
73 |
+
vectorstore_type = config.get("vectorstore", "TYPE")
|
74 |
+
|
75 |
+
if vectorstore_type.lower() == "huggingface_spaces":
|
76 |
+
space_url = config.get("vectorstore", "SPACE_URL")
|
77 |
+
collection_name = config.get("vectorstore", "COLLECTION_NAME")
|
78 |
+
hf_token = config.get("vectorstore", "HF_TOKEN", fallback=None)
|
79 |
+
return HuggingFaceSpacesVectorStore(space_url, collection_name, hf_token)
|
80 |
+
|
81 |
+
elif vectorstore_type.lower() == "qdrant":
|
82 |
+
host = config.get("vectorstore", "HOST")
|
83 |
+
port = int(config.get("vectorstore", "PORT"))
|
84 |
+
collection_name = config.get("vectorstore", "COLLECTION_NAME")
|
85 |
+
api_key = config.get("vectorstore", "API_KEY", fallback=None)
|
86 |
+
return QdrantVectorStore(host, port, collection_name, api_key)
|
87 |
+
|
88 |
+
else:
|
89 |
+
raise ValueError(f"Unsupported vector store type: {vectorstore_type}")
|
params.cfg
CHANGED
@@ -3,11 +3,15 @@ TOP_K = 10
|
|
3 |
SCORE_THRESHOLD = 0.6
|
4 |
|
5 |
[vectorstore]
|
6 |
-
TYPE =
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
#
|
|
|
|
|
|
|
|
|
11 |
|
12 |
[embeddings]
|
13 |
MODEL_NAME = BAAI/bge-m3
|
|
|
3 |
SCORE_THRESHOLD = 0.6
|
4 |
|
5 |
[vectorstore]
|
6 |
+
TYPE = huggingface_spaces
|
7 |
+
SPACE_URL = GIZ/audit_data
|
8 |
+
COLLECTION_NAME = docling
|
9 |
+
# For future direct Qdrant usage:
|
10 |
+
# TYPE = qdrant
|
11 |
+
# HOST = ip address
|
12 |
+
# PORT = 6333
|
13 |
+
# COLLECTION_NAME = "collection name"
|
14 |
+
# API_KEY = api key for source
|
15 |
|
16 |
[embeddings]
|
17 |
MODEL_NAME = BAAI/bge-m3
|
requirements.txt
CHANGED
@@ -2,4 +2,6 @@ gradio[mcp]
|
|
2 |
langchain
|
3 |
langchain-community
|
4 |
qdrant-client
|
5 |
-
sentence-transformers
|
|
|
|
|
|
2 |
langchain
|
3 |
langchain-community
|
4 |
qdrant-client
|
5 |
+
sentence-transformers
|
6 |
+
gradio_client>=0.10.0
|
7 |
+
huggingface_hub>=0.20.0
|