Spaces:
Runtime error
Runtime error
Commit
·
eb5a3fb
1
Parent(s):
ecd2385
Adding query expansion and reranker
Browse files- main/api.py +94 -2
- main/prompt_templates/chunk_rerank.json +39 -0
- main/prompt_templates/query_expansion.json +25 -0
- main/routes.py +38 -1
- main/schemas.py +52 -2
- main/utils.py +28 -0
main/api.py
CHANGED
@@ -1,10 +1,13 @@
|
|
|
|
|
|
|
|
1 |
import httpx
|
2 |
-
from typing import Optional, AsyncIterator, Dict, Any, Iterator, List
|
3 |
import logging
|
4 |
import asyncio
|
5 |
-
import os
|
6 |
from litserve import LitAPI
|
7 |
from pydantic import BaseModel
|
|
|
8 |
|
9 |
|
10 |
class GenerationResponse(BaseModel):
|
@@ -136,6 +139,95 @@ class InferenceApi(LitAPI):
|
|
136 |
self.logger.error(f"Error in generate_response: {str(e)}")
|
137 |
raise
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
async def generate_stream(
|
140 |
self,
|
141 |
prompt: str,
|
|
|
1 |
+
import json
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
import httpx
|
5 |
+
from typing import Optional, AsyncIterator, Dict, Any, Iterator, List, Callable
|
6 |
import logging
|
7 |
import asyncio
|
|
|
8 |
from litserve import LitAPI
|
9 |
from pydantic import BaseModel
|
10 |
+
from .utils import extract_json
|
11 |
|
12 |
|
13 |
class GenerationResponse(BaseModel):
|
|
|
139 |
self.logger.error(f"Error in generate_response: {str(e)}")
|
140 |
raise
|
141 |
|
142 |
+
async def structured_llm_query(
    self,
    template_name: str,
    input_text: str,
    additional_context: Optional[Dict[str, Any]] = None,
    pre_hooks: Optional[List[Callable]] = None,
    post_hooks: Optional[List[Callable]] = None
) -> Dict[str, Any]:
    """Execute a structured LLM query using a template.

    Loads prompt_templates/<template_name>.json, runs the input through any
    pre_hooks, renders the template's prompt with the (optionally extended)
    context, sends it to the LLM backend, extracts the JSON payload from the
    reply, and finally runs that payload through any post_hooks.

    Raises:
        ValueError: if the named template file does not exist.
    """
    path = Path(__file__).parent / "prompt_templates" / f"{template_name}.json"

    try:
        # Read and parse the prompt template definition.
        with open(path) as handle:
            template = json.load(handle)

        # Run the raw input through each pre-processing hook, in order.
        text = input_text
        for hook in (pre_hooks or []):
            text = hook(text)

        # Build the prompt-formatting context; caller-supplied entries are
        # merged on top of the processed input text.
        fmt_context: Dict[str, Any] = {"input_text": text}
        fmt_context.update(additional_context or {})

        rendered = template["prompt_template"].format(**fmt_context)

        # Dispatch the rendered prompt to the LLM backend.
        llm_response = await self._make_request(
            "POST",
            "generate",
            json={
                "prompt": rendered,
                "system_message": template.get("system_message"),
                "max_new_tokens": 1000
            }
        )

        # Pull the structured JSON object out of the model's free-form reply.
        payload = extract_json(llm_response.json()["generated_text"])

        # Run the parsed result through each post-processing hook, in order.
        for hook in (post_hooks or []):
            payload = hook(payload)

        return payload

    except FileNotFoundError:
        raise ValueError(f"Template {template_name} not found")
    except Exception as e:
        self.logger.error(f"Error in structured_llm_query: {str(e)}")
        raise
|
198 |
+
|
199 |
+
async def expand_query(
    self,
    query: str,
    system_message: Optional[str] = None
) -> Dict[str, Any]:
    """Expand a query for RAG processing.

    Thin wrapper around structured_llm_query using the "query_expansion"
    template; the system_message is forwarded only when one was supplied.
    """
    extra: Optional[Dict[str, Any]] = None
    if system_message:
        extra = {"system_message": system_message}
    return await self.structured_llm_query(
        template_name="query_expansion",
        input_text=query,
        additional_context=extra
    )
|
210 |
+
|
211 |
+
async def rerank_chunks(
    self,
    query: str,
    chunks: List[str],
    system_message: Optional[str] = None
) -> Dict[str, Any]:
    """Rerank text chunks based on their relevance to the query.

    The chunks are presented to the LLM as a 1-based numbered list so the
    model can reference them unambiguously.
    """
    numbered = "\n".join(
        f"{index}. {chunk}" for index, chunk in enumerate(chunks, start=1)
    )
    return await self.structured_llm_query(
        template_name="chunk_rerank",
        input_text=query,
        additional_context={
            "chunks": numbered,
            "system_message": system_message
        }
    )
|
229 |
+
|
230 |
+
|
231 |
async def generate_stream(
|
232 |
self,
|
233 |
prompt: str,
|
main/prompt_templates/chunk_rerank.json
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "chunk_rerank",
|
3 |
+
"description": "Evaluate and rank text chunks based on their relevance to a query",
|
4 |
+
"system_message": "You are a helpful assistant that evaluates text chunks for their relevance to a query. You always respond in valid JSON format.",
|
5 |
+
"prompt_template": "Please analyze the following query and text chunks, ranking the chunks by their relevance and importance to answering the query. Prioritize chunks that contain specific, relevant information over general statements.\n\nQuery: {input_text}\n\nText Chunks to evaluate:\n{chunks}\n\nCreate a JSON response with the following fields:\n- original_query: the exact query\n- ranked_chunks: array of the top 5 most relevant chunks, ordered by importance (most important first)\n- got_chunks: set to false if no chunks were provided or if they're all irrelevant\n\nEnsure your response is valid JSON and contains only these fields.",
|
6 |
+
"response_schema": {
|
7 |
+
"type": "object",
|
8 |
+
"properties": {
|
9 |
+
"original_query": {
|
10 |
+
"type": "string",
|
11 |
+
"description": "The exact query being processed"
|
12 |
+
},
|
13 |
+
"ranked_chunks": {
|
14 |
+
"type": "array",
|
15 |
+
"items": {
|
16 |
+
"type": "string"
|
17 |
+
},
|
18 |
+
"maxItems": 5,
|
19 |
+
"description": "Top 5 most relevant chunks in order of importance"
|
20 |
+
},
|
21 |
+
"got_chunks": {
|
22 |
+
"type": "boolean",
|
23 |
+
"description": "Whether any relevant chunks were found"
|
24 |
+
}
|
25 |
+
},
|
26 |
+
"required": ["original_query", "ranked_chunks", "got_chunks"]
|
27 |
+
},
|
28 |
+
"example_response": {
|
29 |
+
"original_query": "What are the key principles of relativity?",
|
30 |
+
"ranked_chunks": [
|
31 |
+
"Einstein's theory of special relativity is based on two fundamental principles: the principle of relativity and the constancy of the speed of light.",
|
32 |
+
"The principle of relativity states that the laws of physics are the same in all inertial reference frames.",
|
33 |
+
"In special relativity, time dilation occurs when objects move at high speeds relative to one another.",
|
34 |
+
"Mass and energy are equivalent, as expressed in the famous equation E=mc².",
|
35 |
+
"The theory led to revolutionary predictions about space and time, including length contraction."
|
36 |
+
],
|
37 |
+
"got_chunks": true
|
38 |
+
}
|
39 |
+
}
|
main/prompt_templates/query_expansion.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "query_expansion",
|
3 |
+
"description": "Expand a query for RAG processing with additional context and search terms",
|
4 |
+
"system_message": "You are a helpful assistant that creates JSON responses. Always ensure your response is valid JSON.",
|
5 |
+
"prompt_template": "Please analyze this query and create a JSON response with the following fields:\n- original_query: the exact query as provided\n- expanded_query: a more detailed version of the query that might help in getting better answers\n- search_terms: a list of key terms that would be useful for searching related information\n- call_rag: set to false if this query doesn't require searching through external documents (like math problems, coding questions, or general knowledge)\n\nThe query is: \"{input_text}\"\n\nYour response must be valid JSON and contain only these fields. Do not include any other text.",
|
6 |
+
"response_schema": {
|
7 |
+
"type": "object",
|
8 |
+
"properties": {
|
9 |
+
"original_query": {"type": "string"},
|
10 |
+
"expanded_query": {"type": "string"},
|
11 |
+
"search_terms": {
|
12 |
+
"type": "array",
|
13 |
+
"items": {"type": "string"}
|
14 |
+
},
|
15 |
+
"call_rag": {"type": "boolean"}
|
16 |
+
},
|
17 |
+
"required": ["original_query", "expanded_query", "search_terms", "call_rag"]
|
18 |
+
},
|
19 |
+
"example_response": {
|
20 |
+
"original_query": "What is quantum entanglement?",
|
21 |
+
"expanded_query": "Explain quantum entanglement, its significance in quantum mechanics, and how it challenges classical physics",
|
22 |
+
"search_terms": ["quantum entanglement", "quantum mechanics", "EPR paradox", "quantum physics", "spooky action"],
|
23 |
+
"call_rag": true
|
24 |
+
}
|
25 |
+
}
|
main/routes.py
CHANGED
@@ -12,7 +12,7 @@ from .schemas import (
|
|
12 |
SystemStatusResponse,
|
13 |
ValidationResponse,
|
14 |
ChatCompletionRequest,
|
15 |
-
ChatCompletionResponse
|
16 |
)
|
17 |
|
18 |
router = APIRouter()
|
@@ -113,6 +113,43 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
|
113 |
logger.error(f"Error in chat completion endpoint: {str(e)}")
|
114 |
raise HTTPException(status_code=500, detail=str(e))
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
@router.post("/embedding", response_model=EmbeddingResponse)
|
117 |
async def generate_embedding(request: EmbeddingRequest):
|
118 |
"""Generate embedding vector from text"""
|
|
|
12 |
SystemStatusResponse,
|
13 |
ValidationResponse,
|
14 |
ChatCompletionRequest,
|
15 |
+
ChatCompletionResponse, QueryExpansionResponse, QueryExpansionRequest, ChunkRerankResponse, ChunkRerankRequest
|
16 |
)
|
17 |
|
18 |
router = APIRouter()
|
|
|
113 |
logger.error(f"Error in chat completion endpoint: {str(e)}")
|
114 |
raise HTTPException(status_code=500, detail=str(e))
|
115 |
|
116 |
+
@router.post("/expand_query", response_model=QueryExpansionResponse)
async def expand_query(request: QueryExpansionRequest):
    """Expand a query for RAG processing"""
    logger.info(f"Received query expansion request: {request.query[:50]}...")
    try:
        result = await api.expand_query(
            query=request.query,
            system_message=request.system_message
        )
        logger.info("Successfully expanded query")
        return result
    except ValueError as e:
        # api.structured_llm_query converts a missing template file into a
        # ValueError ("Template ... not found"), and utils.extract_json raises
        # ValueError when the LLM reply contains no JSON object -- so this,
        # not FileNotFoundError or json.JSONDecodeError, is the exception that
        # can actually reach this handler. (The previous FileNotFoundError
        # branch was dead code, and the json.JSONDecodeError branch relied on
        # a `json` import not present in this module.)
        logger.error(f"Query expansion failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    except Exception as e:
        logger.error(f"Error in expand_query endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
136 |
+
|
137 |
+
@router.post("/rerank", response_model=ChunkRerankResponse)
async def rerank_chunks(request: ChunkRerankRequest):
    """Rerank chunks based on their relevance to the query"""
    logger.info(f"Received reranking request for query: {request.query[:50]}...")
    try:
        reranked = await api.rerank_chunks(
            query=request.query,
            chunks=request.chunks,
            system_message=request.system_message
        )
    except Exception as e:
        # Surface any backend failure as an HTTP 500 with its message.
        logger.error(f"Error in rerank_chunks endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    logger.info(f"Successfully reranked {len(request.chunks)} chunks")
    return reranked
|
152 |
+
|
153 |
@router.post("/embedding", response_model=EmbeddingResponse)
|
154 |
async def generate_embedding(request: EmbeddingRequest):
|
155 |
"""Generate embedding vector from text"""
|
main/schemas.py
CHANGED
@@ -1,7 +1,35 @@
|
|
1 |
-
|
|
|
|
|
2 |
from typing import List, Optional, Dict, Union
|
3 |
from time import time
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
class ChatMessage(BaseModel):
|
6 |
role: str
|
7 |
content: str
|
@@ -91,4 +119,26 @@ class ValidationResponse(BaseModel):
|
|
91 |
model_validation: Dict[str, bool]
|
92 |
folder_validation: Dict[str, bool]
|
93 |
overall_status: str
|
94 |
-
issues: List[str]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from pathlib import Path
|
3 |
+
from pydantic import BaseModel, Field, create_model, ConfigDict
|
4 |
from typing import List, Optional, Dict, Union
|
5 |
from time import time
|
6 |
|
7 |
+
class QueryExpansionRequest(BaseModel):
    """Request body for the /expand_query endpoint."""
    query: str
    system_message: Optional[str] = None

# Load the query-expansion template once at import time so the response model
# can advertise the template's example_response in its OpenAPI schema.
_qe_template_path = Path(__file__).parent / "prompt_templates" / "query_expansion.json"
with open(_qe_template_path) as _qe_file:
    _qe_template = json.load(_qe_file)

# Response model mirroring the template's response_schema.
# NOTE: create_model() treats arbitrary keyword arguments as *field*
# definitions, so configuration must go through the reserved __config__
# keyword -- passing it as `model_config=model_config` makes pydantic v2
# reject the model at import time.
QueryExpansionResponse = create_model(
    'QueryExpansionResponse',
    __config__=ConfigDict(
        json_schema_extra={'example': _qe_template['example_response']}
    ),
    original_query=(str, ...),
    expanded_query=(str, ...),
    search_terms=(List[str], ...),
    call_rag=(bool, ...),
)
|
32 |
+
|
33 |
class ChatMessage(BaseModel):
|
34 |
role: str
|
35 |
content: str
|
|
|
119 |
model_validation: Dict[str, bool]
|
120 |
folder_validation: Dict[str, bool]
|
121 |
overall_status: str
|
122 |
+
issues: List[str]
|
123 |
+
|
124 |
+
class ChunkRerankRequest(BaseModel):
    """Request body for the /rerank endpoint."""
    query: str
    chunks: List[str]
    system_message: Optional[str] = None

# Load the chunk-rerank template once at import time so the response model
# can advertise the template's example_response in its OpenAPI schema.
_rerank_template_path = Path(__file__).parent / "prompt_templates" / "chunk_rerank.json"
with open(_rerank_template_path) as _rerank_file:
    _rerank_template = json.load(_rerank_file)

class ChunkRerankResponse(BaseModel):
    """Response model for chunk reranking, based on template schema"""
    original_query: str = Field(..., description="The exact query being processed")
    # pydantic v2 removed Field(max_items=...); max_length is the v2 spelling
    # for constraining list length (the old kwarg raises PydanticUserError at
    # class-creation time under v2).
    ranked_chunks: List[str] = Field(..., description="Top 5 most relevant chunks in order of importance", max_length=5)
    got_chunks: bool = Field(..., description="Whether any relevant chunks were found")

    # pydantic v2-style configuration (the nested `class Config` is deprecated).
    model_config = ConfigDict(
        json_schema_extra={"example": _rerank_template['example_response']}
    )
|
main/utils.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions for the inference API."""
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
from typing import Dict, Any
|
5 |
+
|
6 |
+
def extract_json(text: str) -> Dict[str, Any]:
    """Extract the first valid JSON object embedded in *text*.

    Handles cases like:
    - Clean JSON: {"key": "value"}
    - JSON with prefix: Sure! Here's your JSON: {"key": "value"}
    - JSON with suffix: {"key": "value"} Let me know if you need anything else!

    Returns:
        The parsed JSON object as a dict.

    Raises:
        ValueError: if no valid JSON object can be found in the text.
    """
    # BUG FIX: the previous pattern r'\{(?:[^{}]|(?R))*\}' used (?R)
    # recursion, which the stdlib `re` module does not support (only the
    # third-party `regex` package does) -- it raised re.error on every call.
    # Instead, attempt a real JSON decode at each '{' position; raw_decode
    # correctly handles arbitrarily nested objects.
    decoder = json.JSONDecoder()
    for candidate in re.finditer(r'\{', text):
        try:
            parsed, _end = decoder.raw_decode(text, candidate.start())
        except json.JSONDecodeError:
            continue
        # Only accept JSON *objects*, matching the declared return type.
        if isinstance(parsed, dict):
            return parsed

    # No '{' position yielded a valid JSON object.
    raise ValueError("No valid JSON found in response")
|