simonlee-cb commited on
Commit
fcb8f25
·
1 Parent(s): 2a2c2ad

refactor: formatting

Browse files
app.py CHANGED
@@ -8,4 +8,4 @@ with demo.route("PicEdit"):
8
  image_edit_demo.demo.render()
9
 
10
  if __name__ == "__main__":
11
- demo.launch()
 
8
  image_edit_demo.demo.render()
9
 
10
  if __name__ == "__main__":
11
+ demo.launch()
image_edit_chat.py CHANGED
@@ -5,13 +5,10 @@ from src.hopter.client import Hopter, Environment
5
  from src.services.generate_mask import GenerateMaskService
6
  from dotenv import load_dotenv
7
  from src.utils import upload_image
8
- from pydantic_ai.messages import (
9
- ToolCallPart,
10
- ToolReturnPart
11
- )
12
  from pydantic_ai.models.openai import OpenAIModel
13
 
14
- model = OpenAIModel(
15
  "gpt-4o",
16
  api_key=os.environ.get("OPENAI_API_KEY"),
17
  )
@@ -33,76 +30,65 @@ EXAMPLES = [
33
  "text": "Replace the background to the space with stars and planets",
34
  "files": [
35
  "https://cdn.prod.website-files.com/66f230993926deadc0ac3a44/66f370d65f158cbbcfbcc532_Crossed%20Arms%20Levi%20Meir%20Clancy.jpg"
36
- ]
37
  },
38
  {
39
  "text": "Change all the balloons to red in the image",
40
  "files": [
41
  "https://www.apple.com/tv-pr/articles/2024/10/apple-tv-unveils-severance-season-two-teaser-ahead-of-the-highly-anticipated-return-of-the-emmy-and-peabody-award-winning-phenomenon/images/big-image/big-image-01/1023024_Severance_Season_Two_Official_Trailer_Big_Image_01_big_image_post.jpg.large_2x.jpg"
42
- ]
43
  },
44
  {
45
  "text": "Change coffee to a glass of water",
46
  "files": [
47
  "https://previews.123rf.com/images/vadymvdrobot/vadymvdrobot1812/vadymvdrobot181201149/113217373-image-of-smiling-woman-holding-takeaway-coffee-in-paper-cup-and-taking-selfie-while-walking-through.jpg"
48
- ]
49
  },
50
  {
51
  "text": "ENHANCE!",
52
  "files": [
53
  "https://m.media-amazon.com/images/M/MV5BNzM3ODc5NzEtNzJkOC00MDM4LWI0MTYtZTkyNmY3ZTBhYzkxXkEyXkFqcGc@._V1_QL75_UX1000_CR0,52,1000,563_.jpg"
54
- ]
55
- }
56
  ]
57
 
58
  load_dotenv()
59
 
 
60
  def build_user_message(chat_input):
61
  text = chat_input["text"]
62
  images = chat_input["files"]
63
- messages = [
64
- {
65
- "role": "user",
66
- "content": text
67
- }
68
- ]
69
  if images:
70
- messages.extend([
71
- {
72
- "role": "user",
73
- "content": {"path": image}
74
- }
75
- for image in images
76
- ])
77
  return messages
78
 
 
79
  def build_messages_for_agent(chat_input, past_messages):
80
  # filter out image messages from past messages to save on tokens
81
  messages = past_messages
82
-
83
  # add the user's text message
84
  if chat_input["text"]:
85
- messages.append({
86
- "type": "text",
87
- "text": chat_input["text"]
88
- })
89
 
90
  # add the user's image message
91
  files = chat_input.get("files", [])
92
  image_url = upload_image(files[0]) if files else None
93
  if image_url:
94
- messages.append({
95
- "type": "image_url",
96
- "image_url": {"url": image_url}
97
- })
98
 
99
  return messages
100
-
 
101
  def select_example(x: gr.SelectData, chat_input):
102
  chat_input["text"] = x.value["text"]
103
  chat_input["files"] = x.value["files"]
104
  return chat_input
105
 
 
106
  async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
107
  # Prepare messages for the UI
108
  chatbot.extend(build_user_message(chat_input))
@@ -113,15 +99,10 @@ async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
113
  files = chat_input.get("files", [])
114
  image_url = upload_image(files[0]) if files else None
115
  messages = [
116
- {
117
- "type": "text",
118
- "text": text
119
- },
120
  ]
121
  if image_url:
122
- messages.append(
123
- {"type": "image_url", "image_url": {"url": image_url}}
124
- )
125
  current_image = image_url
126
 
127
  # Dependencies
@@ -131,66 +112,67 @@ async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
131
  edit_instruction=text,
132
  image_url=current_image,
133
  hopter_client=hopter,
134
- mask_service=mask_service
135
  )
136
 
137
  # Run the agent
138
  async with image_edit_agent.run_stream(
139
- messages,
140
- deps=deps,
141
- message_history=past_messages
142
  ) as result:
143
  for message in result.new_messages():
144
  for call in message.parts:
145
  if isinstance(call, ToolCallPart):
146
  call_args = (
147
  call.args.args_json
148
- if hasattr(call.args, 'args_json')
149
  else call.args
150
  )
151
  metadata = {
152
- 'title': f'🛠️ Using {call.tool_name}',
153
  }
154
  # set the tool call id so that when the tool returns
155
  # we can find this message and update with the result
156
  if call.tool_call_id is not None:
157
- metadata['id'] = call.tool_call_id
158
 
159
  # Create a tool call message to show on the UI
160
  gr_message = {
161
- 'role': 'assistant',
162
- 'content': 'Parameters: ' + call_args,
163
- 'metadata': metadata,
164
  }
165
  chatbot.append(gr_message)
166
  if isinstance(call, ToolReturnPart):
167
  for gr_message in chatbot:
168
  # Skip messages without metadata
169
- if not gr_message.get('metadata'):
170
  continue
171
 
172
- if gr_message['metadata'].get('id', '') == call.tool_call_id:
173
  if isinstance(call.content, EditImageResult):
174
- chatbot.append({
175
- "role": "assistant",
176
- "content": gr.Image(call.content.edited_image_url),
177
- "files": [call.content.edited_image_url]
178
- })
 
 
 
 
179
  current_image = call.content.edited_image_url
180
  else:
181
- gr_message['content'] += (
182
- f'\nOutput: {call.content}'
183
- )
184
  yield gr.skip(), chatbot, gr.skip(), gr.skip()
185
 
186
- chatbot.append({'role': 'assistant', 'content': ''})
187
  async for message in result.stream_text():
188
- chatbot[-1]['content'] = message
189
  yield gr.skip(), chatbot, gr.skip(), gr.skip()
190
  past_messages = result.all_messages()
191
 
192
  yield gr.Textbox(interactive=True), gr.skip(), past_messages, current_image
193
 
 
194
  with gr.Blocks() as demo:
195
  gr.Markdown(INTRO)
196
 
@@ -198,10 +180,10 @@ with gr.Blocks() as demo:
198
  past_messages = gr.State([])
199
  chatbot = gr.Chatbot(
200
  elem_id="chatbot",
201
- label='Image Editing Assistant',
202
- type='messages',
203
- avatar_images=(None, 'https://ai.pydantic.dev/img/logo-white.svg'),
204
- examples=EXAMPLES
205
  )
206
 
207
  with gr.Row():
@@ -209,8 +191,8 @@ with gr.Blocks() as demo:
209
  interactive=True,
210
  file_count="single",
211
  show_label=False,
212
- placeholder='How would you like to edit this image?',
213
- sources=["upload"]
214
  )
215
  generation = chat_input.submit(
216
  stream_from_agent,
@@ -233,7 +215,7 @@ with gr.Blocks() as demo:
233
  inputs=[chat_input],
234
  outputs=[chat_input],
235
  )
236
-
237
 
238
- if __name__ == '__main__':
239
- demo.launch()
 
 
5
  from src.services.generate_mask import GenerateMaskService
6
  from dotenv import load_dotenv
7
  from src.utils import upload_image
8
+ from pydantic_ai.messages import ToolCallPart, ToolReturnPart
 
 
 
9
  from pydantic_ai.models.openai import OpenAIModel
10
 
11
+ model = OpenAIModel(
12
  "gpt-4o",
13
  api_key=os.environ.get("OPENAI_API_KEY"),
14
  )
 
30
  "text": "Replace the background to the space with stars and planets",
31
  "files": [
32
  "https://cdn.prod.website-files.com/66f230993926deadc0ac3a44/66f370d65f158cbbcfbcc532_Crossed%20Arms%20Levi%20Meir%20Clancy.jpg"
33
+ ],
34
  },
35
  {
36
  "text": "Change all the balloons to red in the image",
37
  "files": [
38
  "https://www.apple.com/tv-pr/articles/2024/10/apple-tv-unveils-severance-season-two-teaser-ahead-of-the-highly-anticipated-return-of-the-emmy-and-peabody-award-winning-phenomenon/images/big-image/big-image-01/1023024_Severance_Season_Two_Official_Trailer_Big_Image_01_big_image_post.jpg.large_2x.jpg"
39
+ ],
40
  },
41
  {
42
  "text": "Change coffee to a glass of water",
43
  "files": [
44
  "https://previews.123rf.com/images/vadymvdrobot/vadymvdrobot1812/vadymvdrobot181201149/113217373-image-of-smiling-woman-holding-takeaway-coffee-in-paper-cup-and-taking-selfie-while-walking-through.jpg"
45
+ ],
46
  },
47
  {
48
  "text": "ENHANCE!",
49
  "files": [
50
  "https://m.media-amazon.com/images/M/MV5BNzM3ODc5NzEtNzJkOC00MDM4LWI0MTYtZTkyNmY3ZTBhYzkxXkEyXkFqcGc@._V1_QL75_UX1000_CR0,52,1000,563_.jpg"
51
+ ],
52
+ },
53
  ]
54
 
55
  load_dotenv()
56
 
57
+
58
  def build_user_message(chat_input):
59
  text = chat_input["text"]
60
  images = chat_input["files"]
61
+ messages = [{"role": "user", "content": text}]
 
 
 
 
 
62
  if images:
63
+ messages.extend(
64
+ [{"role": "user", "content": {"path": image}} for image in images]
65
+ )
 
 
 
 
66
  return messages
67
 
68
+
69
  def build_messages_for_agent(chat_input, past_messages):
70
  # filter out image messages from past messages to save on tokens
71
  messages = past_messages
72
+
73
  # add the user's text message
74
  if chat_input["text"]:
75
+ messages.append({"type": "text", "text": chat_input["text"]})
 
 
 
76
 
77
  # add the user's image message
78
  files = chat_input.get("files", [])
79
  image_url = upload_image(files[0]) if files else None
80
  if image_url:
81
+ messages.append({"type": "image_url", "image_url": {"url": image_url}})
 
 
 
82
 
83
  return messages
84
+
85
+
86
  def select_example(x: gr.SelectData, chat_input):
87
  chat_input["text"] = x.value["text"]
88
  chat_input["files"] = x.value["files"]
89
  return chat_input
90
 
91
+
92
  async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
93
  # Prepare messages for the UI
94
  chatbot.extend(build_user_message(chat_input))
 
99
  files = chat_input.get("files", [])
100
  image_url = upload_image(files[0]) if files else None
101
  messages = [
102
+ {"type": "text", "text": text},
 
 
 
103
  ]
104
  if image_url:
105
+ messages.append({"type": "image_url", "image_url": {"url": image_url}})
 
 
106
  current_image = image_url
107
 
108
  # Dependencies
 
112
  edit_instruction=text,
113
  image_url=current_image,
114
  hopter_client=hopter,
115
+ mask_service=mask_service,
116
  )
117
 
118
  # Run the agent
119
  async with image_edit_agent.run_stream(
120
+ messages, deps=deps, message_history=past_messages
 
 
121
  ) as result:
122
  for message in result.new_messages():
123
  for call in message.parts:
124
  if isinstance(call, ToolCallPart):
125
  call_args = (
126
  call.args.args_json
127
+ if hasattr(call.args, "args_json")
128
  else call.args
129
  )
130
  metadata = {
131
+ "title": f"🛠️ Using {call.tool_name}",
132
  }
133
  # set the tool call id so that when the tool returns
134
  # we can find this message and update with the result
135
  if call.tool_call_id is not None:
