simonlee-cb commited on
Commit
2a2c2ad
·
1 Parent(s): 583b7ad

feat: add console streaming for debugging

Browse files
server.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
2
+ from fastapi.responses import StreamingResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ import asyncio
5
+ import os
6
+ from dotenv import load_dotenv
7
+ from typing import Optional, List, Dict, Any
8
+ import json
9
+ from pydantic import BaseModel
10
+
11
+ # Import project modules
12
+ from src.agents.image_edit_agent import image_edit_agent, ImageEditDeps
13
+ from src.agents.generic_agent import generic_agent
14
+ from src.hopter.client import Hopter, Environment
15
+ from src.services.generate_mask import GenerateMaskService
16
+ from src.utils import upload_file_to_base64, upload_image
17
+
18
+ # Load environment variables
19
+ load_dotenv()
20
+
21
+ app = FastAPI(title="Image Edit API")
22
+
23
+ # Add CORS middleware
24
+ app.add_middleware(
25
+ CORSMiddleware,
26
+ allow_origins=["*"], # Allows all origins
27
+ allow_credentials=True,
28
+ allow_methods=["*"], # Allows all methods
29
+ allow_headers=["*"], # Allows all headers
30
+ )
31
+
32
+ class EditRequest(BaseModel):
33
+ edit_instruction: str
34
+ image_url: Optional[str] = None
35
+
36
+ class MessageContent(BaseModel):
37
+ type: str
38
+ text: Optional[str] = None
39
+ image_url: Optional[Dict[str, str]] = None
40
+
41
+ class Message(BaseModel):
42
+ content: List[MessageContent]
43
+
44
+
45
+ @app.post("/test/stream")
46
+ async def test(query: str):
47
+ async def stream_messages():
48
+ async with generic_agent.run_stream(query) as result:
49
+ async for message in result.stream(debounce_by=0.01):
50
+ yield json.dumps(message) + "\n"
51
+
52
+ return StreamingResponse(stream_messages(), media_type="text/plain")
53
+
54
+ @app.post("/edit")
55
+ async def edit_image(request: EditRequest):
56
+ """
57
+ Edit an image based on the provided instruction.
58
+ Returns the URL of the edited image.
59
+ """
60
+ try:
61
+ # Initialize services
62
+ hopter = Hopter(
63
+ api_key=os.environ.get("HOPTER_API_KEY"),
64
+ environment=Environment.STAGING
65
+ )
66
+ mask_service = GenerateMaskService(hopter=hopter)
67
+
68
+ # Initialize dependencies
69
+ deps = ImageEditDeps(
70
+ edit_instruction=request.edit_instruction,
71
+ image_url=request.image_url,
72
+ hopter_client=hopter,
73
+ mask_service=mask_service
74
+ )
75
+
76
+ # Create messages
77
+ messages = [
78
+ {
79
+ "type": "text",
80
+ "text": request.edit_instruction
81
+ }
82
+ ]
83
+
84
+ if request.image_url:
85
+ messages.append({
86
+ "type": "image_url",
87
+ "image_url": {
88
+ "url": request.image_url
89
+ }
90
+ })
91
+
92
+ # Run the agent
93
+ result = await image_edit_agent.run(messages, deps=deps)
94
+
95
+ # Return the result
96
+ return {"edited_image_url": result.edited_image_url}
97
+
98
+ except Exception as e:
99
+ raise HTTPException(status_code=500, detail=str(e))
100
+
101
+ @app.post("/edit/stream")
102
+ async def edit_image_stream(request: EditRequest):
103
+ """
104
+ Edit an image based on the provided instruction.
105
+ Streams the agent's responses back to the client.
106
+ """
107
+ try:
108
+ # Initialize services
109
+ hopter = Hopter(
110
+ api_key=os.environ.get("HOPTER_API_KEY"),
111
+ environment=Environment.STAGING
112
+ )
113
+ mask_service = GenerateMaskService(hopter=hopter)
114
+
115
+ # Initialize dependencies
116
+ deps = ImageEditDeps(
117
+ edit_instruction=request.edit_instruction,
118
+ image_url=request.image_url,
119
+ hopter_client=hopter,
120
+ mask_service=mask_service
121
+ )
122
+
123
+ # Create messages
124
+ messages = [
125
+ {
126
+ "type": "text",
127
+ "text": request.edit_instruction
128
+ }
129
+ ]
130
+
131
+ if request.image_url:
132
+ messages.append({
133
+ "type": "image_url",
134
+ "image_url": {
135
+ "url": request.image_url
136
+ }
137
+ })
138
+
139
+ async def stream_generator():
140
+ async with image_edit_agent.run_stream(messages, deps=deps) as result:
141
+ async for message in result.stream():
142
+ # Convert message to JSON and yield
143
+ yield json.dumps(message) + "\n"
144
+
145
+ return StreamingResponse(
146
+ stream_generator(),
147
+ media_type="application/x-ndjson"
148
+ )
149
+
150
+ except Exception as e:
151
+ raise HTTPException(status_code=500, detail=str(e))
152
+
153
+ @app.post("/upload")
154
+ async def upload_image_file(file: UploadFile = File(...)):
155
+ """
156
+ Upload an image file and return its URL.
157
+ """
158
+ try:
159
+ # Save the uploaded file to a temporary location
160
+ temp_file_path = f"/tmp/{file.filename}"
161
+ with open(temp_file_path, "wb") as buffer:
162
+ buffer.write(await file.read())
163
+
164
+ # Upload the image to Google Cloud Storage
165
+ image_url = upload_image(temp_file_path)
166
+
167
+ # Remove the temporary file
168
+ os.remove(temp_file_path)
169
+
170
+ return {"image_url": image_url}
171
+
172
+ except Exception as e:
173
+ raise HTTPException(status_code=500, detail=str(e))
174
+
175
+ @app.get("/health")
176
+ async def health_check():
177
+ """
178
+ Health check endpoint.
179
+ """
180
+ return {"status": "ok"}
181
+
182
+ if __name__ == "__main__":
183
+ import uvicorn
184
+ uvicorn.run(app, host="0.0.0.0", port=8000)
src/agents/generic_agent.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic_ai import Agent, RunContext
2
+ from pydantic_ai.models.openai import OpenAIModel
3
+ from dotenv import load_dotenv
4
+ import os
5
+
6
+ load_dotenv()
7
+
8
+ model = OpenAIModel(
9
+ "gpt-4o",
10
+ api_key=os.environ.get("OPENAI_API_KEY")
11
+ )
12
+
13
+ system_prompt = """
14
+ You are a helpful assistant that can answer questions and help with tasks.
15
+ """
16
+
17
+ generic_agent = Agent(
18
+ model=model,
19
+ system_prompt=system_prompt,
20
+ tools=[],
21
+ )
22
+
23
+
stream_utils.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import requests
3
+ from rich.console import Console
4
+ from rich.live import Live
5
+ from rich.markdown import Markdown
6
+ from rich.panel import Panel
7
+ from rich.text import Text
8
+
9
+ class StreamResponseHandler:
10
+ """
11
+ A utility class for handling streaming responses from API endpoints.
12
+ Provides rich formatting and real-time updates of the response content.
13
+ """
14
+
15
+ def __init__(self, console=None):
16
+ """
17
+ Initialize the stream response handler.
18
+
19
+ Args:
20
+ console (Console, optional): A Rich console instance. If not provided, a new one will be created.
21
+ """
22
+ self.console = console or Console()
23
+
24
+ def check_server_health(self, health_url="http://localhost:8000/health"):
25
+ """
26
+ Check if the server is running and accessible.
27
+
28
+ Args:
29
+ health_url (str, optional): The URL to check server health. Defaults to "http://localhost:8000/health".
30
+
31
+ Returns:
32
+ bool: True if the server is running and accessible, False otherwise.
33
+ """
34
+ try:
35
+ self.console.print("Checking server health...", style="bold yellow")
36
+ response = requests.get(health_url)
37
+ if response.status_code == 200:
38
+ self.console.print("[bold green]✓ Server is running and accessible.[/]")
39
+ return True
40
+ else:
41
+ self.console.print(f"[bold red]✗ Server health check failed[/] with status code: {response.status_code}")
42
+ return False
43
+ except requests.exceptions.ConnectionError:
44
+ self.console.print("[bold red]✗ Error:[/] Could not connect to the server. Make sure it's running.")
45
+ return False
46
+ except Exception as e:
47
+ self.console.print(f"[bold red]✗ Error checking server health:[/] {e}")
48
+ return False
49
+
50
+ def stream_response(self, url, payload=None, params=None, method="POST", title="AI Response"):
51
+ """
52
+ Send a request to an endpoint and stream the output to the terminal.
53
+
54
+ Args:
55
+ url (str): The URL of the endpoint to send the request to.
56
+ payload (dict, optional): The JSON payload to send in the request body. Defaults to None.
57
+ params (dict, optional): The query parameters to send in the request. Defaults to None.
58
+ method (str, optional): The HTTP method to use. Defaults to "POST".
59
+ title (str, optional): The title to display in the panel. Defaults to "AI Response".
60
+
61
+ Returns:
62
+ bool: True if the streaming was successful, False otherwise.
63
+ """
64
+ # Display request information
65
+ self.console.print(f"Sending request to [bold cyan]{url}[/]")
66
+ if payload:
67
+ self.console.print("Payload:", style="bold")
68
+ self.console.print(json.dumps(payload, indent=2))
69
+ if params:
70
+ self.console.print("Parameters:", style="bold")
71
+ self.console.print(json.dumps(params, indent=2))
72
+
73
+ try:
74
+ # Prepare the request
75
+ request_kwargs = {
76
+ "stream": True
77
+ }
78
+ if payload:
79
+ request_kwargs["json"] = payload
80
+ if params:
81
+ request_kwargs["params"] = params
82
+
83
+ # Make the request
84
+ with getattr(requests, method.lower())(url, **request_kwargs) as response:
85
+ # Check if the request was successful
86
+ if response.status_code != 200:
87
+ self.console.print(f"[bold red]Error:[/] Received status code {response.status_code}")
88
+ self.console.print(f"Response: {response.text}")
89
+ return False
90
+
91
+ # Initialize an empty response text
92
+ full_response = ""
93
+
94
+ # Use Rich's Live display to update the content in place
95
+ with Live(Panel("Waiting for response...", title=title, border_style="blue"), refresh_per_second=10) as live:
96
+ # Process the streaming response
97
+ for line in response.iter_lines():
98
+ if line:
99
+ # Decode the line and parse it as JSON
100
+ decoded_line = line.decode('utf-8')
101
+ try:
102
+ # Parse the JSON
103
+ data = json.loads(decoded_line)
104
+
105
+ # Extract and display the content
106
+ if isinstance(data, dict):
107
+ if "content" in data:
108
+ for content in data["content"]:
109
+ if content.get("type") == "text":
110
+ text_content = content.get("text", "")
111
+ # Append to the full response
112
+ full_response += text_content
113
+ # Update the live display with the current full response
114
+ live.update(Panel(Markdown(full_response), title=title, border_style="green"))
115
+ elif content.get("type") == "image_url":
116
+ image_url = content.get("image_url", {}).get("url", "")
117
+ # Add a note about the image URL
118
+ image_note = f"\n\n[Image URL: {image_url}]"
119
+ full_response += image_note
120
+ live.update(Panel(Markdown(full_response), title=title, border_style="green"))
121
+ elif "edited_image_url" in data:
122
+ # Handle edited image URL from edit endpoint
123
+ image_url = data.get("edited_image_url", "")
124
+ image_note = f"\n\n[Edited Image URL: {image_url}]"
125
+ full_response += image_note
126
+ live.update(Panel(Markdown(full_response), title=title, border_style="green"))
127
+ else:
128
+ # For other types of data, just show the JSON
129
+ live.update(Panel(Text(json.dumps(data, indent=2)), title="Raw JSON Response", border_style="yellow"))
130
+ else:
131
+ live.update(Panel(Text(decoded_line), title="Raw Response", border_style="yellow"))
132
+ except json.JSONDecodeError:
133
+ # If it's not valid JSON, just show the raw line
134
+ live.update(Panel(Text(f"Raw response: {decoded_line}"), title="Invalid JSON", border_style="red"))
135
+
136
+ self.console.print("[bold green]Stream completed.[/]")
137
+ return True
138
+
139
+ except requests.exceptions.ConnectionError:
140
+ self.console.print(f"[bold red]Error:[/] Could not connect to the server at {url}", style="red")
141
+ self.console.print("Make sure the server is running and accessible.", style="red")
142
+ return False
143
+ except requests.exceptions.RequestException as e:
144
+ self.console.print(f"[bold red]Error:[/] {e}", style="red")
145
+ return False
test_edit_stream.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import requests
5
+ import json
6
+ from dotenv import load_dotenv
7
+ from stream_utils import StreamResponseHandler
8
+
9
+ # Load environment variables
10
+ load_dotenv()
11
+
12
+ def get_default_image():
13
+ """Get the default image path and convert it to a data URI."""
14
+ image_path = "./assets/lakeview.jpg"
15
+ if os.path.exists(image_path):
16
+ try:
17
+ from src.utils import image_path_to_uri
18
+ image_uri = image_path_to_uri(image_path)
19
+ print(f"Using default image: {image_path}")
20
+ return image_uri
21
+ except Exception as e:
22
+ print(f"Error converting image to URI: {e}")
23
+ return None
24
+ else:
25
+ print(f"Warning: Default image not found at {image_path}")
26
+ return None
27
+
28
+ def upload_image(handler, image_path):
29
+ """
30
+ Upload an image to the server.
31
+
32
+ Args:
33
+ handler (StreamResponseHandler): The stream response handler.
34
+ image_path (str): Path to the image file to upload.
35
+
36
+ Returns:
37
+ str: The URL of the uploaded image, or None if upload failed.
38
+ """
39
+ if not os.path.exists(image_path):
40
+ handler.console.print(f"[bold red]Error:[/] Image file not found at {image_path}")
41
+ return None
42
+
43
+ try:
44
+ handler.console.print(f"Uploading image: [bold]{image_path}[/]")
45
+ with open(image_path, 'rb') as f:
46
+ files = {'file': (os.path.basename(image_path), f)}
47
+ response = requests.post("http://localhost:8000/upload", files=files)
48
+ if response.status_code == 200:
49
+ image_url = response.json().get("image_url")
50
+ handler.console.print(f"Image uploaded successfully. URL: [bold green]{image_url}[/]")
51
+ return image_url
52
+ else:
53
+ handler.console.print(f"[bold red]Failed to upload image.[/] Status code: {response.status_code}")
54
+ handler.console.print(f"Response: {response.text}")
55
+ return None
56
+ except Exception as e:
57
+ handler.console.print(f"[bold red]Error uploading image:[/] {e}")
58
+ return None
59
+
60
+ def main():
61
+ # Create a stream response handler
62
+ handler = StreamResponseHandler()
63
+
64
+ # Parse command line arguments
65
+ parser = argparse.ArgumentParser(description="Test the image edit streaming API.")
66
+ parser.add_argument("--instruction", "-i", required=True, help="The edit instruction.")
67
+ parser.add_argument("--image", "-img", help="The URL of the image to edit.")
68
+ parser.add_argument("--upload", "-u", help="Path to an image file to upload first.")
69
+
70
+ args = parser.parse_args()
71
+
72
+ # Check if the server is running
73
+ if not handler.check_server_health():
74
+ sys.exit(1)
75
+
76
+ image_url = args.image
77
+
78
+ # If upload is specified, upload the image first
79
+ if args.upload:
80
+ image_url = upload_image(handler, args.upload)
81
+ if not image_url:
82
+ handler.console.print("[yellow]Warning:[/] Failed to upload image. Continuing without image URL.")
83
+
84
+ # Use the default image if no image URL is provided
85
+ if not image_url:
86
+ image_url = get_default_image()
87
+ if not image_url:
88
+ handler.console.print("[yellow]No image URL provided and default image not available.[/]")
89
+ handler.console.print("The agent may ask for an image if needed.")
90
+
91
+ # Prepare the payload for the edit request
92
+ payload = {
93
+ "edit_instruction": args.instruction
94
+ }
95
+
96
+ if image_url:
97
+ payload["image_url"] = image_url
98
+
99
+ # Stream the edit request
100
+ endpoint_url = "http://localhost:8000/edit/stream"
101
+ handler.stream_response(endpoint_url, payload=payload, title="Image Edit Response")
102
+
103
+ if __name__ == "__main__":
104
+ main()
test_generic_stream.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ from dotenv import load_dotenv
4
+ from stream_utils import StreamResponseHandler
5
+
6
+ # Load environment variables
7
+ load_dotenv()
8
+
9
+ def main():
10
+ # Create a console for rich output
11
+ handler = StreamResponseHandler()
12
+
13
+ # Parse command line arguments
14
+ parser = argparse.ArgumentParser(description="Test the generic agent streaming API.")
15
+ parser.add_argument("--query", "-q", required=True, help="The query or message to send to the generic agent.")
16
+
17
+ args = parser.parse_args()
18
+
19
+ # Check if the server is running
20
+ if not handler.check_server_health():
21
+ sys.exit(1)
22
+
23
+ # Stream the generic request
24
+ endpoint_url = "http://localhost:8000/test/stream"
25
+ params = {"query": args.query}
26
+ handler.stream_response(endpoint_url, params=params, title="Generic Agent Response")
27
+
28
+ if __name__ == "__main__":
29
+ main()