File size: 4,477 Bytes
a21dee1
 
 
 
 
 
6962136
a21dee1
 
7daa838
c55fe6a
ad3aed5
 
 
 
a21dee1
 
 
 
 
 
 
 
 
6962136
a21dee1
 
 
 
 
9e822e4
 
6962136
a21dee1
 
 
 
 
 
 
 
 
 
9e822e4
 
 
c55fe6a
9e822e4
a21dee1
 
 
 
 
 
ad3aed5
 
 
 
 
 
 
 
 
 
a21dee1
9e822e4
a21dee1
6962136
 
 
 
 
 
 
a21dee1
9e822e4
 
 
 
 
ad3aed5
 
9e822e4
c55fe6a
ad3aed5
a21dee1
9e822e4
ad3aed5
c55fe6a
ad3aed5
 
7daa838
 
 
 
 
 
 
 
 
9e6acd9
 
 
7daa838
ad3aed5
 
a21dee1
 
 
c55fe6a
 
 
 
 
 
 
 
 
 
 
 
a21dee1
c55fe6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a21dee1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIModel
from dotenv import load_dotenv
import os
import asyncio
from dataclasses import dataclass
from typing import Optional
import logfire
from src.services.generate_mask import GenerateMaskService
from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
from src.services.image_uploader import ImageUploader
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
import base64
import tempfile
from PIL import Image

# Load environment variables via python-dotenv before the os.environ lookups
# below (LOGFIRE_TOKEN here, OPENAI_API_KEY and HOPTER_API_KEY further down).
load_dotenv()

# Configure Logfire with the token from LOGFIRE_TOKEN (None when the variable
# is unset) and enable its OpenAI instrumentation.
logfire.configure(token=os.environ.get("LOGFIRE_TOKEN"))
logfire.instrument_openai()

# System prompt passed to the agent below. This text is sent to the model at
# runtime, so any wording change alters the agent's behavior.
system_prompt = """
I will give you an editing instruction of the image. 
if the edit instruction involved modifying parts of the image, please generate a mask for it.
if images are not provided, ask the user to provide an image.
"""

@dataclass
class ImageEditDeps:
    """Dependencies made available to the agent's tools through ``ctx.deps``."""

    # The user's edit request; edit_object forwards it to the mask service.
    edit_instruction: str
    # Hopter client used by the tools for magic_replace / super_resolution.
    hopter_client: Hopter
    # Service used by edit_object to build the mask instruction and the mask.
    mask_service: GenerateMaskService
    # Location of the image to edit; None when no image was supplied.
    image_url: Optional[str] = None

# OpenAI model ("gpt-4o") that backs the agent; the API key is read from the
# OPENAI_API_KEY environment variable (None when it is unset).
model = OpenAIModel(
    "gpt-4o",
    api_key=os.environ.get("OPENAI_API_KEY"),
)

@dataclass
class MaskGenerationResult:
    """Holder for a generated mask image.

    NOTE(review): not referenced anywhere else in the visible part of this
    file -- confirm whether it is still needed.
    """

    # Presumably the mask image as a base64 string (inferred from the name).
    mask_image_base64: str


@dataclass
class EditImageResult:
    """Return value of the editing tools (edit_object, super_resolution)."""

    # Value returned by upload_image_from_base64 for the edited image.
    edited_image_url: str

# Agent that the tools below are registered on. It runs on `model`, uses the
# module-level system prompt, and expects an ImageEditDeps instance as deps.
# Despite its name, its tools cover both object edits and super resolution.
mask_generation_agent = Agent(
    model,
    system_prompt=system_prompt,
    deps_type=ImageEditDeps
)

