# chat-image-edit / gradio_chat.py
# feat: added examples (commit 9e6acd9, 8.61 kB)
import gradio as gr
from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
import os
from src.hopter.client import Hopter, Environment
from src.services.generate_mask import GenerateMaskService
from dotenv import load_dotenv
from src.utils import image_path_to_uri, upload_image
from pydantic_ai.messages import (
ToolCallPart,
ToolReturnPart
)
from src.agents.mask_generation_agent import EditImageResult
from pydantic_ai.agent import Agent
from pydantic_ai.models.openai import OpenAIModel
# Load variables from .env BEFORE any os.environ reads below — in the
# original order, OPENAI_API_KEY was read before load_dotenv() ran, so a
# key defined only in .env was silently missed.
load_dotenv()

# LLM backing the plain-chat fallback agent.
model = OpenAIModel(
    "gpt-4o",
    api_key=os.environ.get("OPENAI_API_KEY"),
)

# Text-only assistant agent; shares ImageEditDeps so it can be swapped in
# for the mask-generation agent without changing call sites.
simple_agent = Agent(
    model,
    system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
    deps_type=ImageEditDeps
)
def build_user_message(chat_input):
    """Convert a MultimodalTextbox payload into chatbot-format user messages.

    Args:
        chat_input: dict with "text" (str) and "files" (list of paths).

    Returns:
        A list of {"role": "user", ...} messages: one for the text, then one
        per attached image (content is {"path": <file>}).
    """
    messages = [{"role": "user", "content": chat_input["text"]}]
    for image in chat_input["files"] or []:
        messages.append({"role": "user", "content": {"path": image}})
    return messages
def build_messages_for_agent(chat_input, past_messages):
    """Build the multimodal message list to send to the agent.

    Args:
        chat_input: dict with "text" (str) and optional "files" (list of paths).
        past_messages: prior conversation messages to prepend.

    Returns:
        A NEW list: past messages, then the user's text part, then an
        image_url part for the first uploaded file (if any).
    """
    # NOTE(review): the original comment said image messages are filtered out
    # of past_messages "to save on tokens", but no filtering is implemented —
    # confirm whether that is still intended.
    # Copy so the caller's history list is not mutated in place (the original
    # aliased past_messages and appended to it).
    messages = list(past_messages)
    # Add the user's text message.
    if chat_input["text"]:
        messages.append({
            "type": "text",
            "text": chat_input["text"]
        })
    # Add the user's image message (only the first file is used).
    files = chat_input.get("files", [])
    image_url = upload_image(files[0]) if files else None
    if image_url:
        messages.append({
            "type": "image_url",
            "image_url": {"url": image_url}
        })
    return messages
def select_example(x: gr.SelectData, chat_input):
    """Copy a clicked chatbot example into the multimodal input box.

    Args:
        x: the Gradio select event; x.value holds the example's dict.
        chat_input: current MultimodalTextbox value to overwrite.

    Returns:
        The updated chat_input dict.
    """
    example = x.value
    chat_input["text"] = example["text"]
    chat_input["files"] = example["files"]
    return chat_input
async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
    """Run the mask-generation agent and stream UI updates as it works.

    Args:
        chat_input: MultimodalTextbox value ({"text": str, "files": [paths]}).
        chatbot: chatbot message list (Gradio "messages" format), mutated and
            re-yielded as the agent progresses.
        past_messages: agent-side message history (pydantic-ai messages).
        current_image: URL of the image currently under edit, or None.

    Yields:
        (chat_input, chatbot, past_messages, current_image) update tuples;
        gr.skip() leaves the corresponding component unchanged.
    """
    # Echo the user's message(s) into the UI and clear the input box.
    chatbot.extend(build_user_message(chat_input))
    # BUG FIX: the third element was `gr.skip` (the bare function object),
    # which would be written into the past_messages State; it must be called.
    yield {"text": "", "files": []}, chatbot, gr.skip(), gr.skip()

    # Build the multimodal content sent to the agent.
    text = chat_input["text"]
    files = chat_input.get("files", [])
    image_url = upload_image(files[0]) if files else None
    messages = [
        {
            "type": "text",
            "text": text
        },
    ]
    if image_url:
        messages.append(
            {"type": "image_url", "image_url": {"url": image_url}}
        )
        # A freshly uploaded image becomes the image under edit.
        current_image = image_url

    # Dependencies shared by the agent's tools.
    hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
    mask_service = GenerateMaskService(hopter=hopter)
    deps = ImageEditDeps(
        edit_instruction=text,
        image_url=current_image,
        hopter_client=hopter,
        mask_service=mask_service
    )

    # Run the agent, surfacing tool calls/returns in the UI as they happen.
    async with mask_generation_agent.run_stream(
        messages,
        deps=deps,
        message_history=past_messages
    ) as result:
        for message in result.new_messages():
            for call in message.parts:
                if isinstance(call, ToolCallPart):
                    call_args = (
                        call.args.args_json
                        if hasattr(call.args, 'args_json')
                        else call.args
                    )
                    metadata = {
                        'title': f'🛠️ Using {call.tool_name}',
                    }
                    # Remember the tool call id so that when the tool returns
                    # we can find this message and update it with the result.
                    if call.tool_call_id is not None:
                        metadata['id'] = call.tool_call_id
                    # BUG FIX: use an f-string instead of 'Parameters: ' +
                    # call_args — call.args may be a dict, which makes string
                    # concatenation raise TypeError.
                    gr_message = {
                        'role': 'assistant',
                        'content': f'Parameters: {call_args}',
                        'metadata': metadata,
                    }
                    chatbot.append(gr_message)
                if isinstance(call, ToolReturnPart):
                    for gr_message in chatbot:
                        # Skip messages without metadata (plain chat turns).
                        if not gr_message.get('metadata'):
                            continue
                        if gr_message['metadata'].get('id', '') == call.tool_call_id:
                            if isinstance(call.content, EditImageResult):
                                # The tool produced an edited image: render it.
                                chatbot.append({
                                    "role": "assistant",
                                    "content": gr.Image(call.content.edited_image_url),
                                    "files": [call.content.edited_image_url]
                                })
                            else:
                                gr_message['content'] += (
                                    f'\nOutput: {call.content}'
                                )
                yield gr.skip(), chatbot, gr.skip(), gr.skip()

        # Stream the agent's final text answer into a fresh assistant bubble.
        chatbot.append({'role': 'assistant', 'content': ''})
        async for message in result.stream_text():
            chatbot[-1]['content'] = message
            yield gr.skip(), chatbot, gr.skip(), gr.skip()

        # Persist the full history and re-enable the input box.
        past_messages = result.all_messages()
        yield gr.Textbox(interactive=True), gr.skip(), past_messages, current_image
