# chat-image-edit / image_edit_chat.py
# simonlee-cb — "refactor: formatting" (commit fcb8f25)
import gradio as gr
from src.agents.image_edit_agent import image_edit_agent, ImageEditDeps, EditImageResult
import os
from src.hopter.client import Hopter, Environment
from src.services.generate_mask import GenerateMaskService
from dotenv import load_dotenv
from src.utils import upload_image
from pydantic_ai.messages import ToolCallPart, ToolReturnPart
from pydantic_ai.models.openai import OpenAIModel
# Load .env before reading API keys: the OpenAIModel call below reads
# OPENAI_API_KEY from the environment immediately, so a key that only lives
# in .env would otherwise be missed. load_dotenv does not override variables
# that are already set, so calling it again later in the module is harmless.
load_dotenv()
# NOTE(review): `model` is not referenced anywhere else in this file; the
# agent presumably configures its own model — confirm before removing.
model = OpenAIModel(
    "gpt-4o",
    api_key=os.environ.get("OPENAI_API_KEY"),
)
# Markdown shown at the top of the demo page (rendered with gr.Markdown).
INTRO = """
# Image Editing Assistant
### Experience seamless image editing using natural language in a chat-based interface.
With this demo, you can:
- Enhance or upscale an image
- Remove objects from an image
- Replace elements within an image
- Change the background
"""
# Example prompts, each shaped like a MultimodalTextbox value:
# {"text": <instruction>, "files": [<image URL>]}.
# Used both as Chatbot examples and as a gr.Examples table in the UI.
# The images are remote URLs, so the examples need network access to load.
EXAMPLES = [
    {
        "text": "Replace the background to the space with stars and planets",
        "files": [
            "https://cdn.prod.website-files.com/66f230993926deadc0ac3a44/66f370d65f158cbbcfbcc532_Crossed%20Arms%20Levi%20Meir%20Clancy.jpg"
        ],
    },
    {
        "text": "Change all the balloons to red in the image",
        "files": [
            "https://www.apple.com/tv-pr/articles/2024/10/apple-tv-unveils-severance-season-two-teaser-ahead-of-the-highly-anticipated-return-of-the-emmy-and-peabody-award-winning-phenomenon/images/big-image/big-image-01/1023024_Severance_Season_Two_Official_Trailer_Big_Image_01_big_image_post.jpg.large_2x.jpg"
        ],
    },
    {
        "text": "Change coffee to a glass of water",
        "files": [
            "https://previews.123rf.com/images/vadymvdrobot/vadymvdrobot1812/vadymvdrobot181201149/113217373-image-of-smiling-woman-holding-takeaway-coffee-in-paper-cup-and-taking-selfie-while-walking-through.jpg"
        ],
    },
    {
        "text": "ENHANCE!",
        "files": [
            "https://m.media-amazon.com/images/M/MV5BNzM3ODc5NzEtNzJkOC00MDM4LWI0MTYtZTkyNmY3ZTBhYzkxXkEyXkFqcGc@._V1_QL75_UX1000_CR0,52,1000,563_.jpg"
        ],
    },
]
# Load variables from a local .env file into the process environment, so
# values read at request time (e.g. HOPTER_API_KEY in stream_from_agent)
# can be supplied through .env.
load_dotenv()
def build_user_message(chat_input):
    """Convert a MultimodalTextbox value into Gradio "messages" entries.

    Produces one user message for the text (skipped when the text is empty),
    followed by one user message per attached file that references the file
    by its path/URL.

    Args:
        chat_input: dict with a ``"text"`` entry and, optionally, a
            ``"files"`` list. Missing or ``None`` entries are treated as empty.

    Returns:
        A list of ``{"role": "user", "content": ...}`` dicts, in display order.
    """
    # .get() mirrors how the rest of the module reads "files"; indexing
    # directly raised KeyError when the key was absent.
    text = chat_input.get("text")
    files = chat_input.get("files") or []

    messages = []
    # Skip empty text (e.g. an image-only submit), which would presumably
    # render as a blank bubble.
    if text:
        messages.append({"role": "user", "content": text})
    messages.extend({"role": "user", "content": {"path": path}} for path in files)
    return messages
