# chat-image-edit / agent.py
# Author: simonlee-cb — "feat: working magic replace" (commit 9e822e4)
from openai import OpenAI
import streamlit as st
from typing import TypedDict, Literal, Optional
from pydantic import BaseModel
import base64
from pydantic_ai.messages import (
ModelMessage,
ModelRequest,
ModelResponse,
SystemPromptPart,
UserPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
RetryPromptPart,
ModelMessagesTypeAdapter
)
import asyncio
from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
from src.hopter.client import Hopter, Environment
import os
from src.services.generate_mask import GenerateMaskService
from dotenv import load_dotenv
load_dotenv()
# Shared Hopter client for the whole app, configured from the environment.
# NOTE(review): assumes HOPTER_API_KEY is set in the environment / .env;
# os.getenv returns None otherwise — confirm Hopter handles a missing key.
hopter = Hopter(
    api_key=os.getenv("HOPTER_API_KEY"),
    environment=Environment.STAGING
)
# Mask-generation service backed by the shared Hopter client above.
mask_service = GenerateMaskService(hopter=hopter)
class ChatMessage(TypedDict):
    """Format of messages sent to the browser/API.

    NOTE(review): this type is not referenced anywhere in this file —
    confirm it is used by another module or consider removing it.
    """
    role: Literal['user', 'model']  # who produced the message
    timestamp: str  # when the message was created; format not set in this file — TODO confirm
    content: str  # the message text
def display_message_part(part):
    """
    Render a single message part in the Streamlit chat UI.

    Only 'system-prompt', 'user-prompt', and 'text' parts are displayed;
    every other part kind (tool calls, tool returns, retries, ...) is
    silently skipped.
    """
    # Dispatch table: part kind -> (chat role, content formatter).
    renderers = {
        'system-prompt': ("system", lambda content: f"**System**: {content}"),
        'user-prompt': ("user", lambda content: content),
        'text': ("assistant", lambda content: content),
    }
    entry = renderers.get(part.part_kind)
    if entry is None:
        return  # unsupported part kind: nothing to render
    role, fmt = entry
    with st.chat_message(role):
        st.markdown(fmt(part.content))
async def run_agent(user_input: str, image_b64: Optional[str] = None):
    """
    Run the mask-generation agent on the user's instruction, streaming the
    reply into the current Streamlit container.

    Args:
        user_input: The edit instruction typed by the user.
        image_b64: Data-URL (base64-encoded) of the uploaded image, or None
            when no image is available.  BUG FIX: this parameter used to be
            required, so the text-only call ``run_agent(user_input)`` in
            main() raised TypeError; it now defaults to None and the image
            part is simply omitted from the prompt.

    Side effects:
        Streams partial text into a ``st.empty()`` placeholder and appends
        this run's messages (minus user-prompt parts) plus the final
        response to ``st.session_state.messages``.
    """
    # Build the multimodal prompt; attach the image part only when present.
    messages = [
        {
            "type": "text",
            "text": user_input
        }
    ]
    if image_b64:
        messages.append({
            "type": "image_url",
            "image_url": {
                "url": image_b64
            }
        })
    deps = ImageEditDeps(
        edit_instruction=user_input,
        image_url=image_b64,
        hopter_client=hopter,
        mask_service=mask_service
    )
    async with mask_generation_agent.run_stream(
        messages,
        deps=deps
    ) as result:
        partial_text = ""
        message_placeholder = st.empty()
        # Render partial text as it arrives from the stream.
        async for chunk in result.stream_text(delta=True):
            partial_text += chunk
            message_placeholder.markdown(partial_text)
        # Stream finished: persist this run's messages, excluding user-prompt
        # parts (the user's request was already appended by the caller).
        filtered_messages = [msg for msg in result.new_messages()
                            if not (hasattr(msg, 'parts') and
                                    any(part.part_kind == 'user-prompt' for part in msg.parts))]
        st.session_state.messages.extend(filtered_messages)
        # Record the fully-streamed text as the final model response.
        st.session_state.messages.append(
            ModelResponse(parts=[TextPart(content=partial_text)])
        )
async def main():
    """Streamlit entry point: a chat UI for instruction-based image editing.

    Renders the conversation history, accepts a new instruction plus an
    optional uploaded image, and streams the agent's reply.

    BUG FIX: the no-image branch previously called ``run_agent(user_input)``,
    which raised TypeError because ``image_b64`` is a required parameter.
    We now ask the user to upload an image instead of crashing.
    """
    st.title("ChatGPT-like clone")

    def encode_image(uploaded_file):
        # Read the raw bytes directly from the Streamlit UploadedFile object.
        return base64.b64encode(uploaded_file.read()).decode("utf-8")

    image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

    # --- Session-state initialisation ------------------------------------
    if "openai_model" not in st.session_state:
        st.session_state["openai_model"] = "gpt-4o"
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "image" not in st.session_state:
        st.session_state.image = None

    # Replay the conversation so far.  Each entry is a ModelRequest or
    # ModelResponse; we iterate over their parts to decide how to display them.
    for msg in st.session_state.messages:
        if isinstance(msg, (ModelRequest, ModelResponse)):
            for part in msg.parts:
                display_message_part(part)

    # Chat input for the user
    user_input = st.chat_input("What would you like to edit your image?")

    if user_input:
        # Remember the most recent upload so follow-up turns reuse it.
        if image is not None:
            st.session_state.image = image

        # Persist the new user request in the conversation history.
        st.session_state.messages.append(
            ModelRequest(parts=[UserPromptPart(content=user_input)])
        )

        # Display the user's prompt (and image, if any) in the UI.
        with st.chat_message("user"):
            if st.session_state.image:
                st.image(st.session_state.image)
            st.markdown(user_input)

        # Stream the assistant's response.
        with st.chat_message("assistant"):
            if st.session_state.image:
                image_data = encode_image(st.session_state.image)
                # NOTE(review): media type is hard-coded to jpeg even for PNG
                # uploads — confirm downstream tolerance or derive it from the
                # uploaded file's actual type.
                image_url = f"data:image/jpeg;base64,{image_data}"
                await run_agent(user_input, image_url)
            else:
                # Calling run_agent without an image used to raise TypeError;
                # prompt for an upload instead.
                st.warning("Please upload an image before sending an edit request.")
if __name__ == "__main__":
asyncio.run(main())