136
+ metadata["id"] = call.tool_call_id
137
 
138
  # Create a tool call message to show on the UI
139
  gr_message = {
140
+ "role": "assistant",
141
+ "content": "Parameters: " + call_args,
142
+ "metadata": metadata,
143
  }
144
  chatbot.append(gr_message)
145
  if isinstance(call, ToolReturnPart):
146
  for gr_message in chatbot:
147
  # Skip messages without metadata
148
+ if not gr_message.get("metadata"):
149
  continue
150
 
151
+ if gr_message["metadata"].get("id", "") == call.tool_call_id:
152
  if isinstance(call.content, EditImageResult):
153
+ chatbot.append(
154
+ {
155
+ "role": "assistant",
156
+ "content": gr.Image(
157
+ call.content.edited_image_url
158
+ ),
159
+ "files": [call.content.edited_image_url],
160
+ }
161
+ )
162
  current_image = call.content.edited_image_url
163
  else:
164
+ gr_message["content"] += f"\nOutput: {call.content}"
 
 
165
  yield gr.skip(), chatbot, gr.skip(), gr.skip()
166
 
167
+ chatbot.append({"role": "assistant", "content": ""})
168
  async for message in result.stream_text():
169
+ chatbot[-1]["content"] = message
170
  yield gr.skip(), chatbot, gr.skip(), gr.skip()
171
  past_messages = result.all_messages()
172
 
173
  yield gr.Textbox(interactive=True), gr.skip(), past_messages, current_image
174
 
175
+
176
  with gr.Blocks() as demo:
177
  gr.Markdown(INTRO)
178
 
 
180
  past_messages = gr.State([])
181
  chatbot = gr.Chatbot(
182
  elem_id="chatbot",
183
+ label="Image Editing Assistant",
184
+ type="messages",
185
+ avatar_images=(None, "https://ai.pydantic.dev/img/logo-white.svg"),
186
+ examples=EXAMPLES,
187
  )
188
 
189
  with gr.Row():
 
191
  interactive=True,
192
  file_count="single",
193
  show_label=False,
194
+ placeholder="How would you like to edit this image?",
195
+ sources=["upload"],
196
  )
197
  generation = chat_input.submit(
198
  stream_from_agent,
 
215
  inputs=[chat_input],
216
  outputs=[chat_input],
217
  )
 
218
 
219
+
220
+ if __name__ == "__main__":
221
+ demo.launch()
image_edit_demo.py CHANGED
@@ -4,50 +4,47 @@ import os
4
  from src.hopter.client import Hopter, Environment
5
  from src.services.generate_mask import GenerateMaskService
6
  from dotenv import load_dotenv
7
- from pydantic_ai.messages import (
8
- ToolReturnPart
9
- )
10
  from src.utils import upload_image
 
11
  load_dotenv()
12
 
 
13
  async def process_edit(image, instruction):
14
  hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
15
  mask_service = GenerateMaskService(hopter=hopter)
16
  image_url = upload_image(image)
17
  messages = [
18
- {
19
- "type": "text",
20
- "text": instruction
21
- },
22
  ]
23
  if image:
24
- messages.append(
25
- {"type": "image_url", "image_url": {"url": image_url}}
26
- )
27
  deps = ImageEditDeps(
28
  edit_instruction=instruction,
29
  image_url=image_url,
30
  hopter_client=hopter,
31
- mask_service=mask_service
32
- )
33
- result = await image_edit_agent.run(
34
- messages,
35
- deps=deps
36
  )
 
37
  # Extract the edited image URL from the tool return
38
  for message in result.new_messages():
39
  for part in message.parts:
40
- if isinstance(part, ToolReturnPart) and isinstance(part.content, EditImageResult):
 
 
41
  return part.content.edited_image_url
42
  return None
43
 
 
44
  async def use_edited_image(edited_image):
45
  return edited_image
46
 
 
47
  def clear_instruction():
48
  # Only clear the instruction text.
49
  return ""
50
 
 
51
  # Create the Gradio interface
52
  with gr.Blocks() as demo:
53
  gr.Markdown("# PicEdit")
@@ -55,57 +52,52 @@ with gr.Blocks() as demo:
55
  Welcome to PicEdit - an AI-powered image editing tool.
56
  Simply upload an image and describe the changes you want to make in natural language.
57
  """)
58
-
59
  with gr.Row():
60
  # Input image on the left
61
  input_image = gr.Image(label="Original Image", type="filepath")
62
-
63
  with gr.Column():
64
  # Output image on the right
65
- output_image = gr.Image(label="Edited Image", type="filepath", interactive=False, scale=3)
 
 
66
  use_edited_btn = gr.Button("👈 Use Edited Image 👈")
67
 
68
  # Text input for editing instructions
69
  instruction = gr.Textbox(
70
  label="Editing Instructions",
71
- placeholder="Describe the changes you want to make to the image..."
72
  )
73
-
74
  # Clear button
75
  with gr.Row():
76
  clear_btn = gr.Button("Clear")
77
  submit_btn = gr.Button("Apply Edit", variant="primary")
78
-
79
  # Set up the event handlers
80
  submit_btn.click(
81
- fn=process_edit,
82
- inputs=[input_image, instruction],
83
- outputs=output_image
84
  )
85
-
86
  use_edited_btn.click(
87
- fn=use_edited_image,
88
- inputs=[output_image],
89
- outputs=[input_image]
90
  )
91
 
92
  # Bind the clear button's click event to only clear the instruction textbox.
93
- clear_btn.click(
94
- fn=clear_instruction,
95
- inputs=[],
96
- outputs=[instruction]
97
- )
98
 
99
  examples = gr.Examples(
100
  examples=[
101
  ["https://i.ibb.co/qYwhcc6j/c837c212afbf.jpg", "remove the pole"],
102
  ["https://i.ibb.co/2Mrxztw/image.png", "replace the cat with a dog"],
103
- ["https://i.ibb.co/9mT4cvnt/resized-78-B40-C09-1037-4-DD3-9-F48-D73637-EE4-E51.png", "ENHANCE!"]
 
 
 
104
  ],
105
- inputs=[input_image, instruction]
106
  )
107
 
108
  if __name__ == "__main__":
109
  demo.launch()
110
-
111
-
 
4
  from src.hopter.client import Hopter, Environment
5
  from src.services.generate_mask import GenerateMaskService
6
  from dotenv import load_dotenv
7
+ from pydantic_ai.messages import ToolReturnPart
 
 
8
  from src.utils import upload_image
9
+
10
  load_dotenv()
11
 
12
+
13
  async def process_edit(image, instruction):
14
  hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
15
  mask_service = GenerateMaskService(hopter=hopter)
16
  image_url = upload_image(image)
17
  messages = [
18
+ {"type": "text", "text": instruction},
 
 
 
19
  ]
20
  if image:
21
+ messages.append({"type": "image_url", "image_url": {"url": image_url}})
 
 
22
  deps = ImageEditDeps(
23
  edit_instruction=instruction,
24
  image_url=image_url,
25
  hopter_client=hopter,
26
+ mask_service=mask_service,
 
 
 
 
27
  )
28
+ result = await image_edit_agent.run(messages, deps=deps)
29
  # Extract the edited image URL from the tool return
30
  for message in result.new_messages():
31
  for part in message.parts:
32
+ if isinstance(part, ToolReturnPart) and isinstance(
33
+ part.content, EditImageResult
34
+ ):
35
  return part.content.edited_image_url
36
  return None
37
 
38
+
39
  async def use_edited_image(edited_image):
40
  return edited_image
41
 
42
+
43
  def clear_instruction():
44
  # Only clear the instruction text.
45
  return ""
46
 
47
+
48
  # Create the Gradio interface
49
  with gr.Blocks() as demo:
50
  gr.Markdown("# PicEdit")
 
52
  Welcome to PicEdit - an AI-powered image editing tool.
53
  Simply upload an image and describe the changes you want to make in natural language.
54
  """)
55
+
56
  with gr.Row():
57
  # Input image on the left
58
  input_image = gr.Image(label="Original Image", type="filepath")
59
+
60
  with gr.Column():
61
  # Output image on the right
62
+ output_image = gr.Image(
63
+ label="Edited Image", type="filepath", interactive=False, scale=3
64
+ )
65
  use_edited_btn = gr.Button("👈 Use Edited Image 👈")
66
 
67
  # Text input for editing instructions
68
  instruction = gr.Textbox(
69
  label="Editing Instructions",
70
+ placeholder="Describe the changes you want to make to the image...",
71
  )
72
+
73
  # Clear button
74
  with gr.Row():
75
  clear_btn = gr.Button("Clear")
76
  submit_btn = gr.Button("Apply Edit", variant="primary")
77
+
78
  # Set up the event handlers
79
  submit_btn.click(
80
+ fn=process_edit, inputs=[input_image, instruction], outputs=output_image
 
 
81
  )
82
+
83
  use_edited_btn.click(
84
+ fn=use_edited_image, inputs=[output_image], outputs=[input_image]
 
 
85
  )
86
 
87
  # Bind the clear button's click event to only clear the instruction textbox.
88
+ clear_btn.click(fn=clear_instruction, inputs=[], outputs=[instruction])
 
 
 
 
89
 
90
  examples = gr.Examples(
91
  examples=[
92
  ["https://i.ibb.co/qYwhcc6j/c837c212afbf.jpg", "remove the pole"],
93
  ["https://i.ibb.co/2Mrxztw/image.png", "replace the cat with a dog"],
94
+ [
95
+ "https://i.ibb.co/9mT4cvnt/resized-78-B40-C09-1037-4-DD3-9-F48-D73637-EE4-E51.png",
96
+ "ENHANCE!",
97
+ ],
98
  ],
99
+ inputs=[input_image, instruction],
100
  )
101
 
102
  if __name__ == "__main__":
103
  demo.launch()
 
 
server.py CHANGED
@@ -1,10 +1,9 @@
1
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
2
  from fastapi.responses import StreamingResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
- import asyncio
5
  import os
6
  from dotenv import load_dotenv
7
- from typing import Optional, List, Dict, Any
8
  import json
9
  from pydantic import BaseModel
10
 
@@ -13,7 +12,7 @@ from src.agents.image_edit_agent import image_edit_agent, ImageEditDeps
13
  from src.agents.generic_agent import generic_agent
14
  from src.hopter.client import Hopter, Environment
15
  from src.services.generate_mask import GenerateMaskService
16
- from src.utils import upload_file_to_base64, upload_image
17
 
18
  # Load environment variables
19
  load_dotenv()
@@ -29,15 +28,18 @@ app.add_middleware(
29
  allow_headers=["*"], # Allows all headers
30
  )
31
 
 
32
  class EditRequest(BaseModel):
33
  edit_instruction: str
34
  image_url: Optional[str] = None
35
 
 
36
  class MessageContent(BaseModel):
37
  type: str
38
  text: Optional[str] = None
39
  image_url: Optional[Dict[str, str]] = None
40
 
 
41
  class Message(BaseModel):
42
  content: List[MessageContent]
43
 
@@ -48,9 +50,10 @@ async def test(query: str):
48
  async with generic_agent.run_stream(query) as result:
49
  async for message in result.stream(debounce_by=0.01):
50
  yield json.dumps(message) + "\n"
51
-
52
  return StreamingResponse(stream_messages(), media_type="text/plain")
53
 
 
54
  @app.post("/edit")
55
  async def edit_image(request: EditRequest):
56
  """
@@ -60,8 +63,7 @@ async def edit_image(request: EditRequest):
60
  try:
61
  # Initialize services
62
  hopter = Hopter(
63
- api_key=os.environ.get("HOPTER_API_KEY"),
64
- environment=Environment.STAGING
65
  )
66
  mask_service = GenerateMaskService(hopter=hopter)
67
 
@@ -70,34 +72,27 @@ async def edit_image(request: EditRequest):
70
  edit_instruction=request.edit_instruction,
71
  image_url=request.image_url,
72
  hopter_client=hopter,
73
- mask_service=mask_service
74
  )
75
 
76
  # Create messages
77
- messages = [
78
- {
79
- "type": "text",
80
- "text": request.edit_instruction
81
- }
82
- ]
83
-
84
  if request.image_url:
85
- messages.append({
86
- "type": "image_url",
87
- "image_url": {
88
- "url": request.image_url
89
- }
90
- })
91
 
92
  # Run the agent
93
  result = await image_edit_agent.run(messages, deps=deps)
94
-
95
  # Return the result
96
  return {"edited_image_url": result.edited_image_url}
97
-
98
  except Exception as e:
99
  raise HTTPException(status_code=500, detail=str(e))
100
 
 
101
  @app.post("/edit/stream")
102
  async def edit_image_stream(request: EditRequest):
103
  """
