File size: 4,266 Bytes
583b7ad
 
 
 
 
 
 
 
 
fcb8f25
 
 
 
 
 
583b7ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcb8f25
583b7ad
 
 
 
 
 
 
fcb8f25
583b7ad
 
 
 
 
 
 
 
 
 
fcb8f25
 
 
583b7ad
 
 
 
 
 
 
 
 
 
 
fcb8f25
583b7ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcb8f25
 
 
583b7ad
 
 
fcb8f25
 
 
583b7ad
 
 
 
fcb8f25
583b7ad
 
 
 
 
 
 
 
 
 
fcb8f25
 
 
583b7ad
 
 
 
fcb8f25
583b7ad
 
 
 
 
 
fcb8f25
 
583b7ad
 
 
fcb8f25
 
 
583b7ad
 
 
 
 
 
 
fcb8f25
583b7ad
fcb8f25
583b7ad
 
 
 
 
fcb8f25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIModel
from dotenv import load_dotenv
import os
import asyncio
from dataclasses import dataclass
from typing import Optional
import logfire
from src.services.generate_mask import GenerateMaskService
from src.hopter.client import (
    Hopter,
    Environment,
    MagicReplaceInput,
    SuperResolutionInput,
)
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
import base64
import tempfile

# Load variables from a local .env file into the process environment before
# any of the os.environ lookups below run.
load_dotenv()

# Configure Logfire with the token from LOGFIRE_TOKEN (None when the variable
# is unset) and enable its OpenAI instrumentation.
logfire.configure(token=os.environ.get("LOGFIRE_TOKEN"))
logfire.instrument_openai()

# System prompt handed to the agent below. This text is sent to the model at
# runtime, so wording changes here alter agent behavior.
system_prompt = """
I will give you an editing instruction of the image. 
if the edit instruction involved modifying parts of the image, please generate a mask for it.
if images are not provided, ask the user to provide an image.
"""


@dataclass
class ImageEditDeps:
    """Dependencies made available to the agent's tools through ``ctx.deps``."""

    # Natural-language edit request; read by the edit_object tool.
    edit_instruction: str
    # Hopter API client used for the magic-replace and super-resolution calls.
    hopter_client: Hopter
    # Service that produces the mask instruction and the mask for an edit.
    mask_service: GenerateMaskService
    # Location of the image to edit; both tools read it. Defaults to None,
    # i.e. no image was supplied.
    image_url: Optional[str] = None


# Chat model backing the agent: OpenAI "gpt-4o", authenticated with the key
# from the OPENAI_API_KEY environment variable (None when it is unset).
model = OpenAIModel(
    "gpt-4o",
    api_key=os.environ.get("OPENAI_API_KEY"),
)


@dataclass
class EditImageResult:
    """Value returned by the image-editing tools in this module."""

    # Value returned by upload_image() for the edited image.
    edited_image_url: str


# The agent the tools below register on; tools receive an ImageEditDeps
# instance via RunContext.deps.
image_edit_agent = Agent(model, system_prompt=system_prompt, deps_type=ImageEditDeps)


def upload_image_from_base64(base64_image: str) -> str:
    """Decode a base64-encoded image, save it to a temp file and upload it.

    Args:
        base64_image: Either ``"<header>,<base64 payload>"`` (for example a
            data URI such as ``"data:image/jpeg;base64,...."``) or a bare
            base64 payload with no header.

    Returns:
        Whatever ``upload_image`` returns for the written temp file.

    Raises:
        binascii.Error: If the payload is not valid base64.
    """
    # Split on the first comma only: the header is metadata, the rest is the
    # payload. A string without a comma is treated as a bare payload instead
    # of failing with IndexError.
    header, separator, payload = base64_image.partition(",")
    if not separator:
        header, payload = "", base64_image

    # The header of a data URI looks like "data:image/jpeg;base64", so an
    # equality check against "image/jpeg" never matches; look for the MIME
    # type inside the header instead. Anything else falls back to PNG.
    suffix = ".jpg" if "image/jpeg" in header.lower() else ".png"
    image_data = base64.b64decode(payload)

    # Write through the handle NamedTemporaryFile already holds; reopening the
    # file by name while it is still open is not portable (fails on Windows).
    # Leaving the ``with`` block closes and flushes it before the upload.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(image_data)
        temp_filename = temp_file.name

    # NOTE(review): the temp file is kept (delete=False) and never removed.
    # Confirm whether upload_image needs the file afterwards before adding
    # cleanup here.
    return upload_image(temp_filename)


