mtyrrell committed on
Commit 9effd1f · 1 Parent(s): acb99db

refactored for langserve

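The refactor splits the orchestrator into a FastAPI app serving LangServe routes (port 7860) plus a background Gradio app serving the MCP endpoints (port 7861). Since LangServe's `add_routes` exposes the standard `/invoke`, `/batch`, and `/stream` sub-routes under each path, the new `/chatfed` route can be exercised roughly as below. This is a minimal sketch, not part of the commit: the host/port assume the defaults in `app/main.py`, and the query and filter values are illustrative only.

```python
# Hedged sketch: call the LangServe route added by this commit.
# Assumes the FastAPI defaults in app/main.py (0.0.0.0:7860).
import requests

resp = requests.post(
    "http://localhost:7860/chatfed/invoke",
    json={"input": {"query": "What changed in the 2024 annual reports?",
                    "year_filter": "2024"}},
)
resp.raise_for_status()
output = resp.json()["output"]  # LangServe wraps results in an "output" key
print(output["result"])         # generated answer
print(output["metadata"])       # per-stage timing and session info
```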
Files changed (1)
  1. app/main.py +520 -133
app/main.py CHANGED
@@ -1,179 +1,566 @@
  import gradio as gr
- from gradio_client import Client
  from langgraph.graph import StateGraph, START, END
- from typing import TypedDict, Optional
  import io
  from PIL import Image

- #OPEN QUESTION: SHOULD WE PASS ALL PARAMS FROM THE ORCHESTRATOR TO THE NODES INSTEAD OF SETTING IN EACH MODULE?

- # Define the state schema
  class GraphState(TypedDict):
      query: str
      context: str
      result: str
-     # Add orchestrator-level parameters (addressing your open question)
      reports_filter: str
      sources_filter: str
      subtype_filter: str
      year_filter: str

- # node 2: retriever
  def retrieve_node(state: GraphState) -> GraphState:
-     client = Client("giz/chatfed_retriever")  # HF repo name
-     context = client.predict(
-         query=state["query"],
-         reports_filter=state.get("reports_filter", ""),
-         sources_filter=state.get("sources_filter", ""),
-         subtype_filter=state.get("subtype_filter", ""),
-         year_filter=state.get("year_filter", ""),
-         api_name="/retrieve"
-     )
-     return {"context": context}

- # node 3: generator
  def generate_node(state: GraphState) -> GraphState:
-     client = Client("giz/chatfed_generator")
-     result = client.predict(
-         query=state["query"],
-         context=state["context"],
-         api_name="/generate"
-     )
-     return {"result": result}

- # build the graph
  workflow = StateGraph(GraphState)
-
- # Add nodes
  workflow.add_node("retrieve", retrieve_node)
  workflow.add_node("generate", generate_node)
-
- # Add edges
  workflow.add_edge(START, "retrieve")
  workflow.add_edge("retrieve", "generate")
  workflow.add_edge("generate", END)

- # Compile the graph
- graph = workflow.compile()

- # Single tool for processing queries
- def process_query(
      query: str,
      reports_filter: str = "",
      sources_filter: str = "",
      subtype_filter: str = "",
      year_filter: str = ""
  ) -> str:
-     """
-     Execute the ChatFed orchestration pipeline to process a user query.
-
-     This function orchestrates a two-step workflow:
-     1. Retrieve relevant context using the ChatFed retriever service with optional filters
-     2. Generate a response using the ChatFed generator service with the retrieved context
-
-     Args:
-         query (str): The user's input query/question to be processed
-         reports_filter (str, optional): Filter for specific report types. Defaults to "".
-         sources_filter (str, optional): Filter for specific data sources. Defaults to "".
-         subtype_filter (str, optional): Filter for document subtypes. Defaults to "".
-         year_filter (str, optional): Filter for specific years. Defaults to "".
-
-     Returns:
-         str: The generated response from the ChatFed generator service
-     """
-     initial_state = {
-         "query": query,
-         "context": "",
-         "result": "",
-         "reports_filter": reports_filter or "",
-         "sources_filter": sources_filter or "",
-         "subtype_filter": subtype_filter or "",
-         "year_filter": year_filter or ""
-     }
-     final_state = graph.invoke(initial_state)
-     return final_state["result"]
-
- # Simple testing interface
- ui = gr.Interface(
-     fn=process_query,
-     inputs=gr.Textbox(lines=2, placeholder="Enter query here"),
-     outputs="text",
-     flagging_mode="never"
- )

- # Add a function to generate the graph visualization
  def get_graph_visualization():
-     """Generate and return the LangGraph workflow visualization as a PIL Image."""
-     # Generate the graph as PNG bytes
-     graph_png_bytes = graph.get_graph().draw_mermaid_png()
-
-     # Convert bytes to PIL Image for Gradio display
-     graph_image = Image.open(io.BytesIO(graph_png_bytes))
-     return graph_image

-
- # Guidance for ChatUI - can be removed later. Questionable whether front end even necessary. Maybe nice to show the graph.
- with gr.Blocks(title="ChatFed Orchestrator") as demo:
-     gr.Markdown("# ChatFed Orchestrator")
-     gr.Markdown("This LangGraph server exposes MCP endpoints for the ChatUI module to call (which triggers the graph).")
-
-     with gr.Row():
-         # Left column - Graph visualization
-         with gr.Column(scale=1):
-             gr.Markdown("**Workflow Visualization**")
-             graph_display = gr.Image(
-                 value=get_graph_visualization(),
-                 label="LangGraph Workflow",
-                 interactive=False,
-                 height=300
-             )
-
-             # Add a refresh button for the graph
-             refresh_graph_btn = gr.Button("🔄 Refresh Graph", size="sm")
-             refresh_graph_btn.click(
-                 fn=get_graph_visualization,
-                 outputs=graph_display
-             )

-         # Right column - Interface and documentation
-         with gr.Column(scale=2):
-             gr.Markdown("**Available MCP Tools:**")
-
-             with gr.Accordion("MCP Endpoint Information", open=True):
-                 gr.Markdown(f"""
-                 **MCP Server Endpoint:** https://giz-chatfed-orchestrator.hf.space/gradio_api/mcp/sse

-                 **For ChatUI Integration:**
-                 ```python
-                 from gradio_client import Client

-                 # Connect to orchestrator
-                 orchestrator_client = Client("https://giz-chatfed-orchestrator.hf.space")

-                 # Basic usage (no filters)
-                 response = orchestrator_client.predict(
-                     query="query",
-                     api_name="/process_query"
-                 )

-                 # Advanced usage with any combination of filters
-                 response = orchestrator_client.predict(
-                     query="query",
-                     reports_filter="annual_reports",
-                     sources_filter="internal",
-                     year_filter="2024",
-                     api_name="/process_query"
-                 )
-                 ```
-                 """)

-             with gr.Accordion("Quick Testing Interface", open=True):
-                 ui.render()

- if __name__ == "__main__":
      demo.launch(
          server_name="0.0.0.0",
-         server_port=7860,
-         mcp_server=True,
-         show_error=True
      )
 
  import gradio as gr
+ from fastapi import FastAPI
+ from langserve import add_routes
  from langgraph.graph import StateGraph, START, END
+ from typing import Optional, Dict, Any, List, Literal, AsyncGenerator
+ from typing_extensions import TypedDict
+ from pydantic import BaseModel
+ from gradio_client import Client
+ import uvicorn
+ import os
+ from datetime import datetime
+ import logging
+ from contextlib import asynccontextmanager
  import io
  from PIL import Image
+ import threading
+ import json
+ import asyncio
+ from langchain_core.runnables import RunnableLambda
+ from langchain_core.output_parsers import StrOutputParser
+
+ # Local imports
+ from utils import getconfig
+
+ config = getconfig("params.cfg")

+ RETRIEVER = config.get("retriever", "RETRIEVER")
+ GENERATOR = config.get("generator", "GENERATOR")

+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Define langgraph state schema
  class GraphState(TypedDict):
      query: str
      context: str
      result: str
      reports_filter: str
      sources_filter: str
      subtype_filter: str
      year_filter: str
+     metadata: Optional[Dict[str, Any]]
+
+ # LangServe input/output schemas
+ class ChatFedInput(TypedDict):
+     query: str
+     reports_filter: Optional[str]
+     sources_filter: Optional[str]
+     subtype_filter: Optional[str]
+     year_filter: Optional[str]
+     session_id: Optional[str]
+     user_id: Optional[str]
+
+ class ChatFedOutput(TypedDict):
+     result: str
+     metadata: Dict[str, Any]
+
+ # ChatUI specific schemas
+ class ChatUIStreamInput(BaseModel):
+     text: str  # ChatUI sends input as "text" field
+
+ class ChatUIStreamOutput(BaseModel):
+     content: str
+
+ class ChatMessage(BaseModel):
+     role: Literal["system", "user", "assistant"]
+     content: str
+
+ class ChatUIInput(BaseModel):
+     messages: List[ChatMessage]

+ # Retriever
  def retrieve_node(state: GraphState) -> GraphState:
+     start_time = datetime.now()
+     logger.info(f"Starting retrieval for query: {state['query'][:100]}...")
+
+     try:
+         client = Client(RETRIEVER)
+         context = client.predict(
+             query=state["query"],
+             reports_filter=state.get("reports_filter", ""),
+             sources_filter=state.get("sources_filter", ""),
+             subtype_filter=state.get("subtype_filter", ""),
+             year_filter=state.get("year_filter", ""),
+             api_name="/retrieve"
+         )
+
+         duration = (datetime.now() - start_time).total_seconds()
+         metadata = state.get("metadata", {})
+         metadata.update({
+             "retrieval_duration_seconds": duration,
+             "context_length": len(context) if context else 0,
+             "retrieval_success": True
+         })
+
+         logger.info(f"Retrieval completed in {duration:.2f}s, context length: {len(context) if context else 0}")
+         return {"context": context, "metadata": metadata}
+
+     except Exception as e:
+         duration = (datetime.now() - start_time).total_seconds()
+         logger.error(f"Retrieval failed after {duration:.2f}s: {str(e)}")
+
+         metadata = state.get("metadata", {})
+         metadata.update({
+             "retrieval_duration_seconds": duration,
+             "retrieval_success": False,
+             "retrieval_error": str(e)
+         })
+         return {"context": "", "metadata": metadata}

+ # Generator
  def generate_node(state: GraphState) -> GraphState:
+     start_time = datetime.now()
+     logger.info(f"Starting generation for query: {state['query'][:100]}...")
+
+     try:
+         client = Client(GENERATOR)
+         result = client.predict(
+             query=state["query"],
+             context=state["context"],
+             api_name="/generate"
+         )
+
+         duration = (datetime.now() - start_time).total_seconds()
+         metadata = state.get("metadata", {})
+         metadata.update({
+             "generation_duration_seconds": duration,
+             "result_length": len(result) if result else 0,
+             "generation_success": True
+         })
+
+         logger.info(f"Generation completed in {duration:.2f}s, result length: {len(result) if result else 0}")
+         return {"result": result, "metadata": metadata}
+
+     except Exception as e:
+         duration = (datetime.now() - start_time).total_seconds()
+         logger.error(f"Generation failed after {duration:.2f}s: {str(e)}")
+
+         metadata = state.get("metadata", {})
+         metadata.update({
+             "generation_duration_seconds": duration,
+             "generation_success": False,
+             "generation_error": str(e)
+         })
+         return {"result": f"Error generating response: {str(e)}", "metadata": metadata}

+ # Build graph
  workflow = StateGraph(GraphState)
  workflow.add_node("retrieve", retrieve_node)
  workflow.add_node("generate", generate_node)
  workflow.add_edge(START, "retrieve")
  workflow.add_edge("retrieve", "generate")
  workflow.add_edge("generate", END)
+ compiled_graph = workflow.compile()

+ # Core processing function (shared by both Gradio and LangServe)
+ def process_chatfed_query_core(
+     query: str,
+     reports_filter: str = "",
+     sources_filter: str = "",
+     subtype_filter: str = "",
+     year_filter: str = "",
+     session_id: Optional[str] = None,
+     user_id: Optional[str] = None,
+     return_metadata: bool = False
+ ):
+     """Core processing function used by both Gradio and LangServe interfaces."""
+     start_time = datetime.now()
+     if not session_id:
+         session_id = f"session_{start_time.strftime('%Y%m%d_%H%M%S')}"
+
+     logger.info(f"Processing query in session {session_id}: {query[:100]}...")
+
+     try:
+         initial_state = {
+             "query": query,
+             "context": "",
+             "result": "",
+             "reports_filter": reports_filter or "",
+             "sources_filter": sources_filter or "",
+             "subtype_filter": subtype_filter or "",
+             "year_filter": year_filter or "",
+             "metadata": {
+                 "session_id": session_id,
+                 "user_id": user_id,
+                 "start_time": start_time.isoformat(),
+                 "orchestrator": "hybrid_gradio_langserve"
+             }
+         }
+
+         final_state = compiled_graph.invoke(initial_state)
+         total_duration = (datetime.now() - start_time).total_seconds()
+
+         final_metadata = final_state.get("metadata", {})
+         final_metadata.update({
+             "total_duration_seconds": total_duration,
+             "end_time": datetime.now().isoformat(),
+             "pipeline_success": True
+         })
+
+         logger.info(f"Query processing completed in {total_duration:.2f}s for session {session_id}")
+
+         if return_metadata:
+             return {"result": final_state["result"], "metadata": final_metadata}
+         else:
+             return final_state["result"]
+
+     except Exception as e:
+         total_duration = (datetime.now() - start_time).total_seconds()
+         logger.error(f"Pipeline failed after {total_duration:.2f}s for session {session_id}: {str(e)}")
+
+         if return_metadata:
+             error_metadata = {
+                 "session_id": session_id,
+                 "total_duration_seconds": total_duration,
+                 "pipeline_success": False,
+                 "error": str(e)
+             }
+             return {"result": f"Error processing query: {str(e)}", "metadata": error_metadata}
+         else:
+             return f"Error processing query: {str(e)}"
+
+ # =============================================================================
+ # GRADIO INTERFACE (MCP ENDPOINTS)
+ # =============================================================================

+ # Gradio wrapper functions for MCP compatibility
+ def process_query_gradio(
      query: str,
      reports_filter: str = "",
      sources_filter: str = "",
      subtype_filter: str = "",
      year_filter: str = ""
  ) -> str:
+     """Gradio-compatible function that exposes MCP endpoints."""
+     return process_chatfed_query_core(
+         query=query,
+         reports_filter=reports_filter,
+         sources_filter=sources_filter,
+         subtype_filter=subtype_filter,
+         year_filter=year_filter,
+         session_id=f"gradio_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+         return_metadata=False
+     )

  def get_graph_visualization():
+     """Generate graph visualization for Gradio interface."""
+     try:
+         graph_png_bytes = compiled_graph.get_graph().draw_mermaid_png()
+         return Image.open(io.BytesIO(graph_png_bytes))
+     except Exception as e:
+         logger.error(f"Failed to generate graph visualization: {e}")
+         return None

+ # Create Gradio interface
+ def create_gradio_interface():
+     with gr.Blocks(title="ChatFed Orchestrator - MCP Endpoints") as demo:
+         gr.Markdown("# ChatFed Orchestrator")
+         gr.Markdown("**MCP Server Endpoints Available** - This interface provides MCP compatibility for ChatUI integration.")

+         with gr.Row():
+             with gr.Column(scale=1):
+                 gr.Markdown("**Workflow Visualization**")
+                 graph_display = gr.Image(
+                     value=get_graph_visualization(),
+                     label="LangGraph Workflow",
+                     interactive=False,
+                     height=300
+                 )
+                 refresh_graph_btn = gr.Button("🔄 Refresh Graph", size="sm")
+                 refresh_graph_btn.click(fn=get_graph_visualization, outputs=graph_display)

+                 gr.Markdown("**🔗 MCP Integration**")
+                 gr.Markdown("MCP endpoints are active and ready for ChatUI integration.")

+             with gr.Column(scale=2):
+                 gr.Markdown("**MCP Endpoint Information**")

+                 with gr.Accordion("MCP Usage", open=True):
+                     gr.Markdown("""
+                     **MCP Server Endpoint:** Available at `/gradio_api/mcp/sse`
+
+                     **For ChatUI Integration:**
+                     ```python
+                     from gradio_client import Client
+
+                     # Connect to orchestrator MCP endpoint
+                     client = Client("https://your-space.hf.space")
+
+                     # Basic usage
+                     response = client.predict(
+                         query="your question",
+                         api_name="/process_query_gradio"
+                     )
+
+                     # With filters
+                     response = client.predict(
+                         query="your question",
+                         reports_filter="annual_reports",
+                         sources_filter="internal",
+                         year_filter="2024",
+                         api_name="/process_query_gradio"
+                     )
+                     ```
+                     """)
+
+                 with gr.Accordion("Test Interface", open=False):
+                     # Test interface
+                     with gr.Row():
+                         with gr.Column():
+                             query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
+                             reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
+                             sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
+                             subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
+                             year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
+                             submit_btn = gr.Button("Submit", variant="primary")

+                         with gr.Column():
+                             output = gr.Textbox(label="Response", lines=10)
+
+                     submit_btn.click(
+                         fn=process_query_gradio,
+                         inputs=[query_input, reports_filter_input, sources_filter_input, subtype_filter_input, year_filter_input],
+                         outputs=output
+                     )
+
+     return demo

+ # =============================================================================
+ # CHATUI STREAMING ADAPTER
+ # =============================================================================

+ async def chatui_streaming_adapter(data: ChatUIStreamInput) -> AsyncGenerator[str, None]:
+     """
+     Streaming adapter for ChatUI integration.
+     ChatUI expects streaming responses.
+     """
+     try:
+         logger.info(f"ChatUI streaming request: {data.text[:100]}...")
+
+         # Process the query using the core function
+         result = process_chatfed_query_core(
+             query=data.text,
+             session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+             return_metadata=False
+         )
+
+         # Stream the response word by word or chunk by chunk
+         words = result.split()
+         for i, word in enumerate(words):
+             if i == 0:
+                 yield word
+             else:
+                 yield f" {word}"
+             # Small delay to simulate streaming
+             await asyncio.sleep(0.01)
+
+     except Exception as e:
+         logger.error(f"ChatUI streaming error: {str(e)}")
+         yield f"Error processing request: {str(e)}"
+
+ def chatui_non_streaming_adapter(data: ChatUIStreamInput):
+     """
+     Non-streaming adapter for ChatUI (fallback).
+     """
+     try:
+         logger.info(f"ChatUI non-streaming request: {data.text[:100]}...")
+
+         result = process_chatfed_query_core(
+             query=data.text,
+             session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+             return_metadata=False
+         )
+         return {"content": result}
+     except Exception as e:
+         logger.error(f"ChatUI adapter error: {str(e)}")
+         return {"content": f"Error processing request: {str(e)}"}
+
+ # =============================================================================
+ # LANGSERVE API (TELEMETRY)
+ # =============================================================================
+
+ def process_chatfed_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
+     """LangServe function with full metadata return."""
+     result = process_chatfed_query_core(
+         query=input_data["query"],
+         reports_filter=input_data.get("reports_filter", ""),
+         sources_filter=input_data.get("sources_filter", ""),
+         subtype_filter=input_data.get("subtype_filter", ""),
+         year_filter=input_data.get("year_filter", ""),
+         session_id=input_data.get("session_id"),
+         user_id=input_data.get("user_id"),
+         return_metadata=True
+     )
+     return ChatFedOutput(result=result["result"], metadata=result["metadata"])
+
+ def chatui_adapter(data: ChatUIInput):
+     """
+     Adapter to allow ChatUI to send full chat history.
+     We extract the latest user message for ChatFed.
+     """
+     last_user_msg = next(m.content for m in reversed(data.messages) if m.role == "user")
+     result = process_chatfed_query_core(query=last_user_msg)
+     return {"result": result, "metadata": {"source": "chatfed-langserve-adapter"}}
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     logger.info("🚀 Hybrid ChatFed Orchestrator starting up...")
+     logger.info("✅ LangGraph compiled successfully")
+     logger.info("🔗 MCP endpoints will be available via Gradio")
+     logger.info("📊 Enhanced API available via LangServe")
+     logger.info("🎯 ChatUI streaming integration enabled")
+     yield
+     logger.info("🛑 Orchestrator shutting down...")
+
+ # Create FastAPI app with docs disabled
+ app = FastAPI(
+     title="ChatFed Orchestrator - Enhanced API",
+     version="1.0.0",
+     description="Enhanced API with observability. MCP endpoints available via Gradio interface.",
+     lifespan=lifespan,
+     docs_url=None,   # Disable /docs endpoint
+     redoc_url=None   # Disable /redoc endpoint
+ )
+
+ # Health check
+ @app.get("/health")
+ async def health_check():
+     return {
+         "status": "healthy",
+         "mcp_endpoints": "available_via_gradio",
+         "enhanced_api": "available_via_langserve",
+         "chatui_integration": "enabled"
+     }
+
+ # Add root endpoint
+ @app.get("/")
+ async def root():
+     return {
+         "message": "ChatFed Orchestrator API",
+         "version": "1.0.0",
+         "endpoints": {
+             "health": "/health",
+             "chatfed": "/chatfed",
+             "chatfed-chatui": "/chatfed-chatui",
+             "chatfed-ui-stream": "/chatfed-ui-stream",  # New for ChatUI streaming
+             "chatfed-ui": "/chatfed-ui",  # New fallback
+             "process_query": "/process_query"
+         },
+         "gradio_interface": "http://localhost:7861/",
+         "mcp_endpoints": "http://localhost:7861/gradio_api/mcp/sse",
+         "note": "LangServe telemetry enabled - ChatUI integration available via /chatfed-ui-stream"
+     }
+
+ # =============================================================================
+ # ADD LANGSERVE ROUTES
+ # =============================================================================
+
+ # Convert functions to Runnables
+ process_chatfed_query_runnable = RunnableLambda(process_chatfed_query_langserve)
+ chatui_adapter_runnable = RunnableLambda(chatui_adapter)
+ chatui_streaming_runnable = RunnableLambda(chatui_streaming_adapter)
+ chatui_non_streaming_runnable = RunnableLambda(chatui_non_streaming_adapter)
+
+ # Add routes with explicit input/output schemas
+ add_routes(
+     app,
+     process_chatfed_query_runnable,
+     path="/chatfed",
+     input_type=ChatFedInput,
+     output_type=ChatFedOutput
+ )
+
+ # Original ChatUI-compatible LangServe route
+ add_routes(
+     app,
+     chatui_adapter_runnable,
+     path="/chatfed-chatui",
+     input_type=ChatUIInput
+ )
+
+ # NEW: ChatUI streaming route (matches your ChatUI config)
+ add_routes(
+     app,
+     chatui_streaming_runnable,
+     path="/chatfed-ui-stream",
+     input_type=ChatUIStreamInput,
+     enable_feedback_endpoint=True,
+     enable_public_trace_link_endpoint=True,
+ )
+
+ # NEW: ChatUI non-streaming fallback route
+ add_routes(
+     app,
+     chatui_non_streaming_runnable,
+     path="/chatfed-ui",
+     input_type=ChatUIStreamInput,
+     output_type=ChatUIStreamOutput,
+     enable_feedback_endpoint=True,
+     enable_public_trace_link_endpoint=True,
+ )
+
+ # Backward compatibility endpoint
+ @app.post("/process_query")
+ async def process_query_endpoint(
+     query: str,
+     reports_filter: str = "",
+     sources_filter: str = "",
+     subtype_filter: str = "",
+     year_filter: str = "",
+     session_id: Optional[str] = None,
+     user_id: Optional[str] = None
+ ):
+     """Backward compatibility endpoint."""
+     return process_chatfed_query_core(
+         query=query,
+         reports_filter=reports_filter,
+         sources_filter=sources_filter,
+         subtype_filter=subtype_filter,
+         year_filter=year_filter,
+         session_id=session_id,
+         user_id=user_id,
+         return_metadata=False
+     )
+
+ # =============================================================================
+ # MAIN APPLICATION LAUNCHER
+ # =============================================================================
+
+ def run_gradio_server():
+     """Run Gradio server in a separate thread for MCP endpoints."""
+     demo = create_gradio_interface()
      demo.launch(
          server_name="0.0.0.0",
+         server_port=7861,  # Different port from FastAPI
+         mcp_server=True,
+         show_error=True,
+         share=False,
+         quiet=True
+     )
+
+ if __name__ == "__main__":
+     # Start Gradio server in background thread for MCP endpoints
+     gradio_thread = threading.Thread(target=run_gradio_server, daemon=True)
+     gradio_thread.start()
+     logger.info("🔗 Gradio MCP server started on port 7861")
+
+     # Start FastAPI server for enhanced API
+     host = os.getenv("HOST", "0.0.0.0")
+     port = int(os.getenv("PORT", "7860"))
+
+     logger.info(f"🚀 Starting FastAPI server on {host}:{port}")
+     logger.info("📊 Enhanced API with LangServe telemetry available")
+     logger.info("🔗 MCP endpoints available via Gradio on port 7861")
+     logger.info("🎯 ChatUI streaming integration ready at /chatfed-ui-stream")
+
+     uvicorn.run(
+         app,
+         host=host,
+         port=port,
+         log_level="info",
+         access_log=True
      )
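For ChatUI, the commit points the UI at the new streaming route. A rough consumer sketch follows, assuming LangServe's standard `/stream` sub-route with server-sent-event framing and the FastAPI defaults above; the exact chunk encoding can vary slightly across LangServe versions, and the query text is illustrative only.

```python
# Hedged sketch: consume /chatfed-ui-stream as server-sent events.
# Each "data:" line carries one JSON-encoded chunk yielded by
# chatui_streaming_adapter; host/port are the defaults in app/main.py.
import json
import requests

with requests.post(
    "http://localhost:7860/chatfed-ui-stream/stream",
    json={"input": {"text": "Summarize the internal sources."}},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line.startswith("data:"):
            payload = line[len("data:"):].strip()
            if payload:  # skip the empty data line on the "end" event
                print(json.loads(payload), end="", flush=True)
```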