Spaces:
Running
Running
from fastapi import FastAPI, UploadFile, File, HTTPException | |
from fastapi.responses import StreamingResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
import os | |
from dotenv import load_dotenv | |
from typing import Optional, List, Dict | |
import json | |
from pydantic import BaseModel | |
# Import project modules | |
from src.agents.image_edit_agent import image_edit_agent, ImageEditDeps | |
from src.agents.generic_agent import generic_agent | |
from src.hopter.client import Hopter, Environment | |
from src.services.generate_mask import GenerateMaskService | |
from src.utils import upload_image | |
# Populate os.environ from a .env file (if present) so later lookups such as
# os.environ.get("HOPTER_API_KEY") can see those values.
load_dotenv()

app = FastAPI(title="Image Edit API")

# Fully permissive CORS: any origin, any method, any header, credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed cross-origin requests; restrict allow_origins for
# production deployments.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class EditRequest(BaseModel):
    """Body of an image-edit request.

    Attributes:
        edit_instruction: Free-text description of the edit to perform.
        image_url: URL of the image to edit; ``None`` when the client omits it.
    """

    edit_instruction: str
    image_url: Optional[str] = None
class MessageContent(BaseModel):
    """One content part of a chat-style message.

    NOTE(review): the shape (``type`` plus optional ``text`` / ``image_url``)
    resembles OpenAI chat content parts; this model is not referenced elsewhere
    in the visible code — confirm it is still needed.

    Attributes:
        type: Discriminator naming the kind of part (e.g. "text", "image_url").
        text: Text payload, when present.
        image_url: Mapping of string keys to string values for an image part.
    """

    type: str
    text: Optional[str] = None
    image_url: Optional[Dict[str, str]] = None
class Message(BaseModel):
    """A message made up of an ordered list of content parts.

    Attributes:
        content: The parts of the message, each a ``MessageContent``.
    """

    content: List[MessageContent]
# NOTE(review): no route decorator is visible for this handler, so as shown it
# is not registered on `app`; add e.g. @app.get(...) if it should be reachable.
async def test(query: str):
    """Stream the generic agent's response to ``query`` as NDJSON.

    Args:
        query: Prompt passed verbatim to ``generic_agent.run_stream``.

    Returns:
        A ``StreamingResponse`` that emits one JSON document per line.
    """

    async def stream_messages():
        async with generic_agent.run_stream(query) as result:
            # debounce_by=0.01 presumably coalesces updates arriving within
            # ~10 ms into a single yielded message.
            async for message in result.stream(debounce_by=0.01):
                yield json.dumps(message) + "\n"

    # Every line is a standalone JSON value (newline-delimited JSON). Advertise
    # it the same way edit_image_stream does, not as text/plain.
    return StreamingResponse(stream_messages(), media_type="application/x-ndjson")