def upload_image_from_base64(base64_image: str) -> str:
    """Decode a base64-encoded image, write it to a temp file and upload it.

    The input may be a full data URI (``data:image/jpeg;base64,<payload>``),
    a plain ``<mime type>,<payload>`` pair, or a bare base64 payload with no
    header at all.

    Args:
        base64_image: The encoded image, with or without a leading header.

    Returns:
        The value returned by ``upload_image`` for the temporary file path.

    Raises:
        binascii.Error: If the payload is not valid base64.
    """
    header, separator, payload = base64_image.partition(",")
    if not separator:
        # No header present: the whole string is the base64 payload.
        header, payload = "", base64_image

    # Reduce a header such as "data:image/jpeg;base64" to its bare MIME type.
    # Comparing the raw header against "image/jpeg" never matches a real data
    # URI, which made every JPEG end up with a ".png" suffix.
    mime_type = header.strip().lower()
    if mime_type.startswith("data:"):
        mime_type = mime_type[len("data:"):]
    mime_type = mime_type.split(";", 1)[0].strip()
    suffix = ".jpg" if mime_type in ("image/jpeg", "image/jpg") else ".png"

    image_data = base64.b64decode(payload)

    # delete=False keeps the file on disk after the handle is closed so that
    # upload_image can read it by path. Write through the handle we already
    # hold instead of reopening the file by name (which fails on Windows).
    # NOTE(review): the temp file is never removed afterwards; left unchanged
    # because it is not visible here whether upload_image needs the path to
    # outlive this call.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(image_data)
        temp_filename = temp_file.name
    return upload_image(temp_filename)

@mask_generation_agent.tool
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    Use this tool to edit an object in the image. for example:
    - remove the pole
    - replace the dog with a cat
    - change the background to a beach
    - remove the person in the image
    - change the hair color to red
    - change the hat to a cap
    """
    # The docstring above doubles as the tool description shown to the model,
    # so it is runtime text: do not reword it casually.
    deps = ctx.deps

    # Fetch the source image once; both the mask step and the replace step
    # consume this same value.
    source_uri = download_image_to_data_uri(deps.image_url)

    # Step 1: derive a mask instruction from the edit request, then the mask.
    instruction = deps.mask_service.get_mask_generation_instruction(
        deps.edit_instruction,
        deps.image_url,
    )
    mask = deps.mask_service.generate_mask(instruction, source_uri)

    # Step 2: run magic replace on the masked region, prompted with the
    # instruction's target caption.
    replace_request = MagicReplaceInput(
        image=source_uri,
        mask=mask,
        prompt=instruction.target_caption,
    )
    replaced = deps.hopter_client.magic_replace(replace_request)

    # Step 3: upload the result and hand its location back to the agent.
    edited_url = upload_image_from_base64(replaced.base64_image)
    return EditImageResult(edited_image_url=edited_url)

@mask_generation_agent.tool
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    run super resolution, upscale, or enhance the image
    """
    # The docstring above doubles as the tool description shown to the model,
    # so it is runtime text: do not reword it casually.
    deps = ctx.deps
    source_uri = download_image_to_data_uri(deps.image_url)

    # Fixed 4x upscale with face enhancement turned off.
    upscale_request = SuperResolutionInput(
        image_b64=source_uri,
        scale=4,
        use_face_enhancement=False,
    )
    upscaled = deps.hopter_client.super_resolution(upscale_request)

    # Upload the upscaled image and hand its location back to the agent.
    upscaled_url = upload_image_from_base64(upscaled.scaled_image)
    return EditImageResult(edited_image_url=upscaled_url)

async def main():
    """Run a sample edit ("remove the light post") on a bundled image.

    Builds the prompt parts and dependencies, streams the agent run and
    prints every streamed message.
    """
    prompt = "remove the light post"
    image_url = image_path_to_uri("./assets/lakeview.jpg")

    # User message: the instruction text followed by the image to edit.
    text_part = {"type": "text", "text": prompt}
    image_part = {"type": "image_url", "image_url": {"url": image_url}}
    messages = [text_part, image_part]

    # Services shared by the tools; the Hopter client targets staging.
    hopter = Hopter(
        api_key=os.environ.get("HOPTER_API_KEY"),
        environment=Environment.STAGING,
    )
    deps = ImageEditDeps(
        edit_instruction=prompt,
        image_url=image_url,
        hopter_client=hopter,
        mask_service=GenerateMaskService(hopter=hopter),
    )

    # Stream the run and echo each message as it arrives.
    async with mask_generation_agent.run_stream(messages, deps=deps) as result:
        async for message in result.stream():
            print(message)


# Run the sample edit only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())