def build_messages_for_agent(chat_input, past_messages):
    """Build the multimodal content list for one agent turn.

    The result is a NEW list containing the entries of ``past_messages``,
    then the user's text (if non-empty), then, when a file is attached, an
    ``image_url`` entry pointing at the uploaded copy of the first file.
    The caller's ``past_messages`` list is left untouched.

    Args:
        chat_input: MultimodalTextbox value with ``"text"`` and optional
            ``"files"``; only the first file is uploaded.
        past_messages: Previously built entries, or ``None``.

    Returns:
        The combined list of content entries.
    """
    # Copy rather than alias so that appending below does not mutate the
    # caller's history list.
    # TODO: image entries in past_messages are NOT filtered out; the intent
    # to drop them to save tokens was never implemented.
    messages = list(past_messages or [])

    text = chat_input.get("text")
    if text:
        messages.append({"type": "text", "text": text})

    files = chat_input.get("files") or []
    image_url = upload_image(files[0]) if files else None
    if image_url:
        messages.append({"type": "image_url", "image_url": {"url": image_url}})
    return messages
def select_example(x: gr.SelectData, chat_input):
    """Copy a clicked Chatbot example into the multimodal input value.

    Args:
        x: Gradio event data; ``x.value`` is the selected example and is
            expected to carry ``"text"`` and ``"files"`` keys.
        chat_input: Current MultimodalTextbox value (a dict), updated in place.

    Returns:
        The same ``chat_input`` dict, now holding the example's text and files.
    """
    picked = x.value
    for field in ("text", "files"):
        chat_input[field] = picked[field]
    return chat_input
def _format_tool_args(args):
    """Render a tool call's arguments as a string for the chat UI.

    Depending on the pydantic_ai version, ``ToolCallPart.args`` may be a
    wrapper exposing ``args_json`` (a JSON string) or ``args_dict`` (a dict),
    or a plain ``str``/``dict``. Only a ``str`` can be concatenated onto the
    "Parameters: " prefix, so everything is normalized to ``str`` here.
    """
    import json

    if hasattr(args, "args_json"):
        args = args.args_json
    elif hasattr(args, "args_dict"):
        args = args.args_dict
    if isinstance(args, str):
        return args
    try:
        return json.dumps(args, ensure_ascii=False)
    except (TypeError, ValueError):
        # Not JSON-serializable: fall back to a readable representation.
        return str(args)


def _find_tool_call_message(chatbot, tool_call_id):
    """Return the newest chatbot message tagged with ``tool_call_id``, or None."""
    if tool_call_id is None:
        return None
    # Search newest-first so a (rare) id reused from an earlier turn cannot
    # capture this turn's result.
    for gr_message in reversed(chatbot):
        metadata = gr_message.get("metadata")
        if metadata and metadata.get("id") == tool_call_id:
            return gr_message
    return None