@@ -107,8 +102,7 @@ async def edit_image_stream(request: EditRequest):
107
  try:
108
  # Initialize services
109
  hopter = Hopter(
110
- api_key=os.environ.get("HOPTER_API_KEY"),
111
- environment=Environment.STAGING
112
  )
113
  mask_service = GenerateMaskService(hopter=hopter)
114
 
@@ -117,24 +111,16 @@ async def edit_image_stream(request: EditRequest):
117
  edit_instruction=request.edit_instruction,
118
  image_url=request.image_url,
119
  hopter_client=hopter,
120
- mask_service=mask_service
121
  )
122
 
123
  # Create messages
124
- messages = [
125
- {
126
- "type": "text",
127
- "text": request.edit_instruction
128
- }
129
- ]
130
-
131
  if request.image_url:
132
- messages.append({
133
- "type": "image_url",
134
- "image_url": {
135
- "url": request.image_url
136
- }
137
- })
138
 
139
  async def stream_generator():
140
  async with image_edit_agent.run_stream(messages, deps=deps) as result:
@@ -142,14 +128,12 @@ async def edit_image_stream(request: EditRequest):
142
  # Convert message to JSON and yield
143
  yield json.dumps(message) + "\n"
144
 
145
- return StreamingResponse(
146
- stream_generator(),
147
- media_type="application/x-ndjson"
148
- )
149
-
150
  except Exception as e:
151
  raise HTTPException(status_code=500, detail=str(e))
152
 
 
153
  @app.post("/upload")
154
  async def upload_image_file(file: UploadFile = File(...)):
155
  """
@@ -160,18 +144,19 @@ async def upload_image_file(file: UploadFile = File(...)):
160
  temp_file_path = f"/tmp/{file.filename}"
161
  with open(temp_file_path, "wb") as buffer:
162
  buffer.write(await file.read())
163
-
164
  # Upload the image to Google Cloud Storage
165
  image_url = upload_image(temp_file_path)
166
-
167
  # Remove the temporary file
168
  os.remove(temp_file_path)
169
-
170
  return {"image_url": image_url}
171
-
172
  except Exception as e:
173
  raise HTTPException(status_code=500, detail=str(e))
174
 
 
175
  @app.get("/health")
176
  async def health_check():
177
  """
@@ -179,6 +164,8 @@ async def health_check():
179
  """
180
  return {"status": "ok"}
181
 
 
182
  if __name__ == "__main__":
183
  import uvicorn
184
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
  from fastapi.responses import StreamingResponse
3
  from fastapi.middleware.cors import CORSMiddleware
 
4
  import os
5
  from dotenv import load_dotenv
6
+ from typing import Optional, List, Dict
7
  import json
8
  from pydantic import BaseModel
9
 
 
12
  from src.agents.generic_agent import generic_agent
13
  from src.hopter.client import Hopter, Environment
14
  from src.services.generate_mask import GenerateMaskService
15
+ from src.utils import upload_image
16
 
17
  # Load environment variables
18
  load_dotenv()
 
28
  allow_headers=["*"], # Allows all headers
29
  )
30
 
31
+
32
  class EditRequest(BaseModel):
33
  edit_instruction: str
34
  image_url: Optional[str] = None
35
 
36
+
37
  class MessageContent(BaseModel):
38
  type: str
39
  text: Optional[str] = None
40
  image_url: Optional[Dict[str, str]] = None
41
 
42
+
43
  class Message(BaseModel):
44
  content: List[MessageContent]
45
 
 
50
  async with generic_agent.run_stream(query) as result:
51
  async for message in result.stream(debounce_by=0.01):
52
  yield json.dumps(message) + "\n"
53
+
54
  return StreamingResponse(stream_messages(), media_type="text/plain")
55
 
56
+
57
  @app.post("/edit")
58
  async def edit_image(request: EditRequest):
59
  """
 
63
  try:
64
  # Initialize services
65
  hopter = Hopter(
66
+ api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
 
67
  )
68
  mask_service = GenerateMaskService(hopter=hopter)
69
 
 
72
  edit_instruction=request.edit_instruction,
73
  image_url=request.image_url,
74
  hopter_client=hopter,
75
+ mask_service=mask_service,
76
  )
77
 
78
  # Create messages
79
+ messages = [{"type": "text", "text": request.edit_instruction}]
80
+
 
 
 
 
 
81
  if request.image_url:
82
+ messages.append(
83
+ {"type": "image_url", "image_url": {"url": request.image_url}}
84
+ )
 
 
 
85
 
86
  # Run the agent
87
  result = await image_edit_agent.run(messages, deps=deps)
88
+
89
  # Return the result
90
  return {"edited_image_url": result.edited_image_url}
91
+
92
  except Exception as e:
93
  raise HTTPException(status_code=500, detail=str(e))
94
 
95
+
96
  @app.post("/edit/stream")
97
  async def edit_image_stream(request: EditRequest):
98
  """
 
102
  try:
103
  # Initialize services
104
  hopter = Hopter(
105
+ api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
 
106
  )
107
  mask_service = GenerateMaskService(hopter=hopter)
108
 
 
111
  edit_instruction=request.edit_instruction,
112
  image_url=request.image_url,
113
  hopter_client=hopter,
114
+ mask_service=mask_service,
115
  )
116
 
117
  # Create messages
118
+ messages = [{"type": "text", "text": request.edit_instruction}]
119
+
 
 
 
 
 
120
  if request.image_url:
121
+ messages.append(
122
+ {"type": "image_url", "image_url": {"url": request.image_url}}
123
+ )
 
 
 
124
 
125
  async def stream_generator():
126
  async with image_edit_agent.run_stream(messages, deps=deps) as result:
 
128
  # Convert message to JSON and yield
129
  yield json.dumps(message) + "\n"
130
 
131
+ return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
132
+
 
 
 
133
  except Exception as e:
134
  raise HTTPException(status_code=500, detail=str(e))
135
 
136
+
137
  @app.post("/upload")
138
  async def upload_image_file(file: UploadFile = File(...)):
139
  """
 
144
  temp_file_path = f"/tmp/{file.filename}"
145
  with open(temp_file_path, "wb") as buffer:
146
  buffer.write(await file.read())
147
+
148
  # Upload the image to Google Cloud Storage
149
  image_url = upload_image(temp_file_path)
150
+
151
  # Remove the temporary file
152
  os.remove(temp_file_path)
153
+
154
  return {"image_url": image_url}
155
+
156
  except Exception as e:
157
  raise HTTPException(status_code=500, detail=str(e))
158
 
159
+
160
  @app.get("/health")
161
  async def health_check():
162
  """
 
164
  """
165
  return {"status": "ok"}
166
 
167
+
168
  if __name__ == "__main__":
169
  import uvicorn
170
+
171
+ uvicorn.run(app, host="0.0.0.0", port=8000)
src/agents/generic_agent.py CHANGED
@@ -1,14 +1,11 @@
1
- from pydantic_ai import Agent, RunContext
2
  from pydantic_ai.models.openai import OpenAIModel
3
  from dotenv import load_dotenv
4
  import os
5
 
6
  load_dotenv()
7
 
8
- model = OpenAIModel(
9
- "gpt-4o",
10
- api_key=os.environ.get("OPENAI_API_KEY")
11
- )
12
 
13
  system_prompt = """
14
  You are a helpful assistant that can answer questions and help with tasks.
@@ -19,5 +16,3 @@ generic_agent = Agent(
19
  system_prompt=system_prompt,
20
  tools=[],
21
  )
22
-
23
-
 
1
+ from pydantic_ai import Agent
2
  from pydantic_ai.models.openai import OpenAIModel
3
  from dotenv import load_dotenv
4
  import os
5
 
6
  load_dotenv()
7
 
8
+ model = OpenAIModel("gpt-4o", api_key=os.environ.get("OPENAI_API_KEY"))
 
 
 
9
 
10
  system_prompt = """
11
  You are a helpful assistant that can answer questions and help with tasks.
 
16
  system_prompt=system_prompt,
17
  tools=[],
18
  )
 
 
src/agents/image_edit_agent.py CHANGED
@@ -7,7 +7,12 @@ from dataclasses import dataclass
7
  from typing import Optional
8
  import logfire
9
  from src.services.generate_mask import GenerateMaskService
10
- from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
 
 
 
 
 
11
  from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
12
  import base64
13
  import tempfile
@@ -23,6 +28,7 @@ if the edit instruction involved modifying parts of the image, please generate a
23
  if images are not provided, ask the user to provide an image.
24
  """
25
 
 
26
  @dataclass
27
  class ImageEditDeps:
28
  edit_instruction: str
@@ -30,6 +36,7 @@ class ImageEditDeps:
30
  mask_service: GenerateMaskService
31
  image_url: Optional[str] = None
32
 
 
33
  model = OpenAIModel(
34
  "gpt-4o",
35
  api_key=os.environ.get("OPENAI_API_KEY"),
@@ -40,11 +47,9 @@ model = OpenAIModel(
40
  class EditImageResult:
41
  edited_image_url: str
42
 
43
- image_edit_agent = Agent(
44
- model,
45
- system_prompt=system_prompt,
46
- deps_type=ImageEditDeps
47
- )
48
 
49
  def upload_image_from_base64(base64_image: str) -> str:
50
  image_format = base64_image.split(",")[0]
@@ -56,6 +61,7 @@ def upload_image_from_base64(base64_image: str) -> str:
56
  f.write(image_data)
57
  return upload_image(temp_filename)
58
 
 
59
  @image_edit_agent.tool
60
  async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
61
  """
@@ -75,15 +81,20 @@ async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
75
  image_uri = download_image_to_data_uri(image_url)
76
 
77
  # Generate mask
78
- mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
 
 
79
  mask = mask_service.generate_mask(mask_instruction, image_uri)
80
 
81
  # Magic replace
82
- input = MagicReplaceInput(image=image_uri, mask=mask, prompt=mask_instruction.target_caption)
 
 
83
  result = hopter_client.magic_replace(input)
84
  uploaded_image = upload_image_from_base64(result.base64_image)
85
  return EditImageResult(edited_image_url=uploaded_image)
86
 
 
87
  @image_edit_agent.tool
88
  async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
89
  """
@@ -94,31 +105,28 @@ async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
94
 
95
  image_uri = download_image_to_data_uri(image_url)
96
 
97
- input = SuperResolutionInput(image_b64=image_uri, scale=4, use_face_enhancement=False)
 
 
98
  result = hopter_client.super_resolution(input)
99
  uploaded_image = upload_image_from_base64(result.scaled_image)
100
  return EditImageResult(edited_image_url=uploaded_image)
101
 
 
102
  async def main():
103
  image_file_path = "./assets/lakeview.jpg"
104
  image_url = image_path_to_uri(image_file_path)
105
 
106
  prompt = "remove the light post"
107
  messages = [
108
- {
109
- "type": "text",
110
- "text": prompt
111
- },
112
- {
113
- "type": "image_url",
114
- "image_url": {
115
- "url": image_url
116
- }
117
- }
118
  ]
119
 
120
  # Initialize services
121
- hopter = Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
 
 
122
  mask_service = GenerateMaskService(hopter=hopter)
123
 
124
  # Initialize dependencies
@@ -126,15 +134,12 @@ async def main():
126
  edit_instruction=prompt,
127
  image_url=image_url,
128
  hopter_client=hopter,
129
- mask_service=mask_service
130
  )
131
- async with image_edit_agent.run_stream(
132
- messages,
133
- deps=deps
134
- ) as result:
135
  async for message in result.stream():
136
  print(message)
137
 
138
 
139
  if __name__ == "__main__":
140
- asyncio.run(main())
 
7
  from typing import Optional
8
  import logfire
9
  from src.services.generate_mask import GenerateMaskService
10
+ from src.hopter.client import (
11
+ Hopter,
12
+ Environment,
13
+ MagicReplaceInput,
14
+ SuperResolutionInput,
15
+ )
16
  from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
17
  import base64
18
  import tempfile
 
28
  if images are not provided, ask the user to provide an image.
29
  """
