mtyrrell committed
Commit e13001c · 1 Parent(s): 915f480
Files changed (1)
  1. app/main.py +67 -313
app/main.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 from fastapi import FastAPI
 from langserve import add_routes
 from langgraph.graph import StateGraph, START, END
-from typing import Optional, Dict, Any, List, Literal
+from typing import Optional, Dict, Any
 from typing_extensions import TypedDict
 from pydantic import BaseModel
 from gradio_client import Client
@@ -11,26 +11,18 @@ import os
 from datetime import datetime
 import logging
 from contextlib import asynccontextmanager
-import io
-from PIL import Image
 import threading
 from langchain_core.runnables import RunnableLambda

-# Local imports
 from utils import getconfig

 config = getconfig("params.cfg")
-
 RETRIEVER = config.get("retriever", "RETRIEVER")
 GENERATOR = config.get("generator", "GENERATOR")

-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)

-# Define langgraph state schema
 class GraphState(TypedDict):
     query: str
     context: str
@@ -41,7 +33,6 @@ class GraphState(TypedDict):
     year_filter: str
     metadata: Optional[Dict[str, Any]]

-# LangServe input/output schemas
 class ChatFedInput(TypedDict):
     query: str
     reports_filter: Optional[str]
@@ -55,24 +46,12 @@ class ChatFedOutput(TypedDict):
     result: str
     metadata: Dict[str, Any]

-# ChatUI specific schemas
-class ChatUIStreamInput(BaseModel):
-    text: str  # ChatUI sends input as "text" field
-
-class ChatUIStreamOutput(BaseModel):
-    content: str
-
-class ChatMessage(BaseModel):
-    role: Literal["system", "user", "assistant"]
-    content: str
-
 class ChatUIInput(BaseModel):
-    messages: List[ChatMessage]
+    text: str

-# Retriever
 def retrieve_node(state: GraphState) -> GraphState:
     start_time = datetime.now()
-    logger.info(f"Starting retrieval for query: {state['query'][:100]}...")
+    logger.info(f"Retrieval: {state['query'][:50]}...")

     try:
         client = Client(RETRIEVER)
@@ -88,30 +67,28 @@ def retrieve_node(state: GraphState) -> GraphState:
         duration = (datetime.now() - start_time).total_seconds()
         metadata = state.get("metadata", {})
         metadata.update({
-            "retrieval_duration_seconds": duration,
+            "retrieval_duration": duration,
             "context_length": len(context) if context else 0,
             "retrieval_success": True
         })

-        logger.info(f"Retrieval completed in {duration:.2f}s, context length: {len(context) if context else 0}")
         return {"context": context, "metadata": metadata}

     except Exception as e:
         duration = (datetime.now() - start_time).total_seconds()
-        logger.error(f"Retrieval failed after {duration:.2f}s: {str(e)}")
+        logger.error(f"Retrieval failed: {str(e)}")

         metadata = state.get("metadata", {})
         metadata.update({
-            "retrieval_duration_seconds": duration,
+            "retrieval_duration": duration,
             "retrieval_success": False,
             "retrieval_error": str(e)
         })
         return {"context": "", "metadata": metadata}

-# Generator
 def generate_node(state: GraphState) -> GraphState:
     start_time = datetime.now()
-    logger.info(f"Starting generation for query: {state['query'][:100]}...")
+    logger.info(f"Generation: {state['query'][:50]}...")

     try:
         client = Client(GENERATOR)
@@ -124,27 +101,25 @@ def generate_node(state: GraphState) -> GraphState:
         duration = (datetime.now() - start_time).total_seconds()
         metadata = state.get("metadata", {})
         metadata.update({
-            "generation_duration_seconds": duration,
+            "generation_duration": duration,
             "result_length": len(result) if result else 0,
             "generation_success": True
         })

-        logger.info(f"Generation completed in {duration:.2f}s, result length: {len(result) if result else 0}")
         return {"result": result, "metadata": metadata}

     except Exception as e:
         duration = (datetime.now() - start_time).total_seconds()
-        logger.error(f"Generation failed after {duration:.2f}s: {str(e)}")
+        logger.error(f"Generation failed: {str(e)}")

         metadata = state.get("metadata", {})
         metadata.update({
-            "generation_duration_seconds": duration,
+            "generation_duration": duration,
             "generation_success": False,
             "generation_error": str(e)
         })
-        return {"result": f"Error generating response: {str(e)}", "metadata": metadata}
+        return {"result": f"Error: {str(e)}", "metadata": metadata}

-# Build graph
 workflow = StateGraph(GraphState)
 workflow.add_node("retrieve", retrieve_node)
 workflow.add_node("generate", generate_node)
@@ -153,8 +128,7 @@ workflow.add_edge("retrieve", "generate")
 workflow.add_edge("generate", END)
 compiled_graph = workflow.compile()

-# Core processing function (shared by both Gradio and LangServe)
-def process_chatfed_query_core(
+def process_query_core(
     query: str,
     reports_filter: str = "",
     sources_filter: str = "",
@@ -164,13 +138,10 @@ def process_chatfed_query_core(
     user_id: Optional[str] = None,
     return_metadata: bool = False
 ):
-    """Core processing function used by both Gradio and LangServe interfaces."""
     start_time = datetime.now()
     if not session_id:
         session_id = f"session_{start_time.strftime('%Y%m%d_%H%M%S')}"

-    logger.info(f"Processing query in session {session_id}: {query[:100]}...")
-
     try:
         initial_state = {
             "query": query,
@@ -183,8 +154,7 @@ def process_chatfed_query_core(
             "metadata": {
                 "session_id": session_id,
                 "user_id": user_id,
-                "start_time": start_time.isoformat(),
-                "orchestrator": "hybrid_gradio_langserve"
+                "start_time": start_time.isoformat()
             }
         }

@@ -193,13 +163,11 @@

         final_metadata = final_state.get("metadata", {})
         final_metadata.update({
-            "total_duration_seconds": total_duration,
+            "total_duration": total_duration,
             "end_time": datetime.now().isoformat(),
             "pipeline_success": True
         })

-        logger.info(f"Query processing completed in {total_duration:.2f}s for session {session_id}")
-
         if return_metadata:
             return {"result": final_state["result"], "metadata": final_metadata}
         else:
@@ -207,33 +175,22 @@

     except Exception as e:
         total_duration = (datetime.now() - start_time).total_seconds()
-        logger.error(f"Pipeline failed after {total_duration:.2f}s for session {session_id}: {str(e)}")
+        logger.error(f"Pipeline failed: {str(e)}")

         if return_metadata:
             error_metadata = {
                 "session_id": session_id,
-                "total_duration_seconds": total_duration,
+                "total_duration": total_duration,
                 "pipeline_success": False,
                 "error": str(e)
             }
-            return {"result": f"Error processing query: {str(e)}", "metadata": error_metadata}
+            return {"result": f"Error: {str(e)}", "metadata": error_metadata}
         else:
-            return f"Error processing query: {str(e)}"
-
-# =============================================================================
-# GRADIO INTERFACE (MCP ENDPOINTS)
-# =============================================================================
+            return f"Error: {str(e)}"

-# Gradio wrapper functions for MCP compatibility
-def process_query_gradio(
-    query: str,
-    reports_filter: str = "",
-    sources_filter: str = "",
-    subtype_filter: str = "",
-    year_filter: str = ""
-) -> str:
-    """Gradio-compatible function that exposes MCP endpoints."""
-    return process_chatfed_query_core(
+def process_query_gradio(query: str, reports_filter: str = "", sources_filter: str = "",
+                         subtype_filter: str = "", year_filter: str = "") -> str:
+    return process_query_core(
         query=query,
         reports_filter=reports_filter,
         sources_filter=sources_filter,
@@ -243,99 +200,7 @@ def process_query_gradio(
         return_metadata=False
     )

-def get_graph_visualization():
-    """Generate graph visualization for Gradio interface."""
-    try:
-        graph_png_bytes = compiled_graph.get_graph().draw_mermaid_png()
-        return Image.open(io.BytesIO(graph_png_bytes))
-    except Exception as e:
-        logger.error(f"Failed to generate graph visualization: {e}")
-        return None
-
-# Create Gradio interface
-def create_gradio_interface():
-    with gr.Blocks(title="ChatFed Orchestrator - MCP Endpoints") as demo:
-        gr.Markdown("# ChatFed Orchestrator")
-        gr.Markdown("**MCP Server Endpoints Available** - This interface provides MCP compatibility for ChatUI integration.")
-
-        with gr.Row():
-            with gr.Column(scale=1):
-                gr.Markdown("**Workflow Visualization**")
-                graph_display = gr.Image(
-                    value=get_graph_visualization(),
-                    label="LangGraph Workflow",
-                    interactive=False,
-                    height=300
-                )
-                refresh_graph_btn = gr.Button("🔄 Refresh Graph", size="sm")
-                refresh_graph_btn.click(fn=get_graph_visualization, outputs=graph_display)
-
-                gr.Markdown("**🔗 MCP Integration**")
-                gr.Markdown("MCP endpoints are active and ready for ChatUI integration.")
-
-            with gr.Column(scale=2):
-                gr.Markdown("**MCP Endpoint Information**")
-
-                with gr.Accordion("MCP Usage", open=True):
-                    gr.Markdown("""
-                    **MCP Server Endpoint:** Available at `/gradio_api/mcp/sse`
-
-                    **For ChatUI Integration:**
-                    ```python
-                    from gradio_client import Client
-
-                    # Connect to orchestrator MCP endpoint
-                    client = Client("https://your-space.hf.space")
-
-                    # Basic usage
-                    response = client.predict(
-                        query="your question",
-                        api_name="/process_query_gradio"
-                    )
-
-                    # With filters
-                    response = client.predict(
-                        query="your question",
-                        reports_filter="annual_reports",
-                        sources_filter="internal",
-                        year_filter="2024",
-                        api_name="/process_query_gradio"
-                    )
-                    ```
-                    """)
-
-                with gr.Accordion("Test Interface", open=False):
-                    # Test interface
-                    with gr.Row():
-                        with gr.Column():
-                            query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
-                            reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
-                            sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
-                            subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
-                            year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
-                            submit_btn = gr.Button("Submit", variant="primary")
-
-                        with gr.Column():
-                            output = gr.Textbox(label="Response", lines=10)
-
-                    submit_btn.click(
-                        fn=process_query_gradio,
-                        inputs=[query_input, reports_filter_input, sources_filter_input, subtype_filter_input, year_filter_input],
-                        outputs=output
-                    )
-
-    return demo
-
-# =============================================================================
-# CHATUI STREAMING ADAPTER
-# =============================================================================
-
-def chatui_streaming_adapter(data) -> str:
-    """
-    Adapter for ChatUI integration.
-    LangServe will automatically create /invoke and /stream endpoints.
-    ChatUI will use the /stream endpoint for streaming responses.
-    """
+def chatui_adapter(data) -> str:
     try:
         # Handle both dict and Pydantic model input
         if hasattr(data, 'text'):
@@ -343,58 +208,21 @@
         elif isinstance(data, dict) and 'text' in data:
             text = data['text']
         else:
-            # Log the actual input structure for debugging
             logger.error(f"Unexpected input structure: {data}")
             return "Error: Invalid input format. Expected 'text' field."

-        logger.info(f"ChatUI request: {text[:100]}...")
-
-        # Process the query using your core function
-        result = process_chatfed_query_core(
+        result = process_query_core(
             query=text,
             session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
             return_metadata=False
         )
-
         return result
-
     except Exception as e:
         logger.error(f"ChatUI error: {str(e)}")
-        return f"Error processing request: {str(e)}"
+        return f"Error: {str(e)}"

-def chatui_non_streaming_adapter(data) -> ChatUIStreamOutput:
-    """
-    Non-streaming adapter for ChatUI (fallback).
-    """
-    try:
-        # Handle both dict and Pydantic model input
-        if hasattr(data, 'text'):
-            text = data.text
-        elif isinstance(data, dict) and 'text' in data:
-            text = data['text']
-        else:
-            logger.error(f"Unexpected input structure: {data}")
-            return ChatUIStreamOutput(content="Error: Invalid input format. Expected 'text' field.")
-
-        logger.info(f"ChatUI non-streaming request: {text[:100]}...")
-
-        result = process_chatfed_query_core(
-            query=text,
-            session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
-            return_metadata=False
-        )
-        return ChatUIStreamOutput(content=result)
-    except Exception as e:
-        logger.error(f"ChatUI adapter error: {str(e)}")
-        return ChatUIStreamOutput(content=f"Error processing request: {str(e)}")
-
-# =============================================================================
-# LANGSERVE API (TELEMETRY)
-# =============================================================================
-
-def process_chatfed_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
-    """LangServe function with full metadata return."""
-    result = process_chatfed_query_core(
+def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
+    result = process_query_core(
         query=input_data["query"],
         reports_filter=input_data.get("reports_filter", ""),
         sources_filter=input_data.get("sources_filter", ""),
@@ -406,147 +234,84 @@ def process_chatfed_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
     )
     return ChatFedOutput(result=result["result"], metadata=result["metadata"])

-def chatui_adapter(data: ChatUIInput):
-    """
-    Adapter to allow ChatUI to send full chat history.
-    We extract the latest user message for ChatFed.
-    """
-    last_user_msg = next(m.content for m in reversed(data.messages) if m.role == "user")
-    result = process_chatfed_query_core(query=last_user_msg)
-    return {"result": result, "metadata": {"source": "chatfed-langserve-adapter"}}
+def create_gradio_interface():
+    with gr.Blocks(title="ChatFed Orchestrator") as demo:
+        gr.Markdown("# ChatFed Orchestrator")
+        gr.Markdown("MCP endpoints available at `/gradio_api/mcp/sse`")
+
+        with gr.Row():
+            with gr.Column():
+                query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
+                reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
+                sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
+                subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
+                year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
+                submit_btn = gr.Button("Submit", variant="primary")
+
+            with gr.Column():
+                output = gr.Textbox(label="Response", lines=10)
+
+        submit_btn.click(
+            fn=process_query_gradio,
+            inputs=[query_input, reports_filter_input, sources_filter_input, subtype_filter_input, year_filter_input],
+            outputs=output
+        )
+
+    return demo

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    logger.info("🚀 Hybrid ChatFed Orchestrator starting up...")
-    logger.info("✅ LangGraph compiled successfully")
-    logger.info("🔗 MCP endpoints will be available via Gradio")
-    logger.info("📊 Enhanced API available via LangServe")
-    logger.info("🎯 ChatUI streaming integration enabled")
+    logger.info("ChatFed Orchestrator starting up...")
     yield
-    logger.info("🛑 Orchestrator shutting down...")
+    logger.info("Orchestrator shutting down...")

-# Create FastAPI app with docs disabled
 app = FastAPI(
-    title="ChatFed Orchestrator - Enhanced API",
+    title="ChatFed Orchestrator",
     version="1.0.0",
-    description="Enhanced API with observability. MCP endpoints available via Gradio interface.",
     lifespan=lifespan,
-    docs_url=None,  # Disable /docs endpoint
-    redoc_url=None  # Disable /redoc endpoint
+    docs_url=None,
+    redoc_url=None
 )

-# Health check
 @app.get("/health")
 async def health_check():
-    return {
-        "status": "healthy",
-        "mcp_endpoints": "available_via_gradio",
-        "enhanced_api": "available_via_langserve",
-        "chatui_integration": "enabled"
-    }
+    return {"status": "healthy"}

-# Add root endpoint
 @app.get("/")
 async def root():
     return {
         "message": "ChatFed Orchestrator API",
-        "version": "1.0.0",
         "endpoints": {
             "health": "/health",
             "chatfed": "/chatfed",
-            "chatfed-chatui": "/chatfed-chatui",
-            "chatfed-ui-stream": "/chatfed-ui-stream",  # New for ChatUI streaming
-            "chatfed-ui": "/chatfed-ui",  # New fallback
-            "process_query": "/process_query"
-        },
-        "gradio_interface": "http://localhost:7861/",
-        "mcp_endpoints": "http://localhost:7861/gradio_api/mcp/sse",
-        "note": "LangServe telemetry enabled - ChatUI integration available via /chatfed-ui-stream"
+            "chatfed-ui-stream": "/chatfed-ui-stream"
+        }
     }

-# =============================================================================
-# ADD LANGSERVE ROUTES
-# =============================================================================
-
-# Convert functions to Runnables
-process_chatfed_query_runnable = RunnableLambda(process_chatfed_query_langserve)
-chatui_adapter_runnable = RunnableLambda(chatui_adapter)
-chatui_streaming_runnable = RunnableLambda(chatui_streaming_adapter)
-chatui_non_streaming_runnable = RunnableLambda(chatui_non_streaming_adapter)
-
-# Add routes with explicit input/output schemas
+# LangServe routes
 add_routes(
     app,
-    process_chatfed_query_runnable,
+    RunnableLambda(process_query_langserve),
     path="/chatfed",
     input_type=ChatFedInput,
     output_type=ChatFedOutput
 )

-# Original ChatUI-compatible LangServe route
 add_routes(
     app,
-    chatui_adapter_runnable,
-    path="/chatfed-chatui",
-    input_type=ChatUIInput
-)
-
-# NEW: ChatUI streaming route (matches your ChatUI config)
-# LangServe will automatically create both /invoke and /stream endpoints
-add_routes(
-    app,
-    chatui_streaming_runnable,
+    RunnableLambda(chatui_adapter),
     path="/chatfed-ui-stream",
-    input_type=ChatUIStreamInput,
+    input_type=ChatUIInput,
     output_type=str,
     enable_feedback_endpoint=True,
     enable_public_trace_link_endpoint=True,
 )

-# NEW: ChatUI non-streaming fallback route
-add_routes(
-    app,
-    chatui_non_streaming_runnable,
-    path="/chatfed-ui",
-    input_type=ChatUIStreamInput,
-    output_type=ChatUIStreamOutput,
-    enable_feedback_endpoint=True,
-    enable_public_trace_link_endpoint=True,
-)
-
-# Backward compatibility endpoint
-@app.post("/process_query")
-async def process_query_endpoint(
-    query: str,
-    reports_filter: str = "",
-    sources_filter: str = "",
-    subtype_filter: str = "",
-    year_filter: str = "",
-    session_id: Optional[str] = None,
-    user_id: Optional[str] = None
-):
-    """Backward compatibility endpoint."""
-    return process_chatfed_query_core(
-        query=query,
-        reports_filter=reports_filter,
-        sources_filter=sources_filter,
-        subtype_filter=subtype_filter,
-        year_filter=year_filter,
-        session_id=session_id,
-        user_id=user_id,
-        return_metadata=False
-    )
-
-# =============================================================================
-# MAIN APPLICATION LAUNCHER
-# =============================================================================
-
 def run_gradio_server():
-    """Run Gradio server in a separate thread for MCP endpoints."""
     demo = create_gradio_interface()
     demo.launch(
         server_name="0.0.0.0",
-        server_port=7861,  # Different port from FastAPI
+        server_port=7861,
         mcp_server=True,
         show_error=True,
         share=False,
@@ -554,24 +319,13 @@ def run_gradio_server():
     )

 if __name__ == "__main__":
-    # Start Gradio server in background thread for MCP endpoints
     gradio_thread = threading.Thread(target=run_gradio_server, daemon=True)
     gradio_thread.start()
-    logger.info("🔗 Gradio MCP server started on port 7861")
+    logger.info("Gradio MCP server started on port 7861")

-    # Start FastAPI server for enhanced API
     host = os.getenv("HOST", "0.0.0.0")
     port = int(os.getenv("PORT", "7860"))

-    logger.info(f"🚀 Starting FastAPI server on {host}:{port}")
-    logger.info("📊 Enhanced API with LangServe telemetry available")
-    logger.info("🔗 MCP endpoints available via Gradio on port 7861")
-    logger.info("🎯 ChatUI streaming integration ready at /chatfed-ui-stream")
+    logger.info(f"Starting FastAPI server on {host}:{port}")

-    uvicorn.run(
-        app,
-        host=host,
-        port=port,
-        log_level="info",
-        access_log=True
-    )
+    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)
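For reference, the Gradio endpoint kept by this commit can be exercised with `gradio_client`, as the in-app usage notes removed above showed. A minimal sketch, assuming a placeholder Space URL and the default `api_name` derived from `process_query_gradio`:

```python
from gradio_client import Client

# Placeholder URL; point this at the Space or host running the orchestrator.
client = Client("https://your-space.hf.space")

# Call the orchestrator's query endpoint, optionally narrowing by filters.
response = client.predict(
    query="your question",
    reports_filter="annual_reports",
    year_filter="2024",
    api_name="/process_query_gradio",
)
print(response)
```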
 
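Likewise, the consolidated `/chatfed-ui-stream` route relies on LangServe auto-generating `/invoke` and `/stream` endpoints (as the removed comments noted). A minimal non-streaming sketch, assuming a local deployment on the launcher's default port 7860; LangServe wraps the runnable's input under an `input` key and returns the result under `output`:

```python
import requests

resp = requests.post(
    "http://localhost:7860/chatfed-ui-stream/invoke",
    json={"input": {"text": "What do the 2024 reports say about funding?"}},
)
resp.raise_for_status()

# chatui_adapter returns a plain string, surfaced by LangServe as "output".
print(resp.json()["output"])
```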