async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
    """Run the image-edit agent for one user turn, streaming UI updates.

    Every yielded value is a 4-tuple matching the event outputs
    ``[chat_input, chatbot, past_messages, current_image]``; ``gr.skip()``
    leaves the corresponding component unchanged.

    Args:
        chat_input: MultimodalTextbox value with ``"text"`` and optional
            ``"files"`` (only the first file is uploaded).
        chatbot: Chat history in Gradio "messages" format; appended to in place.
        past_messages: Message history from previous agent runs.
        current_image: URL of the image being edited, carried between turns.

    Yields:
        A cleared input box with the user's turn echoed; then one update per
        message part of this run's new messages; then the streamed reply
        text; finally a re-enabled input, the new message history, and the
        latest image URL.
    """
    # Show the user's turn immediately and clear the input box.
    chatbot.extend(build_user_message(chat_input))
    # gr.skip must be *called*; yielding the bare function would store it
    # in the past_messages state.
    yield {"text": "", "files": []}, chatbot, gr.skip(), gr.skip()

    # Build the multimodal prompt: the text plus, optionally, the upload.
    text = chat_input["text"]
    files = chat_input.get("files", [])
    image_url = upload_image(files[0]) if files else None
    messages = [{"type": "text", "text": text}]
    if image_url:
        messages.append({"type": "image_url", "image_url": {"url": image_url}})
        # A fresh upload becomes the image this and later edits apply to.
        current_image = image_url

    # Service clients used by the agent's tools (created per request).
    hopter = Hopter(
        os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
    )
    mask_service = GenerateMaskService(hopter=hopter)
    deps = ImageEditDeps(
        edit_instruction=text,
        image_url=current_image,
        hopter_client=hopter,
        mask_service=mask_service,
    )

    async with image_edit_agent.run_stream(
        messages, deps=deps, message_history=past_messages
    ) as result:
        # Render tool activity from this run's new messages.
        for message in result.new_messages():
            for call in message.parts:
                if isinstance(call, ToolCallPart):
                    metadata = {"title": f"🛠️ Using {call.tool_name}"}
                    # Tag the message with the tool call id so the matching
                    # ToolReturnPart can find it and attach the result.
                    if call.tool_call_id is not None:
                        metadata["id"] = call.tool_call_id
                    chatbot.append(
                        {
                            "role": "assistant",
                            "content": "Parameters: " + _format_tool_args(call.args),
                            "metadata": metadata,
                        }
                    )
                elif isinstance(call, ToolReturnPart):
                    # Look up first, then mutate: appending to chatbot while
                    # iterating over it is fragile.
                    tool_message = _find_tool_call_message(chatbot, call.tool_call_id)
                    if tool_message is not None:
                        if isinstance(call.content, EditImageResult):
                            edited_url = call.content.edited_image_url
                            chatbot.append(
                                {
                                    "role": "assistant",
                                    "content": gr.Image(edited_url),
                                    "files": [edited_url],
                                }
                            )
                            # Returned to state below, so the next turn edits
                            # the edited result.
                            current_image = edited_url
                        else:
                            tool_message["content"] += f"\nOutput: {call.content}"
                yield gr.skip(), chatbot, gr.skip(), gr.skip()

        # Stream the final assistant reply into a fresh message.
        chatbot.append({"role": "assistant", "content": ""})
        async for message in result.stream_text():
            chatbot[-1]["content"] = message
            yield gr.skip(), chatbot, gr.skip(), gr.skip()

        past_messages = result.all_messages()
        # Output slot 0 is a MultimodalTextbox, so update it with that type.
        yield (
            gr.MultimodalTextbox(interactive=True),
            gr.skip(),
            past_messages,
            current_image,
        )
# Gradio UI: intro text, the chat transcript, and a multimodal input box.
# Components are laid out in the order they are created inside the Blocks
# context, so statement order here determines the page layout.
with gr.Blocks() as demo:
    gr.Markdown(INTRO)
    # Per-session state: URL of the image currently being edited (set from
    # an upload or from an edit result in stream_from_agent).
    current_image = gr.State(None)
    # Per-session state: agent message history returned after each run.
    past_messages = gr.State([])
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        label="Image Editing Assistant",
        type="messages",
        avatar_images=(None, "https://ai.pydantic.dev/img/logo-white.svg"),
        examples=EXAMPLES,
    )
    with gr.Row():
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_count="single",
            show_label=False,
            placeholder="How would you like to edit this image?",
            sources=["upload"],
        )
    # Submitting the input streams the agent's response. The same components
    # are inputs and outputs so state round-trips on every turn.
    generation = chat_input.submit(
        stream_from_agent,
        inputs=[chat_input, chatbot, past_messages, current_image],
        outputs=[chat_input, chatbot, past_messages, current_image],
    )
    # Clicking a Chatbot example copies it into the input, then runs the agent.
    chatbot.example_select(
        select_example,
        inputs=[chat_input],
        outputs=[chat_input],
    ).then(
        stream_from_agent,
        inputs=[chat_input, chatbot, past_messages, current_image],
        outputs=[chat_input, chatbot, past_messages, current_image],
    )
    # The same examples again, as a clickable table that fills chat_input.
    # NOTE(review): with no `fn`, `outputs` probably has no effect here, and
    # these duplicate the Chatbot examples above — confirm both are wanted.
    examples = gr.Examples(
        examples=EXAMPLES,
        inputs=[chat_input],
        outputs=[chat_input],
    )
# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()