Spaces:
Running
Running
from pydantic_ai import Agent, RunContext | |
from pydantic_ai.settings import ModelSettings | |
from pydantic_ai.models.openai import OpenAIModel | |
from dotenv import load_dotenv | |
import os | |
import asyncio | |
from src.utils import image_path_to_base64 | |
from dataclasses import dataclass | |
# Load environment variables from a local .env file (python-dotenv) so that
# OPENAI_API_KEY, read when the OpenAI model is constructed, can come from it.
load_dotenv()
@dataclass
class ImageEditDeps:
    """Dependencies passed to the agents (and their tools) for one edit request.

    The ``@dataclass`` decorator is required: callers build instances with
    keyword arguments, e.g. ``ImageEditDeps(edit_instruction=..., image_url=...)``.

    Attributes:
        edit_instruction: Natural-language editing instruction, e.g.
            "remove the light post".
        image_url: URL of the image, sent to the model as an ``image_url``
            content part (this file builds a base64 ``data:`` URL).
    """

    edit_instruction: str  # the user's editing instruction text
    image_url: str  # image reference forwarded verbatim to the model
# Single chat model shared by every agent in this module; the API key is taken
# from the environment (optionally populated from .env by load_dotenv()).
model = OpenAIModel("gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))
# Planner agent: given an image plus an editing instruction, its system prompt
# asks it to plan the edit using its tools. No tools are passed in this
# constructor; they have to be attached to the agent separately.
image_edit_agent = Agent(
    model,
    deps_type=ImageEditDeps,
    system_prompt=[
        "Be concise, reply with one sentence.",
        "You are an image editing agent. "
        "You will be given an image and an editing instruction. "
        "Use the tools available to you and come up with a plan to edit the "
        "image according to the instruction.",
    ],
)
# Registered as a tool: image_edit_agent's system prompt tells it to "use the
# tools available", and this signature (RunContext first) is a tool signature.
@image_edit_agent.tool
async def identify_editing_subject(ctx: RunContext[ImageEditDeps]) -> str:
    """Identify the object that the editing instruction refers to.

    Sends the instruction and the image held in ``ctx.deps`` to
    ``mask_generation_agent`` and returns that agent's answer.

    Args:
        ctx: Run context. ``ctx.deps.edit_instruction`` and
            ``ctx.deps.image_url`` supply the instruction and the image;
            ``ctx.usage`` is forwarded to the nested run so its usage is
            accounted together with the calling run's.

    Returns:
        The sub-agent's text reply naming the subject to edit.
    """
    # OpenAI-style multimodal content: the instruction text, then the image.
    messages = [
        {"type": "text", "text": ctx.deps.edit_instruction},
        {"type": "image_url", "image_url": {"url": ctx.deps.image_url}},
    ]
    r = await mask_generation_agent.run(messages, usage=ctx.usage, deps=ctx.deps)
    return r.data
# Subject-extraction agent: prompted to answer with only a short description
# of the object that an editing instruction targets.
# NOTE(review): the prompt asks for "one noun" yet allows up to 5 words, and
# "Only output the new content" conflicts with the example answer ('white cat'
# is the object being replaced, not the new content) — confirm intended wording.
mask_generation_agent = Agent(
    model,
    deps_type=ImageEditDeps,
    system_prompt=[
        "I will give you an editing instruction of the image. "
        "Please output the object needed to be edited.",
        "You only need to output the basic description of the object in no "
        "more than 5 words.",
        "The output should only contain one noun.",
        "For example, the editing instruction is "
        "'Change the white cat to a black dog'. "
        "Then you need to output: 'white cat'. "
        "Only output the new content. Do not output anything else.",
    ],
)
async def generate_mask(ctx: RunContext[ImageEditDeps], mask_subject: str) -> str:
    """Generate a mask for the object targeted by the editing instruction.

    Placeholder: mask generation is not implemented yet.

    Args:
        ctx: Run context; ``ctx.deps`` carries the image URL and instruction.
        mask_subject: Short description of the object to mask (for example,
            the answer produced by ``identify_editing_subject``).

    Returns:
        Annotated as ``str``; the intended content is not defined yet (TODO).

    Raises:
        NotImplementedError: Always, until mask generation is implemented.
    """
    # Fail loudly rather than silently returning None despite the ``-> str``
    # annotation, so no caller can mistake the stub for a real result.
    raise NotImplementedError("generate_mask is not implemented yet")
async def main(
    image_file_path: str = "./assets/lakeview.jpg",
    prompt: str = "remove the light post",
) -> None:
    """Ask ``mask_generation_agent`` which object an edit targets, and print it.

    Args:
        image_file_path: Path of the image to send. The data-URL MIME type is
            chosen from the file extension (PNG/WebP/GIF), defaulting to JPEG.
        prompt: Editing instruction sent alongside the image.
    """
    # Label the data URL with the right MIME type instead of assuming JPEG;
    # unknown extensions keep the previous "image/jpeg" behavior.
    extension = os.path.splitext(image_file_path)[1].lower()
    mime_type = {
        ".png": "image/png",
        ".webp": "image/webp",
        ".gif": "image/gif",
    }.get(extension, "image/jpeg")

    image_base64 = image_path_to_base64(image_file_path)
    image_url = f"data:{mime_type};base64,{image_base64}"

    # OpenAI-style multimodal content: instruction text followed by the image.
    messages = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": image_url}},
    ]
    deps = ImageEditDeps(edit_instruction=prompt, image_url=image_url)
    r = await mask_generation_agent.run(messages, deps=deps)
    print(r.data)
if __name__ == "__main__":
    # Script entry point: run the async demo on a fresh event loop.
    asyncio.run(main())