30
 
31
+
32
  @dataclass
33
  class ImageEditDeps:
34
  edit_instruction: str
 
36
  mask_service: GenerateMaskService
37
  image_url: Optional[str] = None
38
 
39
+
40
  model = OpenAIModel(
41
  "gpt-4o",
42
  api_key=os.environ.get("OPENAI_API_KEY"),
 
47
  class EditImageResult:
48
  edited_image_url: str
49
 
50
+
51
+ image_edit_agent = Agent(model, system_prompt=system_prompt, deps_type=ImageEditDeps)
52
+
 
 
53
 
54
  def upload_image_from_base64(base64_image: str) -> str:
55
  image_format = base64_image.split(",")[0]
 
61
  f.write(image_data)
62
  return upload_image(temp_filename)
63
 
64
+
65
  @image_edit_agent.tool
66
  async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
67
  """
 
81
  image_uri = download_image_to_data_uri(image_url)
82
 
83
  # Generate mask
84
+ mask_instruction = mask_service.get_mask_generation_instruction(
85
+ edit_instruction, image_url
86
+ )
87
  mask = mask_service.generate_mask(mask_instruction, image_uri)
88
 
89
  # Magic replace
90
+ input = MagicReplaceInput(
91
+ image=image_uri, mask=mask, prompt=mask_instruction.target_caption
92
+ )
93
  result = hopter_client.magic_replace(input)
94
  uploaded_image = upload_image_from_base64(result.base64_image)
95
  return EditImageResult(edited_image_url=uploaded_image)
96
 
97
+
98
  @image_edit_agent.tool
99
  async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
100
  """
 
105
 
106
  image_uri = download_image_to_data_uri(image_url)
107
 
108
+ input = SuperResolutionInput(
109
+ image_b64=image_uri, scale=4, use_face_enhancement=False
110
+ )
111
  result = hopter_client.super_resolution(input)
112
  uploaded_image = upload_image_from_base64(result.scaled_image)
113
  return EditImageResult(edited_image_url=uploaded_image)
114
 
115
+
116
  async def main():
117
  image_file_path = "./assets/lakeview.jpg"
118
  image_url = image_path_to_uri(image_file_path)
119
 
120
  prompt = "remove the light post"
121
  messages = [
122
+ {"type": "text", "text": prompt},
123
+ {"type": "image_url", "image_url": {"url": image_url}},
 
 
 
 
 
 
 
 
124
  ]
125
 
126
  # Initialize services
127
+ hopter = Hopter(
128
+ api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
129
+ )
130
  mask_service = GenerateMaskService(hopter=hopter)
131
 
132
  # Initialize dependencies
 
134
  edit_instruction=prompt,
135
  image_url=image_url,
136
  hopter_client=hopter,
137
+ mask_service=mask_service,
138
  )
139
+ async with image_edit_agent.run_stream(messages, deps=deps) as result:
 
 
 
140
  async for message in result.stream():
141
  print(message)
142
 
143
 
144
  if __name__ == "__main__":
145
+ asyncio.run(main())
src/agents/mask_generation_agent.py CHANGED
@@ -1,16 +1,9 @@
1
- from pydantic_ai import Agent, RunContext
2
  from pydantic_ai.models.openai import OpenAIModel
3
  from dotenv import load_dotenv
4
  import os
5
- import asyncio
6
  from dataclasses import dataclass
7
- from typing import Optional
8
  import logfire
9
- from src.services.generate_mask import GenerateMaskService
10
- from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
11
- from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
12
- import base64
13
- import tempfile
14
 
15
  load_dotenv()
16
 
@@ -56,10 +49,9 @@ model = OpenAIModel(
56
  class MaskGenerationResult:
57
  mask_image_base64: str
58
 
59
- mask_generation_agent = Agent(
60
- model,
61
- system_prompt=system_prompt
62
- )
63
 
64
  @mask_generation_agent.tool
65
  async def generate_mask(edit_instruction: str, image_url: str) -> MaskGenerationResult:
 
1
+ from pydantic_ai import Agent
2
  from pydantic_ai.models.openai import OpenAIModel
3
  from dotenv import load_dotenv
4
  import os
 
5
  from dataclasses import dataclass
 
6
  import logfire
 
 
 
 
 
7
 
8
  load_dotenv()
9
 
 
49
  class MaskGenerationResult:
50
  mask_image_base64: str
51
 
52
+
53
+ mask_generation_agent = Agent(model, system_prompt=system_prompt)
54
+
 
55
 
56
  @mask_generation_agent.tool
57
  async def generate_mask(edit_instruction: str, image_url: str) -> MaskGenerationResult:
src/hopter/client.py CHANGED
@@ -9,6 +9,7 @@ from typing import List
9
 
10
  load_dotenv()
11
 
 
12
  class Environment(Enum):
13
  STAGING = "staging"
14
  PRODUCTION = "production"
@@ -20,39 +21,50 @@ class Environment(Enum):
20
  return "https://serving.hopter.staging.picc.co"
21
  case Environment.PRODUCTION:
22
  return "https://serving.hopter.picc.co"
23
-
 
24
  class RamGroundedSamInput(BaseModel):
25
- text_prompt: str = Field(..., description="The text prompt for the mask generation.")
 
 
26
  image_b64: str = Field(..., description="The image in base64 format.")
27
 
 
28
  class RamGroundedSamResult(BaseModel):
29
  mask_b64: str = Field(..., description="The mask image in base64 format.")
30
  class_label: str = Field(..., description="The class label of the mask.")
31
  confidence: float = Field(..., description="The confidence score of the mask.")
32
- bbox: List[float] = Field(..., description="The bounding box of the mask in the format [x1, y1, x2, y2].")
 
 
 
33
 
34
  class MagicReplaceInput(BaseModel):
35
  image: str = Field(..., description="The image in base64 format.")
36
  mask: str = Field(..., description="The mask in base64 format.")
37
  prompt: str = Field(..., description="The prompt for the magic replace.")
38
 
 
39
  class MagicReplaceResult(BaseModel):
40
  base64_image: str = Field(..., description="The edited image in base64 format.")
41
 
 
42
  class SuperResolutionInput(BaseModel):
43
  image_b64: str = Field(..., description="The image in base64 format.")
44
  scale: int = Field(4, description="The scale of the image to upscale to.")
45
- use_face_enhancement: bool = Field(False, description="Whether to use face enhancement.")
 
 
 
46
 
47
  class SuperResolutionResult(BaseModel):
48
- scaled_image: str = Field(..., description="The super-resolved image in base64 format.")
 
 
 
49
 
50
  class Hopter:
51
- def __init__(
52
- self,
53
- api_key: str,
54
- environment: Environment = Environment.PRODUCTION
55
- ):
56
  self.api_key = api_key
57
  self.base_url = environment.base_url
58
  self.client = httpx.Client()
@@ -64,22 +76,22 @@ class Hopter:
64
  f"{self.base_url}/api/v1/services/ram-grounded-sam-api/predictions",
65
  headers={
66
  "Authorization": f"Bearer {self.api_key}",
67
- "Content-Type": "application/json"
68
- },
69
- json={
70
- "input": input.model_dump()
71
  },
72
- timeout=None
 
73
  )
74
  response.raise_for_status() # Raise an error for bad responses
75
  instance = response.json().get("output").get("instances")[0]
76
  print("Generated mask.")
77
  return RamGroundedSamResult(**instance)
78
  except httpx.HTTPStatusError as exc:
79
- print(f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}")
 
 
80
  except Exception as exc:
81
  print(f"An unexpected error occurred: {exc}")
82
-
83
  def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult:
84
  print(f"Magic replacing with input: {input.prompt}")
85
  try:
@@ -87,19 +99,19 @@ class Hopter:
87
  f"{self.base_url}/api/v1/services/sdxl-magic-replace/predictions",
88
  headers={
89
  "Authorization": f"Bearer {self.api_key}",
90
- "Content-Type": "application/json"
91
  },
92
- json={
93
- "input": input.model_dump()
94
- },
95
- timeout=None
96
  )
97
  response.raise_for_status() # Raise an error for bad responses
98
  instance = response.json().get("output")
99
  print("Magic replaced.")
100
  return MagicReplaceResult(**instance)
101
  except httpx.HTTPStatusError as exc:
102
- print(f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}")
 
 
103
  except Exception as exc:
104
  print(f"An unexpected error occurred: {exc}")
105
 
@@ -109,51 +121,50 @@ class Hopter:
109
  f"{self.base_url}/api/v1/services/super-resolution-esrgan/predictions",
110
  headers={
111
  "Authorization": f"Bearer {self.api_key}",
112
- "Content-Type": "application/json"
113
  },
114
- json={
115
- "input": input.model_dump()
116
- },
117
- timeout=None
118
  )
119
  response.raise_for_status() # Raise an error for bad responses
120
  instance = response.json().get("output")
121
  print("Super-resolutin done")
122
  return SuperResolutionResult(**instance)
123
  except httpx.HTTPStatusError as exc:
124
- print(f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}")
 
 
125
  except Exception as exc:
126
  print(f"An unexpected error occurred: {exc}")
127
 
128
 
129
  async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
130
- input = RamGroundedSamInput(
131
- text_prompt="pole",
132
- image_b64=image_url
133
- )
134
  mask = hopter.generate_mask(input)
135
  return mask.mask_b64
136
-
137
- async def test_magic_replace(hopter: Hopter, image_url: str, mask: str, prompt: str) -> str:
138
- input = MagicReplaceInput(
139
- image=image_url,
140
- mask=mask,
141
- prompt=prompt
142
- )
143
  result = hopter.magic_replace(input)
144
  return result.base64_image
145
 
 
146
  async def main():
147
  hopter = Hopter(
148
- api_key=os.getenv("HOPTER_API_KEY"),
149
- environment=Environment.STAGING
150
  )
151
  image_file_path = "./assets/lakeview.jpg"
152
  image_url = image_path_to_uri(image_file_path)
153
 
154
  mask = await test_generate_mask(hopter, image_url)
155
- magic_replace_result = await test_magic_replace(hopter, image_url, mask, "remove the pole")
 
 
156
  print(magic_replace_result)
157
 
 
158
  if __name__ == "__main__":
159
  asyncio.run(main())
 
9
 
10
  load_dotenv()
11
 
12
+
13
  class Environment(Enum):
14
  STAGING = "staging"
15
  PRODUCTION = "production"
 
21
  return "https://serving.hopter.staging.picc.co"
22
  case Environment.PRODUCTION:
23
  return "https://serving.hopter.picc.co"
24
+
25
+
26
  class RamGroundedSamInput(BaseModel):
27
+ text_prompt: str = Field(
28
+ ..., description="The text prompt for the mask generation."
29
+ )
30
  image_b64: str = Field(..., description="The image in base64 format.")
31
 
32
+
33
  class RamGroundedSamResult(BaseModel):
34
  mask_b64: str = Field(..., description="The mask image in base64 format.")
35
  class_label: str = Field(..., description="The class label of the mask.")
36
  confidence: float = Field(..., description="The confidence score of the mask.")
37
+ bbox: List[float] = Field(
38
+ ..., description="The bounding box of the mask in the format [x1, y1, x2, y2]."
39
+ )
40
+
41
 
42
  class MagicReplaceInput(BaseModel):
43
  image: str = Field(..., description="The image in base64 format.")
44
  mask: str = Field(..., description="The mask in base64 format.")
45
  prompt: str = Field(..., description="The prompt for the magic replace.")
46
 
47
+
48
  class MagicReplaceResult(BaseModel):
49
  base64_image: str = Field(..., description="The edited image in base64 format.")
50
 
51
+
52
  class SuperResolutionInput(BaseModel):
53
  image_b64: str = Field(..., description="The image in base64 format.")
54
  scale: int = Field(4, description="The scale of the image to upscale to.")
55
+ use_face_enhancement: bool = Field(
56
+ False, description="Whether to use face enhancement."
57
+ )
58
+
59
 
60
  class SuperResolutionResult(BaseModel):
61
+ scaled_image: str = Field(
62
+ ..., description="The super-resolved image in base64 format."
63
+ )
64
+
65
 
66
  class Hopter:
