# chat-image-edit / src/agents/image-edit-agent.py
# Last commit c55fe6a by simonlee-cb: "feat: working gradio demo" (2.9 kB)
from pydantic_ai import Agent, RunContext
from pydantic_ai.settings import ModelSettings
from pydantic_ai.models.openai import OpenAIModel
from dotenv import load_dotenv
import os
import asyncio
from src.utils import image_path_to_base64
from dataclasses import dataclass
# Load environment variables (python-dotenv, e.g. from a local .env file) before
# the model definition further down reads OPENAI_API_KEY from os.environ.
load_dotenv()
@dataclass()
class ImageEditDeps:
    """Per-run dependencies shared by the image-editing agents and their tools.

    Tools read these through ``ctx.deps`` instead of receiving them as
    model-supplied arguments.
    """

    # Natural-language editing request, e.g. "remove the light post".
    edit_instruction: str
    # URL of the image to edit; the demo entry point passes a base64 data URL.
    image_url: str
# Shared chat model used by every agent in this module. The API key is read
# from the environment (populated by load_dotenv() above); None is passed
# through if the variable is unset.
model = OpenAIModel(
    "gpt-4o",
    api_key=os.getenv("OPENAI_API_KEY"),
)
# Top-level planning agent: given an image and an editing instruction, it is
# prompted to plan the edit using the tools attached to it via the
# @image_edit_agent.tool decorator. Run dependencies are ImageEditDeps.
image_edit_agent = Agent(
    model=model,
    deps_type=ImageEditDeps,
    system_prompt=[
        'Be concise, reply with one sentence.',
        "You are an image editing agent. You will be given an image and an editing instruction. Use the tools available to you and come up with a plan to edit the image according to the instruction."
    ],
)
@image_edit_agent.tool
async def identify_editing_subject(ctx: RunContext[ImageEditDeps]) -> str:
    """
    Identify the subject of the image editing instruction.

    The editing instruction and the image are taken from the run
    dependencies, so this tool takes no arguments.

    Returns:
        The subject of the image editing instruction.
    """
    # NOTE: pydantic_ai uses this docstring as the tool description shown to
    # the model, so it must describe the real (argument-free) signature.
    deps = ctx.deps
    # Multi-part user message: the instruction text plus the image, in the
    # OpenAI chat "content parts" shape.
    # NOTE(review): these raw dicts are handed straight to Agent.run; newer
    # pydantic_ai releases expect str / ImageUrl items instead — confirm
    # against the installed version.
    messages = [
        {"type": "text", "text": deps.edit_instruction},
        {"type": "image_url", "image_url": {"url": deps.image_url}},
    ]
    # Delegate to the subject-extraction agent. Forwarding ctx.usage keeps the
    # sub-agent's token usage counted in this run's usage.
    result = await mask_generation_agent.run(messages, usage=ctx.usage, deps=deps)
    return result.data
# Sub-agent that reduces an editing instruction (plus image) to a short
# description of the object to edit. identify_editing_subject and main() call
# it directly; tools are attached via @mask_generation_agent.tool.
# NOTE(review): the prompt asks for "one noun" and says "Only output the new
# content", yet its own example answers 'white cat' — two words, and the
# existing object rather than the new one. Confirm which output is intended.
mask_generation_agent = Agent(
    model=model,
    deps_type=ImageEditDeps,
    system_prompt=[
        "I will give you an editing instruction of the image. Please output the object needed to be edited.",
        "You only need to output the basic description of the object in no more than 5 words.",
        "The output should only contain one noun.",
        "For example, the editing instruction is 'Change the white cat to a black dog'. Then you need to output: 'white cat'. Only output the new content. Do not output anything else."
    ],
)
@mask_generation_agent.tool
async def generate_mask(ctx: RunContext[ImageEditDeps], mask_subject: str) -> str:
    """
    Generate a mask for the image editing instruction.
    """
    # TODO: placeholder. No mask is produced yet: ctx and mask_subject are
    # unused, and the coroutine resolves to None despite the `-> str` hint.
    # (The docstring above is the tool description sent to the model.)
    return None
async def main(
    image_file_path: str = "./assets/lakeview.jpg",
    prompt: str = "remove the light post",
):
    """
    Demo: ask the subject-extraction agent what an instruction wants edited.

    Encodes the image as a base64 data URL, sends it together with the
    instruction to ``mask_generation_agent`` and prints the agent's answer.

    Args:
        image_file_path: Path of the image to send. Defaults to the demo
            image that was previously hard-coded here.
        prompt: Natural-language editing instruction. Defaults to the
            previously hard-coded demo instruction.

    Returns:
        The agent's output (the same value that is printed).
    """
    import mimetypes

    image_base64 = image_path_to_base64(image_file_path)
    # Label the data URL with the file's actual image type instead of always
    # claiming JPEG; unknown or non-image extensions keep the JPEG default.
    mime_type, _ = mimetypes.guess_type(image_file_path)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    image_url = f"data:{mime_type};base64,{image_base64}"
    messages = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": image_url}},
    ]
    deps = ImageEditDeps(edit_instruction=prompt, image_url=image_url)
    # NOTE: this exercises the sub-agent directly, not image_edit_agent.
    r = await mask_generation_agent.run(messages, deps=deps)
    print(r.data)
    return r.data
if __name__ == "__main__":
    # Script entry point: drive the async demo to completion on a new event loop.
    asyncio.run(main())