simonlee-cb committed on
Commit
6962136
·
1 Parent(s): 14abb7f

fix some issues

Browse files
Files changed (2) hide show
  1. gradio_app.py +56 -24
  2. src/agents/mask_generation_agent.py +10 -17
gradio_app.py CHANGED
@@ -12,7 +12,6 @@ from pydantic_ai.messages import (
12
  from src.agents.mask_generation_agent import EditImageResult
13
  from pydantic_ai.agent import Agent
14
  from pydantic_ai.models.openai import OpenAIModel
15
-
16
  model = OpenAIModel(
17
  "gpt-4o",
18
  api_key=os.environ.get("OPENAI_API_KEY"),
@@ -45,37 +44,65 @@ def build_user_message(chat_input):
45
  ])
46
  return messages
47
 
48
- async def stream_from_agent(chat_input, chatbot, past_messages):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  chatbot.extend(build_user_message(chat_input))
50
- # Clear the input immediately after submission
51
- yield {"text": "", "files": []}, chatbot, gr.skip
52
 
53
- # for agent
54
  text = chat_input["text"]
55
- images = [image_path_to_uri(image) for image in chat_input["files"]]
 
56
  messages = [
57
  {
58
  "type": "text",
59
  "text": text
60
  },
61
  ]
62
- if images:
63
- messages.extend([
64
- {"type": "image_url", "image_url": {"url": image}}
65
- for image in images
66
- ])
67
 
 
68
  hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
69
  mask_service = GenerateMaskService(hopter=hopter)
70
  deps = ImageEditDeps(
71
  edit_instruction=text,
72
- image_url=images[0],
73
  hopter_client=hopter,
74
  mask_service=mask_service
75
  )
 
76
  async with mask_generation_agent.run_stream(
77
  messages,
78
- deps=deps
79
  ) as result:
80
  for message in result.new_messages():
81
  for call in message.parts:
@@ -88,9 +115,12 @@ async def stream_from_agent(chat_input, chatbot, past_messages):
88
  metadata = {
89
  'title': f'🛠️ Using {call.tool_name}',
90
  }
 
 
91
  if call.tool_call_id is not None:
92
  metadata['id'] = call.tool_call_id