67
+ def __init__(self, api_key: str, environment: Environment = Environment.PRODUCTION):
 
 
 
 
68
  self.api_key = api_key
69
  self.base_url = environment.base_url
70
  self.client = httpx.Client()
 
76
  f"{self.base_url}/api/v1/services/ram-grounded-sam-api/predictions",
77
  headers={
78
  "Authorization": f"Bearer {self.api_key}",
79
+ "Content-Type": "application/json",
 
 
 
80
  },
81
+ json={"input": input.model_dump()},
82
+ timeout=None,
83
  )
84
  response.raise_for_status() # Raise an error for bad responses
85
  instance = response.json().get("output").get("instances")[0]
86
  print("Generated mask.")
87
  return RamGroundedSamResult(**instance)
88
  except httpx.HTTPStatusError as exc:
89
+ print(
90
+ f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}"
91
+ )
92
  except Exception as exc:
93
  print(f"An unexpected error occurred: {exc}")
94
+
95
  def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult:
96
  print(f"Magic replacing with input: {input.prompt}")
97
  try:
 
99
  f"{self.base_url}/api/v1/services/sdxl-magic-replace/predictions",
100
  headers={
101
  "Authorization": f"Bearer {self.api_key}",
102
+ "Content-Type": "application/json",
103
  },
104
+ json={"input": input.model_dump()},
105
+ timeout=None,
 
 
106
  )
107
  response.raise_for_status() # Raise an error for bad responses
108
  instance = response.json().get("output")
109
  print("Magic replaced.")
110
  return MagicReplaceResult(**instance)
111
  except httpx.HTTPStatusError as exc:
112
+ print(
113
+ f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}"
114
+ )
115
  except Exception as exc:
116
  print(f"An unexpected error occurred: {exc}")
117
 
 
121
  f"{self.base_url}/api/v1/services/super-resolution-esrgan/predictions",
122
  headers={
123
  "Authorization": f"Bearer {self.api_key}",
124
+ "Content-Type": "application/json",
125
  },
126
+ json={"input": input.model_dump()},
127
+ timeout=None,
 
 
128
  )
129
  response.raise_for_status() # Raise an error for bad responses
130
  instance = response.json().get("output")
131
  print("Super-resolutin done")
132
  return SuperResolutionResult(**instance)
133
  except httpx.HTTPStatusError as exc:
134
+ print(
135
+ f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}"
136
+ )
137
  except Exception as exc:
138
  print(f"An unexpected error occurred: {exc}")
139
 
140
 
141
  async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
142
+ input = RamGroundedSamInput(text_prompt="pole", image_b64=image_url)
 
 
 
143
  mask = hopter.generate_mask(input)
144
  return mask.mask_b64
145
+
146
+
147
+ async def test_magic_replace(
148
+ hopter: Hopter, image_url: str, mask: str, prompt: str
149
+ ) -> str:
150
+ input = MagicReplaceInput(image=image_url, mask=mask, prompt=prompt)
 
151
  result = hopter.magic_replace(input)
152
  return result.base64_image
153
 
154
+
155
  async def main():
156
  hopter = Hopter(
157
+ api_key=os.getenv("HOPTER_API_KEY"), environment=Environment.STAGING
 
158
  )
159
  image_file_path = "./assets/lakeview.jpg"
160
  image_url = image_path_to_uri(image_file_path)
161
 
162
  mask = await test_generate_mask(hopter, image_url)
163
+ magic_replace_result = await test_magic_replace(
164
+ hopter, image_url, mask, "remove the pole"
165
+ )
166
  print(magic_replace_result)
167
 
168
+
169
  if __name__ == "__main__":
170
  asyncio.run(main())
src/models/generate_mask_instruction.py CHANGED
@@ -1,19 +1,17 @@
1
  from pydantic import BaseModel, Field
2
 
 
3
  class GenerateMaskInstruction(BaseModel):
4
  category: str = Field(
5
  ...,
6
- description="The editing category based on the instruction. Must be one of: Addition, Remove, Local, Global, Background."
7
  )
8
  subject: str = Field(
9
  ...,
10
- description="The subject of the editing instruction. Must be a noun in no more than 5 words."
11
- )
12
- caption: str = Field(
13
- ...,
14
- description="The detailed description of the image."
15
  )
 
16
  target_caption: str = Field(
17
  ...,
18
- description="Apply the editing instruction to the image caption. The target caption should describe the image after the editing instruction is applied."
19
- )
 
1
  from pydantic import BaseModel, Field
2
 
3
+
4
  class GenerateMaskInstruction(BaseModel):
5
  category: str = Field(
6
  ...,
7
+ description="The editing category based on the instruction. Must be one of: Addition, Remove, Local, Global, Background.",
8
  )
9
  subject: str = Field(
10
  ...,
11
+ description="The subject of the editing instruction. Must be a noun in no more than 5 words.",
 
 
 
 
12
  )
13
+ caption: str = Field(..., description="The detailed description of the image.")
14
  target_caption: str = Field(
15
  ...,
16
+ description="Apply the editing instruction to the image caption. The target caption should describe the image after the editing instruction is applied.",
17
+ )
src/services/generate_mask.py CHANGED
@@ -6,6 +6,7 @@ from src.hopter.client import Hopter, RamGroundedSamInput, Environment
6
  from src.models.generate_mask_instruction import GenerateMaskInstruction
7
  from src.services.openai_file_upload import OpenAIFileUpload
8
  from src.utils import download_image_to_data_uri
 
9
  load_dotenv()
10
 
11
  system_prompt = """
@@ -37,6 +38,7 @@ Do not output 'sorry, xxx', even if it's a guess, directly output the answer you
37
  </task_3>
38
  """
39
 
 
40
  class GenerateMaskService:
41
  def __init__(self, hopter: Hopter):
42
  self.llm = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
@@ -44,38 +46,29 @@ class GenerateMaskService:
44
  self.openai_file_upload = OpenAIFileUpload()
45
  self.hopter = hopter
46
 
47
- def get_mask_generation_instruction(self, edit_instruction: str, image_url: str) -> GenerateMaskInstruction:
 
 
48
  messages = [
49
- {
50
- "role": "system",
51
- "content": system_prompt
52
- },
53
  {
54
  "role": "user",
55
  "content": [
56
- {
57
- "type": "text",
58
- "text": edit_instruction
59
- },
60
- {
61
- "type": "image_url",
62
- "image_url": {
63
- "url": image_url
64
- }
65
- }
66
- ]
67
- }
68
  ]
69
 
70
  response = self.llm.beta.chat.completions.parse(
71
- model=self.model,
72
- messages=messages,
73
- response_format=GenerateMaskInstruction
74
  )
75
  instruction = response.choices[0].message.parsed
76
  return instruction
77
-
78
- def generate_mask(self, mask_instruction: GenerateMaskInstruction, image_url: str) -> str:
 
 
79
  """
80
  Generate a mask for the image editing instruction.
81
 
@@ -87,14 +80,18 @@ class GenerateMaskService:
87
  """
88
  image_uri = download_image_to_data_uri(image_url)
89
  input = RamGroundedSamInput(
90
- text_prompt=mask_instruction.subject,
91
- image_b64=image_uri
92
  )
93
  generate_mask_result = self.hopter.generate_mask(input)
94
  return generate_mask_result.mask_b64
95
-
 
96
  async def main():
97
- service = GenerateMaskService(Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING))
 
 
 
 
98
  edit_instruction = "remove the light post"
99
  image_file_path = "./assets/lakeview.jpg"
100
  with open(image_file_path, "rb") as image_file:
@@ -105,5 +102,6 @@ async def main():
105
  mask = service.generate_mask(instruction, image_url)
106
  print(mask)
107
 
 
108
  if __name__ == "__main__":
109
  asyncio.run(main())
 
6
  from src.models.generate_mask_instruction import GenerateMaskInstruction
7
  from src.services.openai_file_upload import OpenAIFileUpload
8
  from src.utils import download_image_to_data_uri
9
+
10
  load_dotenv()
11
 
12
  system_prompt = """
 
38
  </task_3>
39
  """
40
 
41
+
42
  class GenerateMaskService:
43
  def __init__(self, hopter: Hopter):
44
  self.llm = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
 
46
  self.openai_file_upload = OpenAIFileUpload()
47
  self.hopter = hopter
48
 
49
+ def get_mask_generation_instruction(
50
+ self, edit_instruction: str, image_url: str
51
+ ) -> GenerateMaskInstruction:
52
  messages = [
53
+ {"role": "system", "content": system_prompt},
 
 
 
54
  {
55
  "role": "user",
56
  "content": [
57
+ {"type": "text", "text": edit_instruction},
58
+ {"type": "image_url", "image_url": {"url": image_url}},
59
+ ],
60
+ },
 
 
 
 
 
 
 
 
61
  ]
62
 
63
  response = self.llm.beta.chat.completions.parse(
64
+ model=self.model, messages=messages, response_format=GenerateMaskInstruction
 
 
65
  )
66
  instruction = response.choices[0].message.parsed
67
  return instruction
68
+
69
+ def generate_mask(
70
+ self, mask_instruction: GenerateMaskInstruction, image_url: str
71
+ ) -> str:
72
  """
73
  Generate a mask for the image editing instruction.
74
 
 
80
  """
81
  image_uri = download_image_to_data_uri(image_url)
82
  input = RamGroundedSamInput(
83
+ text_prompt=mask_instruction.subject, image_b64=image_uri
 
84
  )
85
  generate_mask_result = self.hopter.generate_mask(input)
86
  return generate_mask_result.mask_b64
87
+
88
+
89
  async def main():
90
+ service = GenerateMaskService(
91
+ Hopter(
92
+ api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
93
+ )
94
+ )
95
  edit_instruction = "remove the light post"
96
  image_file_path = "./assets/lakeview.jpg"
97
  with open(image_file_path, "rb") as image_file:
 
102
  mask = service.generate_mask(instruction, image_url)
103
  print(mask)
104
 
105
+
106
  if __name__ == "__main__":
107
  asyncio.run(main())
src/services/google_cloud_image_upload.py CHANGED
@@ -7,14 +7,18 @@ from dotenv import load_dotenv
7
 
8
  load_dotenv()
9
 
 
10
  def get_credentials():
11
  credentials_json_string = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
12
  # create a temp file with the credentials
13
- with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as temp_file:
 
 
14
  temp_file.write(credentials_json_string)
15
  temp_file_path = temp_file.name
16
  return temp_file_path
17
 
 
18
  class GoogleCloudImageUploadService:
19
  BUCKET_NAME = "picchat-assets"
20
  MAX_DIMENSION = 1024
@@ -39,37 +43,49 @@ class GoogleCloudImageUploadService:
39
  # Open and optionally resize the image, then save to a temporary file.
40
  with Image.open(source_file_name) as image:
41
  # Determine the original format. If it's not JPEG or PNG, default to JPEG.
42
- original_format = image.format.upper() if image.format in ['JPEG', 'PNG'] else "JPEG"
 
 
43
 
44
  # Resize if needed.
45
- if image.width > self.MAX_DIMENSION or image.height > self.MAX_DIMENSION:
 
 
 
46
  image.thumbnail((self.MAX_DIMENSION, self.MAX_DIMENSION))
47
-
48
  # Choose the file extension based on the image format.
49
  suffix = ".jpg" if original_format == "JPEG" else ".png"
50
-
51
  # Create a temporary file with the appropriate suffix.
52
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
 
 
53
  temp_filename = temp_file.name
54
  image.save(temp_filename, format=original_format)
55
 
56
  try:
57
  # Set content type based on the image format.
58
- content_type = "image/jpeg" if original_format == "JPEG" else "image/png"
 
 
59
  blob.upload_from_filename(temp_filename, content_type=content_type)
60
  blob.make_public()
61
  finally:
62
  # Remove the temporary file.
63
  os.remove(temp_filename)
64
 
65
- print(f"File {source_file_name} uploaded to {blob_name} in bucket {self.BUCKET_NAME}.")
 
 
66
  return blob.public_url
67
  except Exception as e:
68
  print(f"An error occurred: {e}")
69
  return None
70
 
 
71
  if __name__ == "__main__":
72
  image = "./assets/lakeview.jpg" # Replace with your JPEG or PNG image path.
73
  upload_service = GoogleCloudImageUploadService()
74
  url = upload_service.upload_image_to_gcs(image)
