File size: 5,152 Bytes
d2eb85a
c55fe6a
d2eb85a
 
 
 
 
 
 
 
 
 
9e822e4
 
 
 
c55fe6a
9e822e4
 
c55fe6a
 
 
 
 
 
 
9e822e4
 
 
 
 
c55fe6a
d2eb85a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c55fe6a
 
 
 
 
 
 
 
 
 
d2eb85a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e822e4
 
 
d2eb85a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c55fe6a
d2eb85a
 
c55fe6a
d2eb85a
 
 
 
 
 
 
 
 
 
c55fe6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2eb85a
 
c55fe6a
 
d2eb85a
 
 
 
 
 
 
c55fe6a
 
d2eb85a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import streamlit as st
from typing import TypedDict, Literal
from pydantic_ai.messages import (
    ModelRequest,
    ModelResponse,
    UserPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
)
import asyncio
from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
from src.hopter.client import Hopter, Environment
import os
from src.services.generate_mask import GenerateMaskService
from dotenv import load_dotenv
from src.utils import image_path_to_uri
load_dotenv()

# Configure the Streamlit page; Streamlit requires set_page_config to be the
# first st.* call in the script.
st.set_page_config(
    page_title="Conversational Image Editor",
    page_icon="🧊",
    layout="wide",
    initial_sidebar_state="collapsed"
)

# Module-level Hopter client shared by the agent runs below.
# NOTE(review): os.getenv returns None when HOPTER_API_KEY is unset —
# presumably the client fails downstream; confirm startup behavior.
hopter = Hopter(
    api_key=os.getenv("HOPTER_API_KEY"),
    environment=Environment.STAGING
)
# Mask-generation service backed by the Hopter client; passed to the agent deps.
mask_service = GenerateMaskService(hopter=hopter)
# Session-state key for the chat input widget.
# NOTE(review): appears unused in this file — verify callers before removing.
user_msg_input_key = "input_user_msg"

class ChatMessage(TypedDict):
    """Wire format for a single chat message sent to the browser/API."""

    # Who produced the message.
    role: Literal['user', 'model']
    # When the message was produced (string form; format chosen by the caller).
    timestamp: str
    # The message body text.
    content: str


def display_message_part(part):
    """
    Render a single message part in the Streamlit chat UI.

    Dispatches on ``part.part_kind`` ('system-prompt', 'user-prompt',
    'text', 'tool-call', 'tool-return'); any other kind is silently
    skipped.
    """
    kind = part.part_kind

    if kind == 'system-prompt':
        # System prompts get an explicit label in the system lane.
        with st.chat_message("system"):
            st.markdown(f"**System**: {part.content}")
    elif kind == 'user-prompt':
        with st.chat_message("user"):
            st.markdown(part.content)
    elif kind == 'text':
        # Plain model text renders in the assistant lane.
        with st.chat_message("assistant"):
            st.markdown(part.content)
    elif kind == 'tool-call':
        # Tool invocations: show the tool name and its raw arguments.
        with st.chat_message("assistant"):
            st.markdown(f"**{part.tool_name}**: {part.args}")
    elif kind == 'tool-return':
        # Tool results: show the tool name and what it returned.
        with st.chat_message("assistant"):
            st.markdown(f"**{part.tool_name}**: {part.content}")

async def run_agent(user_input: str, image_b64: str):
    """
    Run one agent turn, streaming its text into the UI, then persist the
    run's messages to session state and trigger a rerun.

    Args:
        user_input: The user's edit instruction.
        image_b64: The image as a data URI / base64 URL string.
    """
    # Multimodal prompt: the instruction plus the image being edited.
    prompt_parts = [
        {"type": "text", "text": user_input},
        {"type": "image_url", "image_url": {"url": image_b64}},
    ]
    deps = ImageEditDeps(
        edit_instruction=user_input,
        image_url=image_b64,
        hopter_client=hopter,
        mask_service=mask_service
    )
    async with mask_generation_agent.run_stream(
        prompt_parts,
        deps=deps
    ) as result:
        streamed_text = ""
        placeholder = st.empty()

        # Incrementally render the model's text as deltas arrive.
        async for delta in result.stream_text(delta=True):
            streamed_text += delta
            placeholder.markdown(streamed_text)

        # Persist this run's messages, skipping user-prompt entries
        # (the caller already stored the user's message).
        def _contains_user_prompt(msg):
            return hasattr(msg, 'parts') and any(
                p.part_kind == 'user-prompt' for p in msg.parts
            )

        st.session_state.messages.extend(
            msg for msg in result.new_messages()
            if not _contains_user_prompt(msg)
        )

        # Record the fully streamed text as the final model response.
        st.session_state.messages.append(
            ModelResponse(parts=[TextPart(content=streamed_text)])
        )
    st.rerun()

async def main():
    """Lay out the editor page and handle one chat turn per Streamlit rerun."""
    st.title("Conversational Image Editor")

    # Seed session state on first load only; reruns keep existing values.
    for key, initial in (("openai_model", "gpt-4o"),
                         ("messages", []),
                         ("image", None)):
        if key not in st.session_state:
            st.session_state[key] = initial

    chat_col, image_col = st.columns(2)

    with chat_col:
        # Replay the conversation so far. Each stored entry is a
        # ModelRequest/ModelResponse (or tool part); render part by part.
        renderable = (ModelRequest, ModelResponse, ToolCallPart, ToolReturnPart)
        for msg in st.session_state.messages:
            if isinstance(msg, renderable):
                for part in msg.parts:
                    display_message_part(part)

    with image_col:
        st.session_state.image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
        if st.session_state.image:
            st.image(st.session_state.image)
        else:
            st.write("Upload an image to get started")

    # Chat input stays disabled until an image has been uploaded.
    user_input = st.chat_input("What would you like to edit your image?", disabled=not st.session_state.image)
    if user_input and st.session_state.image:
        # Store the user's message, then stream the agent's reply beneath it.
        st.session_state.messages.append(
            ModelRequest(parts=[UserPromptPart(content=user_input)])
        )

        with st.chat_message("assistant"):
            image_url = image_path_to_uri(st.session_state.image)
            await run_agent(user_input, image_url)

if __name__ == "__main__":
    # Streamlit re-executes this script top to bottom on every interaction;
    # drive the async main() with a fresh event loop each run.
    asyncio.run(main())