chat-image-edit / gradio_app.py
simonlee-cb's picture
feat: working gradio demo
c55fe6a
raw
history blame
5.42 kB
import json
import os

import gradio as gr
from dotenv import load_dotenv
from pydantic_ai.agent import Agent
from pydantic_ai.messages import (
    ToolCallPart,
    ToolReturnPart,
)
from pydantic_ai.models.openai import OpenAIModel

from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
from src.agents.mask_generation_agent import EditImageResult
from src.hopter.client import Hopter, Environment
from src.services.generate_mask import GenerateMaskService
from src.utils import image_path_to_uri
# Populate os.environ from a local .env file before anything below reads it,
# so an OPENAI_API_KEY defined only in .env is visible to the model setup.
load_dotenv()

# Chat model shared by the agent(s) defined here.
model = OpenAIModel(
    "gpt-4o",
    api_key=os.environ.get("OPENAI_API_KEY"),
)

# General-purpose assistant agent using the image-edit dependency type.
# NOTE(review): not referenced elsewhere in this file as shown — confirm
# whether it is still needed.
simple_agent = Agent(
    model,
    system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
    deps_type=ImageEditDeps,
)
def build_user_message(chat_input):
    """Turn a MultimodalTextbox submission into Chatbot "messages" entries.

    Args:
        chat_input: Mapping with a ``"text"`` string and a ``"files"``
            sequence of file paths (may be empty or None).

    Returns:
        A list whose first entry is the user's text message.
        Each submitted file then follows as its own user message with
        content ``{"path": <file>}``.
    """
    text_entry = {"role": "user", "content": chat_input["text"]}
    attachments = chat_input["files"] or []
    file_entries = [
        {"role": "user", "content": {"path": file_path}}
        for file_path in attachments
    ]
    return [text_entry, *file_entries]
def _build_agent_prompt(text, image_uris):
    """Build the multimodal content list sent to the agent: text first, then one image_url item per URI."""
    prompt = [{"type": "text", "text": text}]
    prompt.extend(
        {"type": "image_url", "image_url": {"url": uri}} for uri in image_uris
    )
    return prompt


def _format_tool_args(args):
    """Render tool-call arguments as a display string for the chat.

    Objects exposing ``args_json`` are unwrapped to that string.
    Objects exposing ``args_dict`` are unwrapped to that value.
    Presumably these are older pydantic_ai argument wrappers — verify against
    the installed version.
    Strings are returned unchanged; anything else is JSON-encoded.
    """
    if hasattr(args, "args_json"):
        return args.args_json
    if hasattr(args, "args_dict"):
        args = args.args_dict
    if isinstance(args, str):
        return args
    # Display only: fall back to str() for values json cannot encode.
    return json.dumps(args, default=str)


def _find_tool_call_message(chatbot, tool_call_id):
    """Return the first chat message whose metadata ``id`` equals *tool_call_id*, else None."""
    for gr_message in chatbot:
        # metadata may be missing, or (depending on the Gradio version —
        # assumption) present as None on messages round-tripped through Chatbot.
        metadata = gr_message.get("metadata") or {}
        if metadata.get("id", "") == tool_call_id:
            return gr_message
    return None


async def stream_from_agent(chat_input, chatbot, past_messages):
    """Run the mask-generation agent on a submitted message and stream the reply.

    Async generator wired to ``chat_input.submit``. Every yield is a triple
    matching ``outputs=[chat_input, chatbot, past_messages]``.

    Args:
        chat_input: MultimodalTextbox value with ``"text"`` and ``"files"`` keys.
        chatbot: Current chat history as a list of message dicts; mutated in place.
        past_messages: Messages returned by ``result.all_messages()`` on earlier
            turns. They are passed back to the agent as ``message_history``.

    Yields:
        ``(chat_input_update, chatbot, past_messages_update)`` triples.
    """
    chatbot.extend(build_user_message(chat_input))
    # Clear the input immediately after submission. gr.skip must be *called*:
    # yielding the bare function would store it in the past_messages State.
    yield {"text": "", "files": []}, chatbot, gr.skip()

    text = chat_input["text"]
    image_uris = [image_path_to_uri(path) for path in chat_input["files"]]
    if not image_uris:
        # The deps below are built from the first image, so a turn without
        # one cannot proceed (previously this raised IndexError).
        chatbot.append({"role": "assistant", "content": "Please upload an image to edit."})
        yield gr.skip(), chatbot, gr.skip()
        return

    hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
    mask_service = GenerateMaskService(hopter=hopter)
    deps = ImageEditDeps(
        edit_instruction=text,
        image_url=image_uris[0],
        hopter_client=hopter,
        mask_service=mask_service,
    )

    async with mask_generation_agent.run_stream(
        _build_agent_prompt(text, image_uris),
        deps=deps,
        # Feed earlier turns back in; otherwise the stored State is never used.
        message_history=past_messages,
    ) as result:
        for message in result.new_messages():
            for part in message.parts:
                if isinstance(part, ToolCallPart):
                    metadata = {"title": f"🛠️ Using {part.tool_name}"}
                    if part.tool_call_id is not None:
                        # Lets the matching ToolReturnPart find this message.
                        metadata["id"] = part.tool_call_id
                    chatbot.append({
                        "role": "assistant",
                        "content": "Parameters: " + _format_tool_args(part.args),
                        "metadata": metadata,
                    })
                elif isinstance(part, ToolReturnPart):
                    # Locate the call message first, then mutate chatbot, so we
                    # never append to the list while iterating over it.
                    call_message = _find_tool_call_message(chatbot, part.tool_call_id)
                    if call_message is not None:
                        if isinstance(part.content, EditImageResult):
                            chatbot.append({
                                "role": "assistant",
                                "content": gr.Image(part.content.edited_image_url),
                                "files": [part.content.edited_image_url],
                            })
                        else:
                            call_message["content"] += f"\nOutput: {part.content}"
                yield gr.skip(), chatbot, gr.skip()

        # Placeholder assistant message that is filled in as text streams.
        chatbot.append({"role": "assistant", "content": ""})
        async for partial_text in result.stream_text():
            chatbot[-1]["content"] = partial_text
            yield gr.skip(), chatbot, gr.skip()
        past_messages = result.all_messages()

    # The output component is a MultimodalTextbox, so send an update of that type.
    yield gr.MultimodalTextbox(interactive=True), gr.skip(), past_messages
# Gradio UI: header, chat history, multimodal input and the submit wiring.
with gr.Blocks() as demo:
    gr.HTML(
        """
<div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%">
<img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto">
<div>
<h1 style="margin: 0 0 1rem 0">Image Editing Assistant</h1>
<h3 style="margin: 0 0 0.5rem 0">
This assistant edits images according to your instructions.
</h3>
</div>
</div>
"""
    )
    # Per-session agent message history, updated by stream_from_agent.
    past_messages = gr.State([])
    chatbot = gr.Chatbot(
        label='Image Editing Assistant',
        type='messages',
        avatar_images=(None, 'https://ai.pydantic.dev/img/logo-white.svg'),
    )
    with gr.Row():
        # Every submitted file is converted with image_path_to_uri and sent to
        # the model as an image_url, so only accept uploaded images here
        # (microphone recordings and other file types would be sent as images).
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_count="multiple",
            file_types=["image"],
            show_label=False,
            placeholder='How would you like to edit this image?',
            sources=["upload"],
        )
    generation = chat_input.submit(
        stream_from_agent,
        inputs=[chat_input, chatbot, past_messages],
        outputs=[chat_input, chatbot, past_messages],
    )

if __name__ == '__main__':
    demo.launch()