75
- print(url)
 
7
 
8
  load_dotenv()
9
 
10
+
11
  def get_credentials():
12
  credentials_json_string = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
13
  # create a temp file with the credentials
14
+ with tempfile.NamedTemporaryFile(
15
+ mode="w+", delete=False, suffix=".json"
16
+ ) as temp_file:
17
  temp_file.write(credentials_json_string)
18
  temp_file_path = temp_file.name
19
  return temp_file_path
20
 
21
+
22
  class GoogleCloudImageUploadService:
23
  BUCKET_NAME = "picchat-assets"
24
  MAX_DIMENSION = 1024
 
43
  # Open and optionally resize the image, then save to a temporary file.
44
  with Image.open(source_file_name) as image:
45
  # Determine the original format. If it's not JPEG or PNG, default to JPEG.
46
+ original_format = (
47
+ image.format.upper() if image.format in ["JPEG", "PNG"] else "JPEG"
48
+ )
49
 
50
  # Resize if needed.
51
+ if (
52
+ image.width > self.MAX_DIMENSION
53
+ or image.height > self.MAX_DIMENSION
54
+ ):
55
  image.thumbnail((self.MAX_DIMENSION, self.MAX_DIMENSION))
56
+
57
  # Choose the file extension based on the image format.
58
  suffix = ".jpg" if original_format == "JPEG" else ".png"
59
+
60
  # Create a temporary file with the appropriate suffix.
61
+ with tempfile.NamedTemporaryFile(
62
+ delete=False, suffix=suffix
63
+ ) as temp_file:
64
  temp_filename = temp_file.name
65
  image.save(temp_filename, format=original_format)
66
 
67
  try:
68
  # Set content type based on the image format.
69
+ content_type = (
70
+ "image/jpeg" if original_format == "JPEG" else "image/png"
71
+ )
72
  blob.upload_from_filename(temp_filename, content_type=content_type)
73
  blob.make_public()
74
  finally:
75
  # Remove the temporary file.
76
  os.remove(temp_filename)
77
 
78
+ print(
79
+ f"File {source_file_name} uploaded to {blob_name} in bucket {self.BUCKET_NAME}."
80
+ )
81
  return blob.public_url
82
  except Exception as e:
83
  print(f"An error occurred: {e}")
84
  return None
85
 
86
+
87
  if __name__ == "__main__":
88
  image = "./assets/lakeview.jpg" # Replace with your JPEG or PNG image path.
89
  upload_service = GoogleCloudImageUploadService()
90
  url = upload_service.upload_image_to_gcs(image)
91
+ print(url)
src/services/image_uploader.py CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
5
  import os
6
  from pydantic import BaseModel
7
 
 
8
  class ImageInfo(BaseModel):
9
  filename: str
10
  name: str
@@ -12,6 +13,7 @@ class ImageInfo(BaseModel):
12
  extension: str
13
  url: str
14
 
 
15
  class ImgBBData(BaseModel):
16
  id: str
17
  title: str
@@ -28,33 +30,35 @@ class ImgBBData(BaseModel):
28
  medium: ImageInfo
29
  delete_url: str
30
 
 
31
  class ImgBBResponse(BaseModel):
32
  data: ImgBBData
33
  success: bool
34
  status: int
35
 
 
36
  class ImageUploader:
37
  """A class to handle image uploads to ImgBB service."""
38
-
39
  def __init__(self, api_key: str):
40
  """
41
  Initialize the ImageUploader with an API key.
42
-
43
  Args:
44
  api_key (str): The ImgBB API key
45
  """
46
  self.api_key = api_key
47
  self.base_url = "https://api.imgbb.com/1/upload"
48
-
49
  def upload(
50
  self,
51
  image: Union[str, bytes, Path],
52
  name: Optional[str] = None,
53
- expiration: Optional[int] = None
54
  ) -> ImgBBResponse:
55
  """
56
  Upload an image to ImgBB.
57
-
58
  Args:
59
  image: Can be:
60
  - A file path (str or Path)
@@ -64,20 +68,20 @@ class ImageUploader:
64
  - Bytes of an image
65
  name: Optional name for the uploaded file
66
  expiration: Optional expiration time in seconds (60-15552000)
67
-
68
  Returns:
69
  ImgBBResponse containing the parsed upload response from ImgBB
70
-
71
  Raises:
72
  ValueError: If the image format is invalid or upload fails
73
  requests.RequestException: If the API request fails
74
  """
75
  # Prepare the parameters
76
- params = {'key': self.api_key}
77
  if expiration:
78
  if not 60 <= expiration <= 15552000:
79
  raise ValueError("Expiration must be between 60 and 15552000 seconds")
80
- params['expiration'] = expiration
81
 
82
  # Handle different image input types
83
  if isinstance(image, (str, Path)):
@@ -85,38 +89,40 @@ class ImageUploader:
85
  files = {}
86
  if os.path.isfile(image_str):
87
  # It's a file path
88
- with open(image_str, 'rb') as file:
89
- files['image'] = file
90
- elif image_str.startswith(('http://', 'https://')):
91
  # It's a URL
92
- files['image'] = (None, image_str)
93
- elif image_str.startswith('data:image/'):
94
  # It's a data URI
95
  # Extract the base64 part after the comma
96
- base64_data = image_str.split(',', 1)[1]
97
- files['image'] = (None, base64_data)
98
  else:
99
  # Assume it's base64 data
100
- files['image'] = (None, image_str)
101
 
102
  if name:
103
- files['name'] = (None, name)
104
  response = requests.post(self.base_url, params=params, files=files)
105
  elif isinstance(image, bytes):
106
  # Convert bytes to base64
107
- base64_image = base64.b64encode(image).decode('utf-8')
108
- files = {
109
- 'image': (None, base64_image)
110
- }
111
  if name:
112
- files['name'] = (None, name)
113
  response = requests.post(self.base_url, params=params, files=files)
114
  else:
115
- raise ValueError("Invalid image format. Must be file path, URL, base64 string, or bytes")
 
 
116
 
117
  # Check the response
118
  if response.status_code != 200:
119
- raise ValueError(f"Upload failed with status {response.status_code}: {response.text}")
 
 
120
 
121
  # Parse the response using Pydantic model
122
  response_json = response.json()
@@ -126,16 +132,16 @@ class ImageUploader:
126
  self,
127
  file_path: Union[str, Path],
128
  name: Optional[str] = None,
129
- expiration: Optional[int] = None
130
  ) -> ImgBBResponse:
131
  """
132
  Convenience method to upload an image file.
133
-
134
  Args:
135
  file_path: Path to the image file
136
  name: Optional name for the uploaded file
137
  expiration: Optional expiration time in seconds (60-15552000)
138
-
139
  Returns:
140
  ImgBBResponse containing the parsed upload response from ImgBB
141
  """
@@ -145,16 +151,16 @@ class ImageUploader:
145
  self,
146
  image_url: str,
147
  name: Optional[str] = None,
148
- expiration: Optional[int] = None
149
  ) -> ImgBBResponse:
150
  """
151
  Convenience method to upload an image from a URL.
152
-
153
  Args:
154
  image_url: URL of the image to upload
155
  name: Optional name for the uploaded file
156
  expiration: Optional expiration time in seconds (60-15552000)
157
-
158
  Returns:
159
  ImgBBResponse containing the parsed upload response from ImgBB
160
  """
 
5
  import os
6
  from pydantic import BaseModel
7
 
8
+
9
  class ImageInfo(BaseModel):
10
  filename: str
11
  name: str
 
13
  extension: str
14
  url: str
15
 
16
+
17
  class ImgBBData(BaseModel):
18
  id: str
19
  title: str
 
30
  medium: ImageInfo
31
  delete_url: str
32
 
33
+
34
  class ImgBBResponse(BaseModel):
35
  data: ImgBBData
36
  success: bool
37
  status: int
38
 
39
+
40
  class ImageUploader:
41
  """A class to handle image uploads to ImgBB service."""
42
+
43
  def __init__(self, api_key: str):
44
  """
45
  Initialize the ImageUploader with an API key.
46
+
47
  Args:
48
  api_key (str): The ImgBB API key
49
  """
50
  self.api_key = api_key
51
  self.base_url = "https://api.imgbb.com/1/upload"
52
+
53
  def upload(
54
  self,
55
  image: Union[str, bytes, Path],
56
  name: Optional[str] = None,
57
+ expiration: Optional[int] = None,
58
  ) -> ImgBBResponse:
59
  """
60
  Upload an image to ImgBB.
61
+
62
  Args:
63
  image: Can be:
64
  - A file path (str or Path)
 
68
  - Bytes of an image
69
  name: Optional name for the uploaded file
70
  expiration: Optional expiration time in seconds (60-15552000)
71
+
72
  Returns:
73
  ImgBBResponse containing the parsed upload response from ImgBB
74
+
75
  Raises:
76
  ValueError: If the image format is invalid or upload fails
77
  requests.RequestException: If the API request fails
78
  """
79
  # Prepare the parameters
80
+ params = {"key": self.api_key}
81
  if expiration:
82
  if not 60 <= expiration <= 15552000:
83
  raise ValueError("Expiration must be between 60 and 15552000 seconds")
84
+ params["expiration"] = expiration
85
 
86
  # Handle different image input types
87
  if isinstance(image, (str, Path)):
 
89
  files = {}
90
  if os.path.isfile(image_str):
91
  # It's a file path
92
+ with open(image_str, "rb") as file:
93
+ files["image"] = file
94
+ elif image_str.startswith(("http://", "https://")):
95
  # It's a URL
96
+ files["image"] = (None, image_str)
97
+ elif image_str.startswith("data:image/"):
98
  # It's a data URI
99
  # Extract the base64 part after the comma
100
+ base64_data = image_str.split(",", 1)[1]
101
+ files["image"] = (None, base64_data)
102
  else:
103
  # Assume it's base64 data
104
+ files["image"] = (None, image_str)
105
 
106
  if name:
107
+ files["name"] = (None, name)
108
  response = requests.post(self.base_url, params=params, files=files)
109
  elif isinstance(image, bytes):
110
  # Convert bytes to base64
111
+ base64_image = base64.b64encode(image).decode("utf-8")
112
+ files = {"image": (None, base64_image)}
 
 
113
  if name:
114
+ files["name"] = (None, name)
115
  response = requests.post(self.base_url, params=params, files=files)
116
  else:
117
+ raise ValueError(
118
+ "Invalid image format. Must be file path, URL, base64 string, or bytes"
119
+ )
120
 
121
  # Check the response
122
  if response.status_code != 200:
123
+ raise ValueError(
124
+ f"Upload failed with status {response.status_code}: {response.text}"
125
+ )
126
 
127
  # Parse the response using Pydantic model
128
  response_json = response.json()
 
132
  self,
133
  file_path: Union[str, Path],
134
  name: Optional[str] = None,
135
+ expiration: Optional[int] = None,
136
  ) -> ImgBBResponse:
137
  """
138
  Convenience method to upload an image file.
139
+
140
  Args:
141
  file_path: Path to the image file
142
  name: Optional name for the uploaded file
143
  expiration: Optional expiration time in seconds (60-15552000)
144
+
145
  Returns:
146
  ImgBBResponse containing the parsed upload response from ImgBB
147
  """
 
151
  self,
152
  image_url: str,
153
  name: Optional[str] = None,
154
+ expiration: Optional[int] = None,
155
  ) -> ImgBBResponse:
156
  """
157
  Convenience method to upload an image from a URL.
158
+
159
  Args:
160
  image_url: URL of the image to upload
161
  name: Optional name for the uploaded file
162
  expiration: Optional expiration time in seconds (60-15552000)
163
+
164
  Returns:
165
  ImgBBResponse containing the parsed upload response from ImgBB
166
  """
src/services/openai_file_upload.py CHANGED
@@ -4,6 +4,7 @@ import os
4
 
5
  load_dotenv()
6
 
 
7
  class OpenAIFileUpload:
8
  def __init__(self):
9
  self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
 
4
 
5
  load_dotenv()
6
 
7
+
8
  class OpenAIFileUpload:
9
  def __init__(self):
10
  self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
src/utils.py CHANGED
@@ -4,16 +4,21 @@ from src.services.google_cloud_image_upload import GoogleCloudImageUploadService
4
  from PIL import Image
5
  from urllib.request import urlopen
