File size: 2,897 Bytes
a21dee1
 
 
 
 
 
c55fe6a
a21dee1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c55fe6a
 
a21dee1
c55fe6a
 
 
 
 
 
 
 
 
 
a21dee1
c55fe6a
 
a21dee1
c55fe6a
 
 
 
 
 
a21dee1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from pydantic_ai import Agent, RunContext
from pydantic_ai.settings import ModelSettings
from pydantic_ai.models.openai import OpenAIModel
from dotenv import load_dotenv
import os
import asyncio
from src.utils import image_path_to_base64
from dataclasses import dataclass

load_dotenv()

@dataclass
class ImageEditDeps:
    """Per-run dependencies handed to the agents in this module via ``ctx.deps``."""

    # Natural-language edit request; sent as the "text" part of the user message.
    edit_instruction: str
    # Image reference sent as the "image_url" part of the user message.
    image_url: str

# Single OpenAI chat model instance used by the agents in this module.
# The key is looked up in the process environment (load_dotenv() has already
# run at import time); the lookup yields None when OPENAI_API_KEY is unset.
model = OpenAIModel("gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))

# Planning agent: it is told (via the system prompt) to use its tools to come
# up with an edit plan. Its run dependencies are an ImageEditDeps instance.
image_edit_agent = Agent(
    model,
    deps_type=ImageEditDeps,
    system_prompt=[
        "Be concise, reply with one sentence.",
        "You are an image editing agent. You will be given an image and an editing instruction. Use the tools available to you and come up with a plan to edit the image according to the instruction.",
    ],
)


@image_edit_agent.tool
async def identify_editing_subject(ctx: RunContext[ImageEditDeps]) -> str:
    """
    Identify the subject of the image editing instruction.

    Returns:
        The subject of the image editing instruction.
    """
    # The tool has no model-supplied arguments: both the instruction and the
    # image are taken from the run dependencies (ctx.deps), so the docstring
    # deliberately carries no "Args" section.
    user_content = [
        {"type": "text", "text": ctx.deps.edit_instruction},
        {"type": "image_url", "image_url": {"url": ctx.deps.image_url}},
    ]

    # Delegate to the mask-generation agent, forwarding the same deps and the
    # caller's usage object (presumably so the delegate's usage is accumulated
    # on this run -- confirm against the pydantic_ai delegation docs).
    delegate_result = await mask_generation_agent.run(
        user_content,
        usage=ctx.usage,
        deps=ctx.deps,
    )
    return delegate_result.data

# Agent that names the object an edit instruction targets. It shares the
# module-level model and the ImageEditDeps dependency type.
# NOTE(review): the last prompt line says "Only output the new content" while
# its own example outputs the object being replaced ('white cat') -- confirm
# the intended wording before changing it.
mask_generation_agent = Agent(
    model,
    deps_type=ImageEditDeps,
    system_prompt=[
        "I will give you an editing instruction of the image. Please output the object needed to be edited.",
        "You only need to output the basic description of the object in no more than 5 words.",
        "The output should only contain one noun.",
        "For example, the editing instruction is 'Change the white cat to a black dog'. Then you need to output: 'white cat'. Only output the new content. Do not output anything else.",
    ],
)

@mask_generation_agent.tool
async def generate_mask(ctx: RunContext[ImageEditDeps], mask_subject: str) -> str:
    """
    Generate a mask for the image editing instruction.
    """
    # TODO: placeholder -- no mask is produced yet. The tool hands None back
    # to the agent even though the signature is annotated as returning str.
    return None

async def main():
    """Run the mask-generation agent on a sample image and print its answer."""
    prompt = "remove the light post"
    image_file_path = "./assets/lakeview.jpg"

    # Embed the sample image in a data URL so it can travel inside the message.
    image_url = f"data:image/jpeg;base64,{image_path_to_base64(image_file_path)}"

    # User message: the instruction text followed by the image.
    text_part = {"type": "text", "text": prompt}
    image_part = {"type": "image_url", "image_url": {"url": image_url}}

    deps = ImageEditDeps(edit_instruction=prompt, image_url=image_url)
    result = await mask_generation_agent.run([text_part, image_part], deps=deps)
    print(result.data)


# Script entry point: run the async main() when this file is executed directly.
if __name__ == "__main__":
    asyncio.run(main())