# chat-image-edit/src/agents/mask_generation_agent.py
# Author: simonlee-cb
# Commit: 9e6acd9 ("feat: added examples")
# Size: 4.48 kB
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIModel
from dotenv import load_dotenv
import os
import asyncio
from dataclasses import dataclass
from typing import Optional
import logfire
from src.services.generate_mask import GenerateMaskService
from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
from src.services.image_uploader import ImageUploader
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
import base64
import tempfile
from PIL import Image
# Load variables from a local .env file first: the environment lookups that
# follow in this module depend on them.
load_dotenv()

# Configure Logfire with the LOGFIRE_TOKEN environment variable (None when it
# is unset), then instrument OpenAI calls.
logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))
logfire.instrument_openai()
# System prompt passed to the agent defined further down. This is runtime text
# sent to the model, so any change in wording changes agent behavior.
system_prompt = """
I will give you an editing instruction of the image.
if the edit instruction involved modifying parts of the image, please generate a mask for it.
if images are not provided, ask the user to provide an image.
"""
# Dependencies supplied to an agent run; the tools read them via ``ctx.deps``.
# Field order is part of the constructor signature, so do not reorder.
@dataclass
class ImageEditDeps:
    # The user's edit request; edit_object derives the mask instruction from it.
    edit_instruction: str
    # Hopter client used for the magic-replace and super-resolution calls.
    hopter_client: Hopter
    # Service that produces the mask instruction and the mask itself.
    mask_service: GenerateMaskService
    # Location of the image to work on; None when no image was supplied.
    image_url: Optional[str] = None