6
  import io
 
 
7
  def image_path_to_base64(image_path: str) -> str:
8
  with open(image_path, "rb") as image_file:
9
  return base64.b64encode(image_file.read()).decode("utf-8")
10
 
 
11
  def upload_file_to_base64(file: UploadFile) -> str:
12
  return base64.b64encode(file.file.read()).decode("utf-8")
13
 
 
14
  def image_path_to_uri(image_path: str) -> str:
15
  return f"data:image/jpeg;base64,{image_path_to_base64(image_path)}"
16
 
 
17
  def upload_image(image_path: str) -> str:
18
  """
19
  Upload an image to Google Cloud Storage and return the public URL.
@@ -27,6 +32,7 @@ def upload_image(image_path: str) -> str:
27
  upload_service = GoogleCloudImageUploadService()
28
  return upload_service.upload_image_to_gcs(image_path)
29
 
 
30
  def download_image_to_data_uri(image_url: str) -> str:
31
  # Open the image from the URL
32
  response = urlopen(image_url)
@@ -34,16 +40,20 @@ def download_image_to_data_uri(image_url: str) -> str:
34
 
35
  # Determine the image format; default to 'JPEG' if not found
36
  image_format = img.format if img.format is not None else "JPEG"
37
-
38
  # Build the MIME type; for 'JPEG', use 'image/jpeg'
39
- mime_type = "image/jpeg" if image_format.upper() == "JPEG" else f"image/{image_format.lower()}"
40
-
 
 
 
 
41
  # Save the image to an in-memory buffer using the detected format
42
  buffered = io.BytesIO()
43
  img.save(buffered, format=image_format)
44
-
45
  # Encode the image bytes to base64
46
  img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
47
-
48
  # Return the data URI with the correct MIME type
49
- return f"data:{mime_type};base64,{img_base64}"
 
4
  from PIL import Image
5
  from urllib.request import urlopen
6
  import io
7
+
8
+
9
  def image_path_to_base64(image_path: str) -> str:
10
  with open(image_path, "rb") as image_file:
11
  return base64.b64encode(image_file.read()).decode("utf-8")
12
 
13
+
14
  def upload_file_to_base64(file: UploadFile) -> str:
15
  return base64.b64encode(file.file.read()).decode("utf-8")
16
 
17
+
18
  def image_path_to_uri(image_path: str) -> str:
19
  return f"data:image/jpeg;base64,{image_path_to_base64(image_path)}"
20
 
21
+
22
  def upload_image(image_path: str) -> str:
23
  """
24
  Upload an image to Google Cloud Storage and return the public URL.
 
32
  upload_service = GoogleCloudImageUploadService()
33
  return upload_service.upload_image_to_gcs(image_path)
34
 
35
+
36
  def download_image_to_data_uri(image_url: str) -> str:
37
  # Open the image from the URL
38
  response = urlopen(image_url)
 
40
 
41
  # Determine the image format; default to 'JPEG' if not found
42
  image_format = img.format if img.format is not None else "JPEG"
43
+
44
  # Build the MIME type; for 'JPEG', use 'image/jpeg'
45
+ mime_type = (
46
+ "image/jpeg"
47
+ if image_format.upper() == "JPEG"
48
+ else f"image/{image_format.lower()}"
49
+ )
50
+
51
  # Save the image to an in-memory buffer using the detected format
52
  buffered = io.BytesIO()
53
  img.save(buffered, format=image_format)
54
+
55
  # Encode the image bytes to base64
56
  img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
57
+
58
  # Return the data URI with the correct MIME type
59
+ return f"data:{mime_type};base64,{img_base64}"
stream_utils.py CHANGED
@@ -6,28 +6,29 @@ from rich.markdown import Markdown
6
  from rich.panel import Panel
7
  from rich.text import Text
8
 
 
9
  class StreamResponseHandler:
10
  """
11
  A utility class for handling streaming responses from API endpoints.
12
  Provides rich formatting and real-time updates of the response content.
13
  """
14
-
15
  def __init__(self, console=None):
16
  """
17
  Initialize the stream response handler.
18
-
19
  Args:
20
  console (Console, optional): A Rich console instance. If not provided, a new one will be created.
21
  """
22
  self.console = console or Console()
23
-
24
  def check_server_health(self, health_url="http://localhost:8000/health"):
25
  """
26
  Check if the server is running and accessible.
27
-
28
  Args:
29
  health_url (str, optional): The URL to check server health. Defaults to "http://localhost:8000/health".
30
-
31
  Returns:
32
  bool: True if the server is running and accessible, False otherwise.
33
  """
@@ -38,26 +39,32 @@ class StreamResponseHandler:
38
  self.console.print("[bold green]✓ Server is running and accessible.[/]")
39
  return True
40
  else:
41
- self.console.print(f"[bold red]✗ Server health check failed[/] with status code: {response.status_code}")
 
 
42
  return False
43
  except requests.exceptions.ConnectionError:
44
- self.console.print("[bold red]✗ Error:[/] Could not connect to the server. Make sure it's running.")
 
 
45
  return False
46
  except Exception as e:
47
  self.console.print(f"[bold red]✗ Error checking server health:[/] {e}")
48
  return False
49
-
50
- def stream_response(self, url, payload=None, params=None, method="POST", title="AI Response"):
 
 
51
  """
52
  Send a request to an endpoint and stream the output to the terminal.
53
-
54
  Args:
55
  url (str): The URL of the endpoint to send the request to.
56
  payload (dict, optional): The JSON payload to send in the request body. Defaults to None.
57
  params (dict, optional): The query parameters to send in the request. Defaults to None.
58
  method (str, optional): The HTTP method to use. Defaults to "POST".
59
  title (str, optional): The title to display in the panel. Defaults to "AI Response".
60
-
61
  Returns:
62
  bool: True if the streaming was successful, False otherwise.
63
  """
@@ -69,39 +76,42 @@ class StreamResponseHandler:
69
  if params:
70
  self.console.print("Parameters:", style="bold")
71
  self.console.print(json.dumps(params, indent=2))
72
-
73
  try:
74
  # Prepare the request
75
- request_kwargs = {
76
- "stream": True
77
- }
78
  if payload:
79
  request_kwargs["json"] = payload
80
  if params:
81
  request_kwargs["params"] = params
82
-
83
  # Make the request
84
  with getattr(requests, method.lower())(url, **request_kwargs) as response:
85
  # Check if the request was successful
86
  if response.status_code != 200:
87
- self.console.print(f"[bold red]Error:[/] Received status code {response.status_code}")
 
 
88
  self.console.print(f"Response: {response.text}")
89
  return False
90
-
91
  # Initialize an empty response text
92
  full_response = ""
93
-
94
  # Use Rich's Live display to update the content in place
95
- with Live(Panel("Waiting for response...", title=title, border_style="blue"), refresh_per_second=10) as live:
 
 
 
96
  # Process the streaming response
97
  for line in response.iter_lines():
98
  if line:
99
  # Decode the line and parse it as JSON
100
- decoded_line = line.decode('utf-8')
101
  try:
102
  # Parse the JSON
103
  data = json.loads(decoded_line)
104
-
105
  # Extract and display the content
106
  if isinstance(data, dict):
107
  if "content" in data:
@@ -111,35 +121,82 @@ class StreamResponseHandler:
111
  # Append to the full response
112
  full_response += text_content
113
  # Update the live display with the current full response
114
- live.update(Panel(Markdown(full_response), title=title, border_style="green"))
 
 
 
 
 
 
115
  elif content.get("type") == "image_url":
116
- image_url = content.get("image_url", {}).get("url", "")
 
 
117
  # Add a note about the image URL
118
- image_note = f"\n\n[Image URL: {image_url}]"
 
 
119
  full_response += image_note
120
- live.update(Panel(Markdown(full_response), title=title, border_style="green"))
 
 
 
 
 
 
121
  elif "edited_image_url" in data:
122
  # Handle edited image URL from edit endpoint
123
  image_url = data.get("edited_image_url", "")
124
- image_note = f"\n\n[Edited Image URL: {image_url}]"
 
 
125
  full_response += image_note
126
- live.update(Panel(Markdown(full_response), title=title, border_style="green"))
 
 
 
 
 
 
127
  else:
128
  # For other types of data, just show the JSON
129
- live.update(Panel(Text(json.dumps(data, indent=2)), title="Raw JSON Response", border_style="yellow"))
 
 
 
 
 
 
130
  else:
131
- live.update(Panel(Text(decoded_line), title="Raw Response", border_style="yellow"))
 
 
 
 
 
 
132
  except json.JSONDecodeError:
133
  # If it's not valid JSON, just show the raw line
134
- live.update(Panel(Text(f"Raw response: {decoded_line}"), title="Invalid JSON", border_style="red"))
 
 
 
 
 
 
135
 
136
  self.console.print("[bold green]Stream completed.[/]")
137
  return True
138
-
139
  except requests.exceptions.ConnectionError:
140
- self.console.print(f"[bold red]Error:[/] Could not connect to the server at {url}", style="red")
141
- self.console.print("Make sure the server is running and accessible.", style="red")
 
 
 
 
 
142
  return False
143
  except requests.exceptions.RequestException as e:
144
  self.console.print(f"[bold red]Error:[/] {e}", style="red")
145
- return False
 
6
  from rich.panel import Panel
7
  from rich.text import Text
8
 
9
+
10
  class StreamResponseHandler:
11
  """
12
  A utility class for handling streaming responses from API endpoints.
13
  Provides rich formatting and real-time updates of the response content.
14
  """
15
+
16
  def __init__(self, console=None):
17
  """
18
  Initialize the stream response handler.
19
+
20
  Args:
21
  console (Console, optional): A Rich console instance. If not provided, a new one will be created.
22
  """
23
  self.console = console or Console()
24
+
25
  def check_server_health(self, health_url="http://localhost:8000/health"):
26
  """
27
  Check if the server is running and accessible.
28
+
29
  Args:
30
  health_url (str, optional): The URL to check server health. Defaults to "http://localhost:8000/health".
31
+
32
  Returns:
33
  bool: True if the server is running and accessible, False otherwise.
34
  """
 
39
  self.console.print("[bold green]✓ Server is running and accessible.[/]")
40
  return True
41
  else:
42
+ self.console.print(
43
+ f"[bold red]✗ Server health check failed[/] with status code: {response.status_code}"
44
+ )
45
  return False
46
  except requests.exceptions.ConnectionError:
47
+ self.console.print(
48
+ "[bold red]✗ Error:[/] Could not connect to the server. Make sure it's running."
49
+ )
50
  return False
51
  except Exception as e:
52
  self.console.print(f"[bold red]✗ Error checking server health:[/] {e}")
53
  return False
54
+
55
+ def stream_response(
56
+ self, url, payload=None, params=None, method="POST", title="AI Response"
57
+ ):
58
  """
59
  Send a request to an endpoint and stream the output to the terminal.
60
+
61
  Args:
62
  url (str): The URL of the endpoint to send the request to.
63
  payload (dict, optional): The JSON payload to send in the request body. Defaults to None.
64
  params (dict, optional): The query parameters to send in the request. Defaults to None.
65
  method (str, optional): The HTTP method to use. Defaults to "POST".
66
  title (str, optional): The title to display in the panel. Defaults to "AI Response".
67
+
68
  Returns:
69
  bool: True if the streaming was successful, False otherwise.
70
  """
 
76
  if params:
77
  self.console.print("Parameters:", style="bold")
78
  self.console.print(json.dumps(params, indent=2))
79
+
80
  try:
81
  # Prepare the request
82
+ request_kwargs = {"stream": True}
 
 
83
  if payload:
84
  request_kwargs["json"] = payload
85
  if params:
86
  request_kwargs["params"] = params
87
+
88
  # Make the request
89
  with getattr(requests, method.lower())(url, **request_kwargs) as response:
90
  # Check if the request was successful
91
  if response.status_code != 200:
92
+ self.console.print(
93
+ f"[bold red]Error:[/] Received status code {response.status_code}"
94
+ )
95
  self.console.print(f"Response: {response.text}")
96
  return False
97
+
98
  # Initialize an empty response text
99
  full_response = ""
100
+
101
  # Use Rich's Live display to update the content in place
102
+ with Live(
103
+ Panel("Waiting for response...", title=title, border_style="blue"),
104
+ refresh_per_second=10,
105
+ ) as live:
106
  # Process the streaming response
