File size: 5,028 Bytes
d2eb85a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e822e4
 
 
 
 
 
 
 
 
 
 
 
d2eb85a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e822e4
 
 
d2eb85a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from openai import OpenAI
import streamlit as st
from typing import TypedDict, Literal, Optional
from pydantic import BaseModel
import base64
from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    SystemPromptPart,
    UserPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    RetryPromptPart,
    ModelMessagesTypeAdapter
)
import asyncio
from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
from src.hopter.client import Hopter, Environment
import os
from src.services.generate_mask import GenerateMaskService
from dotenv import load_dotenv

# Load environment variables via python-dotenv first, so the API key lookup
# below can see anything it provides.
load_dotenv()

# Module-level Hopter client, pointed at the staging environment, and the
# mask-generation service built on top of it. Both are handed to the agent
# through ImageEditDeps in run_agent().
hopter = Hopter(
    environment=Environment.STAGING,
    api_key=os.environ.get("HOPTER_API_KEY"),
)
mask_service = GenerateMaskService(hopter=hopter)

class ChatMessage(TypedDict):
    """Dictionary shape of a chat message sent to the browser/API.

    All three keys are required.
    """

    # Who authored the message: the end user or the model.
    role: Literal["user", "model"]
    # When the message was produced, as a string.
    timestamp: str
    # The message text.
    content: str


def display_message_part(part):
    """
    Render one message part as a Streamlit chat bubble.

    Only three part kinds are shown: 'system-prompt' (as the "system" role,
    prefixed with a bold "System" label), 'user-prompt' (as "user") and
    'text' (as "assistant"). Any other kind, such as tool calls or tool
    returns, is skipped without rendering anything.
    """
    # Chat role used for each part kind that is displayed.
    role_by_kind = {
        'system-prompt': "system",
        'user-prompt': "user",
        'text': "assistant",
    }
    role = role_by_kind.get(part.part_kind)
    if role is None:
        # Not a kind we display.
        return

    with st.chat_message(role):
        if role == "system":
            st.markdown(f"**System**: {part.content}")
        else:
            st.markdown(part.content)

async def run_agent(user_input: str, image_b64: Optional[str] = None):
    """
    Run the mask generation agent for one user turn and stream its reply.

    The reply text is rendered incrementally into a Streamlit placeholder.
    Once the stream ends, the messages produced by the run (minus any
    user-prompt messages) and the final text response are appended to
    ``st.session_state.messages``.

    Args:
        user_input: The edit instruction typed by the user.
        image_b64: The image to edit, forwarded unchanged as the prompt's
            ``image_url`` and as ``ImageEditDeps.image_url``. The caller in
            this file passes a base64 ``data:`` URL. When ``None`` (the
            default) a text-only prompt is sent, so the function can be
            called without an image.
    """
    messages = [
        {
            "type": "text",
            "text": user_input
        }
    ]
    # Attach the image part only when an image was supplied; callers without
    # an uploaded image invoke run_agent(user_input) with no second argument.
    if image_b64 is not None:
        messages.append(
            {
                "type": "image_url",
                "image_url": {
                    "url": image_b64
                }
            }
        )

    # NOTE(review): image_url may now be None here; confirm ImageEditDeps and
    # the agent's tools tolerate a missing image.
    deps = ImageEditDeps(
        edit_instruction=user_input,
        image_url=image_b64,
        hopter_client=hopter,
        mask_service=mask_service
    )
    async with mask_generation_agent.run_stream(
        messages,
        deps=deps
    ) as result:
        partial_text = ""
        message_placeholder = st.empty()

        # Render partial text as it arrives
        async for chunk in result.stream_text(delta=True):
            partial_text += chunk
            message_placeholder.markdown(partial_text)

        # Now that the stream is finished, record the new messages from this
        # run. Messages containing a user-prompt part are dropped because
        # main() already appended the user's prompt to the history.
        filtered_messages = [
            msg for msg in result.new_messages()
            if not (hasattr(msg, 'parts') and
                    any(part.part_kind == 'user-prompt' for part in msg.parts))
        ]
        st.session_state.messages.extend(filtered_messages)

        # Add the final response explicitly: per the pydantic_ai docs, the
        # final result message is not added to the run's messages when
        # streaming with stream_text(delta=True).
        st.session_state.messages.append(
            ModelResponse(parts=[TextPart(content=partial_text)])
        )

async def main():
    """
    Streamlit page: upload an image, chat about edits, and stream the reply.

    On every script run this replays the stored conversation from
    ``st.session_state.messages``. When the user submits a prompt it records
    the prompt, shows it (with the stored image, if any) and awaits
    ``run_agent`` inside an assistant chat bubble.
    """
    st.title("ChatGPT-like clone")

    def encode_image(uploaded_file):
        # getvalue() returns the whole buffer regardless of the current read
        # position, so a file object kept in session state and reused on a
        # later turn still encodes to the full image (read() would return
        # only what lies after the current position).
        return base64.b64encode(uploaded_file.getvalue()).decode("utf-8")

    image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

    if "openai_model" not in st.session_state:
        st.session_state["openai_model"] = "gpt-4o"

    if "messages" not in st.session_state:
        st.session_state.messages = []

    if "image" not in st.session_state:
        st.session_state.image = None

    # Display all messages from the conversation so far
    # Each message is either a ModelRequest or ModelResponse.
    # We iterate over their parts to decide how to display them.
    for msg in st.session_state.messages:
        if isinstance(msg, (ModelRequest, ModelResponse)):
            for part in msg.parts:
                display_message_part(part)

    # Chat input for the user
    user_input = st.chat_input("What would you like to edit your image?")

    if user_input:
        # Remember the most recent upload so later turns can reuse it.
        if image is not None:
            st.session_state.image = image

        # We append a new request to the conversation explicitly
        st.session_state.messages.append(
            ModelRequest(parts=[UserPromptPart(content=user_input)])
        )

        stored_image = st.session_state.image

        # Display user prompt in the UI
        with st.chat_message("user"):
            if stored_image:
                st.image(stored_image)
            st.markdown(user_input)

        # Display the assistant's partial response while streaming
        with st.chat_message("assistant"):
            # Actually run the agent now, streaming the text
            if stored_image:
                image_data = encode_image(stored_image)
                # Label the data URL with the upload's own MIME type (the
                # uploader also accepts PNG); fall back to JPEG if the
                # object does not report one.
                mime_type = getattr(stored_image, "type", None) or "image/jpeg"
                image_url = f"data:{mime_type};base64,{image_data}"
                await run_agent(user_input, image_url)
            else:
                # NOTE(review): this branch requires run_agent to accept a
                # call without an image argument.
                await run_agent(user_input)


# Script entry point: drive the async main() coroutine to completion.
if __name__ == "__main__":
    asyncio.run(main())