@image_edit_agent.tool
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    Use this tool to edit an object in the image. for example:
    - remove the pole
    - replace the dog with a cat
    - change the background to a beach
    - remove the person in the image
    - change the hair color to red
    - change the hat to a cap
    """
    # NOTE: the docstring above doubles as the tool description shown to the
    # model, so it is intentionally left untouched; developer notes live in
    # comments.
    #
    # Pipeline: fetch image -> build mask instruction -> generate mask ->
    # magic replace -> upload the result.
    deps = ctx.deps

    # image_url is Optional on ImageEditDeps; fail with a clear message here
    # rather than passing None into the download helper.
    if not deps.image_url:
        raise ValueError(
            "edit_object requires an image, but ctx.deps.image_url is not set"
        )

    image_uri = download_image_to_data_uri(deps.image_url)

    # Generate mask
    mask_instruction = deps.mask_service.get_mask_generation_instruction(
        deps.edit_instruction, deps.image_url
    )
    mask = deps.mask_service.generate_mask(mask_instruction, image_uri)

    # Magic replace (local renamed from `input` to avoid shadowing the builtin)
    replace_input = MagicReplaceInput(
        image=image_uri, mask=mask, prompt=mask_instruction.target_caption
    )
    result = deps.hopter_client.magic_replace(replace_input)

    uploaded_image = upload_image_from_base64(result.base64_image)
    return EditImageResult(edited_image_url=uploaded_image)


@image_edit_agent.tool
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    run super resolution, upscale, or enhance the image
    """
    # NOTE: the docstring above doubles as the tool description shown to the
    # model, so it is intentionally left untouched; developer notes live in
    # comments.
    deps = ctx.deps

    # image_url is Optional on ImageEditDeps; fail with a clear message here
    # rather than passing None into the download helper.
    if not deps.image_url:
        raise ValueError(
            "super_resolution requires an image, but ctx.deps.image_url is not set"
        )

    image_uri = download_image_to_data_uri(deps.image_url)

    # Fixed request settings: scale factor 4, face enhancement disabled.
    # (Local renamed from `input` to avoid shadowing the builtin.)
    upscale_input = SuperResolutionInput(
        image_b64=image_uri, scale=4, use_face_enhancement=False
    )
    result = deps.hopter_client.super_resolution(upscale_input)

    uploaded_image = upload_image_from_base64(result.scaled_image)
    return EditImageResult(edited_image_url=uploaded_image)


async def main():
    """Run a sample edit against a local image and print the streamed output.

    Builds the Hopter client, the mask service and the agent dependencies,
    then streams the agent run for the hard-coded prompt and image.
    """
    user_prompt = "remove the light post"
    source_uri = image_path_to_uri("./assets/lakeview.jpg")

    # Service clients shared by the tools; Hopter targets the staging
    # environment.
    hopter_client = Hopter(
        api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
    )
    agent_deps = ImageEditDeps(
        edit_instruction=user_prompt,
        hopter_client=hopter_client,
        mask_service=GenerateMaskService(hopter=hopter_client),
        image_url=source_uri,
    )

    # NOTE(review): the prompt is passed as OpenAI-style content parts;
    # confirm run_stream accepts this shape in the installed pydantic-ai
    # version.
    text_part = {"type": "text", "text": user_prompt}
    image_part = {"type": "image_url", "image_url": {"url": source_uri}}

    async with image_edit_agent.run_stream(
        [text_part, image_part], deps=agent_deps
    ) as stream_result:
        async for chunk in stream_result.stream():
            print(chunk)


# Script entry point: run the sample edit only when executed directly, not
# when this module is imported.
if __name__ == "__main__":
    asyncio.run(main())