Spaces:
Running
Running
from pydantic_ai import Agent, RunContext | |
from pydantic_ai.models.openai import OpenAIModel | |
from dotenv import load_dotenv | |
import os | |
import asyncio | |
from dataclasses import dataclass | |
from typing import Optional | |
import logfire | |
from src.services.generate_mask import GenerateMaskService | |
from src.hopter.client import ( | |
Hopter, | |
Environment, | |
MagicReplaceInput, | |
SuperResolutionInput, | |
) | |
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image | |
import base64 | |
import tempfile | |
# Load variables from a .env file first, so the environment lookups below
# (and the ones further down the module) can see them.
load_dotenv()

# Configure logfire and instrument OpenAI calls.
logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))
logfire.instrument_openai()

# System prompt handed to the agent when it is constructed.
system_prompt = """
I will give you an editing instruction of the image.
if the edit instruction involved modifying parts of the image, please generate a mask for it.
if images are not provided, ask the user to provide an image.
"""
@dataclass
class ImageEditDeps:
    """Dependencies made available to the agent's tools through ``ctx.deps``.

    Declared as a dataclass so that instances can be built with keyword
    arguments; a plain class with only annotations has no ``__init__`` that
    accepts them.
    """

    # Editing instruction for the current run.
    edit_instruction: str
    # Client on which ``magic_replace`` and ``super_resolution`` are called.
    hopter_client: Hopter
    # Service used to derive the mask instruction and to generate the mask.
    mask_service: GenerateMaskService
    # Location of the image to edit; defaults to None when not supplied.
    image_url: Optional[str] = None
# Model backing the agent; the API key is read from the environment.
openai_api_key = os.getenv("OPENAI_API_KEY")
model = OpenAIModel("gpt-4o", api_key=openai_api_key)
@dataclass
class EditImageResult:
    """Value returned by the image-editing functions.

    Declared as a dataclass so it can be constructed with
    ``EditImageResult(edited_image_url=...)``.
    """

    # Value returned by ``upload_image`` for the edited image.
    edited_image_url: str
# Agent that runs the conversation; tools receive an ImageEditDeps instance
# through their RunContext.
image_edit_agent = Agent(
    model,
    system_prompt=system_prompt,
    deps_type=ImageEditDeps,
)
def upload_image_from_base64(base64_image: str) -> str:
    """Decode a base64 image, write it to a temporary file and upload it.

    Args:
        base64_image: Either a data URI such as
            ``data:image/jpeg;base64,<payload>`` or a bare base64 payload
            with no header.

    Returns:
        Whatever ``upload_image`` returns for the temporary file.

    Raises:
        binascii.Error: If the payload is not valid base64.
    """
    header, separator, payload = base64_image.partition(",")
    if not separator:
        # No header present: the whole string is the payload. Indexing the
        # second element of a split used to raise IndexError here.
        header, payload = "", base64_image

    # Reduce "data:image/jpeg;base64" to the bare media type "image/jpeg".
    # Comparing the raw header against "image/jpeg" never matched for a data
    # URI, so JPEG data was always written with a ".png" suffix.
    media_type = header.strip()
    if media_type.lower().startswith("data:"):
        media_type = media_type[len("data:"):]
    media_type = media_type.split(";", 1)[0].strip().lower()
    suffix = ".jpg" if media_type == "image/jpeg" else ".png"

    image_data = base64.b64decode(payload)

    # delete=False keeps the file on disk after the handle is closed, so it
    # can be reopened by name. Writing through the handle avoids opening the
    # same path a second time.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(image_data)
        temp_filename = temp_file.name

    # NOTE(review): the temporary file is never removed. Confirm whether
    # upload_image still needs the path after it returns before adding cleanup.
    return upload_image(temp_filename)
@image_edit_agent.tool
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    Use this tool to edit an object in the image. for example:
    - remove the pole
    - replace the dog with a cat
    - change the background to a beach
    - remove the person in the image
    - change the hair color to red
    - change the hat to a cap
    """
    # The docstring above is kept verbatim: once the function is registered
    # as a tool, that text serves as the tool's description for the model.
    #
    # Registration via the decorator is what makes the function callable by
    # the agent; without it the agent has no tools at all.
    #
    # Reads the edit instruction, image location, mask service and Hopter
    # client from ctx.deps and returns an EditImageResult carrying the value
    # produced by uploading the edited image.
    deps = ctx.deps
    image_uri = download_image_to_data_uri(deps.image_url)

    # Generate mask
    mask_instruction = deps.mask_service.get_mask_generation_instruction(
        deps.edit_instruction, deps.image_url
    )
    mask = deps.mask_service.generate_mask(mask_instruction, image_uri)

    # Magic replace (local renamed so it no longer shadows the builtin `input`)
    replace_input = MagicReplaceInput(
        image=image_uri, mask=mask, prompt=mask_instruction.target_caption
    )
    result = deps.hopter_client.magic_replace(replace_input)

    uploaded_image = upload_image_from_base64(result.base64_image)
    return EditImageResult(edited_image_url=uploaded_image)
@image_edit_agent.tool
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    run super resolution, upscale, or enhance the image
    """
    # The docstring above is kept verbatim: once the function is registered
    # as a tool, that text serves as the tool's description for the model.
    #
    # Registration via the decorator is what makes the function callable by
    # the agent; without it the agent has no tools at all.
    #
    # Reads the image location and Hopter client from ctx.deps, requests a
    # 4x upscale without face enhancement, and returns an EditImageResult
    # carrying the value produced by uploading the scaled image.
    deps = ctx.deps
    image_uri = download_image_to_data_uri(deps.image_url)

    # Local renamed so it no longer shadows the builtin `input`.
    upscale_input = SuperResolutionInput(
        image_b64=image_uri, scale=4, use_face_enhancement=False
    )
    result = deps.hopter_client.super_resolution(upscale_input)

    uploaded_image = upload_image_from_base64(result.scaled_image)
    return EditImageResult(edited_image_url=uploaded_image)
async def main():
    """Run one sample edit against a bundled image and print the stream.

    Builds the Hopter client, mask service and agent dependencies, then
    streams the agent's output for a fixed prompt to stdout.
    """
    prompt = "remove the light post"
    image_url = image_path_to_uri("./assets/lakeview.jpg")

    # User message: the text instruction followed by the image.
    text_part = {"type": "text", "text": prompt}
    image_part = {"type": "image_url", "image_url": {"url": image_url}}
    messages = [text_part, image_part]

    # Initialize services
    hopter_api_key = os.environ.get("HOPTER_API_KEY")
    hopter = Hopter(api_key=hopter_api_key, environment=Environment.STAGING)
    mask_service = GenerateMaskService(hopter=hopter)

    # Initialize dependencies
    deps = ImageEditDeps(
        edit_instruction=prompt,
        image_url=image_url,
        hopter_client=hopter,
        mask_service=mask_service,
    )

    async with image_edit_agent.run_stream(messages, deps=deps) as stream_result:
        async for chunk in stream_result.stream():
            print(chunk)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())