93
 
 
94
  gr_message = {
95
  'role': 'assistant',
96
  'content': 'Parameters: ' + call_args,
@@ -99,10 +129,11 @@ async def stream_from_agent(chat_input, chatbot, past_messages):
99
  chatbot.append(gr_message)
100
  if isinstance(call, ToolReturnPart):
101
  for gr_message in chatbot:
102
- if (
103
- gr_message.get('metadata', {}).get('id', '')
104
- == call.tool_call_id
105
- ):
 
106
  if isinstance(call.content, EditImageResult):
107
  chatbot.append({
108
  "role": "assistant",
@@ -113,15 +144,15 @@ async def stream_from_agent(chat_input, chatbot, past_messages):
113
  gr_message['content'] += (
114
  f'\nOutput: {call.content}'
115
  )
116
- yield gr.skip(), chatbot, gr.skip()
117
 
118
  chatbot.append({'role': 'assistant', 'content': ''})
119
  async for message in result.stream_text():
120
  chatbot[-1]['content'] = message
121
- yield gr.skip(), chatbot, gr.skip()
122
  past_messages = result.all_messages()
123
 
124
- yield gr.Textbox(interactive=True), gr.skip(), past_messages
125
 
126
  with gr.Blocks() as demo:
127
  gr.HTML(
@@ -138,6 +169,7 @@ with gr.Blocks() as demo:
138
  """
139
  )
140
 
 
141
  past_messages = gr.State([])
142
  chatbot = gr.Chatbot(
143
  label='Image Editing Assistant',
@@ -147,15 +179,15 @@ with gr.Blocks() as demo:
147
  with gr.Row():
148
  chat_input = gr.MultimodalTextbox(
149
  interactive=True,
150
- file_count="multiple",
151
  show_label=False,
152
  placeholder='How would you like to edit this image?',
153
- sources=["upload", "microphone"]
154
  )
155
  generation = chat_input.submit(
156
  stream_from_agent,
157
- inputs=[chat_input, chatbot, past_messages],
158
- outputs=[chat_input, chatbot, past_messages],
159
  )
160
 
161
  if __name__ == '__main__':
 
12
  from src.agents.mask_generation_agent import EditImageResult
13
  from pydantic_ai.agent import Agent
14
  from pydantic_ai.models.openai import OpenAIModel
 
15
  model = OpenAIModel(
16
  "gpt-4o",
17
  api_key=os.environ.get("OPENAI_API_KEY"),
 
44
  ])
45
  return messages
46
 
47
def build_messages_for_agent(chat_input, past_messages):
    """Assemble the message list sent to the agent for one user turn.

    Image entries are dropped from ``past_messages`` to save on tokens, then
    the user's text (when non-empty) and the first uploaded file (when
    present) are appended.

    Args:
        chat_input: Mapping submitted by the chat textbox. The ``"text"`` and
            ``"files"`` keys are read; either may be missing, empty or None.
        past_messages: Messages from earlier turns, or None when there are
            none yet. The list passed in is not modified.

    Returns:
        A new list holding the retained past messages followed by the new
        text and image entries.
    """
    # filter out image messages from past messages to save on tokens
    messages = [
        msg
        for msg in (past_messages or [])
        if not (isinstance(msg, dict) and msg.get("type") == "image_url")
    ]

    # add the user's text message; .get() keeps a payload without a
    # "text" key from raising KeyError
    text = chat_input.get("text")
    if text:
        messages.append({
            "type": "text",
            "text": text,
        })

    # add the user's image message; only the first file is forwarded, and
    # `or []` covers a "files" key that is present but set to None
    files = chat_input.get("files") or []
    image_url = image_path_to_uri(files[0]) if files else None
    if image_url:
        messages.append({
            "type": "image_url",
            "image_url": {"url": image_url},
        })

    return messages
70
+
71
+
72
+ async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
73
+ # Prepare messages for the UI
74
  chatbot.extend(build_user_message(chat_input))
75
+ yield {"text": "", "files": []}, chatbot, gr.skip, gr.skip()
 
76
 
77
+ # Prepare messages for the agent
78
  text = chat_input["text"]
79
+ files = chat_input.get("files", [])
80
+ image_url = image_path_to_uri(files[0]) if files else None
81
  messages = [
82
  {
83
  "type": "text",
84
  "text": text
85
  },
86
  ]
87
+ if image_url:
88
+ messages.append(
89
+ {"type": "image_url", "image_url": {"url": image_url}}
90
+ )
91
+ current_image = image_url
92
 
93
+ # Dependencies
94
  hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
95
  mask_service = GenerateMaskService(hopter=hopter)
96
  deps = ImageEditDeps(
97
  edit_instruction=text,
98
+ image_url=current_image,
99
  hopter_client=hopter,
100
  mask_service=mask_service
101
  )
102
+ # Run the agent
103
  async with mask_generation_agent.run_stream(
104
  messages,
105
+ deps=deps,
106
  ) as result:
107
  for message in result.new_messages():
108
  for call in message.parts:
 
115
  metadata = {
116
  'title': f'🛠️ Using {call.tool_name}',
117
  }
118
+ # set the tool call id so that when the tool returns
119
+ # we can find this message and update with the result
120
  if call.tool_call_id is not None:
121
  metadata['id'] = call.tool_call_id
122
 
123
+ # Create a tool call message to show on the UI
124
  gr_message = {
125
  'role': 'assistant',
126
  'content': 'Parameters: ' + call_args,
 
129
  chatbot.append(gr_message)
130
  if isinstance(call, ToolReturnPart):
131
  for gr_message in chatbot:
132
+ # Skip messages without metadata
133
+ if not gr_message.get('metadata'):
134
+ continue
135
+
136
+ if gr_message['metadata'].get('id', '') == call.tool_call_id:
137
  if isinstance(call.content, EditImageResult):
138
  chatbot.append({
139
  "role": "assistant",
 
144
  gr_message['content'] += (
145
  f'\nOutput: {call.content}'
146
  )
147
+ yield gr.skip(), chatbot, gr.skip(), gr.skip()
148
 
149
  chatbot.append({'role': 'assistant', 'content': ''})
150
  async for message in result.stream_text():
151
  chatbot[-1]['content'] = message
152
+ yield gr.skip(), chatbot, gr.skip(), gr.skip()
153
  past_messages = result.all_messages()
154
 
155
+ yield gr.Textbox(interactive=True), gr.skip(), past_messages, current_image
156
 
157
  with gr.Blocks() as demo:
158
  gr.HTML(
 
169
  """
170
  )
171
 
172
+ current_image = gr.State(None)
173
  past_messages = gr.State([])
174
  chatbot = gr.Chatbot(
175
  label='Image Editing Assistant',
 
179
  with gr.Row():
180
  chat_input = gr.MultimodalTextbox(
181
  interactive=True,
182
+ file_count="single",
183
  show_label=False,
184
  placeholder='How would you like to edit this image?',
185
+ sources=["upload"]
186
  )
187
  generation = chat_input.submit(
188
  stream_from_agent,
189
+ inputs=[chat_input, chatbot, past_messages, current_image],
190
+ outputs=[chat_input, chatbot, past_messages, current_image],
191
  )
192
 
193
  if __name__ == '__main__':
src/agents/mask_generation_agent.py CHANGED
@@ -4,6 +4,7 @@ from dotenv import load_dotenv
4
  import os
5
  import asyncio
6
  from dataclasses import dataclass
 
7
  import logfire
8
  from src.services.generate_mask import GenerateMaskService
9
  from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
@@ -18,14 +19,15 @@ logfire.instrument_openai()
18
  system_prompt = """
19
  I will give you an editing instruction of the image.
20
  if the edit instruction involved modifying parts of the image, please generate a mask for it.
 
21
  """
22
 
23
  @dataclass
24
  class ImageEditDeps:
25
  edit_instruction: str
26
- image_url: str
27
  hopter_client: Hopter
28
  mask_service: GenerateMaskService
 
29
 
30
  model = OpenAIModel(
31
  "gpt-4o",
@@ -47,25 +49,16 @@ mask_generation_agent = Agent(
47
  deps_type=ImageEditDeps
48
  )
49
 
50
- # @mask_generation_agent.tool
51
- # async def generate_mask(ctx: RunContext[ImageEditDeps]) -> MaskGenerationResult:
52
- # """
53
- # Generate a mask for the image editing instruction.
54
- # """
55
- # print("Invoking generate_mask tool")
56
- # service = GenerateMaskService()
57
- # mask_instruction = await service.get_mask_generation_instruction(ctx.deps.edit_instruction, ctx.deps.image_url)
58
- # response = mask_instruction.model_dump_json(indent=4)
59
- # print(f"generate_mask tool response: {response}")
60
-
61
- # mask = await service.generate_mask(mask_instruction, ctx.deps.image_url)
62
- # print("Exiting generate_mask tool")
63
- # return MaskGenerationResult(mask_image_base64=mask)
64
-
65
  @mask_generation_agent.tool
66
  async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
67
  """
68
- Edit an object in the image.
 
 
 
 
 
 
69
  """
70
  edit_instruction = ctx.deps.edit_instruction
71
  image_url = ctx.deps.image_url
 
4
  import os
5
  import asyncio
6
  from dataclasses import dataclass
7
+ from typing import Optional
8
  import logfire
9
  from src.services.generate_mask import GenerateMaskService
10
  from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
 
19
  system_prompt = """
20
  I will give you an editing instruction of the image.
21
  if the edit instruction involved modifying parts of the image, please generate a mask for it.
22
+ if images are not provided, ask the user to provide an image.
23
  """
24
 
25
  @dataclass
26
  class ImageEditDeps:
27
  edit_instruction: str
 
28
  hopter_client: Hopter
29
  mask_service: GenerateMaskService
30
+ image_url: Optional[str] = None
31
 
32
  model = OpenAIModel(
33
  "gpt-4o",
 
49
  deps_type=ImageEditDeps
50
  )
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  @mask_generation_agent.tool
53
  async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
54
  """
55
+ Use this tool to edit an object in the image. for example:
56
+ - remove the pole
57
+ - replace the dog with a cat
58
+ - change the background to a beach
59
+ - remove the person in the image
60
+ - change the hair color to red
61
+ - change the hat to a cap
62
  """
63
  edit_instruction = ctx.deps.edit_instruction
64
  image_url = ctx.deps.image_url