import os

import gradio as gr
from dotenv import load_dotenv
from pydantic_ai.agent import Agent
from pydantic_ai.messages import (
    ToolCallPart,
    ToolReturnPart,
)
from pydantic_ai.models.openai import OpenAIModel

from src.agents.mask_generation_agent import (
    EditImageResult,
    ImageEditDeps,
    mask_generation_agent,
)
from src.hopter.client import Environment, Hopter
from src.services.generate_mask import GenerateMaskService
from src.utils import upload_image

# Load environment variables before any os.environ.get() calls below.
load_dotenv()

model = OpenAIModel(
    "gpt-4o",
    api_key=os.environ.get("OPENAI_API_KEY"),
)

simple_agent = Agent(
    model,
    system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
    deps_type=ImageEditDeps,
)
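# Illustrative only (not wired into the UI below): `simple_agent` could be
# exercised directly, assuming deps are built the same way stream_from_agent
# builds them further down:
#
#   result = await simple_agent.run("What can you do?", deps=deps)
#   print(result.data)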
def build_user_message(chat_input):
    text = chat_input["text"]
    images = chat_input["files"]
    messages = [
        {
            "role": "user",
            "content": text,
        }
    ]
    if images:
        messages.extend([
            {
                "role": "user",
                "content": {"path": image},
            }
            for image in images
        ])
    return messages
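# Example, derived from the logic above: the payload
#   {"text": "Remove the person", "files": ["/tmp/photo.png"]}
# produces two UI messages:
#   [{"role": "user", "content": "Remove the person"},
#    {"role": "user", "content": {"path": "/tmp/photo.png"}}]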
def build_messages_for_agent(chat_input, past_messages):
    # NOTE: past messages are currently passed through unchanged; filtering
    # out image messages here would save tokens.
    messages = past_messages
    # add the user's text message
    if chat_input["text"]:
        messages.append({
            "type": "text",
            "text": chat_input["text"]
        })
    # add the user's image message
    files = chat_input.get("files", [])
    image_url = upload_image(files[0]) if files else None
    if image_url:
        messages.append({
            "type": "image_url",
            "image_url": {"url": image_url}
        })
    return messages
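# Example, derived from the logic above: text plus one uploaded image appends
# OpenAI-style content parts (URL hypothetical):
#   [{"type": "text", "text": "Make the sky blue"},
#    {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}}]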
def select_example(x: gr.SelectData, chat_input):
    chat_input["text"] = x.value["text"]
    chat_input["files"] = x.value["files"]
    return chat_input
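# `x.value` mirrors the example dicts passed to gr.Chatbot(examples=...) below,
# so it carries the same "text" and "files" keys.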
async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
    # Prepare messages for the UI
    chatbot.extend(build_user_message(chat_input))
    yield {"text": "", "files": []}, chatbot, gr.skip(), gr.skip()

    # Prepare messages for the agent
    text = chat_input["text"]
    files = chat_input.get("files", [])
    image_url = upload_image(files[0]) if files else None
    messages = [
        {
            "type": "text",
            "text": text
        },
    ]
    if image_url:
        messages.append(
            {"type": "image_url", "image_url": {"url": image_url}}
        )
        current_image = image_url
    # Dependencies
    hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
    mask_service = GenerateMaskService(hopter=hopter)
    deps = ImageEditDeps(
        edit_instruction=text,
        image_url=current_image,
        hopter_client=hopter,
        mask_service=mask_service
    )
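    # GenerateMaskService wraps the Hopter client, and both are handed to the
    # agent through ImageEditDeps rather than reached for as globals.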
    # Run the agent
    async with mask_generation_agent.run_stream(
        messages,
        deps=deps,
        message_history=past_messages
    ) as result:
        for message in result.new_messages():
            for call in message.parts:
                if isinstance(call, ToolCallPart):
                    call_args = (
                        call.args.args_json
                        if hasattr(call.args, 'args_json')
                        else call.args
                    )
                    metadata = {
                        'title': f'🛠️ Using {call.tool_name}',
                    }
                    # set the tool call id so that when the tool returns
                    # we can find this message and update it with the result
                    if call.tool_call_id is not None:
                        metadata['id'] = call.tool_call_id

                    # Create a tool call message to show on the UI
                    gr_message = {
                        'role': 'assistant',
                        'content': f'Parameters: {call_args}',
                        'metadata': metadata,
                    }
                    chatbot.append(gr_message)
                if isinstance(call, ToolReturnPart):
                    for gr_message in chatbot:
                        # Skip messages without metadata
                        if not gr_message.get('metadata'):
                            continue
                        if gr_message['metadata'].get('id', '') == call.tool_call_id:
                            if isinstance(call.content, EditImageResult):
                                chatbot.append({
                                    'role': 'assistant',
                                    'content': gr.Image(call.content.edited_image_url),
                                })
                            else:
                                gr_message['content'] += (
                                    f'\nOutput: {call.content}'
                                )
                yield gr.skip(), chatbot, gr.skip(), gr.skip()

        chatbot.append({'role': 'assistant', 'content': ''})
        async for message in result.stream_text():
            chatbot[-1]['content'] = message
            yield gr.skip(), chatbot, gr.skip(), gr.skip()

        past_messages = result.all_messages()
        yield gr.Textbox(interactive=True), gr.skip(), past_messages, current_image
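# Every yield in stream_from_agent emits one value per component in the event's
# `outputs` list (chat_input, chatbot, past_messages, current_image);
# gr.skip() leaves the corresponding component unchanged for that step.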
with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%">
            <img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto">
            <div>
                <h1 style="margin: 0 0 1rem 0">Image Editing Assistant</h1>
                <h3 style="margin: 0 0 0.5rem 0">
                    This assistant edits images according to your instructions.
                </h3>
            </div>
        </div>
        """
    )
    current_image = gr.State(None)
    past_messages = gr.State([])

    chatbot = gr.Chatbot(
        elem_id="chatbot",
        label='Image Editing Assistant',
        type='messages',
        avatar_images=(None, 'https://ai.pydantic.dev/img/logo-white.svg'),
        examples=[
            {
                "text": "Remove the person in the image",
                "files": [
                    "https://www.apple.com/tv-pr/articles/2024/10/apple-tv-unveils-severance-season-two-teaser-ahead-of-the-highly-anticipated-return-of-the-emmy-and-peabody-award-winning-phenomenon/images/big-image/big-image-01/1023024_Severance_Season_Two_Official_Trailer_Big_Image_01_big_image_post.jpg.large_2x.jpg"
                ]
            },
            {
                "text": "Change all the balloons to red in the image",
                "files": [
                    "https://www.apple.com/tv-pr/articles/2024/10/apple-tv-unveils-severance-season-two-teaser-ahead-of-the-highly-anticipated-return-of-the-emmy-and-peabody-award-winning-phenomenon/images/big-image/big-image-01/1023024_Severance_Season_Two_Official_Trailer_Big_Image_01_big_image_post.jpg.large_2x.jpg"
                ]
            },
            {
                "text": "Change coffee to a glass of water",
                "files": [
                    "https://previews.123rf.com/images/vadymvdrobot/vadymvdrobot1812/vadymvdrobot181201149/113217373-image-of-smiling-woman-holding-takeaway-coffee-in-paper-cup-and-taking-selfie-while-walking-through.jpg"
                ]
            },
            {
                "text": "ENHANCE!",
                "files": [
                    "https://m.media-amazon.com/images/M/MV5BNzM3ODc5NzEtNzJkOC00MDM4LWI0MTYtZTkyNmY3ZTBhYzkxXkEyXkFqcGc@._V1_QL75_UX1000_CR0,52,1000,563_.jpg"
                ]
            }
        ]
    )
    with gr.Row():
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_count="single",
            show_label=False,
            placeholder='How would you like to edit this image?',
            sources=["upload"]
        )

    generation = chat_input.submit(
        stream_from_agent,
        inputs=[chat_input, chatbot, past_messages, current_image],
        outputs=[chat_input, chatbot, past_messages, current_image],
    )
    chatbot.example_select(
        select_example,
        inputs=[chat_input],
        outputs=[chat_input],
    ).then(
        stream_from_agent,
        inputs=[chat_input, chatbot, past_messages, current_image],
        outputs=[chat_input, chatbot, past_messages, current_image],
    )
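    # Clicking an example first copies its text/files into the textbox via
    # select_example, then .then() feeds the populated input through the same
    # stream_from_agent pipeline as a manual submit.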
if __name__ == '__main__':
    demo.launch()