# NOTE(review): no route decorator is visible for this handler, so as shown it
# is not registered on `app`; add e.g. @app.post(...) if it should be reachable.
async def edit_image(request: EditRequest):
    """Edit an image based on the provided instruction.

    Builds a Hopter client (staging environment) and a mask service, passes
    them to ``image_edit_agent`` together with the instruction and the optional
    image URL, and returns the URL of the edited image.

    Args:
        request: The edit instruction and, optionally, the source image URL.

    Returns:
        ``{"edited_image_url": ...}`` read from the agent result.

    Raises:
        HTTPException: Re-raised unchanged if a dependency raises one; any
            other failure becomes a 500 whose detail is the error message.
    """
    try:
        # NOTE(review): api_key is None when HOPTER_API_KEY is unset; consider
        # failing fast with a clear message instead.
        hopter = Hopter(
            api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
        )
        mask_service = GenerateMaskService(hopter=hopter)

        deps = ImageEditDeps(
            edit_instruction=request.edit_instruction,
            image_url=request.image_url,
            hopter_client=hopter,
            mask_service=mask_service,
        )

        # Content-part dicts: the instruction text first, then the image if given.
        messages = [{"type": "text", "text": request.edit_instruction}]
        if request.image_url:
            messages.append(
                {"type": "image_url", "image_url": {"url": request.image_url}}
            )

        result = await image_edit_agent.run(messages, deps=deps)
        # NOTE(review): assumes the run result exposes `edited_image_url`
        # directly; agent frameworks often wrap the output (e.g. `.data` /
        # `.output`) — confirm against image_edit_agent's result type.
        return {"edited_image_url": result.edited_image_url}
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of rewrapping them as 500s.
        raise
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible for this handler, so as shown it
# is not registered on `app`; add e.g. @app.post(...) if it should be reachable.
async def edit_image_stream(request: EditRequest):
    """Edit an image based on the provided instruction, streaming agent output.

    Setup errors (client/service construction) are reported as an HTTP 500
    before streaming begins. Errors raised while the agent is streaming are
    reported in-band as a final line ``{"error": "<message>"}``, because the
    response status has already been sent by then.

    Args:
        request: The edit instruction and, optionally, the source image URL.

    Returns:
        A ``StreamingResponse`` of newline-delimited JSON messages.

    Raises:
        HTTPException: Re-raised unchanged if a dependency raises one during
            setup; any other setup failure becomes a 500.
    """
    try:
        # NOTE(review): api_key is None when HOPTER_API_KEY is unset.
        hopter = Hopter(
            api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
        )
        mask_service = GenerateMaskService(hopter=hopter)

        deps = ImageEditDeps(
            edit_instruction=request.edit_instruction,
            image_url=request.image_url,
            hopter_client=hopter,
            mask_service=mask_service,
        )

        # Content-part dicts: the instruction text first, then the image if given.
        messages = [{"type": "text", "text": request.edit_instruction}]
        if request.image_url:
            messages.append(
                {"type": "image_url", "image_url": {"url": request.image_url}}
            )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    async def stream_generator():
        # This body runs only after the handler has returned, while the
        # StreamingResponse is being sent; headers and status are already on
        # the wire, so the try/except above can never observe these errors.
        # Report them as a final JSON line instead of aborting the connection.
        try:
            async with image_edit_agent.run_stream(messages, deps=deps) as result:
                async for message in result.stream():
                    yield json.dumps(message) + "\n"
        except Exception as e:
            yield json.dumps({"error": str(e)}) + "\n"

    return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
# NOTE(review): no route decorator is visible for this handler, so as shown it
# is not registered on `app`; add e.g. @app.post(...) if it should be reachable.
async def upload_image_file(file: UploadFile = File(...)):
    """Upload an image file and return its URL.

    The upload is written into a private temporary directory, passed to
    ``upload_image``, and the directory is removed afterwards, also when the
    upload fails. This keeps concurrent uploads with the same name from
    colliding and keeps a crafted filename from escaping the directory.

    Args:
        file: The multipart file sent by the client.

    Returns:
        ``{"image_url": ...}`` with the value returned by ``upload_image``.

    Raises:
        HTTPException: 500 with the error message if saving or uploading fails.
    """
    import tempfile

    # Keep only the last path component of the client-supplied name. The old
    # f"/tmp/{file.filename}" let names like "../../etc/x" write outside /tmp,
    # and a missing name became "/tmp/None". The basename is kept in case
    # upload_image derives the stored object name from it.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload"

    try:
        # TemporaryDirectory guarantees cleanup even if upload_image raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_file_path = os.path.join(tmp_dir, safe_name)
            with open(temp_file_path, "wb") as buffer:
                buffer.write(await file.read())
            # Upload the image to Google Cloud Storage.
            # NOTE(review): upload_image is not awaited, so it is presumably
            # synchronous and blocks the event loop while it runs; consider
            # running it in a thread pool.
            image_url = upload_image(temp_file_path)
        return {"image_url": image_url}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible for this handler, so as shown it
# is not registered on `app`; add e.g. @app.get(...) if it should be reachable.
async def health_check():
    """Liveness probe.

    Returns:
        The constant payload ``{"status": "ok"}``.
    """
    payload = dict(status="ok")
    return payload
if __name__ == "__main__":
    # Direct-execution entry point: serve `app` with uvicorn on every
    # interface at port 8000.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=8000, host="0.0.0.0")