Spaces:
Running
Running
import streamlit as st | |
from typing import TypedDict, Literal | |
from pydantic_ai.messages import ( | |
ModelRequest, | |
ModelResponse, | |
UserPromptPart, | |
TextPart, | |
ToolCallPart, | |
ToolReturnPart, | |
) | |
import asyncio | |
from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps | |
from src.hopter.client import Hopter, Environment | |
import os | |
from src.services.generate_mask import GenerateMaskService | |
from dotenv import load_dotenv | |
from src.utils import image_path_to_uri | |
# Populate os.environ from a local .env file (if one exists) so the
# os.getenv lookup below can see it.
load_dotenv()

# Page-level configuration; this is the first st.* call in the script.
st.set_page_config(
    page_title="Conversational Image Editor",
    page_icon="🧊",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# One Hopter client (pointed at the STAGING environment) shared by the
# mask-generation service and the agent dependencies built per request.
hopter = Hopter(api_key=os.getenv("HOPTER_API_KEY"), environment=Environment.STAGING)
mask_service = GenerateMaskService(hopter=hopter)

# Session-state / widget key name for the user message input.
# NOTE(review): not referenced anywhere in the code shown — confirm it is still needed.
user_msg_input_key = "input_user_msg"
# Typed dict describing one chat message as exchanged with the browser/API:
# who sent it ('user' or 'model'), when, and its text content.
ChatMessage = TypedDict(
    "ChatMessage",
    {
        "role": Literal['user', 'model'],
        "timestamp": str,
        "content": str,
    },
)
ChatMessage.__doc__ = """Format of messages sent to the browser/API."""
def display_message_part(part):
    """
    Render one message part as a Streamlit chat bubble.

    ``part.part_kind`` chooses both the bubble's role and how the part is
    turned into markdown:

    * ``system-prompt`` -> "system" bubble, content prefixed with a bold label
    * ``user-prompt``   -> "user" bubble, raw content
    * ``text``          -> "assistant" bubble, raw content
    * ``tool-call``     -> "assistant" bubble, bold tool name followed by args
    * ``tool-return``   -> "assistant" bubble, bold tool name followed by content

    Parts of any other kind render nothing.
    """
    # kind -> (chat role, function producing the markdown for that part)
    renderers = {
        'system-prompt': ("system", lambda p: f"**System**: {p.content}"),
        'user-prompt': ("user", lambda p: p.content),
        'text': ("assistant", lambda p: p.content),
        'tool-call': ("assistant", lambda p: f"**{p.tool_name}**: {p.args}"),
        'tool-return': ("assistant", lambda p: f"**{p.tool_name}**: {p.content}"),
    }

    entry = renderers.get(part.part_kind)
    if entry is None:
        return

    role, to_markdown = entry
    with st.chat_message(role):
        st.markdown(to_markdown(part))
def _has_user_prompt(msg) -> bool:
    """Return True if ``msg`` has ``parts`` and any of them is a user-prompt part."""
    return hasattr(msg, 'parts') and any(
        part.part_kind == 'user-prompt' for part in msg.parts
    )


def _ends_with_text_response(messages, text: str) -> bool:
    """Return True if the last of ``messages`` is a ModelResponse holding a text part equal to ``text``."""
    if not messages:
        return False
    last = messages[-1]
    return isinstance(last, ModelResponse) and any(
        part.part_kind == 'text' and part.content == text for part in last.parts
    )


async def run_agent(user_input: str, image_b64: str):
    """
    Stream one agent turn about the uploaded image and record it in session state.

    Partial assistant text is rendered live into an ``st.empty()`` placeholder,
    created inside whatever Streamlit container is active when this is called.
    When the stream finishes:

    * the run's new messages are appended to ``st.session_state.messages``,
      except those containing a user-prompt part (the caller stores the prompt
      itself);
    * the streamed text is appended as a ``ModelResponse``, unless it is empty
      or the last new message already is that text response;
    * ``st.rerun()`` is called so the page redraws from session state.

    Args:
        user_input: The user's edit instruction. It is sent as the text part of
            the prompt and passed to the deps as ``edit_instruction``.
        image_b64: The image reference sent with the prompt and stored as
            ``deps.image_url``. Presumably a data URI built by the caller —
            confirm.
    """
    # OpenAI-style multimodal content list (text + image_url).
    # NOTE(review): confirm the installed pydantic_ai version accepts this
    # shape as the `user_prompt` argument of `run_stream`.
    prompt = [
        {"type": "text", "text": user_input},
        {"type": "image_url", "image_url": {"url": image_b64}},
    ]
    deps = ImageEditDeps(
        edit_instruction=user_input,
        image_url=image_b64,
        hopter_client=hopter,
        mask_service=mask_service,
    )

    async with mask_generation_agent.run_stream(prompt, deps=deps) as result:
        partial_text = ""
        message_placeholder = st.empty()
        # Render partial text as it arrives.
        async for chunk in result.stream_text(delta=True):
            partial_text += chunk
            message_placeholder.markdown(partial_text)

        # The stream is finished; collect this run's messages, skipping the
        # request that carries the user prompt (already stored by the caller).
        filtered_messages = [
            msg for msg in result.new_messages() if not _has_user_prompt(msg)
        ]

    st.session_state.messages.extend(filtered_messages)

    # Some pydantic_ai versions omit the final text response from
    # new_messages() when streaming with delta=True, and others include it.
    # Append it only when it is missing, so the reply is neither lost nor
    # shown twice. Never append an empty bubble.
    if partial_text and not _ends_with_text_response(filtered_messages, partial_text):
        st.session_state.messages.append(
            ModelResponse(parts=[TextPart(content=partial_text)])
        )

    # Ends this script run and starts a new one, which redraws the history.
    st.rerun()
async def main():
    """
    Build the Conversational Image Editor page for one Streamlit script run.

    The left column replays the stored conversation
    (``st.session_state.messages``). The right column holds the image uploader
    and preview. When the user submits a prompt while an image is uploaded,
    the prompt is stored in session state and echoed in the chat column, and
    ``run_agent`` streams the assistant's reply into that same column.
    """
    st.title("Conversational Image Editor")

    # Seed per-session state the first time this session runs the script.
    if "openai_model" not in st.session_state:
        st.session_state["openai_model"] = "gpt-4o"
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "image" not in st.session_state:
        st.session_state.image = None

    chat_col, image_col = st.columns(2)

    with chat_col:
        # Replay the conversation so far. Only ModelRequest/ModelResponse are
        # messages with `.parts`. ToolCallPart/ToolReturnPart are individual
        # parts (rendered via display_message_part), so testing for them here
        # would only send a part into `msg.parts` and fail.
        for msg in st.session_state.messages:
            if isinstance(msg, (ModelRequest, ModelResponse)):
                for part in msg.parts:
                    display_message_part(part)

    with image_col:
        st.session_state.image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
        if st.session_state.image:
            st.image(st.session_state.image)
        else:
            st.write("Upload an image to get started")

    # Chat input for the user; disabled until an image has been uploaded.
    user_input = st.chat_input("What would you like to edit your image?", disabled=not st.session_state.image)
    if user_input and st.session_state.image:
        st.session_state.messages.append(
            ModelRequest(parts=[UserPromptPart(content=user_input)])
        )
        # The history loop above ran before this prompt existed. Echo it now,
        # and stream the reply inside the chat column rather than below the
        # two-column layout.
        with chat_col:
            with st.chat_message("user"):
                st.markdown(user_input)
            with st.chat_message("assistant"):
                # NOTE(review): this passes the uploaded-file object, not a
                # filesystem path — confirm image_path_to_uri accepts it.
                image_url = image_path_to_uri(st.session_state.image)
                await run_agent(user_input, image_url)
# Script entry point. main() is a coroutine, so it is driven to completion on
# a fresh event loop.
# NOTE(review): Streamlit appears to execute the script as "__main__", so this
# presumably runs on every rerun — confirm.
if __name__ == "__main__":
    page_coroutine = main()
    asyncio.run(page_coroutine)