# Spaces:
# Running
# Running
import asyncio
import base64
import os
import tempfile
from dataclasses import dataclass
from typing import Optional

import logfire
from dotenv import load_dotenv
from PIL import Image
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.models.openai import OpenAIModel

from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
from src.services.generate_mask import GenerateMaskService
from src.services.image_uploader import ImageUploader
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
# Load variables from a local .env file first; the API keys below are read
# from os.environ afterwards.
load_dotenv()

# Logfire tracing: configure with the token from the environment, then
# instrument OpenAI client calls.
logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))
logfire.instrument_openai()
# System prompt for the editing agent. Written as explicit string
# concatenation so the exact text, including the leading and trailing
# newlines, is visible.
system_prompt = (
    "\n"
    "I will give you an editing instruction of the image.\n"
    "if the edit instruction involved modifying parts of the image, "
    "please generate a mask for it.\n"
    "if images are not provided, ask the user to provide an image.\n"
)
@dataclass
class ImageEditDeps:
    """Per-run dependencies passed to the agent and read by tools via ``ctx.deps``.

    Must be a dataclass: ``main`` builds it with keyword arguments, which a
    plain class without ``__init__`` would reject with a TypeError.
    """

    # The user's natural-language edit request (e.g. "remove the light post").
    edit_instruction: str
    # Hopter API client; the tools call its magic_replace / super_resolution.
    hopter_client: Hopter
    # Builds the mask instruction and the mask used by edit_object.
    mask_service: GenerateMaskService
    # URI of the image to edit (main passes image_path_to_uri(...));
    # None when no image was supplied.
    image_url: Optional[str] = None
# GPT-4o chat model backing the agent; the key is read from the environment
# (populated from .env by load_dotenv earlier in this module).
model = OpenAIModel("gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))
@dataclass
class MaskGenerationResult:
    """Result holder for a generated mask.

    Declared as a dataclass so it can be constructed as
    ``MaskGenerationResult(mask_image_base64=...)``.
    """

    # The mask image as a base64 string (presumably a data URI; TODO confirm).
    mask_image_base64: str
@dataclass
class EditImageResult:
    """Value returned by the editing tools.

    Must be a dataclass: the tools construct it as
    ``EditImageResult(edited_image_url=...)``.
    """

    # Location of the edited image, as returned by upload_image_from_base64.
    edited_image_url: str
# The image-editing agent. deps_type makes ctx.deps an ImageEditDeps inside
# tools registered on this agent.
mask_generation_agent = Agent(model, system_prompt=system_prompt, deps_type=ImageEditDeps)
def upload_image_from_base64(base64_image: str) -> str: | |
image_format = base64_image.split(",")[0] | |
image_data = base64.b64decode(base64_image.split(",")[1]) | |
suffix = ".jpg" if image_format == "image/jpeg" else ".png" | |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: | |
temp_filename = temp_file.name | |
with open(temp_filename, "wb") as f: | |
f.write(image_data) | |
return upload_image(temp_filename) | |
@mask_generation_agent.tool
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    Use this tool to edit an object in the image. for example:
    - remove the pole
    - replace the dog with a cat
    - change the background to a beach
    - remove the person in the image
    - change the hair color to red
    - change the hat to a cap
    """
    # pydantic_ai uses the docstring above as the tool description shown to
    # the model, so developer notes live in comments instead.
    deps = ctx.deps
    if not deps.image_url:
        # image_url is Optional in ImageEditDeps. ModelRetry hands this message
        # back to the model so it can ask the user for an image, as the system
        # prompt instructs.
        raise ModelRetry("No image was provided. Ask the user to provide an image.")

    image_uri = download_image_to_data_uri(deps.image_url)

    # Turn the edit request into a mask instruction, then generate the mask.
    mask_instruction = deps.mask_service.get_mask_generation_instruction(
        deps.edit_instruction, deps.image_url
    )
    mask = deps.mask_service.generate_mask(mask_instruction, image_uri)

    # Replace the masked region via Hopter, guided by the instruction's target caption.
    replace_input = MagicReplaceInput(
        image=image_uri, mask=mask, prompt=mask_instruction.target_caption
    )
    result = deps.hopter_client.magic_replace(replace_input)

    edited_image_url = upload_image_from_base64(result.base64_image)
    return EditImageResult(edited_image_url=edited_image_url)
@mask_generation_agent.tool
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    run super resolution, upscale, or enhance the image
    """
    # pydantic_ai uses the docstring above as the tool description shown to
    # the model, so developer notes live in comments instead.
    deps = ctx.deps
    if not deps.image_url:
        # image_url is Optional in ImageEditDeps; let the model ask the user
        # for an image instead of failing inside the download helper.
        raise ModelRetry("No image was provided. Ask the user to provide an image.")

    image_uri = download_image_to_data_uri(deps.image_url)

    # Fixed scale=4, face enhancement off.
    upscale_input = SuperResolutionInput(
        image_b64=image_uri, scale=4, use_face_enhancement=False
    )
    result = deps.hopter_client.super_resolution(upscale_input)

    edited_image_url = upload_image_from_base64(result.scaled_image)
    return EditImageResult(edited_image_url=edited_image_url)
async def main():
    """Demo: ask the agent to remove the light post from ./assets/lakeview.jpg and print the streamed output."""
    prompt = "remove the light post"
    image_url = image_path_to_uri("./assets/lakeview.jpg")

    # NOTE(review): OpenAI-style content parts as plain dicts. Whether
    # run_stream accepts this depends on the installed pydantic_ai version
    # (newer versions type user_prompt as str or a sequence of
    # str/ImageUrl/BinaryContent). Confirm before upgrading.
    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": image_url}},
    ]

    hopter = Hopter(api_key=os.getenv("HOPTER_API_KEY"), environment=Environment.STAGING)
    deps = ImageEditDeps(
        edit_instruction=prompt,
        image_url=image_url,
        hopter_client=hopter,
        mask_service=GenerateMaskService(hopter=hopter),
    )

    async with mask_generation_agent.run_stream(user_content, deps=deps) as result:
        async for message in result.stream():
            print(message)


if __name__ == "__main__":
    asyncio.run(main())