# Streamlit app: conversational image editor (chat-driven edits on an uploaded image).
import streamlit as st
from typing import TypedDict, Literal
from pydantic_ai.messages import (
ModelRequest,
ModelResponse,
UserPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
)
import asyncio
from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
from src.hopter.client import Hopter, Environment
import os
from src.services.generate_mask import GenerateMaskService
from dotenv import load_dotenv
from src.utils import image_path_to_uri
load_dotenv()
st.set_page_config(
page_title="Conversational Image Editor",
page_icon="🧊",
layout="wide",
initial_sidebar_state="collapsed"
)
hopter = Hopter(
api_key=os.getenv("HOPTER_API_KEY"),
environment=Environment.STAGING
)
mask_service = GenerateMaskService(hopter=hopter)
user_msg_input_key = "input_user_msg"
class ChatMessage(TypedDict):
"""Format of messages sent to the browser/API."""
role: Literal['user', 'model']
timestamp: str
content: str
def display_message_part(part):
    """Render a single message part in the Streamlit chat UI.

    Each known ``part_kind`` is mapped to a chat role and a markdown
    string; parts of any other kind are silently skipped.
    """
    kind = part.part_kind
    if kind == 'system-prompt':
        role, body = "system", f"**System**: {part.content}"
    elif kind == 'user-prompt':
        role, body = "user", part.content
    elif kind == 'text':
        role, body = "assistant", part.content
    elif kind == 'tool-call':
        role, body = "assistant", f"**{part.tool_name}**: {part.args}"
    elif kind == 'tool-return':
        role, body = "assistant", f"**{part.tool_name}**: {part.content}"
    else:
        # Unknown part kinds were not handled by the original chain either.
        return
    with st.chat_message(role):
        st.markdown(body)
async def run_agent(user_input: str, image_b64: str) -> None:
    """Run the mask-generation agent on one edit instruction, streaming
    partial text into the UI and recording the run in session state.

    Args:
        user_input: The edit instruction typed by the user.
        image_b64: The uploaded image as a data URI / base64 URL string.

    Side effects:
        Renders streaming text into an ``st.empty`` placeholder, appends
        the run's messages to ``st.session_state.messages``, and finishes
        with ``st.rerun()`` to redraw the conversation.
    """
    # Multimodal prompt: a text part plus an image_url part.
    # NOTE(review): prior conversation history is not passed to the agent
    # here — each run sees only this turn; confirm that is intentional.
    messages = [
        {
            "type": "text",
            "text": user_input
        },
        {
            "type": "image_url",
            "image_url": {
                "url": image_b64
            }
        }
    ]
    # Dependencies injected into the agent's tools for this run.
    deps = ImageEditDeps(
        edit_instruction=user_input,
        image_url=image_b64,
        hopter_client=hopter,
        mask_service=mask_service
    )
    async with mask_generation_agent.run_stream(
        messages,
        deps=deps
    ) as result:
        partial_text = ""
        message_placeholder = st.empty()
        # Render partial text as it arrives
        async for chunk in result.stream_text(delta=True):
            partial_text += chunk
            message_placeholder.markdown(partial_text)
        # Now that the stream is finished, we have a final result.
        # Add new messages from this run, excluding user-prompt messages
        # (the user's prompt was already appended by the caller, so keeping
        # it here would display it twice).
        filtered_messages = [msg for msg in result.new_messages()
                             if not (hasattr(msg, 'parts') and
                                     any(part.part_kind == 'user-prompt' for part in msg.parts))]
        st.session_state.messages.extend(filtered_messages)
        # Add the final response to the messages
        st.session_state.messages.append(
            ModelResponse(parts=[TextPart(content=partial_text)])
        )
        st.rerun()
async def main():
    """Render the two-column editor UI and process one chat turn."""
    st.title("Conversational Image Editor")

    # Seed session state on the first run of the script.
    defaults = {"openai_model": "gpt-4o", "messages": [], "image": None}
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    chat_col, image_col = st.columns(2)

    with chat_col:
        # Replay the conversation so far; each stored message exposes
        # .parts, and display_message_part decides how to render each part.
        displayable = (ModelRequest, ModelResponse, ToolCallPart, ToolReturnPart)
        for msg in st.session_state.messages:
            if isinstance(msg, displayable):
                for part in msg.parts:
                    display_message_part(part)

    with image_col:
        uploaded = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
        st.session_state.image = uploaded
        if uploaded:
            st.image(uploaded)
        else:
            st.write("Upload an image to get started")

    # Chat input is disabled until an image has been uploaded.
    user_input = st.chat_input("What would you like to edit your image?", disabled=not st.session_state.image)

    if user_input and st.session_state.image:
        st.session_state.messages.append(
            ModelRequest(parts=[UserPromptPart(content=user_input)])
        )
        # Stream the assistant's response inside an assistant chat bubble.
        with st.chat_message("assistant"):
            image_url = image_path_to_uri(st.session_state.image)
            await run_agent(user_input, image_url)
# Script entry point. Fix: removed a stray trailing "|" (web-extraction
# artifact) that made this line a syntax error.
if __name__ == "__main__":
    asyncio.run(main())