chat-image-edit / server.py
simonlee-cb's picture
refactor: formatting
fcb8f25
import json
import logging
import os
import tempfile
from typing import Optional, List, Dict

from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# Import project modules
from src.agents.image_edit_agent import image_edit_agent, ImageEditDeps
from src.agents.generic_agent import generic_agent
from src.hopter.client import Hopter, Environment
from src.services.generate_mask import GenerateMaskService
from src.utils import upload_image
# Populate os.environ from a .env file so that settings such as
# HOPTER_API_KEY (read by the edit endpoints below) are available.
load_dotenv()

app = FastAPI(title="Image Edit API")

# Fully open CORS policy: any origin, any method, any request header.
# NOTE(review): allow_credentials=True is combined with "*" wildcards here.
# Starlette's CORS documentation says the wildcards cannot be used when
# credentials are allowed; confirm whether credentialed requests are really
# needed and, if so, list explicit origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
class EditRequest(BaseModel):
    """Request body accepted by the ``/edit`` and ``/edit/stream`` endpoints."""

    # Natural-language edit instruction. The endpoints send it to the agent as
    # the text part of the prompt and also store it on ImageEditDeps.
    edit_instruction: str
    # Optional URL of the image to edit. When non-empty, the endpoints append it
    # to the agent prompt as an "image_url" part; it is always stored on
    # ImageEditDeps (possibly as None).
    image_url: Optional[str] = None
class MessageContent(BaseModel):
    """One content part of a chat-style message.

    The fields mirror the part dicts the edit endpoints build by hand:
    ``{"type": "text", "text": ...}`` and
    ``{"type": "image_url", "image_url": {"url": ...}}``.
    NOTE(review): this model is not referenced by any endpoint in this file.
    """

    # Part kind, e.g. "text" or "image_url"; the value is not validated here.
    type: str
    # Text payload; presumably set when type == "text" (not enforced).
    text: Optional[str] = None
    # Image payload such as {"url": "..."}; presumably set when
    # type == "image_url" (not enforced).
    image_url: Optional[Dict[str, str]] = None
class Message(BaseModel):
    """A chat-style message made of a list of content parts.

    NOTE(review): this model is not referenced by any endpoint in this file.
    """

    # Ordered content parts (text and/or image) making up the message.
    content: List[MessageContent]
@app.post("/test/stream")
async def test(query: str):
    """
    Stream ``generic_agent``'s reply to ``query`` back to the client.

    Every chunk produced by the agent's stream (called with
    ``debounce_by=0.01``) is JSON-encoded and written on its own line.
    NOTE(review): the body is NDJSON-shaped but served as ``text/plain``,
    unlike ``/edit/stream`` which uses ``application/x-ndjson``.
    """

    async def emit_chunks():
        async with generic_agent.run_stream(query) as run:
            chunks = run.stream(debounce_by=0.01)
            async for chunk in chunks:
                line = json.dumps(chunk)
                yield f"{line}\n"

    return StreamingResponse(emit_chunks(), media_type="text/plain")
@app.post("/edit")
async def edit_image(request: EditRequest):
    """
    Edit an image based on the provided instruction.
    Returns the URL of the edited image.

    Builds a Hopter client for the staging environment and a mask service,
    runs ``image_edit_agent`` once on the instruction (plus the image URL
    when one is given), and responds with ``{"edited_image_url": ...}``.

    Raises:
        HTTPException: 500 carrying the error text if any step fails.
    """
    try:
        instruction = request.edit_instruction
        source_url = request.image_url

        client = Hopter(
            api_key=os.environ.get("HOPTER_API_KEY"),
            environment=Environment.STAGING,
        )
        agent_deps = ImageEditDeps(
            edit_instruction=instruction,
            image_url=source_url,
            hopter_client=client,
            mask_service=GenerateMaskService(hopter=client),
        )

        # Prompt = instruction text, followed by the image part only when a
        # (non-empty) URL was supplied.
        prompt = [{"type": "text", "text": instruction}]
        if source_url:
            prompt += [{"type": "image_url", "image_url": {"url": source_url}}]

        outcome = await image_edit_agent.run(prompt, deps=agent_deps)
        return {"edited_image_url": outcome.edited_image_url}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/edit/stream")
async def edit_image_stream(request: EditRequest):
    """
    Edit an image based on the provided instruction.
    Streams the agent's responses back to the client.

    The body is newline-delimited JSON (``application/x-ndjson``): each item
    yielded by ``image_edit_agent``'s stream is JSON-encoded on its own line.

    Failures while building the Hopter client or the agent dependencies are
    reported as HTTP 500. The agent itself only runs once the response body
    is being consumed, i.e. after the handler has returned and the status
    line has been sent. Errors at that stage therefore cannot become an HTTP
    error; they are logged and sent to the client as a final
    ``{"error": "<message>"}`` line instead of silently cutting the stream.

    Raises:
        HTTPException: 500 if setup fails before streaming begins.
    """
    try:
        # Initialize services
        hopter = Hopter(
            api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
        )
        mask_service = GenerateMaskService(hopter=hopter)
        # Initialize dependencies
        deps = ImageEditDeps(
            edit_instruction=request.edit_instruction,
            image_url=request.image_url,
            hopter_client=hopter,
            mask_service=mask_service,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Create messages: the instruction text, plus the image when one was given.
    messages = [{"type": "text", "text": request.edit_instruction}]
    if request.image_url:
        messages.append(
            {"type": "image_url", "image_url": {"url": request.image_url}}
        )

    async def stream_generator():
        # This generator is iterated by StreamingResponse after the handler
        # has returned, so the try/except above can never see its errors.
        try:
            async with image_edit_agent.run_stream(messages, deps=deps) as result:
                async for message in result.stream():
                    # Convert message to JSON and yield
                    yield json.dumps(message) + "\n"
        except Exception as e:
            # Keep the traceback server-side, then tell the client in-band.
            logging.getLogger(__name__).exception("Streaming image edit failed")
            yield json.dumps({"error": str(e)}) + "\n"

    return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
@app.post("/upload")
async def upload_image_file(file: UploadFile = File(...)):
    """
    Upload an image file and return its URL.

    The uploaded bytes are written to a fresh per-request temporary
    directory, under the client's file name reduced to its final path
    component. That path is passed to ``upload_image``. The directory is
    deleted afterwards, whether or not the upload succeeded.

    Returns:
        ``{"image_url": <value returned by upload_image>}``.

    Raises:
        HTTPException: 500 with the error text if saving or uploading fails.
    """
    # ``file.filename`` is client-controlled and may be missing. Keep only its
    # last path component so a name like "../../x" cannot write outside the
    # temporary directory. Fall back to a fixed name when nothing usable is
    # left.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload"
    try:
        # A private directory per request keeps concurrent uploads of the same
        # file name from overwriting each other. TemporaryDirectory removes it
        # on exit, including when upload_image raises.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, safe_name)
            with open(temp_file_path, "wb") as buffer:
                buffer.write(await file.read())
            # Upload the image to Google Cloud Storage
            image_url = upload_image(temp_file_path)
        return {"image_url": image_url}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """
    Health check endpoint.

    Performs no checks of its own; it always answers with ``{"status": "ok"}``.
    """
    return dict(status="ok")
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on every
    # interface at port 8000.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)