chat-image-edit / agent.py
simonlee-cb's picture
feat: working gradio demo
c55fe6a
raw
history blame
5.15 kB
import streamlit as st
from typing import TypedDict, Literal
from pydantic_ai.messages import (
ModelRequest,
ModelResponse,
UserPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
)
import asyncio
from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
from src.hopter.client import Hopter, Environment
import os
from src.services.generate_mask import GenerateMaskService
from dotenv import load_dotenv
from src.utils import image_path_to_uri
# Load environment variables via python-dotenv before anything below reads
# them (HOPTER_API_KEY is looked up when the client is built).
load_dotenv()

# Page-level Streamlit configuration.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
    page_title="Conversational Image Editor",
    page_icon="🧊",
)

# Shared Hopter client, pointed at the staging environment.
hopter = Hopter(
    environment=Environment.STAGING,
    api_key=os.environ.get("HOPTER_API_KEY"),
)

# Mask-generation service built on top of the shared client; handed to the
# agent through its dependencies.
mask_service = GenerateMaskService(hopter=hopter)

# Widget/session key name for the user's message input.
user_msg_input_key = "input_user_msg"
class ChatMessage(TypedDict):
    """Shape of a single chat message as delivered to the browser/API."""

    # Who authored the message: the end user or the model.
    role: Literal["user", "model"]
    # When the message was produced, as a string (format is not fixed here).
    timestamp: str
    # The message text.
    content: str
# Maps a part's ``part_kind`` to the chat role it is shown under and a
# function producing the markdown for that part.
_PART_RENDERERS = {
    'system-prompt': ("system", lambda p: f"**System**: {p.content}"),
    'user-prompt': ("user", lambda p: p.content),
    'text': ("assistant", lambda p: p.content),
    'tool-call': ("assistant", lambda p: f"**{p.tool_name}**: {p.args}"),
    'tool-return': ("assistant", lambda p: f"**{p.tool_name}**: {p.content}"),
}


def display_message_part(part):
    """
    Render one message part as a chat bubble in the Streamlit UI.

    The bubble's role and text are chosen from ``part.part_kind``:
    system prompts, user prompts, model text, tool calls and tool
    returns are supported. Parts of any other kind are not rendered.
    """
    renderer = _PART_RENDERERS.get(part.part_kind)
    if renderer is None:
        # Unknown kind: nothing to show.
        return

    role, to_markdown = renderer
    with st.chat_message(role):
        st.markdown(to_markdown(part))
async def run_agent(user_input: str, image_b64: str):
    """Stream one agent turn into the UI and record it in the session history.

    Sends the edit instruction together with the image to
    ``mask_generation_agent``, renders the reply incrementally in a
    placeholder, stores the run's messages in ``st.session_state.messages``
    and finally asks Streamlit to rerun the script.

    Args:
        user_input: The edit instruction typed by the user.
        image_b64: Image reference used both as the ``image_url`` entry of
            the prompt and as ``ImageEditDeps.image_url``.
    """
    prompt_parts = [
        {"type": "text", "text": user_input},
        {"type": "image_url", "image_url": {"url": image_b64}},
    ]
    agent_deps = ImageEditDeps(
        edit_instruction=user_input,
        image_url=image_b64,
        hopter_client=hopter,
        mask_service=mask_service,
    )

    def carries_user_prompt(message) -> bool:
        """Return True when ``message`` has parts and one is a user prompt."""
        if not hasattr(message, 'parts'):
            return False
        return any(p.part_kind == 'user-prompt' for p in message.parts)

    async with mask_generation_agent.run_stream(prompt_parts, deps=agent_deps) as result:
        streamed_text = ""
        placeholder = st.empty()

        # Re-render the accumulated text every time a new delta arrives.
        async for delta in result.stream_text(delta=True):
            streamed_text = streamed_text + delta
            placeholder.markdown(streamed_text)

        # Stream finished: keep this run's messages, minus any message that
        # carries a user prompt.
        kept = [m for m in result.new_messages() if not carries_user_prompt(m)]
        st.session_state.messages.extend(kept)

        # Record the streamed text as the final model response.
        st.session_state.messages.append(
            ModelResponse(parts=[TextPart(content=streamed_text)])
        )

    # NOTE(review): the original indentation was lost; the rerun is placed
    # after the stream context has closed — confirm this matches the intent.
    st.rerun()
async def main():
    """Render the page and, when the user submits a prompt, run one agent turn.

    The page has two columns: the conversation history on the left and the
    image uploader/preview on the right. The chat input stays disabled until
    an image has been uploaded. On submit, the user's prompt is appended to
    ``st.session_state.messages`` and ``run_agent`` streams the reply.
    """
    st.title("Conversational Image Editor")

    # Seed per-session state the first time the script runs for a session.
    if "openai_model" not in st.session_state:
        st.session_state["openai_model"] = "gpt-4o"
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "image" not in st.session_state:
        st.session_state.image = None

    chat_col, image_col = st.columns(2)

    with chat_col:
        # Replay the conversation so far, part by part.
        for msg in st.session_state.messages:
            # Only ModelRequest / ModelResponse carry ``.parts``. The previous
            # check also let ToolCallPart / ToolReturnPart through; those are
            # parts themselves, so reading ``.parts`` on them would raise
            # AttributeError. Anything that is not a message is skipped.
            if not isinstance(msg, (ModelRequest, ModelResponse)):
                continue
            for part in msg.parts:
                display_message_part(part)

    with image_col:
        st.session_state.image = st.file_uploader(
            "Upload an image", type=["png", "jpg", "jpeg"]
        )
        if st.session_state.image:
            st.image(st.session_state.image)
        else:
            st.write("Upload an image to get started")

    # Chat input for the user; unusable until an image is available.
    # NOTE(review): nesting reconstructed from de-indented source — confirm
    # the input is meant to live outside the two columns.
    user_input = st.chat_input(
        "What would you like to edit your image?",
        disabled=not st.session_state.image,
    )

    if user_input and st.session_state.image:
        # The prompt is stored here; run_agent leaves user-prompt messages
        # out of what it adds to the history.
        st.session_state.messages.append(
            ModelRequest(parts=[UserPromptPart(content=user_input)])
        )

        # Display the assistant's partial response while streaming.
        with st.chat_message("assistant"):
            image_url = image_path_to_uri(st.session_state.image)
            await run_agent(user_input, image_url)
if __name__ == "__main__":
    # Drive the async page entry point to completion on a fresh event loop.
    asyncio.run(main())