Spaces:
Running
Running
File size: 4,477 Bytes
a21dee1 6962136 a21dee1 7daa838 c55fe6a ad3aed5 a21dee1 6962136 a21dee1 9e822e4 6962136 a21dee1 9e822e4 c55fe6a 9e822e4 a21dee1 ad3aed5 a21dee1 9e822e4 a21dee1 6962136 a21dee1 9e822e4 ad3aed5 9e822e4 c55fe6a ad3aed5 a21dee1 9e822e4 ad3aed5 c55fe6a ad3aed5 7daa838 9e6acd9 7daa838 ad3aed5 a21dee1 c55fe6a a21dee1 c55fe6a a21dee1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIModel
from dotenv import load_dotenv
import os
import asyncio
from dataclasses import dataclass
from typing import Optional
import logfire
from src.services.generate_mask import GenerateMaskService
from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
from src.services.image_uploader import ImageUploader
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
import base64
import tempfile
from PIL import Image
# Pull variables from a .env file into the process environment first, so the
# os.getenv lookups below (and elsewhere in this module) can see them.
load_dotenv()

# Configure Logfire with the token from the environment (None when unset) and
# turn on its OpenAI instrumentation.
logfire.configure(token=os.getenv("LOGFIRE_TOKEN"))
logfire.instrument_openai()
# System prompt handed to the agent. The model sees this exact text, including
# the leading and trailing newline, so change it deliberately.
system_prompt = (
    "\n"
    "I will give you an editing instruction of the image.\n"
    "if the edit instruction involved modifying parts of the image, please generate a mask for it.\n"
    "if images are not provided, ask the user to provide an image.\n"
)
@dataclass
class ImageEditDeps:
    """Per-run dependencies given to the agent; tools read them via ``ctx.deps``."""

    # Natural-language edit request; edit_object forwards it to the mask service.
    edit_instruction: str
    # Hopter API client used for the magic-replace and super-resolution calls.
    hopter_client: Hopter
    # Service that derives a mask-generation instruction and a mask for the image.
    mask_service: GenerateMaskService
    # Image to operate on (main passes the result of image_path_to_uri).
    # May be None when no image was supplied.
    image_url: Optional[str] = None