with gr.Blocks() as demo:
    # Static page header: logo plus title/description.
    gr.HTML(
        """
<div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%">
    <img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto">
    <div>
        <h1 style="margin: 0 0 1rem 0">Image Editing Assistant</h1>
        <h3 style="margin: 0 0 0.5rem 0">
            This assistant edits images according to your instructions.
        </h3>
    </div>
</div>
"""
    )
    # Per-session state: the URL of the image currently being edited, and the
    # agent-side message history (pydantic-ai messages, not UI messages).
    current_image = gr.State(None)
    past_messages = gr.State([])
    # Chat display in "messages" format, seeded with clickable examples.
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        label='Image Editing Assistant',
        type='messages',
        avatar_images=(None, 'https://ai.pydantic.dev/img/logo-white.svg'),
        examples=[
            {
                "text": "Remove the person in the image",
                "files": [
                    "https://www.apple.com/tv-pr/articles/2024/10/apple-tv-unveils-severance-season-two-teaser-ahead-of-the-highly-anticipated-return-of-the-emmy-and-peabody-award-winning-phenomenon/images/big-image/big-image-01/1023024_Severance_Season_Two_Official_Trailer_Big_Image_01_big_image_post.jpg.large_2x.jpg"
                ]
            },
            {
                "text": "Change all the balloons to red in the image",
                "files": [
                    "https://www.apple.com/tv-pr/articles/2024/10/apple-tv-unveils-severance-season-two-teaser-ahead-of-the-highly-anticipated-return-of-the-emmy-and-peabody-award-winning-phenomenon/images/big-image/big-image-01/1023024_Severance_Season_Two_Official_Trailer_Big_Image_01_big_image_post.jpg.large_2x.jpg"
                ]
            },
            {
                "text": "Change coffee to a glass of water",
                "files": [
                    "https://previews.123rf.com/images/vadymvdrobot/vadymvdrobot1812/vadymvdrobot181201149/113217373-image-of-smiling-woman-holding-takeaway-coffee-in-paper-cup-and-taking-selfie-while-walking-through.jpg"
                ]
            },
            {
                "text": "ENHANCE!",
                "files": [
                    "https://m.media-amazon.com/images/M/MV5BNzM3ODc5NzEtNzJkOC00MDM4LWI0MTYtZTkyNmY3ZTBhYzkxXkEyXkFqcGc@._V1_QL75_UX1000_CR0,52,1000,563_.jpg"
                ]
            }
        ]
    )
    # Single-file multimodal input: text instruction plus optional image.
    with gr.Row():
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_count="single",
            show_label=False,
            placeholder='How would you like to edit this image?',
            sources=["upload"]
        )
    # Submitting the textbox streams the agent's progress into the chatbot
    # and updates both State holders.
    generation = chat_input.submit(
        stream_from_agent,
        inputs=[chat_input, chatbot, past_messages, current_image],
        outputs=[chat_input, chatbot, past_messages, current_image],
    )
    # Clicking an example first fills the input box, then runs the agent
    # with the same wiring as a manual submit.
    chatbot.example_select(
        select_example,
        inputs=[chat_input],
        outputs=[chat_input],
    ).then(
        stream_from_agent,
        inputs=[chat_input, chatbot, past_messages, current_image],
        outputs=[chat_input, chatbot, past_messages, current_image],
    )
# Standard script entry guard: launch the Gradio server when run directly.
if __name__ == '__main__':
    demo.launch()