import asyncio
import base64
import contextlib
import os
import tempfile
from dataclasses import dataclass
from typing import Optional

from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.models.openai import OpenAIModel
from dotenv import load_dotenv
import logfire

from src.services.generate_mask import GenerateMaskService
from src.hopter.client import (
    Hopter,
    Environment,
    MagicReplaceInput,
    SuperResolutionInput,
)
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image

load_dotenv()

# Send traces to Logfire and instrument OpenAI client calls.
logfire.configure(token=os.environ.get("LOGFIRE_TOKEN"))
logfire.instrument_openai()

# NOTE: this text is sent to the model verbatim as the system prompt; editing
# it changes agent behavior.
system_prompt = """
I will give you an editing instruction of the image.
if the edit instruction involved modifying parts of the image, please generate a mask for it.
if images are not provided, ask the user to provide an image.
"""


@dataclass
class ImageEditDeps:
    """Per-run dependencies handed to the agent's tools through ``RunContext``."""

    # The user's natural-language edit request (e.g. "remove the light post").
    edit_instruction: str
    # Client for the Hopter image APIs (magic replace, super resolution).
    hopter_client: Hopter
    # Produces the mask instruction/target caption and the mask itself.
    mask_service: GenerateMaskService
    # URI of the source image; None when the user supplied no image.
    image_url: Optional[str] = None


model = OpenAIModel(
    "gpt-4o",
    api_key=os.environ.get("OPENAI_API_KEY"),
)


@dataclass
class EditImageResult:
    """Tool result returned to the model after an edit."""

    # Value returned by upload_image for the edited image (presumably a URL).
    edited_image_url: str


image_edit_agent = Agent(model, system_prompt=system_prompt, deps_type=ImageEditDeps)


def upload_image_from_base64(base64_image: str) -> str:
    """Decode a base64-encoded image, upload it, and return the upload result.

    Args:
        base64_image: Either a data URI such as ``data:image/jpeg;base64,...``
            or a bare base64 payload with no header.

    Returns:
        Whatever ``upload_image`` returns for the decoded image file.

    Raises:
        binascii.Error: If the payload is not valid base64.
    """
    header, sep, payload = base64_image.partition(",")
    if not sep:
        # No data-URI header: the whole string is the base64 payload.
        header, payload = "", base64_image

    # A data-URI header looks like "data:image/jpeg;base64"; extract just the
    # MIME type rather than comparing the whole header to "image/jpeg".
    mime_type = header.removeprefix("data:").split(";")[0]
    suffix = ".jpg" if mime_type in ("image/jpeg", "image/jpg") else ".png"
    image_data = base64.b64decode(payload)

    # upload_image takes a file path, so stage the bytes in a temp file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(image_data)
        temp_filename = temp_file.name
    try:
        return upload_image(temp_filename)
    finally:
        # The file is only needed for the upload; don't leave one behind per
        # call. Cleanup failure must not mask the upload's own outcome.
        with contextlib.suppress(OSError):
            os.remove(temp_filename)


def _require_image_url(deps: ImageEditDeps) -> str:
    """Return the run's image URI, or ask the model to retry if none was given."""
    if not deps.image_url:
        # Consistent with the system prompt: without an image the model should
        # ask the user for one instead of calling an editing tool.
        raise ModelRetry(
            "No image was provided. Ask the user to provide an image to edit."
        )
    return deps.image_url


@image_edit_agent.tool
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    Use this tool to edit an object in the image.

    for example:
    - remove the pole
    - replace the dog with a cat
    - change the background to a beach
    - remove the person in the image
    - change the hair color to red
    - change the hat to a cap
    """
    # NOTE: the docstring above is the tool description sent to the model;
    # keep developer notes in comments so they don't leak into the prompt.
    #
    # Pipeline: download the source image -> derive a mask instruction ->
    # generate the mask -> Hopter magic replace -> upload the result.
    # All of these calls are synchronous, so they run in worker threads to
    # avoid blocking the event loop the agent runs on.
    deps = ctx.deps
    image_url = _require_image_url(deps)
    mask_service = deps.mask_service

    image_uri = await asyncio.to_thread(download_image_to_data_uri, image_url)

    mask_instruction = await asyncio.to_thread(
        mask_service.get_mask_generation_instruction,
        deps.edit_instruction,
        image_url,
    )
    mask = await asyncio.to_thread(
        mask_service.generate_mask, mask_instruction, image_uri
    )

    # The target caption describes what the masked region should become.
    replace_input = MagicReplaceInput(
        image=image_uri, mask=mask, prompt=mask_instruction.target_caption
    )
    result = await asyncio.to_thread(deps.hopter_client.magic_replace, replace_input)

    uploaded_image = await asyncio.to_thread(
        upload_image_from_base64, result.base64_image
    )
    return EditImageResult(edited_image_url=uploaded_image)


@image_edit_agent.tool
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    run super resolution, upscale, or enhance the image
    """
    # NOTE: the docstring above is the tool description sent to the model.
    #
    # Upscales the source image 4x through Hopter (face enhancement off) and
    # uploads the result; blocking calls run in worker threads.
    hopter_client = ctx.deps.hopter_client
    image_url = _require_image_url(ctx.deps)

    image_uri = await asyncio.to_thread(download_image_to_data_uri, image_url)

    sr_input = SuperResolutionInput(
        image_b64=image_uri, scale=4, use_face_enhancement=False
    )
    result = await asyncio.to_thread(hopter_client.super_resolution, sr_input)

    uploaded_image = await asyncio.to_thread(
        upload_image_from_base64, result.scaled_image
    )
    return EditImageResult(edited_image_url=uploaded_image)


async def main():
    """Demo: ask the agent to remove the light post from a bundled sample image."""
    image_file_path = "./assets/lakeview.jpg"
    image_url = image_path_to_uri(image_file_path)
    prompt = "remove the light post"
    # NOTE(review): these are OpenAI chat-API style content parts; confirm the
    # installed pydantic_ai version accepts them as a user prompt (newer
    # versions expect str / ImageUrl / BinaryContent items). The tools read
    # the image from deps, not from this message.
    messages = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": image_url}},
    ]

    # Initialize services
    hopter = Hopter(
        api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
    )
    mask_service = GenerateMaskService(hopter=hopter)

    # Initialize dependencies
    deps = ImageEditDeps(
        edit_instruction=prompt,
        image_url=image_url,
        hopter_client=hopter,
        mask_service=mask_service,
    )

    async with image_edit_agent.run_stream(messages, deps=deps) as result:
        async for message in result.stream():
            print(message)


if __name__ == "__main__":
    asyncio.run(main())