simonlee-cb committed
Commit d2eb85a · 1 Parent(s): a21dee1

feat: hook up streamlit to agent

agent.py ADDED
@@ -0,0 +1,145 @@
+ from openai import OpenAI
+ import streamlit as st
+ from typing import TypedDict, Literal, Optional
+ from pydantic import BaseModel
+ import base64
+ from pydantic_ai.messages import (
+     ModelMessage,
+     ModelRequest,
+     ModelResponse,
+     SystemPromptPart,
+     UserPromptPart,
+     TextPart,
+     ToolCallPart,
+     ToolReturnPart,
+     RetryPromptPart,
+     ModelMessagesTypeAdapter
+ )
+ import asyncio
+ from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
+
+ class ChatMessage(TypedDict):
+     """Format of messages sent to the browser/API."""
+
+     role: Literal['user', 'model']
+     timestamp: str
+     content: str
+
+
+ def display_message_part(part):
+     """
+     Display a single part of a message in the Streamlit UI.
+     Customize how you display system prompts, user prompts,
+     tool calls, tool returns, etc.
+     """
+     # system-prompt
+     if part.part_kind == 'system-prompt':
+         with st.chat_message("system"):
+             st.markdown(f"**System**: {part.content}")
+     # user-prompt
+     elif part.part_kind == 'user-prompt':
+         with st.chat_message("user"):
+             st.markdown(part.content)
+     # text
+     elif part.part_kind == 'text':
+         with st.chat_message("assistant"):
+             st.markdown(part.content)
+
+ async def run_agent(user_input: str, image_b64: Optional[str] = None):
+     # Build the multimodal prompt; the image part is only included when an image was uploaded.
+     messages = [
+         {
+             "type": "text",
+             "text": user_input
+         }
+     ]
+     if image_b64:
+         messages.append({
+             "type": "image_url",
+             "image_url": {
+                 "url": image_b64
+             }
+         })
+     deps = ImageEditDeps(
+         edit_instruction=user_input,
+         image_url=image_b64
+     )
+     async with mask_generation_agent.run_stream(
+         messages,
+         deps=deps
+     ) as result:
+         partial_text = ""
+         message_placeholder = st.empty()
+
+         # Render partial text as it arrives
+         async for chunk in result.stream_text(delta=True):
+             partial_text += chunk
+             message_placeholder.markdown(partial_text)
+
+         # Now that the stream is finished, we have a final result.
+         # Add new messages from this run, excluding user-prompt messages
+         filtered_messages = [msg for msg in result.new_messages()
+                              if not (hasattr(msg, 'parts') and
+                                      any(part.part_kind == 'user-prompt' for part in msg.parts))]
+         st.session_state.messages.extend(filtered_messages)
+
+         # Add the final response to the messages
+         st.session_state.messages.append(
+             ModelResponse(parts=[TextPart(content=partial_text)])
+         )
+
+ async def main():
+     st.title("ChatGPT-like clone")
+
+     def encode_image(uploaded_file):
+         # Read the file directly from the UploadedFile object
+         return base64.b64encode(uploaded_file.read()).decode("utf-8")
+
+     image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
+
+     if "openai_model" not in st.session_state:
+         st.session_state["openai_model"] = "gpt-4o"
+
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+
+     if "image" not in st.session_state:
+         st.session_state.image = None
+
+     # Display all messages from the conversation so far.
+     # Each message is either a ModelRequest or ModelResponse.
+     # We iterate over their parts to decide how to display them.
+     for msg in st.session_state.messages:
+         if isinstance(msg, (ModelRequest, ModelResponse)):
+             for part in msg.parts:
+                 display_message_part(part)
+
+     # Chat input for the user
+     user_input = st.chat_input("How would you like to edit your image?")
+
+     if user_input:
+         if image is not None:
+             st.session_state.image = image
+
+         # We append a new request to the conversation explicitly
+         st.session_state.messages.append(
+             ModelRequest(parts=[UserPromptPart(content=user_input)])
+         )
+
+         # Display the user prompt in the UI
+         with st.chat_message("user"):
+             if st.session_state.image:
+                 st.image(st.session_state.image)
+             st.markdown(user_input)
+
+         # Display the assistant's partial response while streaming
+         with st.chat_message("assistant"):
+             # Actually run the agent now, streaming the text
+             if st.session_state.image:
+                 image_data = encode_image(st.session_state.image)
+                 image_url = f"data:image/jpeg;base64,{image_data}"
+                 await run_agent(user_input, image_url)
+             else:
+                 await run_agent(user_input)
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
src/agents/{mask-generation-agent.py → mask_generation_agent.py} RENAMED
File without changes
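
The renamed src/agents/mask_generation_agent.py is not shown in this commit. For reference, a minimal sketch of the interface that agent.py relies on (a pydantic_ai Agent named mask_generation_agent plus an ImageEditDeps dependency container) might look like the following; the model name, system prompt, and field defaults are assumptions, not the committed code:

from dataclasses import dataclass
from typing import Optional

from pydantic_ai import Agent


@dataclass
class ImageEditDeps:
    # Field names taken from how agent.py constructs ImageEditDeps; the default is assumed.
    edit_instruction: str
    image_url: Optional[str] = None


# Hypothetical agent definition: the real module may use a different model or prompt.
mask_generation_agent = Agent(
    "openai:gpt-4o",
    deps_type=ImageEditDeps,
    system_prompt=(
        "You help the user edit an image by identifying the region to change "
        "and generating a mask for it."
    ),
)

With that module in place, the app would typically be launched with: streamlit run agent.py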