Spaces:
Running
Running
File size: 4,266 Bytes
583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 583b7ad fcb8f25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIModel
from dotenv import load_dotenv
import os
import asyncio
from dataclasses import dataclass
from typing import Optional
import logfire
from src.services.generate_mask import GenerateMaskService
from src.hopter.client import (
Hopter,
Environment,
MagicReplaceInput,
SuperResolutionInput,
)
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
import base64
import tempfile
# Load variables from a local .env file into the process environment before
# the os.environ lookups in this module run.
load_dotenv()
# Configure Logfire with the token read from the environment (None when the
# variable is unset) and turn on its OpenAI instrumentation.
logfire.configure(token=os.environ.get("LOGFIRE_TOKEN"))
logfire.instrument_openai()
# System prompt handed to the agent at construction time. This text is sent to
# the model, so changing its wording changes runtime behavior.
system_prompt = """
I will give you an editing instruction of the image.
if the edit instruction involved modifying parts of the image, please generate a mask for it.
if images are not provided, ask the user to provide an image.
"""
@dataclass
class ImageEditDeps:
    """Dependencies made available to the agent's tools through ``ctx.deps``."""

    # Natural-language description of the requested edit.
    edit_instruction: str
    # Hopter client used by the tools for magic replace and super resolution.
    hopter_client: Hopter
    # Service that produces the mask-generation instruction and the mask.
    mask_service: GenerateMaskService
    # Location of the image to edit; defaults to None when no image was given.
    image_url: Optional[str] = None
# OpenAI chat model backing the agent; the API key is read from the
# environment (None when OPENAI_API_KEY is unset).
model = OpenAIModel(
"gpt-4o",
api_key=os.environ.get("OPENAI_API_KEY"),
)
@dataclass
class EditImageResult:
    """Value returned by the image-editing tools."""

    # Result of uploading the edited image (see upload_image_from_base64).
    edited_image_url: str
# The agent the tools below register on; its tools receive an ImageEditDeps
# instance as ``ctx.deps``.
image_edit_agent = Agent(model, system_prompt=system_prompt, deps_type=ImageEditDeps)
def upload_image_from_base64(base64_image: str) -> str:
    """Decode a base64-encoded image, write it to a temp file, and upload it.

    Args:
        base64_image: Either ``"<header>,<base64 payload>"`` -- for example a
            data URI such as ``"data:image/jpeg;base64,..."`` or the shorter
            ``"image/jpeg,..."`` form -- or a bare base64 payload with no
            header at all.

    Returns:
        Whatever ``upload_image`` returns for the temporary file (callers in
        this file store it as the edited image's URL).

    Raises:
        binascii.Error: If the payload is not valid base64.
    """
    header, comma, payload = base64_image.partition(",")
    if not comma:
        # No header present: the whole string is the payload.
        header, payload = "", base64_image

    # Reduce "data:image/jpeg;base64" (or plain "image/jpeg") to the bare
    # media type so the JPEG check actually matches data URIs.
    media_type = header.strip().lower()
    if media_type.startswith("data:"):
        media_type = media_type[len("data:"):]
    media_type = media_type.split(";", 1)[0]
    suffix = ".jpg" if media_type == "image/jpeg" else ".png"

    # Decode before creating the file so a bad payload leaves nothing on disk.
    image_data = base64.b64decode(payload)

    # delete=False so the file survives the ``with`` block and can be reopened
    # by path in upload_image (required on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(image_data)
        temp_filename = temp_file.name

    try:
        return upload_image(temp_filename)
    finally:
        # Best-effort cleanup: the temp file is only a staging copy for the
        # upload, and a failed removal must not mask the upload's outcome.
        # NOTE(review): assumes upload_image has finished reading the file by
        # the time it returns -- confirm against src.utils.upload_image.
        try:
            os.remove(temp_filename)
        except OSError:
            pass