# Chat model backing the agent: OpenAI "gpt-4o", authenticated with the
# OPENAI_API_KEY environment variable (None when the variable is unset).
model = OpenAIModel("gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))
# Container for a generated mask. Not referenced anywhere else in the visible
# part of this file.
@dataclass
class MaskGenerationResult:
    # The mask image, base64-encoded.
    mask_image_base64: str
# Value returned by the agent tools in this module.
@dataclass
class EditImageResult:
    # What upload_image returned for the result image (presumably a URL --
    # confirm against src.utils.upload_image).
    edited_image_url: str
# The agent that the tools below are registered on. Tools receive an
# ImageEditDeps instance through their RunContext.
mask_generation_agent = Agent(model, deps_type=ImageEditDeps, system_prompt=system_prompt)
def upload_image_from_base64(base64_image: str) -> str:
    """Decode a base64 image, write it to a temporary file, and upload it.

    Args:
        base64_image: Either a data URI of the form
            ``data:image/jpeg;base64,<payload>`` or a bare base64 payload
            with no header.

    Returns:
        Whatever ``upload_image`` returns for the temporary file path.

    Raises:
        binascii.Error: If the payload is not valid base64.
    """
    # Split on the first comma only: what precedes it is the data-URI header,
    # what follows is the payload. A bare payload contains no comma at all.
    header, separator, payload = base64_image.partition(",")
    if not separator:
        header, payload = "", base64_image
    image_data = base64.b64decode(payload)

    # The header reads "data:image/jpeg;base64", so it never equals
    # "image/jpeg" outright; look for the MIME type inside it instead.
    suffix = ".jpg" if "image/jpeg" in header.lower() else ".png"

    # delete=False keeps the file on disk after the handle closes so that
    # upload_image can open it by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(image_data)
        temp_filename = temp_file.name

    # NOTE(review): the temporary file is never removed. Confirm that
    # upload_image is done with the path when it returns before adding cleanup.
    return upload_image(temp_filename)
# Agent tool: mask the region described by the edit instruction, run Hopter's
# magic replace on it, and upload the result.
#
# Reads edit_instruction, image_url, mask_service and hopter_client from
# ctx.deps. Returns an EditImageResult carrying the uploaded image location.
# Raises ValueError when no image_url was supplied.
#
# NOTE: keep the docstring below byte-for-byte. pydantic_ai uses a tool's
# docstring as the description shown to the model, so it is runtime text.
@mask_generation_agent.tool
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    Use this tool to edit an object in the image. for example:
    - remove the pole
    - replace the dog with a cat
    - change the background to a beach
    - remove the person in the image
    - change the hair color to red
    - change the hat to a cap
    """
    deps = ctx.deps

    # image_url is Optional and defaults to None; fail with a clear message
    # instead of passing None on to the downloader.
    if not deps.image_url:
        raise ValueError("edit_object requires an image, but ctx.deps.image_url is not set")

    # NOTE(review): the calls below are made synchronously inside a coroutine.
    # If they perform network I/O they will stall the event loop -- confirm.
    image_uri = download_image_to_data_uri(deps.image_url)

    # Generate mask
    mask_instruction = deps.mask_service.get_mask_generation_instruction(
        deps.edit_instruction, deps.image_url
    )
    mask = deps.mask_service.generate_mask(mask_instruction, image_uri)

    # Magic replace (local is not called `input`, to avoid shadowing the builtin)
    replace_input = MagicReplaceInput(
        image=image_uri,
        mask=mask,
        prompt=mask_instruction.target_caption,
    )
    result = deps.hopter_client.magic_replace(replace_input)

    uploaded_image = upload_image_from_base64(result.base64_image)
    return EditImageResult(edited_image_url=uploaded_image)
# Agent tool: upscale the image through Hopter's super resolution and upload
# the result.
#
# Reads image_url and hopter_client from ctx.deps. Returns an EditImageResult
# carrying the uploaded image location. Raises ValueError when no image_url
# was supplied.
#
# NOTE: keep the docstring below byte-for-byte. pydantic_ai uses a tool's
# docstring as the description shown to the model, so it is runtime text.
@mask_generation_agent.tool
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    run super resolution, upscale, or enhance the image
    """
    deps = ctx.deps

    # image_url is Optional and defaults to None; fail with a clear message
    # instead of passing None on to the downloader.
    if not deps.image_url:
        raise ValueError("super_resolution requires an image, but ctx.deps.image_url is not set")

    # NOTE(review): the calls below are made synchronously inside a coroutine.
    # If they perform network I/O they will stall the event loop -- confirm.
    image_uri = download_image_to_data_uri(deps.image_url)

    # The upscale factor is fixed at 4 and face enhancement is switched off.
    # (Local is not called `input`, to avoid shadowing the builtin.)
    upscale_input = SuperResolutionInput(
        image_b64=image_uri,
        scale=4,
        use_face_enhancement=False,
    )
    result = deps.hopter_client.super_resolution(upscale_input)

    uploaded_image = upload_image_from_base64(result.scaled_image)
    return EditImageResult(edited_image_url=uploaded_image)
async def main():
    """Run the agent once on a bundled sample image and print what it streams.

    Sends the instruction "remove the light post" together with
    ./assets/lakeview.jpg, using a Hopter client on the staging environment.
    """
    instruction = "remove the light post"
    source_uri = image_path_to_uri("./assets/lakeview.jpg")

    # The user message is a two-part list: the text instruction, then the image.
    text_part = {"type": "text", "text": instruction}
    image_part = {"type": "image_url", "image_url": {"url": source_uri}}

    # Initialize services
    hopter_client = Hopter(
        api_key=os.environ.get("HOPTER_API_KEY"),
        environment=Environment.STAGING,
    )
    masking = GenerateMaskService(hopter=hopter_client)

    # Initialize dependencies
    run_deps = ImageEditDeps(
        edit_instruction=instruction,
        image_url=source_uri,
        hopter_client=hopter_client,
        mask_service=masking,
    )

    async with mask_generation_agent.run_stream(
        [text_part, image_part],
        deps=run_deps,
    ) as stream_result:
        async for chunk in stream_result.stream():
            print(chunk)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())