# Chat model backing the agent. The API key comes from the environment; when
# OPENAI_API_KEY is unset, None is passed through to OpenAIModel.
model = OpenAIModel("gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))
@dataclass
class MaskGenerationResult:
    """Container for a generated mask image encoded as a base64 string."""

    # NOTE(review): not referenced by the tools in this module; possibly dead
    # code — confirm before removing.
    mask_image_base64: str
@dataclass
class EditImageResult:
    """Value returned by the agent's image-editing tools."""

    # Location of the uploaded result, as returned by upload_image_from_base64.
    edited_image_url: str
# The agent driving this module. Functions decorated with
# @mask_generation_agent.tool receive an ImageEditDeps instance as ctx.deps.
mask_generation_agent = Agent(
    model,
    deps_type=ImageEditDeps,
    system_prompt=system_prompt,
)
def upload_image_from_base64(base64_image: str) -> str:
    """Decode a base64-encoded image, write it to a temp file and upload it.

    Args:
        base64_image: Either a data URI such as ``"data:image/jpeg;base64,<payload>"``
            or a bare base64 payload with no header.

    Returns:
        Whatever ``upload_image`` returns for the temporary file; the tools in
        this module store it as ``EditImageResult.edited_image_url``.

    Raises:
        binascii.Error: if the payload is not valid base64.
    """
    header, sep, payload = base64_image.partition(",")
    if not sep:
        # No header at all: the whole string is the base64 payload.
        header, payload = "", base64_image

    # A data-URI header looks like "data:image/jpeg;base64"; extract the MIME
    # type. The previous check compared the entire header to "image/jpeg", which
    # never matched, so JPEGs were always saved with a ".png" suffix.
    mime_type = header.split(":", 1)[-1].split(";", 1)[0].strip().lower()
    suffix = ".jpg" if mime_type in ("image/jpeg", "image/jpg") else ".png"

    # Decode before touching the filesystem so bad input leaves no temp file.
    image_data = base64.b64decode(payload)

    # Write through the temp file's own handle (no second open by name).
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(image_data)
        temp_path = temp_file.name
    try:
        # Assumes upload_image has finished reading the file when it returns
        # (it is called synchronously and returns a str).
        return upload_image(temp_path)
    finally:
        # delete=False previously left one orphaned file per call.
        os.remove(temp_path)
@mask_generation_agent.tool
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    Use this tool to edit an object in the image. for example:
    - remove the pole
    - replace the dog with a cat
    - change the background to a beach
    - remove the person in the image
    - change the hair color to red
    - change the hat to a cap
    """
    # NOTE: the docstring above is sent to the model as the tool description,
    # so it is kept verbatim; implementation notes live in these comments.
    #
    # Pipeline: fetch the image as a data URI, have the mask service derive a
    # mask-generation instruction and a mask, run Hopter magic-replace with that
    # mask and the instruction's target caption, then upload the result.
    deps = ctx.deps
    image_url = deps.image_url
    if not image_url:
        # ImageEditDeps.image_url is Optional; fail clearly instead of handing
        # None to the download helper.
        raise ValueError("edit_object requires an image, but no image_url was provided")

    # Every helper below is a synchronous (blocking) call; run each in a worker
    # thread so the event loop is not stalled while the agent is streaming.
    image_uri = await asyncio.to_thread(download_image_to_data_uri, image_url)

    # Generate mask
    mask_instruction = await asyncio.to_thread(
        deps.mask_service.get_mask_generation_instruction,
        deps.edit_instruction,
        image_url,
    )
    mask = await asyncio.to_thread(deps.mask_service.generate_mask, mask_instruction, image_uri)

    # Magic replace
    replace_input = MagicReplaceInput(
        image=image_uri,
        mask=mask,
        prompt=mask_instruction.target_caption,
    )
    result = await asyncio.to_thread(deps.hopter_client.magic_replace, replace_input)

    uploaded_image = await asyncio.to_thread(upload_image_from_base64, result.base64_image)
    return EditImageResult(edited_image_url=uploaded_image)
@mask_generation_agent.tool
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    run super resolution, upscale, or enhance the image
    """
    # NOTE: the docstring above is the tool description shown to the model, so
    # it is kept verbatim.
    #
    # Downloads the image from ctx.deps.image_url, sends it to Hopter's
    # super-resolution endpoint (scale=4, face enhancement off), uploads the
    # scaled result and returns where it was uploaded.
    image_url = ctx.deps.image_url
    if not image_url:
        # ImageEditDeps.image_url is Optional; fail clearly instead of handing
        # None to the download helper.
        raise ValueError("super_resolution requires an image, but no image_url was provided")

    # The helpers are synchronous (blocking); offload them so the event loop
    # stays responsive during the agent's streamed run.
    image_uri = await asyncio.to_thread(download_image_to_data_uri, image_url)
    sr_input = SuperResolutionInput(image_b64=image_uri, scale=4, use_face_enhancement=False)
    result = await asyncio.to_thread(ctx.deps.hopter_client.super_resolution, sr_input)
    uploaded_image = await asyncio.to_thread(upload_image_from_base64, result.scaled_image)
    return EditImageResult(edited_image_url=uploaded_image)
async def main():
    """Demo: ask the agent to remove the light post from ./assets/lakeview.jpg.

    Builds a two-part user message (instruction text plus image), wires the
    Hopter client and mask service into ImageEditDeps, and prints each value
    yielded by the streamed agent run.
    """
    prompt = "remove the light post"
    image_url = image_path_to_uri("./assets/lakeview.jpg")

    # OpenAI-style content parts: the instruction text, then the image.
    text_part = {"type": "text", "text": prompt}
    image_part = {"type": "image_url", "image_url": {"url": image_url}}
    messages = [text_part, image_part]

    # Hopter client against the staging environment; the mask service is
    # built on top of the same client.
    hopter = Hopter(
        api_key=os.getenv("HOPTER_API_KEY"),
        environment=Environment.STAGING,
    )
    deps = ImageEditDeps(
        edit_instruction=prompt,
        image_url=image_url,
        hopter_client=hopter,
        mask_service=GenerateMaskService(hopter=hopter),
    )

    async with mask_generation_agent.run_stream(messages, deps=deps) as result:
        async for message in result.stream():
            print(message)
if __name__ == "__main__":
    # Run the demo only when executed as a script (not on import).
    # The stray trailing " |" after this call made the file a SyntaxError.
    asyncio.run(main())