"""FastAPI service exposing image-editing agent endpoints.

Endpoints:
    POST /test/stream  - stream the generic agent's output for a query (NDJSON).
    POST /edit         - run the image edit agent and return the edited image URL.
    POST /edit/stream  - run the image edit agent and stream its output (NDJSON).
    POST /upload       - upload an image file and return its URL.
    GET  /health       - liveness check.
"""

import json
import logging
import os
import tempfile
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple

from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# Import project modules
from src.agents.generic_agent import generic_agent
from src.agents.image_edit_agent import ImageEditDeps, image_edit_agent
from src.hopter.client import Environment, Hopter
from src.services.generate_mask import GenerateMaskService
from src.utils import upload_image

# Load environment variables (e.g. HOPTER_API_KEY) from a .env file, if present.
load_dotenv()

logger = logging.getLogger(__name__)

# Media type for newline-delimited JSON streams; shared by all streaming endpoints.
NDJSON_MEDIA_TYPE = "application/x-ndjson"


def _cors_origins() -> List[str]:
    """Return the allowed CORS origins.

    Read from the comma-separated ``CORS_ALLOW_ORIGINS`` environment variable.
    When it is unset or blank, fall back to ``["*"]``, the previous hard-coded
    value.
    """
    raw = os.environ.get("CORS_ALLOW_ORIGINS", "*")
    origins = [origin.strip() for origin in raw.split(",") if origin.strip()]
    return origins or ["*"]


app = FastAPI(title="Image Edit API")

# Add CORS middleware.
# NOTE(review): with allow_credentials=True, a wildcard origin lets any site make
# credentialed requests. Set CORS_ALLOW_ORIGINS to an explicit list in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins(),
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)


class EditRequest(BaseModel):
    """Request body for the /edit and /edit/stream endpoints."""

    edit_instruction: str  # natural-language description of the desired edit
    image_url: Optional[str] = None  # optional source image to edit


class MessageContent(BaseModel):
    """One typed content part of a chat-style message (text or image URL)."""

    type: str
    text: Optional[str] = None
    image_url: Optional[Dict[str, str]] = None


class Message(BaseModel):
    """A chat-style message composed of typed content parts."""

    content: List[MessageContent]


def _ndjson_line(message: Any) -> str:
    """Serialize one streamed item as a newline-terminated JSON line.

    Pydantic models are not accepted by ``json.dumps``, so they are converted
    to plain JSON-compatible data first. Any other value is passed to
    ``json.dumps`` unchanged.
    """
    if isinstance(message, BaseModel):
        message = message.model_dump(mode="json")
    return json.dumps(message) + "\n"


def _build_edit_inputs(
    request: EditRequest,
) -> Tuple[List[Dict[str, Any]], ImageEditDeps]:
    """Build the agent prompt messages and dependencies for an edit request.

    Returns:
        A ``(messages, deps)`` pair. ``messages`` is always a text part holding
        the instruction, followed by an image-URL part when the request
        includes ``image_url``.
    """
    # NOTE(review): the Hopter environment is hard-coded to STAGING.
    hopter = Hopter(
        api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
    )
    mask_service = GenerateMaskService(hopter=hopter)

    deps = ImageEditDeps(
        edit_instruction=request.edit_instruction,
        image_url=request.image_url,
        hopter_client=hopter,
        mask_service=mask_service,
    )

    messages: List[Dict[str, Any]] = [
        {"type": "text", "text": request.edit_instruction}
    ]
    if request.image_url:
        messages.append({"type": "image_url", "image_url": {"url": request.image_url}})

    return messages, deps


@app.post("/test/stream")
async def test(query: str):
    """Stream the generic agent's output for ``query`` as NDJSON.

    If the agent fails mid-stream, the failure is logged and a final
    ``{"error": ...}`` line is emitted. This is needed because the HTTP status
    has already been sent by then.
    """

    async def stream_messages() -> AsyncIterator[str]:
        try:
            async with generic_agent.run_stream(query) as result:
                async for message in result.stream(debounce_by=0.01):
                    yield _ndjson_line(message)
        except Exception as exc:  # stream already started: report in-band
            logger.exception("Streaming generic agent output failed")
            yield _ndjson_line({"error": str(exc)})

    return StreamingResponse(stream_messages(), media_type=NDJSON_MEDIA_TYPE)


@app.post("/edit")
async def edit_image(request: EditRequest):
    """
    Edit an image based on the provided instruction.
    Returns the URL of the edited image.

    Raises:
        HTTPException: 500 with the error text if any step fails.
    """
    try:
        messages, deps = _build_edit_inputs(request)
        result = await image_edit_agent.run(messages, deps=deps)
        # NOTE(review): assumes the run result exposes `edited_image_url`
        # directly; confirm against image_edit_agent's result type.
        return {"edited_image_url": result.edited_image_url}
    except Exception as e:
        logger.exception("Image edit failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/edit/stream")
async def edit_image_stream(request: EditRequest):
    """
    Edit an image based on the provided instruction.
    Streams the agent's responses back to the client as NDJSON.

    Setup failures (before streaming starts) raise HTTPException 500. Failures
    while streaming are logged and reported as a final ``{"error": ...}`` line,
    because the HTTP status has already been sent by then.
    """
    try:
        messages, deps = _build_edit_inputs(request)
    except Exception as e:
        logger.exception("Preparing streamed image edit failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    async def stream_generator() -> AsyncIterator[str]:
        try:
            async with image_edit_agent.run_stream(messages, deps=deps) as result:
                async for message in result.stream():
                    yield _ndjson_line(message)
        except Exception as exc:  # stream already started: report in-band
            logger.exception("Streaming image edit failed")
            yield _ndjson_line({"error": str(exc)})

    return StreamingResponse(stream_generator(), media_type=NDJSON_MEDIA_TYPE)


@app.post("/upload")
async def upload_image_file(file: UploadFile = File(...)):
    """
    Upload an image file and return its URL.

    The upload is staged in a private temporary directory that is always
    removed afterwards, including when the upload fails.

    Raises:
        HTTPException: 500 with the error text if saving or uploading fails.
    """
    # The client controls file.filename. Keep only its final path component so
    # values like "../../etc/x" cannot escape the temp directory, and fall back
    # to a fixed name when nothing usable remains.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload"

    try:
        # A per-request directory avoids collisions between concurrent uploads
        # of the same filename while preserving the original file name/extension.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_file_path = os.path.join(tmp_dir, safe_name)
            with open(temp_file_path, "wb") as buffer:
                buffer.write(await file.read())

            # Hand the saved file to the project's upload helper, which returns
            # the image URL (the original comment says Google Cloud Storage).
            image_url = upload_image(temp_file_path)

        return {"image_url": image_url}
    except Exception as e:
        logger.exception("Image upload failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health")
async def health_check():
    """
    Health check endpoint.
    """
    return {"status": "ok"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)