cleanup
app/main.py (+67 −313)
```diff
@@ -2,7 +2,7 @@ import gradio as gr
 from fastapi import FastAPI
 from langserve import add_routes
 from langgraph.graph import StateGraph, START, END
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any
 from typing_extensions import TypedDict
 from pydantic import BaseModel
 from gradio_client import Client
@@ -11,26 +11,18 @@ import os
 from datetime import datetime
 import logging
 from contextlib import asynccontextmanager
-import io
-from PIL import Image
 import threading
 from langchain_core.runnables import RunnableLambda
 
-# Local imports
 from utils import getconfig
 
 config = getconfig("params.cfg")
-
 RETRIEVER = config.get("retriever", "RETRIEVER")
 GENERATOR = config.get("generator", "GENERATOR")
 
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
-# Define langgraph state schema
 class GraphState(TypedDict):
     query: str
     context: str
@@ -41,7 +33,6 @@ class GraphState(TypedDict):
     year_filter: str
     metadata: Optional[Dict[str, Any]]
 
-# LangServe input/output schemas
 class ChatFedInput(TypedDict):
     query: str
     reports_filter: Optional[str]
```
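The diff never touches `utils.getconfig`, so its contract is only implied by the `config.get("retriever", "RETRIEVER")` calls above. A minimal sketch of the helper those calls suggest, assuming a stock `configparser` INI file; the section and option names come from the calls, while the Space IDs in the comment are placeholders:

```python
from configparser import ConfigParser

def getconfig(path: str) -> ConfigParser:
    # Assumed shape of utils.getconfig: read an INI-style file and return
    # the parser, so callers can do config.get(section, option).
    config = ConfigParser()
    config.read(path)
    return config

# params.cfg would then look something like:
#   [retriever]
#   RETRIEVER = your-org/chatfed-retriever
#   [generator]
#   GENERATOR = your-org/chatfed-generator
```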
```diff
@@ -55,24 +46,12 @@ class ChatFedOutput(TypedDict):
     result: str
     metadata: Dict[str, Any]
 
-# ChatUI specific schemas
-class ChatUIStreamInput(BaseModel):
-    text: str  # ChatUI sends input as "text" field
-
-class ChatUIStreamOutput(BaseModel):
-    content: str
-
-class ChatMessage(BaseModel):
-    role: Literal["system", "user", "assistant"]
-    content: str
-
 class ChatUIInput(BaseModel):
-
+    text: str
 
-# Retriever
 def retrieve_node(state: GraphState) -> GraphState:
     start_time = datetime.now()
-    logger.info(f"
+    logger.info(f"Retrieval: {state['query'][:50]}...")
 
     try:
         client = Client(RETRIEVER)
@@ -88,30 +67,28 @@ def retrieve_node(state: GraphState) -> GraphState:
         duration = (datetime.now() - start_time).total_seconds()
         metadata = state.get("metadata", {})
         metadata.update({
-            "
+            "retrieval_duration": duration,
             "context_length": len(context) if context else 0,
             "retrieval_success": True
         })
 
-        logger.info(f"Retrieval completed in {duration:.2f}s, context length: {len(context) if context else 0}")
         return {"context": context, "metadata": metadata}
 
     except Exception as e:
         duration = (datetime.now() - start_time).total_seconds()
-        logger.error(f"Retrieval failed
+        logger.error(f"Retrieval failed: {str(e)}")
 
         metadata = state.get("metadata", {})
         metadata.update({
-            "
+            "retrieval_duration": duration,
             "retrieval_success": False,
             "retrieval_error": str(e)
         })
         return {"context": "", "metadata": metadata}
 
-# Generator
 def generate_node(state: GraphState) -> GraphState:
     start_time = datetime.now()
-    logger.info(f"
+    logger.info(f"Generation: {state['query'][:50]}...")
 
     try:
         client = Client(GENERATOR)
@@ -124,27 +101,25 @@ def generate_node(state: GraphState) -> GraphState:
         duration = (datetime.now() - start_time).total_seconds()
         metadata = state.get("metadata", {})
         metadata.update({
-            "
+            "generation_duration": duration,
             "result_length": len(result) if result else 0,
             "generation_success": True
         })
 
-        logger.info(f"Generation completed in {duration:.2f}s, result length: {len(result) if result else 0}")
         return {"result": result, "metadata": metadata}
 
     except Exception as e:
         duration = (datetime.now() - start_time).total_seconds()
-        logger.error(f"Generation failed
+        logger.error(f"Generation failed: {str(e)}")
 
         metadata = state.get("metadata", {})
         metadata.update({
-            "
+            "generation_duration": duration,
             "generation_success": False,
             "generation_error": str(e)
         })
-        return {"result": f"Error
+        return {"result": f"Error: {str(e)}", "metadata": metadata}
 
-# Build graph
 workflow = StateGraph(GraphState)
 workflow.add_node("retrieve", retrieve_node)
 workflow.add_node("generate", generate_node)
```
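Both node hunks cut off right after the `Client(...)` constructor, so the lines that actually produce `context` and `result` (old lines 79-87 and 118-123) are not part of the diff. For orientation, a hedged sketch of what the retrieval call plausibly looks like with `gradio_client`; the `api_name` and argument order are assumptions, not taken from this file:

```python
from gradio_client import Client

# RETRIEVER is a Space ID or URL from params.cfg (placeholder here).
client = Client("your-org/chatfed-retriever")
context = client.predict(
    "your question",       # query
    api_name="/retrieve",  # assumed endpoint name on the retriever Space
)
```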
```diff
@@ -153,8 +128,7 @@ workflow.add_edge("retrieve", "generate")
 workflow.add_edge("generate", END)
 compiled_graph = workflow.compile()
 
-
-def process_chatfed_query_core(
+def process_query_core(
     query: str,
     reports_filter: str = "",
     sources_filter: str = "",
@@ -164,13 +138,10 @@ def process_chatfed_query_core(
     user_id: Optional[str] = None,
     return_metadata: bool = False
 ):
-    """Core processing function used by both Gradio and LangServe interfaces."""
     start_time = datetime.now()
     if not session_id:
         session_id = f"session_{start_time.strftime('%Y%m%d_%H%M%S')}"
 
-    logger.info(f"Processing query in session {session_id}: {query[:100]}...")
-
     try:
         initial_state = {
             "query": query,
@@ -183,8 +154,7 @@ def process_chatfed_query_core(
             "metadata": {
                 "session_id": session_id,
                 "user_id": user_id,
-                "start_time": start_time.isoformat(),
-                "orchestrator": "hybrid_gradio_langserve"
+                "start_time": start_time.isoformat()
             }
         }
 
@@ -193,13 +163,11 @@ def process_chatfed_query_core(
 
         final_metadata = final_state.get("metadata", {})
         final_metadata.update({
-            "
+            "total_duration": total_duration,
             "end_time": datetime.now().isoformat(),
             "pipeline_success": True
         })
 
-        logger.info(f"Query processing completed in {total_duration:.2f}s for session {session_id}")
-
         if return_metadata:
             return {"result": final_state["result"], "metadata": final_metadata}
         else:
@@ -207,33 +175,22 @@ def process_chatfed_query_core(
 
     except Exception as e:
         total_duration = (datetime.now() - start_time).total_seconds()
-        logger.error(f"Pipeline failed
+        logger.error(f"Pipeline failed: {str(e)}")
 
         if return_metadata:
             error_metadata = {
                 "session_id": session_id,
-                "
+                "total_duration": total_duration,
                 "pipeline_success": False,
                 "error": str(e)
             }
-            return {"result": f"Error
+            return {"result": f"Error: {str(e)}", "metadata": error_metadata}
         else:
-            return f"Error
+            return f"Error: {str(e)}"
-
-# =============================================================================
-# GRADIO INTERFACE (MCP ENDPOINTS)
-# =============================================================================
 
-
-def process_query_gradio(
-    query: str,
-    reports_filter: str = "",
-    sources_filter: str = "",
-    subtype_filter: str = "",
-    year_filter: str = ""
-) -> str:
-    """Gradio-compatible function that exposes MCP endpoints."""
-    return process_chatfed_query_core(
+def process_query_gradio(query: str, reports_filter: str = "", sources_filter: str = "",
+                         subtype_filter: str = "", year_filter: str = "") -> str:
+    return process_query_core(
         query=query,
         reports_filter=reports_filter,
         sources_filter=sources_filter,
```
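New lines 161-162 fall between hunks, so the actual graph invocation that produces `final_state` and `total_duration` is not shown. Given the surrounding code, the elided step is presumably of this form (names from the function above; the exact call is an assumption):

```python
        # Presumed elided body of process_query_core: run the compiled
        # LangGraph pipeline on the initial state and time the round trip.
        final_state = compiled_graph.invoke(initial_state)
        total_duration = (datetime.now() - start_time).total_seconds()
```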
````diff
@@ -243,99 +200,7 @@ def process_query_gradio(
         return_metadata=False
     )
 
-def get_graph_visualization():
-    """Generate graph visualization for Gradio interface."""
-    try:
-        graph_png_bytes = compiled_graph.get_graph().draw_mermaid_png()
-        return Image.open(io.BytesIO(graph_png_bytes))
-    except Exception as e:
-        logger.error(f"Failed to generate graph visualization: {e}")
-        return None
-
-# Create Gradio interface
-def create_gradio_interface():
-    with gr.Blocks(title="ChatFed Orchestrator - MCP Endpoints") as demo:
-        gr.Markdown("# ChatFed Orchestrator")
-        gr.Markdown("**MCP Server Endpoints Available** - This interface provides MCP compatibility for ChatUI integration.")
-
-        with gr.Row():
-            with gr.Column(scale=1):
-                gr.Markdown("**Workflow Visualization**")
-                graph_display = gr.Image(
-                    value=get_graph_visualization(),
-                    label="LangGraph Workflow",
-                    interactive=False,
-                    height=300
-                )
-                refresh_graph_btn = gr.Button("🔄 Refresh Graph", size="sm")
-                refresh_graph_btn.click(fn=get_graph_visualization, outputs=graph_display)
-
-                gr.Markdown("**🔗 MCP Integration**")
-                gr.Markdown("MCP endpoints are active and ready for ChatUI integration.")
-
-            with gr.Column(scale=2):
-                gr.Markdown("**MCP Endpoint Information**")
-
-                with gr.Accordion("MCP Usage", open=True):
-                    gr.Markdown("""
-                    **MCP Server Endpoint:** Available at `/gradio_api/mcp/sse`
-
-                    **For ChatUI Integration:**
-                    ```python
-                    from gradio_client import Client
-
-                    # Connect to orchestrator MCP endpoint
-                    client = Client("https://your-space.hf.space")
-
-                    # Basic usage
-                    response = client.predict(
-                        query="your question",
-                        api_name="/process_query_gradio"
-                    )
-
-                    # With filters
-                    response = client.predict(
-                        query="your question",
-                        reports_filter="annual_reports",
-                        sources_filter="internal",
-                        year_filter="2024",
-                        api_name="/process_query_gradio"
-                    )
-                    ```
-                    """)
-
-                with gr.Accordion("Test Interface", open=False):
-                    # Test interface
-                    with gr.Row():
-                        with gr.Column():
-                            query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
-                            reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
-                            sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
-                            subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
-                            year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
-                            submit_btn = gr.Button("Submit", variant="primary")
-
-                        with gr.Column():
-                            output = gr.Textbox(label="Response", lines=10)
-
-                    submit_btn.click(
-                        fn=process_query_gradio,
-                        inputs=[query_input, reports_filter_input, sources_filter_input, subtype_filter_input, year_filter_input],
-                        outputs=output
-                    )
-
-    return demo
-
-# =============================================================================
-# CHATUI STREAMING ADAPTER
-# =============================================================================
-
-def chatui_streaming_adapter(data) -> str:
-    """
-    Adapter for ChatUI integration.
-    LangServe will automatically create /invoke and /stream endpoints.
-    ChatUI will use the /stream endpoint for streaming responses.
-    """
+def chatui_adapter(data) -> str:
     try:
         # Handle both dict and Pydantic model input
         if hasattr(data, 'text'):
@@ -343,58 +208,21 @@ def chatui_streaming_adapter(data) -> str:
         elif isinstance(data, dict) and 'text' in data:
             text = data['text']
         else:
-            # Log the actual input structure for debugging
             logger.error(f"Unexpected input structure: {data}")
             return "Error: Invalid input format. Expected 'text' field."
 
-
-
-        # Process the query using your core function
-        result = process_chatfed_query_core(
+        result = process_query_core(
            query=text,
            session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            return_metadata=False
         )
-
         return result
-
     except Exception as e:
         logger.error(f"ChatUI error: {str(e)}")
-        return f"Error
+        return f"Error: {str(e)}"
 
-def chatui_non_streaming_adapter(data):
-    """
-    Non-streaming adapter for ChatUI (fallback).
-    """
-    try:
-        # Handle both dict and Pydantic model input
-        if hasattr(data, 'text'):
-            text = data.text
-        elif isinstance(data, dict) and 'text' in data:
-            text = data['text']
-        else:
-            logger.error(f"Unexpected input structure: {data}")
-            return ChatUIStreamOutput(content="Error: Invalid input format. Expected 'text' field.")
-
-        logger.info(f"ChatUI non-streaming request: {text[:100]}...")
-
-        result = process_chatfed_query_core(
-            query=text,
-            session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
-            return_metadata=False
-        )
-        return ChatUIStreamOutput(content=result)
-    except Exception as e:
-        logger.error(f"ChatUI adapter error: {str(e)}")
-        return ChatUIStreamOutput(content=f"Error processing request: {str(e)}")
-
-# =============================================================================
-# LANGSERVE API (TELEMETRY)
-# =============================================================================
-
-def process_chatfed_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
-    """LangServe function with full metadata return."""
-    result = process_chatfed_query_core(
+def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
+    result = process_query_core(
         query=input_data["query"],
         reports_filter=input_data.get("reports_filter", ""),
         sources_filter=input_data.get("sources_filter", ""),
````
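`chatui_adapter` and `process_query_langserve` are mounted below via `add_routes`, which makes LangServe generate `/invoke`, `/batch`, and `/stream` endpoints under each path. A quick sketch of exercising the ChatUI route once the server is up (host and port per the launcher at the end; the question text is a placeholder):

```python
import requests

# LangServe wraps the request payload in "input" and the return value in "output".
resp = requests.post(
    "http://localhost:7860/chatfed-ui-stream/invoke",
    json={"input": {"text": "your question"}},
)
print(resp.json()["output"])
```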
```diff
@@ -406,147 +234,84 @@ def process_chatfed_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
     )
     return ChatFedOutput(result=result["result"], metadata=result["metadata"])
 
-def
-    ""
-
-
-
-
-
-
+def create_gradio_interface():
+    with gr.Blocks(title="ChatFed Orchestrator") as demo:
+        gr.Markdown("# ChatFed Orchestrator")
+        gr.Markdown("MCP endpoints available at `/gradio_api/mcp/sse`")
+
+        with gr.Row():
+            with gr.Column():
+                query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
+                reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
+                sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
+                subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
+                year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
+                submit_btn = gr.Button("Submit", variant="primary")
+
+            with gr.Column():
+                output = gr.Textbox(label="Response", lines=10)
+
+        submit_btn.click(
+            fn=process_query_gradio,
+            inputs=[query_input, reports_filter_input, sources_filter_input, subtype_filter_input, year_filter_input],
+            outputs=output
+        )
+
+    return demo
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    logger.info("
-    logger.info("✅ LangGraph compiled successfully")
-    logger.info("🔗 MCP endpoints will be available via Gradio")
-    logger.info("📊 Enhanced API available via LangServe")
-    logger.info("🎯 ChatUI streaming integration enabled")
+    logger.info("ChatFed Orchestrator starting up...")
     yield
-    logger.info("
+    logger.info("Orchestrator shutting down...")
 
-# Create FastAPI app with docs disabled
 app = FastAPI(
-    title="ChatFed Orchestrator
+    title="ChatFed Orchestrator",
     version="1.0.0",
-    description="Enhanced API with observability. MCP endpoints available via Gradio interface.",
     lifespan=lifespan,
-    docs_url=None,
-    redoc_url=None
+    docs_url=None,
+    redoc_url=None
 )
 
-# Health check
 @app.get("/health")
 async def health_check():
-    return {
-        "status": "healthy",
-        "mcp_endpoints": "available_via_gradio",
-        "enhanced_api": "available_via_langserve",
-        "chatui_integration": "enabled"
-    }
+    return {"status": "healthy"}
 
-# Add root endpoint
 @app.get("/")
 async def root():
     return {
         "message": "ChatFed Orchestrator API",
-        "version": "1.0.0",
         "endpoints": {
             "health": "/health",
             "chatfed": "/chatfed",
-            "chatfed-
-
-            "chatfed-ui": "/chatfed-ui",  # New fallback
-            "process_query": "/process_query"
-        },
-        "gradio_interface": "http://localhost:7861/",
-        "mcp_endpoints": "http://localhost:7861/gradio_api/mcp/sse",
-        "note": "LangServe telemetry enabled - ChatUI integration available via /chatfed-ui-stream"
+            "chatfed-ui-stream": "/chatfed-ui-stream"
+        }
     }
 
-# =============================================================================
-# ADD LANGSERVE ROUTES
-# =============================================================================
-
-# Convert functions to Runnables
-process_chatfed_query_runnable = RunnableLambda(process_chatfed_query_langserve)
-chatui_adapter_runnable = RunnableLambda(chatui_adapter)
-chatui_streaming_runnable = RunnableLambda(chatui_streaming_adapter)
-chatui_non_streaming_runnable = RunnableLambda(chatui_non_streaming_adapter)
-
-# Add routes with explicit input/output schemas
+# LangServe routes
 add_routes(
     app,
-    process_chatfed_query_runnable,
+    RunnableLambda(process_query_langserve),
     path="/chatfed",
     input_type=ChatFedInput,
     output_type=ChatFedOutput
 )
 
-# Original ChatUI-compatible LangServe route
 add_routes(
     app,
-    chatui_adapter_runnable,
-    path="/chatfed-chatui",
-    input_type=ChatUIInput
-)
-
-# NEW: ChatUI streaming route (matches your ChatUI config)
-# LangServe will automatically create both /invoke and /stream endpoints
-add_routes(
-    app,
-    chatui_streaming_runnable,
+    RunnableLambda(chatui_adapter),
     path="/chatfed-ui-stream",
-    input_type=
+    input_type=ChatUIInput,
     output_type=str,
     enable_feedback_endpoint=True,
     enable_public_trace_link_endpoint=True,
 )
 
-# NEW: ChatUI non-streaming fallback route
-add_routes(
-    app,
-    chatui_non_streaming_runnable,
-    path="/chatfed-ui",
-    input_type=ChatUIStreamInput,
-    output_type=ChatUIStreamOutput,
-    enable_feedback_endpoint=True,
-    enable_public_trace_link_endpoint=True,
-)
-
-# Backward compatibility endpoint
-@app.post("/process_query")
-async def process_query_endpoint(
-    query: str,
-    reports_filter: str = "",
-    sources_filter: str = "",
-    subtype_filter: str = "",
-    year_filter: str = "",
-    session_id: Optional[str] = None,
-    user_id: Optional[str] = None
-):
-    """Backward compatibility endpoint."""
-    return process_chatfed_query_core(
-        query=query,
-        reports_filter=reports_filter,
-        sources_filter=sources_filter,
-        subtype_filter=subtype_filter,
-        year_filter=year_filter,
-        session_id=session_id,
-        user_id=user_id,
-        return_metadata=False
-    )
-
-# =============================================================================
-# MAIN APPLICATION LAUNCHER
-# =============================================================================
-
 def run_gradio_server():
-    """Run Gradio server in a separate thread for MCP endpoints."""
     demo = create_gradio_interface()
     demo.launch(
         server_name="0.0.0.0",
-        server_port=7861,
+        server_port=7861,
         mcp_server=True,
         show_error=True,
         share=False,
```
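The cleanup also drops the accordion that documented MCP usage from inside the Gradio app. The calling pattern it showed still applies to the new, slimmer interface; reconstructed here from the deleted lines, with the Space URL as a placeholder:

```python
from gradio_client import Client

# Connect to the orchestrator Space (placeholder URL).
client = Client("https://your-space.hf.space")

# Query with optional filters via the MCP-exposed Gradio endpoint.
response = client.predict(
    query="your question",
    reports_filter="annual_reports",
    sources_filter="internal",
    year_filter="2024",
    api_name="/process_query_gradio",
)
```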
```diff
@@ -554,24 +319,13 @@ def run_gradio_server():
     )
 
 if __name__ == "__main__":
-    # Start Gradio server in background thread for MCP endpoints
     gradio_thread = threading.Thread(target=run_gradio_server, daemon=True)
     gradio_thread.start()
-    logger.info("
+    logger.info("Gradio MCP server started on port 7861")
 
-    # Start FastAPI server for enhanced API
     host = os.getenv("HOST", "0.0.0.0")
     port = int(os.getenv("PORT", "7860"))
 
-    logger.info(f"
-    logger.info("📊 Enhanced API with LangServe telemetry available")
-    logger.info("🔗 MCP endpoints available via Gradio on port 7861")
-    logger.info("🎯 ChatUI streaming integration ready at /chatfed-ui-stream")
+    logger.info(f"Starting FastAPI server on {host}:{port}")
 
-    uvicorn.run(
-        app,
-        host=host,
-        port=port,
-        log_level="info",
-        access_log=True
-    )
+    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)
```
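The launcher runs two servers in one process: the Gradio app (with `mcp_server=True`) in a daemon thread on port 7861, and the FastAPI/LangServe app under uvicorn on port 7860. A smoke test against the defaults, assuming a local run:

```python
import requests

# FastAPI side (port 7860)
print(requests.get("http://localhost:7860/health").json())  # {"status": "healthy"}
print(requests.get("http://localhost:7860/").json())        # endpoint listing

# The Gradio/MCP side listens separately on port 7861; the MCP SSE endpoint
# sits at http://localhost:7861/gradio_api/mcp/sse per the interface text above.
```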