"""Streamlit front-end for a conversational image editor.

The page shows a chat column and an image column. Once an image has been
uploaded, each chat input is sent to ``mask_generation_agent`` together with
the image, and the agent's text reply is streamed into the page.
"""

import asyncio
import os
from typing import Literal, TypedDict

import streamlit as st
from dotenv import load_dotenv
from pydantic_ai.messages import (
    ModelRequest,
    ModelResponse,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)

from src.agents.mask_generation_agent import ImageEditDeps, mask_generation_agent
from src.hopter.client import Environment, Hopter
from src.services.generate_mask import GenerateMaskService
from src.utils import image_path_to_uri

# Load variables from a .env file before HOPTER_API_KEY is read below.
load_dotenv()

st.set_page_config(
    page_title="Conversational Image Editor",
    page_icon="🧊",
    layout="wide",
    initial_sidebar_state="collapsed",
)

hopter = Hopter(
    api_key=os.getenv("HOPTER_API_KEY"),
    environment=Environment.STAGING,
)
mask_service = GenerateMaskService(hopter=hopter)

user_msg_input_key = "input_user_msg"


class ChatMessage(TypedDict):
    """Format of messages sent to the browser/API."""

    role: Literal['user', 'model']
    timestamp: str
    content: str


def display_message_part(part) -> None:
    """Render a single message part as a chat bubble in the Streamlit UI.

    The bubble's role and text are chosen from ``part.part_kind``:
    ``system-prompt``, ``user-prompt``, ``text``, ``tool-call`` and
    ``tool-return`` are handled; any other kind is silently ignored.

    Args:
        part: An object exposing ``part_kind`` plus the attributes used for
            that kind (``content``, and for tool parts ``tool_name`` and
            ``args``).
    """
    kind = part.part_kind

    if kind == 'system-prompt':
        with st.chat_message("system"):
            st.markdown(f"**System**: {part.content}")
    elif kind == 'user-prompt':
        with st.chat_message("user"):
            st.markdown(part.content)
    elif kind == 'text':
        with st.chat_message("assistant"):
            st.markdown(part.content)
    elif kind == 'tool-call':
        with st.chat_message("assistant"):
            st.markdown(f"**{part.tool_name}**: {part.args}")
    elif kind == 'tool-return':
        with st.chat_message("assistant"):
            st.markdown(f"**{part.tool_name}**: {part.content}")


async def run_agent(user_input: str, image_b64: str) -> None:
    """Stream the agent's reply into the page and record it in the history.

    Sends the instruction and the image to ``mask_generation_agent``, renders
    the text incrementally in a placeholder, appends the run's messages to
    ``st.session_state.messages`` and finally requests a Streamlit rerun.

    Args:
        user_input: The edit instruction typed by the user.
        image_b64: The image reference passed to the agent as the
            ``image_url`` value (produced by ``image_path_to_uri`` in
            ``main``).
    """
    messages = [
        {
            "type": "text",
            "text": user_input,
        },
        {
            "type": "image_url",
            "image_url": {
                "url": image_b64,
            },
        },
    ]

    deps = ImageEditDeps(
        edit_instruction=user_input,
        image_url=image_b64,
        hopter_client=hopter,
        mask_service=mask_service,
    )

    async with mask_generation_agent.run_stream(messages, deps=deps) as result:
        partial_text = ""
        message_placeholder = st.empty()

        # Re-render the accumulated text every time a new delta arrives.
        async for chunk in result.stream_text(delta=True):
            partial_text += chunk
            message_placeholder.markdown(partial_text)

        # The stream is finished. Keep the messages produced by this run but
        # drop any that carry a user prompt: main() has already stored the
        # user's message before calling this function.
        filtered_messages = [
            msg
            for msg in result.new_messages()
            if not (
                hasattr(msg, 'parts')
                and any(part.part_kind == 'user-prompt' for part in msg.parts)
            )
        ]
        st.session_state.messages.extend(filtered_messages)

        # NOTE(review): the final text is appended explicitly, presumably
        # because new_messages() omits it when streaming with delta=True;
        # confirm against the installed pydantic_ai version.
        st.session_state.messages.append(
            ModelResponse(parts=[TextPart(content=partial_text)])
        )

    # Rerun only after the stream context has been closed, so the whole
    # conversation is redrawn from session state.
    st.rerun()


async def main() -> None:
    """Build the page: history and chat input on one side, image on the other."""
    st.title("Conversational Image Editor")

    if "openai_model" not in st.session_state:
        st.session_state["openai_model"] = "gpt-4o"

    if "messages" not in st.session_state:
        st.session_state.messages = []

    if "image" not in st.session_state:
        st.session_state.image = None

    chat_col, image_col = st.columns(2)

    with chat_col:
        # Replay the conversation so far. Stored entries are normally
        # ModelRequest / ModelResponse objects whose parts are rendered one
        # by one.
        for msg in st.session_state.messages:
            if isinstance(msg, (ModelRequest, ModelResponse)):
                for part in msg.parts:
                    display_message_part(part)
            elif isinstance(msg, (ToolCallPart, ToolReturnPart)):
                # A bare tool part is itself a part and has no `.parts`
                # attribute, so render it directly instead of iterating.
                display_message_part(msg)

    with image_col:
        st.session_state.image = st.file_uploader(
            "Upload an image", type=["png", "jpg", "jpeg"]
        )
        if st.session_state.image:
            st.image(st.session_state.image)
        else:
            st.write("Upload an image to get started")

    # The chat input stays disabled until an image has been uploaded.
    user_input = st.chat_input(
        "What would you like to edit your image?",
        disabled=not st.session_state.image,
    )

    if user_input and st.session_state.image:
        st.session_state.messages.append(
            ModelRequest(parts=[UserPromptPart(content=user_input)])
        )

        # Show the assistant's partial response while it streams.
        with st.chat_message("assistant"):
            image_url = image_path_to_uri(st.session_state.image)
            await run_agent(user_input, image_url)


if __name__ == "__main__":
    asyncio.run(main())