# Agent tool: mask-guided edit of part of the image.
#
# Reads the edit instruction, image location, mask service and Hopter client
# from ``ctx.deps``; asks the mask service for a mask-generation instruction and
# a mask; runs Hopter's magic replace with that mask and the instruction's
# target caption; uploads the result and returns it as ``EditImageResult``.
#
# Raises ValueError when the dependencies carry no image URL, instead of
# passing None on to the downloader.
#
# NOTE: the docstring is kept verbatim on purpose -- pydantic_ai uses a tool's
# docstring as the tool description given to the model, so it is runtime text.
# NOTE(review): the calls below are synchronous inside an ``async def`` --
# confirm whether blocking the event loop here is acceptable.
@image_edit_agent.tool
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    Use this tool to edit an object in the image. for example:
    - remove the pole
    - replace the dog with a cat
    - change the background to a beach
    - remove the person in the image
    - change the hair color to red
    - change the hat to a cap
    """
    edit_instruction = ctx.deps.edit_instruction
    image_url = ctx.deps.image_url
    mask_service = ctx.deps.mask_service
    hopter_client = ctx.deps.hopter_client

    # image_url is Optional on ImageEditDeps; fail with a clear message rather
    # than letting the download helper choke on a missing value.
    if not image_url:
        raise ValueError("edit_object requires an image, but no image_url was provided")

    image_uri = download_image_to_data_uri(image_url)

    # Generate mask
    mask_instruction = mask_service.get_mask_generation_instruction(
        edit_instruction, image_url
    )
    mask = mask_service.generate_mask(mask_instruction, image_uri)

    # Magic replace (local renamed from ``input`` to avoid shadowing the builtin)
    replace_input = MagicReplaceInput(
        image=image_uri, mask=mask, prompt=mask_instruction.target_caption
    )
    result = hopter_client.magic_replace(replace_input)

    uploaded_image = upload_image_from_base64(result.base64_image)
    return EditImageResult(edited_image_url=uploaded_image)
# Agent tool: upscale the image.
#
# Reads the image location and Hopter client from ``ctx.deps``, downloads the
# image as a data URI, runs Hopter super resolution at scale 4 with face
# enhancement disabled, uploads the result and returns it as
# ``EditImageResult``.
#
# Raises ValueError when the dependencies carry no image URL, instead of
# passing None on to the downloader.
#
# NOTE: the docstring is kept verbatim on purpose -- pydantic_ai uses a tool's
# docstring as the tool description given to the model, so it is runtime text.
@image_edit_agent.tool
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    run super resolution, upscale, or enhance the image
    """
    image_url = ctx.deps.image_url
    hopter_client = ctx.deps.hopter_client

    # image_url is Optional on ImageEditDeps; fail with a clear message rather
    # than letting the download helper choke on a missing value.
    if not image_url:
        raise ValueError(
            "super_resolution requires an image, but no image_url was provided"
        )

    image_uri = download_image_to_data_uri(image_url)

    # Local renamed from ``input`` to avoid shadowing the builtin.
    upscale_input = SuperResolutionInput(
        image_b64=image_uri, scale=4, use_face_enhancement=False
    )
    result = hopter_client.super_resolution(upscale_input)

    uploaded_image = upload_image_from_base64(result.scaled_image)
    return EditImageResult(edited_image_url=uploaded_image)
async def main():
    """Demo run: ask the agent to remove the light post from the sample image.

    Builds the Hopter client, mask service and agent dependencies, then streams
    the agent's response and prints each streamed item.
    """
    instruction = "remove the light post"
    source_uri = image_path_to_uri("./assets/lakeview.jpg")

    # The user turn is one text part plus one image part.
    text_part = {"type": "text", "text": instruction}
    image_part = {"type": "image_url", "image_url": {"url": source_uri}}
    user_content = [text_part, image_part]

    # Services shared by the agent's tools.
    hopter_api = Hopter(
        api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
    )
    masker = GenerateMaskService(hopter=hopter_api)

    agent_deps = ImageEditDeps(
        edit_instruction=instruction,
        image_url=source_uri,
        hopter_client=hopter_api,
        mask_service=masker,
    )

    async with image_edit_agent.run_stream(user_content, deps=agent_deps) as run:
        async for chunk in run.stream():
            print(chunk)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())
|