AurelioAguirre committed
Commit eb5a3fb · 1 Parent(s): ecd2385

Adding query expansion and reranker

main/api.py CHANGED
@@ -1,10 +1,13 @@
+import json
+from pathlib import Path
+
 import httpx
-from typing import Optional, AsyncIterator, Dict, Any, Iterator, List
+from typing import Optional, AsyncIterator, Dict, Any, Iterator, List, Callable
 import logging
 import asyncio
-import os
 from litserve import LitAPI
 from pydantic import BaseModel
+from .utils import extract_json


 class GenerationResponse(BaseModel):
@@ -136,6 +139,95 @@ class InferenceApi(LitAPI):
             self.logger.error(f"Error in generate_response: {str(e)}")
             raise

+    async def structured_llm_query(
+            self,
+            template_name: str,
+            input_text: str,
+            additional_context: Optional[Dict[str, Any]] = None,
+            pre_hooks: Optional[List[Callable]] = None,
+            post_hooks: Optional[List[Callable]] = None
+    ) -> Dict[str, Any]:
+        """Execute a structured LLM query using a template."""
+        template_path = Path(__file__).parent / "prompt_templates" / f"{template_name}.json"
+
+        try:
+            # Load and parse template
+            with open(template_path) as f:
+                template = json.load(f)
+
+            # Apply pre-processing hooks
+            processed_input = input_text
+            if pre_hooks:
+                for hook in pre_hooks:
+                    processed_input = hook(processed_input)
+
+            # Format the prompt with the context
+            context = {"input_text": processed_input}
+            if additional_context:
+                context.update(additional_context)
+
+            prompt = template["prompt_template"].format(**context)
+
+            # Make the request to the LLM
+            response = await self._make_request(
+                "POST",
+                "generate",
+                json={
+                    "prompt": prompt,
+                    "system_message": template.get("system_message"),
+                    "max_new_tokens": 1000
+                }
+            )
+
+            # Extract JSON from response
+            data = response.json()
+            result = extract_json(data["generated_text"])
+
+            # Apply any additional post-processing hooks
+            if post_hooks:
+                for hook in post_hooks:
+                    result = hook(result)
+
+            return result
+
+        except FileNotFoundError:
+            raise ValueError(f"Template {template_name} not found")
+        except Exception as e:
+            self.logger.error(f"Error in structured_llm_query: {str(e)}")
+            raise
+
+    async def expand_query(
+            self,
+            query: str,
+            system_message: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """Expand a query for RAG processing."""
+        return await self.structured_llm_query(
+            template_name="query_expansion",
+            input_text=query,
+            additional_context={"system_message": system_message} if system_message else None
+        )
+
+    async def rerank_chunks(
+            self,
+            query: str,
+            chunks: List[str],
+            system_message: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """Rerank text chunks based on their relevance to the query."""
+        # Format chunks as numbered list for better LLM processing
+        formatted_chunks = "\n".join(f"{i+1}. {chunk}" for i, chunk in enumerate(chunks))
+
+        return await self.structured_llm_query(
+            template_name="chunk_rerank",
+            input_text=query,
+            additional_context={
+                "chunks": formatted_chunks,
+                "system_message": system_message
+            }
+        )
+
+
     async def generate_stream(
         self,
         prompt: str,
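
For reference, a minimal sketch of how the two new helpers might be exercised from async code; the InferenceApi instance and its _make_request() wiring are assumed to be set up as in the rest of the repo, and the demo() driver below is purely illustrative:

# Illustrative driver only: assumes an initialized InferenceApi whose
# _make_request() can reach the LLM backend.
# Run with: asyncio.run(demo(api)) once an api instance exists.
async def demo(api):
    expansion = await api.expand_query("What is quantum entanglement?")
    print(expansion["expanded_query"])
    print(expansion["search_terms"])

    ranking = await api.rerank_chunks(
        query="What is quantum entanglement?",
        chunks=[
            "Entanglement links the states of two particles.",
            "The weather today is sunny.",
        ],
    )
    print(ranking["ranked_chunks"], ranking["got_chunks"])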
main/prompt_templates/chunk_rerank.json ADDED
@@ -0,0 +1,39 @@
+{
+  "name": "chunk_rerank",
+  "description": "Evaluate and rank text chunks based on their relevance to a query",
+  "system_message": "You are a helpful assistant that evaluates text chunks for their relevance to a query. You always respond in valid JSON format.",
+  "prompt_template": "Please analyze the following query and text chunks, ranking the chunks by their relevance and importance to answering the query. Prioritize chunks that contain specific, relevant information over general statements.\n\nQuery: {input_text}\n\nText Chunks to evaluate:\n{chunks}\n\nCreate a JSON response with the following fields:\n- original_query: the exact query\n- ranked_chunks: array of the top 5 most relevant chunks, ordered by importance (most important first)\n- got_chunks: set to false if no chunks were provided or if they're all irrelevant\n\nEnsure your response is valid JSON and contains only these fields.",
+  "response_schema": {
+    "type": "object",
+    "properties": {
+      "original_query": {
+        "type": "string",
+        "description": "The exact query being processed"
+      },
+      "ranked_chunks": {
+        "type": "array",
+        "items": {
+          "type": "string"
+        },
+        "maxItems": 5,
+        "description": "Top 5 most relevant chunks in order of importance"
+      },
+      "got_chunks": {
+        "type": "boolean",
+        "description": "Whether any relevant chunks were found"
+      }
+    },
+    "required": ["original_query", "ranked_chunks", "got_chunks"]
+  },
+  "example_response": {
+    "original_query": "What are the key principles of relativity?",
+    "ranked_chunks": [
+      "Einstein's theory of special relativity is based on two fundamental principles: the principle of relativity and the constancy of the speed of light.",
+      "The principle of relativity states that the laws of physics are the same in all inertial reference frames.",
+      "In special relativity, time dilation occurs when objects move at high speeds relative to one another.",
+      "Mass and energy are equivalent, as expressed in the famous equation E=mc².",
+      "The theory led to revolutionary predictions about space and time, including length contraction."
+    ],
+    "got_chunks": true
+  }
+}
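
The template ships a response_schema, but nothing currently enforces it: extract_json only checks that the reply parses as JSON. If stricter enforcement is wanted, the extracted dict could be validated against the schema. A minimal sketch, assuming the third-party jsonschema package (not currently a project dependency):

import json
from pathlib import Path
from jsonschema import validate, ValidationError  # assumed extra dependency

# Path is relative to the repo root; adjust to your working directory.
template = json.loads(Path("main/prompt_templates/chunk_rerank.json").read_text())
candidate = {"original_query": "q", "ranked_chunks": ["a"], "got_chunks": True}
try:
    validate(instance=candidate, schema=template["response_schema"])
except ValidationError as e:
    print(f"LLM response failed schema check: {e.message}")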
main/prompt_templates/query_expansion.json ADDED
@@ -0,0 +1,25 @@
+{
+  "name": "query_expansion",
+  "description": "Expand a query for RAG processing with additional context and search terms",
+  "system_message": "You are a helpful assistant that creates JSON responses. Always ensure your response is valid JSON.",
+  "prompt_template": "Please analyze this query and create a JSON response with the following fields:\n- original_query: the exact query as provided\n- expanded_query: a more detailed version of the query that might help in getting better answers\n- search_terms: a list of key terms that would be useful for searching related information\n- call_rag: set to false if this query doesn't require searching through external documents (like math problems, coding questions, or general knowledge)\n\nThe query is: \"{input_text}\"\n\nYour response must be valid JSON and contain only these fields. Do not include any other text.",
+  "response_schema": {
+    "type": "object",
+    "properties": {
+      "original_query": {"type": "string"},
+      "expanded_query": {"type": "string"},
+      "search_terms": {
+        "type": "array",
+        "items": {"type": "string"}
+      },
+      "call_rag": {"type": "boolean"}
+    },
+    "required": ["original_query", "expanded_query", "search_terms", "call_rag"]
+  },
+  "example_response": {
+    "original_query": "What is quantum entanglement?",
+    "expanded_query": "Explain quantum entanglement, its significance in quantum mechanics, and how it challenges classical physics",
+    "search_terms": ["quantum entanglement", "quantum mechanics", "EPR paradox", "quantum physics", "spooky action"],
+    "call_rag": true
+  }
+}
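
Worth noting: structured_llm_query renders prompt_template with str.format(**context), and str.format silently ignores surplus keyword arguments. That is why expand_query can pass system_message through additional_context even though this template only references {input_text}. A quick illustration:

import json
from pathlib import Path

# Path is relative to the repo root; adjust to your working directory.
template = json.loads(Path("main/prompt_templates/query_expansion.json").read_text())

# The extra system_message key is simply ignored by str.format.
prompt = template["prompt_template"].format(
    input_text="What is quantum entanglement?",
    system_message="unused by this template",
)
print(prompt.splitlines()[0])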
main/routes.py CHANGED
@@ -12,7 +12,7 @@ from .schemas import (
     SystemStatusResponse,
     ValidationResponse,
     ChatCompletionRequest,
-    ChatCompletionResponse
+    ChatCompletionResponse, QueryExpansionResponse, QueryExpansionRequest, ChunkRerankResponse, ChunkRerankRequest
 )

 router = APIRouter()
@@ -113,6 +113,43 @@ async def create_chat_completion(request: ChatCompletionRequest):
         logger.error(f"Error in chat completion endpoint: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))

+@router.post("/expand_query", response_model=QueryExpansionResponse)
+async def expand_query(request: QueryExpansionRequest):
+    """Expand a query for RAG processing"""
+    logger.info(f"Received query expansion request: {request.query[:50]}...")
+    try:
+        result = await api.expand_query(
+            query=request.query,
+            system_message=request.system_message
+        )
+        logger.info("Successfully expanded query")
+        return result
+    except FileNotFoundError as e:
+        logger.error(f"Template file not found: {str(e)}")
+        raise HTTPException(status_code=500, detail="Query expansion template not found")
+    except json.JSONDecodeError as e:
+        logger.error(f"Invalid JSON response from LLM: {str(e)}")
+        raise HTTPException(status_code=500, detail="Invalid response format from LLM")
+    except Exception as e:
+        logger.error(f"Error in expand_query endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.post("/rerank", response_model=ChunkRerankResponse)
+async def rerank_chunks(request: ChunkRerankRequest):
+    """Rerank chunks based on their relevance to the query"""
+    logger.info(f"Received reranking request for query: {request.query[:50]}...")
+    try:
+        result = await api.rerank_chunks(
+            query=request.query,
+            chunks=request.chunks,
+            system_message=request.system_message
+        )
+        logger.info(f"Successfully reranked {len(request.chunks)} chunks")
+        return result
+    except Exception as e:
+        logger.error(f"Error in rerank_chunks endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
 @router.post("/embedding", response_model=EmbeddingResponse)
 async def generate_embedding(request: EmbeddingRequest):
     """Generate embedding vector from text"""
main/schemas.py CHANGED
@@ -1,7 +1,35 @@
-from pydantic import BaseModel, Field
+import json
+from pathlib import Path
+from pydantic import BaseModel, Field, create_model, ConfigDict
 from typing import List, Optional, Dict, Union
 from time import time

+class QueryExpansionRequest(BaseModel):
+    query: str
+    system_message: Optional[str] = None
+
+# Load the template to create the response model
+template_path = Path(__file__).parent / "prompt_templates" / "query_expansion.json"
+with open(template_path) as f:
+    template = json.load(f)
+
+# Create model configuration with proper typing
+model_config = ConfigDict(
+    json_schema_extra={
+        'example': template['example_response']
+    }
+)
+
+# Create the response model based on the template's schema
+QueryExpansionResponse = create_model(
+    'QueryExpansionResponse',
+    original_query=(str, ...),
+    expanded_query=(str, ...),
+    search_terms=(List[str], ...),
+    call_rag=(bool, ...),
+    model_config=model_config
+)
+
 class ChatMessage(BaseModel):
     role: str
     content: str
@@ -91,4 +119,26 @@ class ValidationResponse(BaseModel):
     model_validation: Dict[str, bool]
     folder_validation: Dict[str, bool]
     overall_status: str
-    issues: List[str]
+    issues: List[str]
+
+class ChunkRerankRequest(BaseModel):
+    query: str
+    chunks: List[str]
+    system_message: Optional[str] = None
+
+# Load example from template
+template_path = Path(__file__).parent / "prompt_templates" / "chunk_rerank.json"
+with open(template_path) as f:
+    template = json.load(f)
+example = template['example_response']
+
+class ChunkRerankResponse(BaseModel):
+    """Response model for chunk reranking, based on template schema"""
+    original_query: str = Field(..., description="The exact query being processed")
+    ranked_chunks: List[str] = Field(..., description="Top 5 most relevant chunks in order of importance", max_items=5)
+    got_chunks: bool = Field(..., description="Whether any relevant chunks were found")
+
+    class Config:
+        json_schema_extra = {
+            "example": example
+        }
main/utils.py ADDED
@@ -0,0 +1,28 @@
+"""Utility functions for the inference API."""
+import json
+import re
+from typing import Dict, Any
+
+def extract_json(text: str) -> Dict[str, Any]:
+    """Extract JSON from text that might contain other content.
+
+    Handles cases like:
+    - Clean JSON: {"key": "value"}
+    - JSON with prefix: Sure! Here's your JSON: {"key": "value"}
+    - JSON with suffix: {"key": "value"} Let me know if you need anything else!
+    """
+    # Python's re module does not support recursive patterns such as (?R),
+    # so instead try to decode a JSON object starting at each opening brace.
+    decoder = json.JSONDecoder()
+
+    # Try each candidate position until we find valid JSON
+    for match in re.finditer(r'\{', text):
+        try:
+            parsed, _ = decoder.raw_decode(text, match.start())
+            if isinstance(parsed, dict):
+                return parsed
+        except json.JSONDecodeError:
+            continue
+
+    # If we couldn't find any valid JSON, raise an error
+    raise ValueError("No valid JSON found in response")
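
A quick check of the extraction behavior on the kinds of strings the docstring describes, including a nested object:

from main.utils import extract_json

messy = 'Sure! Here is your JSON: {"call_rag": true, "search_terms": ["a", "b"]} Hope that helps!'
print(extract_json(messy))
# -> {'call_rag': True, 'search_terms': ['a', 'b']}

nested = 'prefix {"outer": {"inner": 1}} suffix'
print(extract_json(nested))
# -> {'outer': {'inner': 1}}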