107
  for line in response.iter_lines():
108
  if line:
109
  # Decode the line and parse it as JSON
110
+ decoded_line = line.decode("utf-8")
111
  try:
112
  # Parse the JSON
113
  data = json.loads(decoded_line)
114
+
115
  # Extract and display the content
116
  if isinstance(data, dict):
117
  if "content" in data:
 
121
  # Append to the full response
122
  full_response += text_content
123
  # Update the live display with the current full response
124
+ live.update(
125
+ Panel(
126
+ Markdown(full_response),
127
+ title=title,
128
+ border_style="green",
129
+ )
130
+ )
131
  elif content.get("type") == "image_url":
132
+ image_url = content.get(
133
+ "image_url", {}
134
+ ).get("url", "")
135
  # Add a note about the image URL
136
+ image_note = (
137
+ f"\n\n[Image URL: {image_url}]"
138
+ )
139
  full_response += image_note
140
+ live.update(
141
+ Panel(
142
+ Markdown(full_response),
143
+ title=title,
144
+ border_style="green",
145
+ )
146
+ )
147
  elif "edited_image_url" in data:
148
  # Handle edited image URL from edit endpoint
149
  image_url = data.get("edited_image_url", "")
150
+ image_note = (
151
+ f"\n\n[Edited Image URL: {image_url}]"
152
+ )
153
  full_response += image_note
154
+ live.update(
155
+ Panel(
156
+ Markdown(full_response),
157
+ title=title,
158
+ border_style="green",
159
+ )
160
+ )
161
  else:
162
  # For other types of data, just show the JSON
163
+ live.update(
164
+ Panel(
165
+ Text(json.dumps(data, indent=2)),
166
+ title="Raw JSON Response",
167
+ border_style="yellow",
168
+ )
169
+ )
170
  else:
171
+ live.update(
172
+ Panel(
173
+ Text(decoded_line),
174
+ title="Raw Response",
175
+ border_style="yellow",
176
+ )
177
+ )
178
  except json.JSONDecodeError:
179
  # If it's not valid JSON, just show the raw line
180
+ live.update(
181
+ Panel(
182
+ Text(f"Raw response: {decoded_line}"),
183
+ title="Invalid JSON",
184
+ border_style="red",
185
+ )
186
+ )
187
 
188
  self.console.print("[bold green]Stream completed.[/]")
189
  return True
190
+
191
  except requests.exceptions.ConnectionError:
192
+ self.console.print(
193
+ f"[bold red]Error:[/] Could not connect to the server at {url}",
194
+ style="red",
195
+ )
196
+ self.console.print(
197
+ "Make sure the server is running and accessible.", style="red"
198
+ )
199
  return False
200
  except requests.exceptions.RequestException as e:
201
  self.console.print(f"[bold red]Error:[/] {e}", style="red")
202
+ return False
test_edit_stream.py CHANGED
@@ -2,19 +2,20 @@ import argparse
2
  import os
3
  import sys
4
  import requests
5
- import json
6
  from dotenv import load_dotenv
7
  from stream_utils import StreamResponseHandler
8
 
9
  # Load environment variables
10
  load_dotenv()
11
 
 
12
  def get_default_image():
13
  """Get the default image path and convert it to a data URI."""
14
  image_path = "./assets/lakeview.jpg"
15
  if os.path.exists(image_path):
16
  try:
17
  from src.utils import image_path_to_uri
 
18
  image_uri = image_path_to_uri(image_path)
19
  print(f"Using default image: {image_path}")
20
  return image_uri
@@ -25,80 +26,93 @@ def get_default_image():
25
  print(f"Warning: Default image not found at {image_path}")
26
  return None
27
 
 
28
  def upload_image(handler, image_path):
29
  """
30
  Upload an image to the server.
31
-
32
  Args:
33
  handler (StreamResponseHandler): The stream response handler.
34
  image_path (str): Path to the image file to upload.
35
-
36
  Returns:
37
  str: The URL of the uploaded image, or None if upload failed.
38
  """
39
  if not os.path.exists(image_path):
40
- handler.console.print(f"[bold red]Error:[/] Image file not found at {image_path}")
 
 
41
  return None
42
-
43
  try:
44
  handler.console.print(f"Uploading image: [bold]{image_path}[/]")
45
- with open(image_path, 'rb') as f:
46
- files = {'file': (os.path.basename(image_path), f)}
47
  response = requests.post("http://localhost:8000/upload", files=files)
48
  if response.status_code == 200:
49
  image_url = response.json().get("image_url")
50
- handler.console.print(f"Image uploaded successfully. URL: [bold green]{image_url}[/]")
 
 
51
  return image_url
52
  else:
53
- handler.console.print(f"[bold red]Failed to upload image.[/] Status code: {response.status_code}")
 
 
54
  handler.console.print(f"Response: {response.text}")
55
  return None
56
  except Exception as e:
57
  handler.console.print(f"[bold red]Error uploading image:[/] {e}")
58
  return None
59
 
 
60
  def main():
61
  # Create a stream response handler
62
  handler = StreamResponseHandler()
63
-
64
  # Parse command line arguments
65
  parser = argparse.ArgumentParser(description="Test the image edit streaming API.")
66
- parser.add_argument("--instruction", "-i", required=True, help="The edit instruction.")
 
 
67
  parser.add_argument("--image", "-img", help="The URL of the image to edit.")
68
  parser.add_argument("--upload", "-u", help="Path to an image file to upload first.")
69
-
70
  args = parser.parse_args()
71
-
72
  # Check if the server is running
73
  if not handler.check_server_health():
74
  sys.exit(1)
75
-
76
  image_url = args.image
77
-
78
  # If upload is specified, upload the image first
79
  if args.upload:
80
  image_url = upload_image(handler, args.upload)
81
  if not image_url:
82
- handler.console.print("[yellow]Warning:[/] Failed to upload image. Continuing without image URL.")
83
-
 
 
84
  # Use the default image if no image URL is provided
85
  if not image_url:
86
  image_url = get_default_image()
87
  if not image_url:
88
- handler.console.print("[yellow]No image URL provided and default image not available.[/]")
 
 
89
  handler.console.print("The agent may ask for an image if needed.")
90
-
91
  # Prepare the payload for the edit request
92
- payload = {
93
- "edit_instruction": args.instruction
94
- }
95
-
96
  if image_url:
97
  payload["image_url"] = image_url
98
-
99
  # Stream the edit request
100
  endpoint_url = "http://localhost:8000/edit/stream"
101
  handler.stream_response(endpoint_url, payload=payload, title="Image Edit Response")
102
 
 
103
  if __name__ == "__main__":
104
- main()
 
2
  import os
3
  import sys
4
  import requests
 
5
  from dotenv import load_dotenv
6
  from stream_utils import StreamResponseHandler
7
 
8
  # Load environment variables
9
  load_dotenv()
10
 
11
+
12
  def get_default_image():
13
  """Get the default image path and convert it to a data URI."""
14
  image_path = "./assets/lakeview.jpg"
15
  if os.path.exists(image_path):
16
  try:
17
  from src.utils import image_path_to_uri
18
+
19
  image_uri = image_path_to_uri(image_path)
20
  print(f"Using default image: {image_path}")
21
  return image_uri
 
26
  print(f"Warning: Default image not found at {image_path}")
27
  return None
28
 
29
+
30
  def upload_image(handler, image_path):
31
  """
32
  Upload an image to the server.
33
+
34
  Args:
35
  handler (StreamResponseHandler): The stream response handler.
36
  image_path (str): Path to the image file to upload.
37
+
38
  Returns:
39
  str: The URL of the uploaded image, or None if upload failed.
40
  """
41
  if not os.path.exists(image_path):
42
+ handler.console.print(
43
+ f"[bold red]Error:[/] Image file not found at {image_path}"
44
+ )
45
  return None
46
+
47
  try:
48
  handler.console.print(f"Uploading image: [bold]{image_path}[/]")
49
+ with open(image_path, "rb") as f:
50
+ files = {"file": (os.path.basename(image_path), f)}
51
  response = requests.post("http://localhost:8000/upload", files=files)
52
  if response.status_code == 200:
53
  image_url = response.json().get("image_url")
54
+ handler.console.print(
55
+ f"Image uploaded successfully. URL: [bold green]{image_url}[/]"
56
+ )
57
  return image_url
58
  else:
59
+ handler.console.print(
60
+ f"[bold red]Failed to upload image.[/] Status code: {response.status_code}"
61
+ )
62
  handler.console.print(f"Response: {response.text}")
63
  return None
64
  except Exception as e:
65
  handler.console.print(f"[bold red]Error uploading image:[/] {e}")
66
  return None
67
 
68
+
69
  def main():
70
  # Create a stream response handler
71
  handler = StreamResponseHandler()
72
+
73
  # Parse command line arguments
74
  parser = argparse.ArgumentParser(description="Test the image edit streaming API.")
75
+ parser.add_argument(
76
+ "--instruction", "-i", required=True, help="The edit instruction."
77
+ )
78
  parser.add_argument("--image", "-img", help="The URL of the image to edit.")
79
  parser.add_argument("--upload", "-u", help="Path to an image file to upload first.")
80
+
81
  args = parser.parse_args()
82
+
83
  # Check if the server is running
84
  if not handler.check_server_health():
85
  sys.exit(1)
86
+
87
  image_url = args.image
88
+
89
  # If upload is specified, upload the image first
90
  if args.upload:
91
  image_url = upload_image(handler, args.upload)
92
  if not image_url:
93
+ handler.console.print(
94
+ "[yellow]Warning:[/] Failed to upload image. Continuing without image URL."
95
+ )
96
+
97
  # Use the default image if no image URL is provided
98
  if not image_url:
99
  image_url = get_default_image()
100
  if not image_url:
101
+ handler.console.print(
102
+ "[yellow]No image URL provided and default image not available.[/]"
103
+ )
104
  handler.console.print("The agent may ask for an image if needed.")
105
+
106
  # Prepare the payload for the edit request
107
+ payload = {"edit_instruction": args.instruction}
108
+
 
 
109
  if image_url:
110
  payload["image_url"] = image_url
111
+
112
  # Stream the edit request
113
  endpoint_url = "http://localhost:8000/edit/stream"
114
  handler.stream_response(endpoint_url, payload=payload, title="Image Edit Response")
115
 
116
+
117
  if __name__ == "__main__":
118
+ main()
test_generic_stream.py CHANGED
@@ -6,24 +6,33 @@ from stream_utils import StreamResponseHandler
6
  # Load environment variables
7
  load_dotenv()
8
 
 
9
  def main():
10
  # Create a console for rich output
11
  handler = StreamResponseHandler()
12
-
13
  # Parse command line arguments
14
- parser = argparse.ArgumentParser(description="Test the generic agent streaming API.")
15
- parser.add_argument("--query", "-q", required=True, help="The query or message to send to the generic agent.")
16
-
 
 
 
 
 
 
 
17
  args = parser.parse_args()
18
-
19
  # Check if the server is running
20
  if not handler.check_server_health():
21
  sys.exit(1)
22
-
23
  # Stream the generic request
24
  endpoint_url = "http://localhost:8000/test/stream"
25
  params = {"query": args.query}
26
  handler.stream_response(endpoint_url, params=params, title="Generic Agent Response")
27
 
 
28
  if __name__ == "__main__":
29
- main()
 
6
  # Load environment variables
7
  load_dotenv()
8
 
9
+
10
  def main():
11
  # Create a console for rich output
12
  handler = StreamResponseHandler()
13
+
14
  # Parse command line arguments
15
+ parser = argparse.ArgumentParser(
16
+ description="Test the generic agent streaming API."
17
+ )
18
+ parser.add_argument(
19
+ "--query",
20
+ "-q",
21
+ required=True,
22
+ help="The query or message to send to the generic agent.",
23
+ )
24
+
25
  args = parser.parse_args()
26
+
27
  # Check if the server is running
28
  if not handler.check_server_health():
29
  sys.exit(1)
30
+
31
  # Stream the generic request
32
  endpoint_url = "http://localhost:8000/test/stream"
33
  params = {"query": args.query}
34
  handler.stream_response(endpoint_url, params=params, title="Generic Agent Response")
35
 
36
+
37
  if __name__ == "__main__":
38
+ main()