Spaces:
Running
Running
Commit
·
fcb8f25
1
Parent(s):
2a2c2ad
refactor: formatting
Browse files- app.py +1 -1
- image_edit_chat.py +53 -71
- image_edit_demo.py +30 -38
- server.py +36 -49
- src/agents/generic_agent.py +2 -7
- src/agents/image_edit_agent.py +31 -26
- src/agents/mask_generation_agent.py +4 -12
- src/hopter/client.py +54 -43
- src/models/generate_mask_instruction.py +6 -8
- src/services/generate_mask.py +24 -26
- src/services/google_cloud_image_upload.py +25 -9
- src/services/image_uploader.py +37 -31
- src/services/openai_file_upload.py +1 -0
- src/utils.py +16 -6
- stream_utils.py +92 -35
- test_edit_stream.py +39 -25
- test_generic_stream.py +16 -7
app.py
CHANGED
@@ -8,4 +8,4 @@ with demo.route("PicEdit"):
|
|
8 |
image_edit_demo.demo.render()
|
9 |
|
10 |
if __name__ == "__main__":
|
11 |
-
demo.launch()
|
|
|
8 |
image_edit_demo.demo.render()
|
9 |
|
10 |
if __name__ == "__main__":
|
11 |
+
demo.launch()
|
image_edit_chat.py
CHANGED
@@ -5,13 +5,10 @@ from src.hopter.client import Hopter, Environment
|
|
5 |
from src.services.generate_mask import GenerateMaskService
|
6 |
from dotenv import load_dotenv
|
7 |
from src.utils import upload_image
|
8 |
-
from pydantic_ai.messages import
|
9 |
-
ToolCallPart,
|
10 |
-
ToolReturnPart
|
11 |
-
)
|
12 |
from pydantic_ai.models.openai import OpenAIModel
|
13 |
|
14 |
-
model = OpenAIModel(
|
15 |
"gpt-4o",
|
16 |
api_key=os.environ.get("OPENAI_API_KEY"),
|
17 |
)
|
@@ -33,76 +30,65 @@ EXAMPLES = [
|
|
33 |
"text": "Replace the background to the space with stars and planets",
|
34 |
"files": [
|
35 |
"https://cdn.prod.website-files.com/66f230993926deadc0ac3a44/66f370d65f158cbbcfbcc532_Crossed%20Arms%20Levi%20Meir%20Clancy.jpg"
|
36 |
-
]
|
37 |
},
|
38 |
{
|
39 |
"text": "Change all the balloons to red in the image",
|
40 |
"files": [
|
41 |
"https://www.apple.com/tv-pr/articles/2024/10/apple-tv-unveils-severance-season-two-teaser-ahead-of-the-highly-anticipated-return-of-the-emmy-and-peabody-award-winning-phenomenon/images/big-image/big-image-01/1023024_Severance_Season_Two_Official_Trailer_Big_Image_01_big_image_post.jpg.large_2x.jpg"
|
42 |
-
]
|
43 |
},
|
44 |
{
|
45 |
"text": "Change coffee to a glass of water",
|
46 |
"files": [
|
47 |
"https://previews.123rf.com/images/vadymvdrobot/vadymvdrobot1812/vadymvdrobot181201149/113217373-image-of-smiling-woman-holding-takeaway-coffee-in-paper-cup-and-taking-selfie-while-walking-through.jpg"
|
48 |
-
]
|
49 |
},
|
50 |
{
|
51 |
"text": "ENHANCE!",
|
52 |
"files": [
|
53 |
"https://m.media-amazon.com/images/M/MV5BNzM3ODc5NzEtNzJkOC00MDM4LWI0MTYtZTkyNmY3ZTBhYzkxXkEyXkFqcGc@._V1_QL75_UX1000_CR0,52,1000,563_.jpg"
|
54 |
-
]
|
55 |
-
}
|
56 |
]
|
57 |
|
58 |
load_dotenv()
|
59 |
|
|
|
60 |
def build_user_message(chat_input):
|
61 |
text = chat_input["text"]
|
62 |
images = chat_input["files"]
|
63 |
-
messages = [
|
64 |
-
{
|
65 |
-
"role": "user",
|
66 |
-
"content": text
|
67 |
-
}
|
68 |
-
]
|
69 |
if images:
|
70 |
-
messages.extend(
|
71 |
-
{
|
72 |
-
|
73 |
-
"content": {"path": image}
|
74 |
-
}
|
75 |
-
for image in images
|
76 |
-
])
|
77 |
return messages
|
78 |
|
|
|
79 |
def build_messages_for_agent(chat_input, past_messages):
|
80 |
# filter out image messages from past messages to save on tokens
|
81 |
messages = past_messages
|
82 |
-
|
83 |
# add the user's text message
|
84 |
if chat_input["text"]:
|
85 |
-
messages.append({
|
86 |
-
"type": "text",
|
87 |
-
"text": chat_input["text"]
|
88 |
-
})
|
89 |
|
90 |
# add the user's image message
|
91 |
files = chat_input.get("files", [])
|
92 |
image_url = upload_image(files[0]) if files else None
|
93 |
if image_url:
|
94 |
-
messages.append({
|
95 |
-
"type": "image_url",
|
96 |
-
"image_url": {"url": image_url}
|
97 |
-
})
|
98 |
|
99 |
return messages
|
100 |
-
|
|
|
101 |
def select_example(x: gr.SelectData, chat_input):
|
102 |
chat_input["text"] = x.value["text"]
|
103 |
chat_input["files"] = x.value["files"]
|
104 |
return chat_input
|
105 |
|
|
|
106 |
async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
|
107 |
# Prepare messages for the UI
|
108 |
chatbot.extend(build_user_message(chat_input))
|
@@ -113,15 +99,10 @@ async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
|
|
113 |
files = chat_input.get("files", [])
|
114 |
image_url = upload_image(files[0]) if files else None
|
115 |
messages = [
|
116 |
-
{
|
117 |
-
"type": "text",
|
118 |
-
"text": text
|
119 |
-
},
|
120 |
]
|
121 |
if image_url:
|
122 |
-
messages.append(
|
123 |
-
{"type": "image_url", "image_url": {"url": image_url}}
|
124 |
-
)
|
125 |
current_image = image_url
|
126 |
|
127 |
# Dependencies
|
@@ -131,66 +112,67 @@ async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
|
|
131 |
edit_instruction=text,
|
132 |
image_url=current_image,
|
133 |
hopter_client=hopter,
|
134 |
-
mask_service=mask_service
|
135 |
)
|
136 |
|
137 |
# Run the agent
|
138 |
async with image_edit_agent.run_stream(
|
139 |
-
messages,
|
140 |
-
deps=deps,
|
141 |
-
message_history=past_messages
|
142 |
) as result:
|
143 |
for message in result.new_messages():
|
144 |
for call in message.parts:
|
145 |
if isinstance(call, ToolCallPart):
|
146 |
call_args = (
|
147 |
call.args.args_json
|
148 |
-
if hasattr(call.args,
|
149 |
else call.args
|
150 |
)
|
151 |
metadata = {
|
152 |
-
|
153 |
}
|
154 |
# set the tool call id so that when the tool returns
|
155 |
# we can find this message and update with the result
|
156 |
if call.tool_call_id is not None:
|
157 |
-
metadata[
|
158 |
|
159 |
# Create a tool call message to show on the UI
|
160 |
gr_message = {
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
}
|
165 |
chatbot.append(gr_message)
|
166 |
if isinstance(call, ToolReturnPart):
|
167 |
for gr_message in chatbot:
|
168 |
# Skip messages without metadata
|
169 |
-
if not gr_message.get(
|
170 |
continue
|
171 |
|
172 |
-
if gr_message[
|
173 |
if isinstance(call.content, EditImageResult):
|
174 |
-
chatbot.append(
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
179 |
current_image = call.content.edited_image_url
|
180 |
else:
|
181 |
-
gr_message[
|
182 |
-
f'\nOutput: {call.content}'
|
183 |
-
)
|
184 |
yield gr.skip(), chatbot, gr.skip(), gr.skip()
|
185 |
|
186 |
-
chatbot.append({
|
187 |
async for message in result.stream_text():
|
188 |
-
chatbot[-1][
|
189 |
yield gr.skip(), chatbot, gr.skip(), gr.skip()
|
190 |
past_messages = result.all_messages()
|
191 |
|
192 |
yield gr.Textbox(interactive=True), gr.skip(), past_messages, current_image
|
193 |
|
|
|
194 |
with gr.Blocks() as demo:
|
195 |
gr.Markdown(INTRO)
|
196 |
|
@@ -198,10 +180,10 @@ with gr.Blocks() as demo:
|
|
198 |
past_messages = gr.State([])
|
199 |
chatbot = gr.Chatbot(
|
200 |
elem_id="chatbot",
|
201 |
-
label=
|
202 |
-
type=
|
203 |
-
avatar_images=(None,
|
204 |
-
examples=EXAMPLES
|
205 |
)
|
206 |
|
207 |
with gr.Row():
|
@@ -209,8 +191,8 @@ with gr.Blocks() as demo:
|
|
209 |
interactive=True,
|
210 |
file_count="single",
|
211 |
show_label=False,
|
212 |
-
placeholder=
|
213 |
-
sources=["upload"]
|
214 |
)
|
215 |
generation = chat_input.submit(
|
216 |
stream_from_agent,
|
@@ -233,7 +215,7 @@ with gr.Blocks() as demo:
|
|
233 |
inputs=[chat_input],
|
234 |
outputs=[chat_input],
|
235 |
)
|
236 |
-
|
237 |
|
238 |
-
|
239 |
-
|
|
|
|
5 |
from src.services.generate_mask import GenerateMaskService
|
6 |
from dotenv import load_dotenv
|
7 |
from src.utils import upload_image
|
8 |
+
from pydantic_ai.messages import ToolCallPart, ToolReturnPart
|
|
|
|
|
|
|
9 |
from pydantic_ai.models.openai import OpenAIModel
|
10 |
|
11 |
+
model = OpenAIModel(
|
12 |
"gpt-4o",
|
13 |
api_key=os.environ.get("OPENAI_API_KEY"),
|
14 |
)
|
|
|
30 |
"text": "Replace the background to the space with stars and planets",
|
31 |
"files": [
|
32 |
"https://cdn.prod.website-files.com/66f230993926deadc0ac3a44/66f370d65f158cbbcfbcc532_Crossed%20Arms%20Levi%20Meir%20Clancy.jpg"
|
33 |
+
],
|
34 |
},
|
35 |
{
|
36 |
"text": "Change all the balloons to red in the image",
|
37 |
"files": [
|
38 |
"https://www.apple.com/tv-pr/articles/2024/10/apple-tv-unveils-severance-season-two-teaser-ahead-of-the-highly-anticipated-return-of-the-emmy-and-peabody-award-winning-phenomenon/images/big-image/big-image-01/1023024_Severance_Season_Two_Official_Trailer_Big_Image_01_big_image_post.jpg.large_2x.jpg"
|
39 |
+
],
|
40 |
},
|
41 |
{
|
42 |
"text": "Change coffee to a glass of water",
|
43 |
"files": [
|
44 |
"https://previews.123rf.com/images/vadymvdrobot/vadymvdrobot1812/vadymvdrobot181201149/113217373-image-of-smiling-woman-holding-takeaway-coffee-in-paper-cup-and-taking-selfie-while-walking-through.jpg"
|
45 |
+
],
|
46 |
},
|
47 |
{
|
48 |
"text": "ENHANCE!",
|
49 |
"files": [
|
50 |
"https://m.media-amazon.com/images/M/MV5BNzM3ODc5NzEtNzJkOC00MDM4LWI0MTYtZTkyNmY3ZTBhYzkxXkEyXkFqcGc@._V1_QL75_UX1000_CR0,52,1000,563_.jpg"
|
51 |
+
],
|
52 |
+
},
|
53 |
]
|
54 |
|
55 |
load_dotenv()
|
56 |
|
57 |
+
|
58 |
def build_user_message(chat_input):
|
59 |
text = chat_input["text"]
|
60 |
images = chat_input["files"]
|
61 |
+
messages = [{"role": "user", "content": text}]
|
|
|
|
|
|
|
|
|
|
|
62 |
if images:
|
63 |
+
messages.extend(
|
64 |
+
[{"role": "user", "content": {"path": image}} for image in images]
|
65 |
+
)
|
|
|
|
|
|
|
|
|
66 |
return messages
|
67 |
|
68 |
+
|
69 |
def build_messages_for_agent(chat_input, past_messages):
|
70 |
# filter out image messages from past messages to save on tokens
|
71 |
messages = past_messages
|
72 |
+
|
73 |
# add the user's text message
|
74 |
if chat_input["text"]:
|
75 |
+
messages.append({"type": "text", "text": chat_input["text"]})
|
|
|
|
|
|
|
76 |
|
77 |
# add the user's image message
|
78 |
files = chat_input.get("files", [])
|
79 |
image_url = upload_image(files[0]) if files else None
|
80 |
if image_url:
|
81 |
+
messages.append({"type": "image_url", "image_url": {"url": image_url}})
|
|
|
|
|
|
|
82 |
|
83 |
return messages
|
84 |
+
|
85 |
+
|
86 |
def select_example(x: gr.SelectData, chat_input):
|
87 |
chat_input["text"] = x.value["text"]
|
88 |
chat_input["files"] = x.value["files"]
|
89 |
return chat_input
|
90 |
|
91 |
+
|
92 |
async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
|
93 |
# Prepare messages for the UI
|
94 |
chatbot.extend(build_user_message(chat_input))
|
|
|
99 |
files = chat_input.get("files", [])
|
100 |
image_url = upload_image(files[0]) if files else None
|
101 |
messages = [
|
102 |
+
{"type": "text", "text": text},
|
|
|
|
|
|
|
103 |
]
|
104 |
if image_url:
|
105 |
+
messages.append({"type": "image_url", "image_url": {"url": image_url}})
|
|
|
|
|
106 |
current_image = image_url
|
107 |
|
108 |
# Dependencies
|
|
|
112 |
edit_instruction=text,
|
113 |
image_url=current_image,
|
114 |
hopter_client=hopter,
|
115 |
+
mask_service=mask_service,
|
116 |
)
|
117 |
|
118 |
# Run the agent
|
119 |
async with image_edit_agent.run_stream(
|
120 |
+
messages, deps=deps, message_history=past_messages
|
|
|
|
|
121 |
) as result:
|
122 |
for message in result.new_messages():
|
123 |
for call in message.parts:
|
124 |
if isinstance(call, ToolCallPart):
|
125 |
call_args = (
|
126 |
call.args.args_json
|
127 |
+
if hasattr(call.args, "args_json")
|
128 |
else call.args
|
129 |
)
|
130 |
metadata = {
|
131 |
+
"title": f"🛠️ Using {call.tool_name}",
|
132 |
}
|
133 |
# set the tool call id so that when the tool returns
|
134 |
# we can find this message and update with the result
|
135 |
if call.tool_call_id is not None:
|
136 |
+
metadata["id"] = call.tool_call_id
|
137 |
|
138 |
# Create a tool call message to show on the UI
|
139 |
gr_message = {
|
140 |
+
"role": "assistant",
|
141 |
+
"content": "Parameters: " + call_args,
|
142 |
+
"metadata": metadata,
|
143 |
}
|
144 |
chatbot.append(gr_message)
|
145 |
if isinstance(call, ToolReturnPart):
|
146 |
for gr_message in chatbot:
|
147 |
# Skip messages without metadata
|
148 |
+
if not gr_message.get("metadata"):
|
149 |
continue
|
150 |
|
151 |
+
if gr_message["metadata"].get("id", "") == call.tool_call_id:
|
152 |
if isinstance(call.content, EditImageResult):
|
153 |
+
chatbot.append(
|
154 |
+
{
|
155 |
+
"role": "assistant",
|
156 |
+
"content": gr.Image(
|
157 |
+
call.content.edited_image_url
|
158 |
+
),
|
159 |
+
"files": [call.content.edited_image_url],
|
160 |
+
}
|
161 |
+
)
|
162 |
current_image = call.content.edited_image_url
|
163 |
else:
|
164 |
+
gr_message["content"] += f"\nOutput: {call.content}"
|
|
|
|
|
165 |
yield gr.skip(), chatbot, gr.skip(), gr.skip()
|
166 |
|
167 |
+
chatbot.append({"role": "assistant", "content": ""})
|
168 |
async for message in result.stream_text():
|
169 |
+
chatbot[-1]["content"] = message
|
170 |
yield gr.skip(), chatbot, gr.skip(), gr.skip()
|
171 |
past_messages = result.all_messages()
|
172 |
|
173 |
yield gr.Textbox(interactive=True), gr.skip(), past_messages, current_image
|
174 |
|
175 |
+
|
176 |
with gr.Blocks() as demo:
|
177 |
gr.Markdown(INTRO)
|
178 |
|
|
|
180 |
past_messages = gr.State([])
|
181 |
chatbot = gr.Chatbot(
|
182 |
elem_id="chatbot",
|
183 |
+
label="Image Editing Assistant",
|
184 |
+
type="messages",
|
185 |
+
avatar_images=(None, "https://ai.pydantic.dev/img/logo-white.svg"),
|
186 |
+
examples=EXAMPLES,
|
187 |
)
|
188 |
|
189 |
with gr.Row():
|
|
|
191 |
interactive=True,
|
192 |
file_count="single",
|
193 |
show_label=False,
|
194 |
+
placeholder="How would you like to edit this image?",
|
195 |
+
sources=["upload"],
|
196 |
)
|
197 |
generation = chat_input.submit(
|
198 |
stream_from_agent,
|
|
|
215 |
inputs=[chat_input],
|
216 |
outputs=[chat_input],
|
217 |
)
|
|
|
218 |
|
219 |
+
|
220 |
+
if __name__ == "__main__":
|
221 |
+
demo.launch()
|
image_edit_demo.py
CHANGED
@@ -4,50 +4,47 @@ import os
|
|
4 |
from src.hopter.client import Hopter, Environment
|
5 |
from src.services.generate_mask import GenerateMaskService
|
6 |
from dotenv import load_dotenv
|
7 |
-
from pydantic_ai.messages import
|
8 |
-
ToolReturnPart
|
9 |
-
)
|
10 |
from src.utils import upload_image
|
|
|
11 |
load_dotenv()
|
12 |
|
|
|
13 |
async def process_edit(image, instruction):
|
14 |
hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
|
15 |
mask_service = GenerateMaskService(hopter=hopter)
|
16 |
image_url = upload_image(image)
|
17 |
messages = [
|
18 |
-
{
|
19 |
-
"type": "text",
|
20 |
-
"text": instruction
|
21 |
-
},
|
22 |
]
|
23 |
if image:
|
24 |
-
messages.append(
|
25 |
-
{"type": "image_url", "image_url": {"url": image_url}}
|
26 |
-
)
|
27 |
deps = ImageEditDeps(
|
28 |
edit_instruction=instruction,
|
29 |
image_url=image_url,
|
30 |
hopter_client=hopter,
|
31 |
-
mask_service=mask_service
|
32 |
-
)
|
33 |
-
result = await image_edit_agent.run(
|
34 |
-
messages,
|
35 |
-
deps=deps
|
36 |
)
|
|
|
37 |
# Extract the edited image URL from the tool return
|
38 |
for message in result.new_messages():
|
39 |
for part in message.parts:
|
40 |
-
if isinstance(part, ToolReturnPart) and isinstance(
|
|
|
|
|
41 |
return part.content.edited_image_url
|
42 |
return None
|
43 |
|
|
|
44 |
async def use_edited_image(edited_image):
|
45 |
return edited_image
|
46 |
|
|
|
47 |
def clear_instruction():
|
48 |
# Only clear the instruction text.
|
49 |
return ""
|
50 |
|
|
|
51 |
# Create the Gradio interface
|
52 |
with gr.Blocks() as demo:
|
53 |
gr.Markdown("# PicEdit")
|
@@ -55,57 +52,52 @@ with gr.Blocks() as demo:
|
|
55 |
Welcome to PicEdit - an AI-powered image editing tool.
|
56 |
Simply upload an image and describe the changes you want to make in natural language.
|
57 |
""")
|
58 |
-
|
59 |
with gr.Row():
|
60 |
# Input image on the left
|
61 |
input_image = gr.Image(label="Original Image", type="filepath")
|
62 |
-
|
63 |
with gr.Column():
|
64 |
# Output image on the right
|
65 |
-
output_image = gr.Image(
|
|
|
|
|
66 |
use_edited_btn = gr.Button("👈 Use Edited Image 👈")
|
67 |
|
68 |
# Text input for editing instructions
|
69 |
instruction = gr.Textbox(
|
70 |
label="Editing Instructions",
|
71 |
-
placeholder="Describe the changes you want to make to the image..."
|
72 |
)
|
73 |
-
|
74 |
# Clear button
|
75 |
with gr.Row():
|
76 |
clear_btn = gr.Button("Clear")
|
77 |
submit_btn = gr.Button("Apply Edit", variant="primary")
|
78 |
-
|
79 |
# Set up the event handlers
|
80 |
submit_btn.click(
|
81 |
-
fn=process_edit,
|
82 |
-
inputs=[input_image, instruction],
|
83 |
-
outputs=output_image
|
84 |
)
|
85 |
-
|
86 |
use_edited_btn.click(
|
87 |
-
fn=use_edited_image,
|
88 |
-
inputs=[output_image],
|
89 |
-
outputs=[input_image]
|
90 |
)
|
91 |
|
92 |
# Bind the clear button's click event to only clear the instruction textbox.
|
93 |
-
clear_btn.click(
|
94 |
-
fn=clear_instruction,
|
95 |
-
inputs=[],
|
96 |
-
outputs=[instruction]
|
97 |
-
)
|
98 |
|
99 |
examples = gr.Examples(
|
100 |
examples=[
|
101 |
["https://i.ibb.co/qYwhcc6j/c837c212afbf.jpg", "remove the pole"],
|
102 |
["https://i.ibb.co/2Mrxztw/image.png", "replace the cat with a dog"],
|
103 |
-
[
|
|
|
|
|
|
|
104 |
],
|
105 |
-
inputs=[input_image, instruction]
|
106 |
)
|
107 |
|
108 |
if __name__ == "__main__":
|
109 |
demo.launch()
|
110 |
-
|
111 |
-
|
|
|
4 |
from src.hopter.client import Hopter, Environment
|
5 |
from src.services.generate_mask import GenerateMaskService
|
6 |
from dotenv import load_dotenv
|
7 |
+
from pydantic_ai.messages import ToolReturnPart
|
|
|
|
|
8 |
from src.utils import upload_image
|
9 |
+
|
10 |
load_dotenv()
|
11 |
|
12 |
+
|
13 |
async def process_edit(image, instruction):
|
14 |
hopter = Hopter(os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
|
15 |
mask_service = GenerateMaskService(hopter=hopter)
|
16 |
image_url = upload_image(image)
|
17 |
messages = [
|
18 |
+
{"type": "text", "text": instruction},
|
|
|
|
|
|
|
19 |
]
|
20 |
if image:
|
21 |
+
messages.append({"type": "image_url", "image_url": {"url": image_url}})
|
|
|
|
|
22 |
deps = ImageEditDeps(
|
23 |
edit_instruction=instruction,
|
24 |
image_url=image_url,
|
25 |
hopter_client=hopter,
|
26 |
+
mask_service=mask_service,
|
|
|
|
|
|
|
|
|
27 |
)
|
28 |
+
result = await image_edit_agent.run(messages, deps=deps)
|
29 |
# Extract the edited image URL from the tool return
|
30 |
for message in result.new_messages():
|
31 |
for part in message.parts:
|
32 |
+
if isinstance(part, ToolReturnPart) and isinstance(
|
33 |
+
part.content, EditImageResult
|
34 |
+
):
|
35 |
return part.content.edited_image_url
|
36 |
return None
|
37 |
|
38 |
+
|
39 |
async def use_edited_image(edited_image):
|
40 |
return edited_image
|
41 |
|
42 |
+
|
43 |
def clear_instruction():
|
44 |
# Only clear the instruction text.
|
45 |
return ""
|
46 |
|
47 |
+
|
48 |
# Create the Gradio interface
|
49 |
with gr.Blocks() as demo:
|
50 |
gr.Markdown("# PicEdit")
|
|
|
52 |
Welcome to PicEdit - an AI-powered image editing tool.
|
53 |
Simply upload an image and describe the changes you want to make in natural language.
|
54 |
""")
|
55 |
+
|
56 |
with gr.Row():
|
57 |
# Input image on the left
|
58 |
input_image = gr.Image(label="Original Image", type="filepath")
|
59 |
+
|
60 |
with gr.Column():
|
61 |
# Output image on the right
|
62 |
+
output_image = gr.Image(
|
63 |
+
label="Edited Image", type="filepath", interactive=False, scale=3
|
64 |
+
)
|
65 |
use_edited_btn = gr.Button("👈 Use Edited Image 👈")
|
66 |
|
67 |
# Text input for editing instructions
|
68 |
instruction = gr.Textbox(
|
69 |
label="Editing Instructions",
|
70 |
+
placeholder="Describe the changes you want to make to the image...",
|
71 |
)
|
72 |
+
|
73 |
# Clear button
|
74 |
with gr.Row():
|
75 |
clear_btn = gr.Button("Clear")
|
76 |
submit_btn = gr.Button("Apply Edit", variant="primary")
|
77 |
+
|
78 |
# Set up the event handlers
|
79 |
submit_btn.click(
|
80 |
+
fn=process_edit, inputs=[input_image, instruction], outputs=output_image
|
|
|
|
|
81 |
)
|
82 |
+
|
83 |
use_edited_btn.click(
|
84 |
+
fn=use_edited_image, inputs=[output_image], outputs=[input_image]
|
|
|
|
|
85 |
)
|
86 |
|
87 |
# Bind the clear button's click event to only clear the instruction textbox.
|
88 |
+
clear_btn.click(fn=clear_instruction, inputs=[], outputs=[instruction])
|
|
|
|
|
|
|
|
|
89 |
|
90 |
examples = gr.Examples(
|
91 |
examples=[
|
92 |
["https://i.ibb.co/qYwhcc6j/c837c212afbf.jpg", "remove the pole"],
|
93 |
["https://i.ibb.co/2Mrxztw/image.png", "replace the cat with a dog"],
|
94 |
+
[
|
95 |
+
"https://i.ibb.co/9mT4cvnt/resized-78-B40-C09-1037-4-DD3-9-F48-D73637-EE4-E51.png",
|
96 |
+
"ENHANCE!",
|
97 |
+
],
|
98 |
],
|
99 |
+
inputs=[input_image, instruction],
|
100 |
)
|
101 |
|
102 |
if __name__ == "__main__":
|
103 |
demo.launch()
|
|
|
|
server.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
-
from fastapi import FastAPI, UploadFile, File,
|
2 |
from fastapi.responses import StreamingResponse
|
3 |
from fastapi.middleware.cors import CORSMiddleware
|
4 |
-
import asyncio
|
5 |
import os
|
6 |
from dotenv import load_dotenv
|
7 |
-
from typing import Optional, List, Dict
|
8 |
import json
|
9 |
from pydantic import BaseModel
|
10 |
|
@@ -13,7 +12,7 @@ from src.agents.image_edit_agent import image_edit_agent, ImageEditDeps
|
|
13 |
from src.agents.generic_agent import generic_agent
|
14 |
from src.hopter.client import Hopter, Environment
|
15 |
from src.services.generate_mask import GenerateMaskService
|
16 |
-
from src.utils import
|
17 |
|
18 |
# Load environment variables
|
19 |
load_dotenv()
|
@@ -29,15 +28,18 @@ app.add_middleware(
|
|
29 |
allow_headers=["*"], # Allows all headers
|
30 |
)
|
31 |
|
|
|
32 |
class EditRequest(BaseModel):
|
33 |
edit_instruction: str
|
34 |
image_url: Optional[str] = None
|
35 |
|
|
|
36 |
class MessageContent(BaseModel):
|
37 |
type: str
|
38 |
text: Optional[str] = None
|
39 |
image_url: Optional[Dict[str, str]] = None
|
40 |
|
|
|
41 |
class Message(BaseModel):
|
42 |
content: List[MessageContent]
|
43 |
|
@@ -48,9 +50,10 @@ async def test(query: str):
|
|
48 |
async with generic_agent.run_stream(query) as result:
|
49 |
async for message in result.stream(debounce_by=0.01):
|
50 |
yield json.dumps(message) + "\n"
|
51 |
-
|
52 |
return StreamingResponse(stream_messages(), media_type="text/plain")
|
53 |
|
|
|
54 |
@app.post("/edit")
|
55 |
async def edit_image(request: EditRequest):
|
56 |
"""
|
@@ -60,8 +63,7 @@ async def edit_image(request: EditRequest):
|
|
60 |
try:
|
61 |
# Initialize services
|
62 |
hopter = Hopter(
|
63 |
-
api_key=os.environ.get("HOPTER_API_KEY"),
|
64 |
-
environment=Environment.STAGING
|
65 |
)
|
66 |
mask_service = GenerateMaskService(hopter=hopter)
|
67 |
|
@@ -70,34 +72,27 @@ async def edit_image(request: EditRequest):
|
|
70 |
edit_instruction=request.edit_instruction,
|
71 |
image_url=request.image_url,
|
72 |
hopter_client=hopter,
|
73 |
-
mask_service=mask_service
|
74 |
)
|
75 |
|
76 |
# Create messages
|
77 |
-
messages = [
|
78 |
-
|
79 |
-
"type": "text",
|
80 |
-
"text": request.edit_instruction
|
81 |
-
}
|
82 |
-
]
|
83 |
-
|
84 |
if request.image_url:
|
85 |
-
messages.append(
|
86 |
-
"type": "image_url",
|
87 |
-
|
88 |
-
"url": request.image_url
|
89 |
-
}
|
90 |
-
})
|
91 |
|
92 |
# Run the agent
|
93 |
result = await image_edit_agent.run(messages, deps=deps)
|
94 |
-
|
95 |
# Return the result
|
96 |
return {"edited_image_url": result.edited_image_url}
|
97 |
-
|
98 |
except Exception as e:
|
99 |
raise HTTPException(status_code=500, detail=str(e))
|
100 |
|
|
|
101 |
@app.post("/edit/stream")
|
102 |
async def edit_image_stream(request: EditRequest):
|
103 |
"""
|
@@ -107,8 +102,7 @@ async def edit_image_stream(request: EditRequest):
|
|
107 |
try:
|
108 |
# Initialize services
|
109 |
hopter = Hopter(
|
110 |
-
api_key=os.environ.get("HOPTER_API_KEY"),
|
111 |
-
environment=Environment.STAGING
|
112 |
)
|
113 |
mask_service = GenerateMaskService(hopter=hopter)
|
114 |
|
@@ -117,24 +111,16 @@ async def edit_image_stream(request: EditRequest):
|
|
117 |
edit_instruction=request.edit_instruction,
|
118 |
image_url=request.image_url,
|
119 |
hopter_client=hopter,
|
120 |
-
mask_service=mask_service
|
121 |
)
|
122 |
|
123 |
# Create messages
|
124 |
-
messages = [
|
125 |
-
|
126 |
-
"type": "text",
|
127 |
-
"text": request.edit_instruction
|
128 |
-
}
|
129 |
-
]
|
130 |
-
|
131 |
if request.image_url:
|
132 |
-
messages.append(
|
133 |
-
"type": "image_url",
|
134 |
-
|
135 |
-
"url": request.image_url
|
136 |
-
}
|
137 |
-
})
|
138 |
|
139 |
async def stream_generator():
|
140 |
async with image_edit_agent.run_stream(messages, deps=deps) as result:
|
@@ -142,14 +128,12 @@ async def edit_image_stream(request: EditRequest):
|
|
142 |
# Convert message to JSON and yield
|
143 |
yield json.dumps(message) + "\n"
|
144 |
|
145 |
-
return StreamingResponse(
|
146 |
-
|
147 |
-
media_type="application/x-ndjson"
|
148 |
-
)
|
149 |
-
|
150 |
except Exception as e:
|
151 |
raise HTTPException(status_code=500, detail=str(e))
|
152 |
|
|
|
153 |
@app.post("/upload")
|
154 |
async def upload_image_file(file: UploadFile = File(...)):
|
155 |
"""
|
@@ -160,18 +144,19 @@ async def upload_image_file(file: UploadFile = File(...)):
|
|
160 |
temp_file_path = f"/tmp/{file.filename}"
|
161 |
with open(temp_file_path, "wb") as buffer:
|
162 |
buffer.write(await file.read())
|
163 |
-
|
164 |
# Upload the image to Google Cloud Storage
|
165 |
image_url = upload_image(temp_file_path)
|
166 |
-
|
167 |
# Remove the temporary file
|
168 |
os.remove(temp_file_path)
|
169 |
-
|
170 |
return {"image_url": image_url}
|
171 |
-
|
172 |
except Exception as e:
|
173 |
raise HTTPException(status_code=500, detail=str(e))
|
174 |
|
|
|
175 |
@app.get("/health")
|
176 |
async def health_check():
|
177 |
"""
|
@@ -179,6 +164,8 @@ async def health_check():
|
|
179 |
"""
|
180 |
return {"status": "ok"}
|
181 |
|
|
|
182 |
if __name__ == "__main__":
|
183 |
import uvicorn
|
184 |
-
|
|
|
|
1 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException
|
2 |
from fastapi.responses import StreamingResponse
|
3 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
4 |
import os
|
5 |
from dotenv import load_dotenv
|
6 |
+
from typing import Optional, List, Dict
|
7 |
import json
|
8 |
from pydantic import BaseModel
|
9 |
|
|
|
12 |
from src.agents.generic_agent import generic_agent
|
13 |
from src.hopter.client import Hopter, Environment
|
14 |
from src.services.generate_mask import GenerateMaskService
|
15 |
+
from src.utils import upload_image
|
16 |
|
17 |
# Load environment variables
|
18 |
load_dotenv()
|
|
|
28 |
allow_headers=["*"], # Allows all headers
|
29 |
)
|
30 |
|
31 |
+
|
32 |
class EditRequest(BaseModel):
|
33 |
edit_instruction: str
|
34 |
image_url: Optional[str] = None
|
35 |
|
36 |
+
|
37 |
class MessageContent(BaseModel):
|
38 |
type: str
|
39 |
text: Optional[str] = None
|
40 |
image_url: Optional[Dict[str, str]] = None
|
41 |
|
42 |
+
|
43 |
class Message(BaseModel):
|
44 |
content: List[MessageContent]
|
45 |
|
|
|
50 |
async with generic_agent.run_stream(query) as result:
|
51 |
async for message in result.stream(debounce_by=0.01):
|
52 |
yield json.dumps(message) + "\n"
|
53 |
+
|
54 |
return StreamingResponse(stream_messages(), media_type="text/plain")
|
55 |
|
56 |
+
|
57 |
@app.post("/edit")
|
58 |
async def edit_image(request: EditRequest):
|
59 |
"""
|
|
|
63 |
try:
|
64 |
# Initialize services
|
65 |
hopter = Hopter(
|
66 |
+
api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
|
|
|
67 |
)
|
68 |
mask_service = GenerateMaskService(hopter=hopter)
|
69 |
|
|
|
72 |
edit_instruction=request.edit_instruction,
|
73 |
image_url=request.image_url,
|
74 |
hopter_client=hopter,
|
75 |
+
mask_service=mask_service,
|
76 |
)
|
77 |
|
78 |
# Create messages
|
79 |
+
messages = [{"type": "text", "text": request.edit_instruction}]
|
80 |
+
|
|
|
|
|
|
|
|
|
|
|
81 |
if request.image_url:
|
82 |
+
messages.append(
|
83 |
+
{"type": "image_url", "image_url": {"url": request.image_url}}
|
84 |
+
)
|
|
|
|
|
|
|
85 |
|
86 |
# Run the agent
|
87 |
result = await image_edit_agent.run(messages, deps=deps)
|
88 |
+
|
89 |
# Return the result
|
90 |
return {"edited_image_url": result.edited_image_url}
|
91 |
+
|
92 |
except Exception as e:
|
93 |
raise HTTPException(status_code=500, detail=str(e))
|
94 |
|
95 |
+
|
96 |
@app.post("/edit/stream")
|
97 |
async def edit_image_stream(request: EditRequest):
|
98 |
"""
|
|
|
102 |
try:
|
103 |
# Initialize services
|
104 |
hopter = Hopter(
|
105 |
+
api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
|
|
|
106 |
)
|
107 |
mask_service = GenerateMaskService(hopter=hopter)
|
108 |
|
|
|
111 |
edit_instruction=request.edit_instruction,
|
112 |
image_url=request.image_url,
|
113 |
hopter_client=hopter,
|
114 |
+
mask_service=mask_service,
|
115 |
)
|
116 |
|
117 |
# Create messages
|
118 |
+
messages = [{"type": "text", "text": request.edit_instruction}]
|
119 |
+
|
|
|
|
|
|
|
|
|
|
|
120 |
if request.image_url:
|
121 |
+
messages.append(
|
122 |
+
{"type": "image_url", "image_url": {"url": request.image_url}}
|
123 |
+
)
|
|
|
|
|
|
|
124 |
|
125 |
async def stream_generator():
|
126 |
async with image_edit_agent.run_stream(messages, deps=deps) as result:
|
|
|
128 |
# Convert message to JSON and yield
|
129 |
yield json.dumps(message) + "\n"
|
130 |
|
131 |
+
return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
|
132 |
+
|
|
|
|
|
|
|
133 |
except Exception as e:
|
134 |
raise HTTPException(status_code=500, detail=str(e))
|
135 |
|
136 |
+
|
137 |
@app.post("/upload")
|
138 |
async def upload_image_file(file: UploadFile = File(...)):
|
139 |
"""
|
|
|
144 |
temp_file_path = f"/tmp/{file.filename}"
|
145 |
with open(temp_file_path, "wb") as buffer:
|
146 |
buffer.write(await file.read())
|
147 |
+
|
148 |
# Upload the image to Google Cloud Storage
|
149 |
image_url = upload_image(temp_file_path)
|
150 |
+
|
151 |
# Remove the temporary file
|
152 |
os.remove(temp_file_path)
|
153 |
+
|
154 |
return {"image_url": image_url}
|
155 |
+
|
156 |
except Exception as e:
|
157 |
raise HTTPException(status_code=500, detail=str(e))
|
158 |
|
159 |
+
|
160 |
@app.get("/health")
|
161 |
async def health_check():
|
162 |
"""
|
|
|
164 |
"""
|
165 |
return {"status": "ok"}
|
166 |
|
167 |
+
|
168 |
if __name__ == "__main__":
|
169 |
import uvicorn
|
170 |
+
|
171 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
src/agents/generic_agent.py
CHANGED
@@ -1,14 +1,11 @@
|
|
1 |
-
from pydantic_ai import Agent
|
2 |
from pydantic_ai.models.openai import OpenAIModel
|
3 |
from dotenv import load_dotenv
|
4 |
import os
|
5 |
|
6 |
load_dotenv()
|
7 |
|
8 |
-
model = OpenAIModel(
|
9 |
-
"gpt-4o",
|
10 |
-
api_key=os.environ.get("OPENAI_API_KEY")
|
11 |
-
)
|
12 |
|
13 |
system_prompt = """
|
14 |
You are a helpful assistant that can answer questions and help with tasks.
|
@@ -19,5 +16,3 @@ generic_agent = Agent(
|
|
19 |
system_prompt=system_prompt,
|
20 |
tools=[],
|
21 |
)
|
22 |
-
|
23 |
-
|
|
|
1 |
+
from pydantic_ai import Agent
|
2 |
from pydantic_ai.models.openai import OpenAIModel
|
3 |
from dotenv import load_dotenv
|
4 |
import os
|
5 |
|
6 |
load_dotenv()
|
7 |
|
8 |
+
model = OpenAIModel("gpt-4o", api_key=os.environ.get("OPENAI_API_KEY"))
|
|
|
|
|
|
|
9 |
|
10 |
system_prompt = """
|
11 |
You are a helpful assistant that can answer questions and help with tasks.
|
|
|
16 |
system_prompt=system_prompt,
|
17 |
tools=[],
|
18 |
)
|
|
|
|
src/agents/image_edit_agent.py
CHANGED
@@ -7,7 +7,12 @@ from dataclasses import dataclass
|
|
7 |
from typing import Optional
|
8 |
import logfire
|
9 |
from src.services.generate_mask import GenerateMaskService
|
10 |
-
from src.hopter.client import
|
|
|
|
|
|
|
|
|
|
|
11 |
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
|
12 |
import base64
|
13 |
import tempfile
|
@@ -23,6 +28,7 @@ if the edit instruction involved modifying parts of the image, please generate a
|
|
23 |
if images are not provided, ask the user to provide an image.
|
24 |
"""
|
25 |
|
|
|
26 |
@dataclass
|
27 |
class ImageEditDeps:
|
28 |
edit_instruction: str
|
@@ -30,6 +36,7 @@ class ImageEditDeps:
|
|
30 |
mask_service: GenerateMaskService
|
31 |
image_url: Optional[str] = None
|
32 |
|
|
|
33 |
model = OpenAIModel(
|
34 |
"gpt-4o",
|
35 |
api_key=os.environ.get("OPENAI_API_KEY"),
|
@@ -40,11 +47,9 @@ model = OpenAIModel(
|
|
40 |
class EditImageResult:
|
41 |
edited_image_url: str
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
deps_type=ImageEditDeps
|
47 |
-
)
|
48 |
|
49 |
def upload_image_from_base64(base64_image: str) -> str:
|
50 |
image_format = base64_image.split(",")[0]
|
@@ -56,6 +61,7 @@ def upload_image_from_base64(base64_image: str) -> str:
|
|
56 |
f.write(image_data)
|
57 |
return upload_image(temp_filename)
|
58 |
|
|
|
59 |
@image_edit_agent.tool
|
60 |
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
61 |
"""
|
@@ -75,15 +81,20 @@ async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
|
75 |
image_uri = download_image_to_data_uri(image_url)
|
76 |
|
77 |
# Generate mask
|
78 |
-
mask_instruction = mask_service.get_mask_generation_instruction(
|
|
|
|
|
79 |
mask = mask_service.generate_mask(mask_instruction, image_uri)
|
80 |
|
81 |
# Magic replace
|
82 |
-
input = MagicReplaceInput(
|
|
|
|
|
83 |
result = hopter_client.magic_replace(input)
|
84 |
uploaded_image = upload_image_from_base64(result.base64_image)
|
85 |
return EditImageResult(edited_image_url=uploaded_image)
|
86 |
|
|
|
87 |
@image_edit_agent.tool
|
88 |
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
89 |
"""
|
@@ -94,31 +105,28 @@ async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
|
94 |
|
95 |
image_uri = download_image_to_data_uri(image_url)
|
96 |
|
97 |
-
input = SuperResolutionInput(
|
|
|
|
|
98 |
result = hopter_client.super_resolution(input)
|
99 |
uploaded_image = upload_image_from_base64(result.scaled_image)
|
100 |
return EditImageResult(edited_image_url=uploaded_image)
|
101 |
|
|
|
102 |
async def main():
|
103 |
image_file_path = "./assets/lakeview.jpg"
|
104 |
image_url = image_path_to_uri(image_file_path)
|
105 |
|
106 |
prompt = "remove the light post"
|
107 |
messages = [
|
108 |
-
{
|
109 |
-
|
110 |
-
"text": prompt
|
111 |
-
},
|
112 |
-
{
|
113 |
-
"type": "image_url",
|
114 |
-
"image_url": {
|
115 |
-
"url": image_url
|
116 |
-
}
|
117 |
-
}
|
118 |
]
|
119 |
|
120 |
# Initialize services
|
121 |
-
hopter = Hopter(
|
|
|
|
|
122 |
mask_service = GenerateMaskService(hopter=hopter)
|
123 |
|
124 |
# Initialize dependencies
|
@@ -126,15 +134,12 @@ async def main():
|
|
126 |
edit_instruction=prompt,
|
127 |
image_url=image_url,
|
128 |
hopter_client=hopter,
|
129 |
-
mask_service=mask_service
|
130 |
)
|
131 |
-
async with image_edit_agent.run_stream(
|
132 |
-
messages,
|
133 |
-
deps=deps
|
134 |
-
) as result:
|
135 |
async for message in result.stream():
|
136 |
print(message)
|
137 |
|
138 |
|
139 |
if __name__ == "__main__":
|
140 |
-
asyncio.run(main())
|
|
|
7 |
from typing import Optional
|
8 |
import logfire
|
9 |
from src.services.generate_mask import GenerateMaskService
|
10 |
+
from src.hopter.client import (
|
11 |
+
Hopter,
|
12 |
+
Environment,
|
13 |
+
MagicReplaceInput,
|
14 |
+
SuperResolutionInput,
|
15 |
+
)
|
16 |
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
|
17 |
import base64
|
18 |
import tempfile
|
|
|
28 |
if images are not provided, ask the user to provide an image.
|
29 |
"""
|
30 |
|
31 |
+
|
32 |
@dataclass
|
33 |
class ImageEditDeps:
|
34 |
edit_instruction: str
|
|
|
36 |
mask_service: GenerateMaskService
|
37 |
image_url: Optional[str] = None
|
38 |
|
39 |
+
|
40 |
model = OpenAIModel(
|
41 |
"gpt-4o",
|
42 |
api_key=os.environ.get("OPENAI_API_KEY"),
|
|
|
47 |
class EditImageResult:
|
48 |
edited_image_url: str
|
49 |
|
50 |
+
|
51 |
+
image_edit_agent = Agent(model, system_prompt=system_prompt, deps_type=ImageEditDeps)
|
52 |
+
|
|
|
|
|
53 |
|
54 |
def upload_image_from_base64(base64_image: str) -> str:
|
55 |
image_format = base64_image.split(",")[0]
|
|
|
61 |
f.write(image_data)
|
62 |
return upload_image(temp_filename)
|
63 |
|
64 |
+
|
65 |
@image_edit_agent.tool
|
66 |
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
67 |
"""
|
|
|
81 |
image_uri = download_image_to_data_uri(image_url)
|
82 |
|
83 |
# Generate mask
|
84 |
+
mask_instruction = mask_service.get_mask_generation_instruction(
|
85 |
+
edit_instruction, image_url
|
86 |
+
)
|
87 |
mask = mask_service.generate_mask(mask_instruction, image_uri)
|
88 |
|
89 |
# Magic replace
|
90 |
+
input = MagicReplaceInput(
|
91 |
+
image=image_uri, mask=mask, prompt=mask_instruction.target_caption
|
92 |
+
)
|
93 |
result = hopter_client.magic_replace(input)
|
94 |
uploaded_image = upload_image_from_base64(result.base64_image)
|
95 |
return EditImageResult(edited_image_url=uploaded_image)
|
96 |
|
97 |
+
|
98 |
@image_edit_agent.tool
|
99 |
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
100 |
"""
|
|
|
105 |
|
106 |
image_uri = download_image_to_data_uri(image_url)
|
107 |
|
108 |
+
input = SuperResolutionInput(
|
109 |
+
image_b64=image_uri, scale=4, use_face_enhancement=False
|
110 |
+
)
|
111 |
result = hopter_client.super_resolution(input)
|
112 |
uploaded_image = upload_image_from_base64(result.scaled_image)
|
113 |
return EditImageResult(edited_image_url=uploaded_image)
|
114 |
|
115 |
+
|
116 |
async def main():
|
117 |
image_file_path = "./assets/lakeview.jpg"
|
118 |
image_url = image_path_to_uri(image_file_path)
|
119 |
|
120 |
prompt = "remove the light post"
|
121 |
messages = [
|
122 |
+
{"type": "text", "text": prompt},
|
123 |
+
{"type": "image_url", "image_url": {"url": image_url}},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
]
|
125 |
|
126 |
# Initialize services
|
127 |
+
hopter = Hopter(
|
128 |
+
api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
|
129 |
+
)
|
130 |
mask_service = GenerateMaskService(hopter=hopter)
|
131 |
|
132 |
# Initialize dependencies
|
|
|
134 |
edit_instruction=prompt,
|
135 |
image_url=image_url,
|
136 |
hopter_client=hopter,
|
137 |
+
mask_service=mask_service,
|
138 |
)
|
139 |
+
async with image_edit_agent.run_stream(messages, deps=deps) as result:
|
|
|
|
|
|
|
140 |
async for message in result.stream():
|
141 |
print(message)
|
142 |
|
143 |
|
144 |
if __name__ == "__main__":
|
145 |
+
asyncio.run(main())
|
src/agents/mask_generation_agent.py
CHANGED
@@ -1,16 +1,9 @@
|
|
1 |
-
from pydantic_ai import Agent
|
2 |
from pydantic_ai.models.openai import OpenAIModel
|
3 |
from dotenv import load_dotenv
|
4 |
import os
|
5 |
-
import asyncio
|
6 |
from dataclasses import dataclass
|
7 |
-
from typing import Optional
|
8 |
import logfire
|
9 |
-
from src.services.generate_mask import GenerateMaskService
|
10 |
-
from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
|
11 |
-
from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
|
12 |
-
import base64
|
13 |
-
import tempfile
|
14 |
|
15 |
load_dotenv()
|
16 |
|
@@ -56,10 +49,9 @@ model = OpenAIModel(
|
|
56 |
class MaskGenerationResult:
|
57 |
mask_image_base64: str
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
)
|
63 |
|
64 |
@mask_generation_agent.tool
|
65 |
async def generate_mask(edit_instruction: str, image_url: str) -> MaskGenerationResult:
|
|
|
1 |
+
from pydantic_ai import Agent
|
2 |
from pydantic_ai.models.openai import OpenAIModel
|
3 |
from dotenv import load_dotenv
|
4 |
import os
|
|
|
5 |
from dataclasses import dataclass
|
|
|
6 |
import logfire
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
load_dotenv()
|
9 |
|
|
|
49 |
class MaskGenerationResult:
|
50 |
mask_image_base64: str
|
51 |
|
52 |
+
|
53 |
+
mask_generation_agent = Agent(model, system_prompt=system_prompt)
|
54 |
+
|
|
|
55 |
|
56 |
@mask_generation_agent.tool
|
57 |
async def generate_mask(edit_instruction: str, image_url: str) -> MaskGenerationResult:
|
src/hopter/client.py
CHANGED
@@ -9,6 +9,7 @@ from typing import List
|
|
9 |
|
10 |
load_dotenv()
|
11 |
|
|
|
12 |
class Environment(Enum):
|
13 |
STAGING = "staging"
|
14 |
PRODUCTION = "production"
|
@@ -20,39 +21,50 @@ class Environment(Enum):
|
|
20 |
return "https://serving.hopter.staging.picc.co"
|
21 |
case Environment.PRODUCTION:
|
22 |
return "https://serving.hopter.picc.co"
|
23 |
-
|
|
|
24 |
class RamGroundedSamInput(BaseModel):
|
25 |
-
text_prompt: str = Field(
|
|
|
|
|
26 |
image_b64: str = Field(..., description="The image in base64 format.")
|
27 |
|
|
|
28 |
class RamGroundedSamResult(BaseModel):
|
29 |
mask_b64: str = Field(..., description="The mask image in base64 format.")
|
30 |
class_label: str = Field(..., description="The class label of the mask.")
|
31 |
confidence: float = Field(..., description="The confidence score of the mask.")
|
32 |
-
bbox: List[float] = Field(
|
|
|
|
|
|
|
33 |
|
34 |
class MagicReplaceInput(BaseModel):
|
35 |
image: str = Field(..., description="The image in base64 format.")
|
36 |
mask: str = Field(..., description="The mask in base64 format.")
|
37 |
prompt: str = Field(..., description="The prompt for the magic replace.")
|
38 |
|
|
|
39 |
class MagicReplaceResult(BaseModel):
|
40 |
base64_image: str = Field(..., description="The edited image in base64 format.")
|
41 |
|
|
|
42 |
class SuperResolutionInput(BaseModel):
|
43 |
image_b64: str = Field(..., description="The image in base64 format.")
|
44 |
scale: int = Field(4, description="The scale of the image to upscale to.")
|
45 |
-
use_face_enhancement: bool = Field(
|
|
|
|
|
|
|
46 |
|
47 |
class SuperResolutionResult(BaseModel):
|
48 |
-
scaled_image: str = Field(
|
|
|
|
|
|
|
49 |
|
50 |
class Hopter:
|
51 |
-
def __init__(
|
52 |
-
self,
|
53 |
-
api_key: str,
|
54 |
-
environment: Environment = Environment.PRODUCTION
|
55 |
-
):
|
56 |
self.api_key = api_key
|
57 |
self.base_url = environment.base_url
|
58 |
self.client = httpx.Client()
|
@@ -64,22 +76,22 @@ class Hopter:
|
|
64 |
f"{self.base_url}/api/v1/services/ram-grounded-sam-api/predictions",
|
65 |
headers={
|
66 |
"Authorization": f"Bearer {self.api_key}",
|
67 |
-
"Content-Type": "application/json"
|
68 |
-
},
|
69 |
-
json={
|
70 |
-
"input": input.model_dump()
|
71 |
},
|
72 |
-
|
|
|
73 |
)
|
74 |
response.raise_for_status() # Raise an error for bad responses
|
75 |
instance = response.json().get("output").get("instances")[0]
|
76 |
print("Generated mask.")
|
77 |
return RamGroundedSamResult(**instance)
|
78 |
except httpx.HTTPStatusError as exc:
|
79 |
-
print(
|
|
|
|
|
80 |
except Exception as exc:
|
81 |
print(f"An unexpected error occurred: {exc}")
|
82 |
-
|
83 |
def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult:
|
84 |
print(f"Magic replacing with input: {input.prompt}")
|
85 |
try:
|
@@ -87,19 +99,19 @@ class Hopter:
|
|
87 |
f"{self.base_url}/api/v1/services/sdxl-magic-replace/predictions",
|
88 |
headers={
|
89 |
"Authorization": f"Bearer {self.api_key}",
|
90 |
-
"Content-Type": "application/json"
|
91 |
},
|
92 |
-
json={
|
93 |
-
|
94 |
-
},
|
95 |
-
timeout=None
|
96 |
)
|
97 |
response.raise_for_status() # Raise an error for bad responses
|
98 |
instance = response.json().get("output")
|
99 |
print("Magic replaced.")
|
100 |
return MagicReplaceResult(**instance)
|
101 |
except httpx.HTTPStatusError as exc:
|
102 |
-
print(
|
|
|
|
|
103 |
except Exception as exc:
|
104 |
print(f"An unexpected error occurred: {exc}")
|
105 |
|
@@ -109,51 +121,50 @@ class Hopter:
|
|
109 |
f"{self.base_url}/api/v1/services/super-resolution-esrgan/predictions",
|
110 |
headers={
|
111 |
"Authorization": f"Bearer {self.api_key}",
|
112 |
-
"Content-Type": "application/json"
|
113 |
},
|
114 |
-
json={
|
115 |
-
|
116 |
-
},
|
117 |
-
timeout=None
|
118 |
)
|
119 |
response.raise_for_status() # Raise an error for bad responses
|
120 |
instance = response.json().get("output")
|
121 |
print("Super-resolutin done")
|
122 |
return SuperResolutionResult(**instance)
|
123 |
except httpx.HTTPStatusError as exc:
|
124 |
-
print(
|
|
|
|
|
125 |
except Exception as exc:
|
126 |
print(f"An unexpected error occurred: {exc}")
|
127 |
|
128 |
|
129 |
async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
|
130 |
-
input = RamGroundedSamInput(
|
131 |
-
text_prompt="pole",
|
132 |
-
image_b64=image_url
|
133 |
-
)
|
134 |
mask = hopter.generate_mask(input)
|
135 |
return mask.mask_b64
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
)
|
143 |
result = hopter.magic_replace(input)
|
144 |
return result.base64_image
|
145 |
|
|
|
146 |
async def main():
|
147 |
hopter = Hopter(
|
148 |
-
api_key=os.getenv("HOPTER_API_KEY"),
|
149 |
-
environment=Environment.STAGING
|
150 |
)
|
151 |
image_file_path = "./assets/lakeview.jpg"
|
152 |
image_url = image_path_to_uri(image_file_path)
|
153 |
|
154 |
mask = await test_generate_mask(hopter, image_url)
|
155 |
-
magic_replace_result = await test_magic_replace(
|
|
|
|
|
156 |
print(magic_replace_result)
|
157 |
|
|
|
158 |
if __name__ == "__main__":
|
159 |
asyncio.run(main())
|
|
|
9 |
|
10 |
load_dotenv()
|
11 |
|
12 |
+
|
13 |
class Environment(Enum):
|
14 |
STAGING = "staging"
|
15 |
PRODUCTION = "production"
|
|
|
21 |
return "https://serving.hopter.staging.picc.co"
|
22 |
case Environment.PRODUCTION:
|
23 |
return "https://serving.hopter.picc.co"
|
24 |
+
|
25 |
+
|
26 |
class RamGroundedSamInput(BaseModel):
|
27 |
+
text_prompt: str = Field(
|
28 |
+
..., description="The text prompt for the mask generation."
|
29 |
+
)
|
30 |
image_b64: str = Field(..., description="The image in base64 format.")
|
31 |
|
32 |
+
|
33 |
class RamGroundedSamResult(BaseModel):
|
34 |
mask_b64: str = Field(..., description="The mask image in base64 format.")
|
35 |
class_label: str = Field(..., description="The class label of the mask.")
|
36 |
confidence: float = Field(..., description="The confidence score of the mask.")
|
37 |
+
bbox: List[float] = Field(
|
38 |
+
..., description="The bounding box of the mask in the format [x1, y1, x2, y2]."
|
39 |
+
)
|
40 |
+
|
41 |
|
42 |
class MagicReplaceInput(BaseModel):
|
43 |
image: str = Field(..., description="The image in base64 format.")
|
44 |
mask: str = Field(..., description="The mask in base64 format.")
|
45 |
prompt: str = Field(..., description="The prompt for the magic replace.")
|
46 |
|
47 |
+
|
48 |
class MagicReplaceResult(BaseModel):
|
49 |
base64_image: str = Field(..., description="The edited image in base64 format.")
|
50 |
|
51 |
+
|
52 |
class SuperResolutionInput(BaseModel):
|
53 |
image_b64: str = Field(..., description="The image in base64 format.")
|
54 |
scale: int = Field(4, description="The scale of the image to upscale to.")
|
55 |
+
use_face_enhancement: bool = Field(
|
56 |
+
False, description="Whether to use face enhancement."
|
57 |
+
)
|
58 |
+
|
59 |
|
60 |
class SuperResolutionResult(BaseModel):
|
61 |
+
scaled_image: str = Field(
|
62 |
+
..., description="The super-resolved image in base64 format."
|
63 |
+
)
|
64 |
+
|
65 |
|
66 |
class Hopter:
|
67 |
+
def __init__(self, api_key: str, environment: Environment = Environment.PRODUCTION):
|
|
|
|
|
|
|
|
|
68 |
self.api_key = api_key
|
69 |
self.base_url = environment.base_url
|
70 |
self.client = httpx.Client()
|
|
|
76 |
f"{self.base_url}/api/v1/services/ram-grounded-sam-api/predictions",
|
77 |
headers={
|
78 |
"Authorization": f"Bearer {self.api_key}",
|
79 |
+
"Content-Type": "application/json",
|
|
|
|
|
|
|
80 |
},
|
81 |
+
json={"input": input.model_dump()},
|
82 |
+
timeout=None,
|
83 |
)
|
84 |
response.raise_for_status() # Raise an error for bad responses
|
85 |
instance = response.json().get("output").get("instances")[0]
|
86 |
print("Generated mask.")
|
87 |
return RamGroundedSamResult(**instance)
|
88 |
except httpx.HTTPStatusError as exc:
|
89 |
+
print(
|
90 |
+
f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}"
|
91 |
+
)
|
92 |
except Exception as exc:
|
93 |
print(f"An unexpected error occurred: {exc}")
|
94 |
+
|
95 |
def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult:
|
96 |
print(f"Magic replacing with input: {input.prompt}")
|
97 |
try:
|
|
|
99 |
f"{self.base_url}/api/v1/services/sdxl-magic-replace/predictions",
|
100 |
headers={
|
101 |
"Authorization": f"Bearer {self.api_key}",
|
102 |
+
"Content-Type": "application/json",
|
103 |
},
|
104 |
+
json={"input": input.model_dump()},
|
105 |
+
timeout=None,
|
|
|
|
|
106 |
)
|
107 |
response.raise_for_status() # Raise an error for bad responses
|
108 |
instance = response.json().get("output")
|
109 |
print("Magic replaced.")
|
110 |
return MagicReplaceResult(**instance)
|
111 |
except httpx.HTTPStatusError as exc:
|
112 |
+
print(
|
113 |
+
f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}"
|
114 |
+
)
|
115 |
except Exception as exc:
|
116 |
print(f"An unexpected error occurred: {exc}")
|
117 |
|
|
|
121 |
f"{self.base_url}/api/v1/services/super-resolution-esrgan/predictions",
|
122 |
headers={
|
123 |
"Authorization": f"Bearer {self.api_key}",
|
124 |
+
"Content-Type": "application/json",
|
125 |
},
|
126 |
+
json={"input": input.model_dump()},
|
127 |
+
timeout=None,
|
|
|
|
|
128 |
)
|
129 |
response.raise_for_status() # Raise an error for bad responses
|
130 |
instance = response.json().get("output")
|
131 |
print("Super-resolutin done")
|
132 |
return SuperResolutionResult(**instance)
|
133 |
except httpx.HTTPStatusError as exc:
|
134 |
+
print(
|
135 |
+
f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}"
|
136 |
+
)
|
137 |
except Exception as exc:
|
138 |
print(f"An unexpected error occurred: {exc}")
|
139 |
|
140 |
|
141 |
async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
|
142 |
+
input = RamGroundedSamInput(text_prompt="pole", image_b64=image_url)
|
|
|
|
|
|
|
143 |
mask = hopter.generate_mask(input)
|
144 |
return mask.mask_b64
|
145 |
+
|
146 |
+
|
147 |
+
async def test_magic_replace(
|
148 |
+
hopter: Hopter, image_url: str, mask: str, prompt: str
|
149 |
+
) -> str:
|
150 |
+
input = MagicReplaceInput(image=image_url, mask=mask, prompt=prompt)
|
|
|
151 |
result = hopter.magic_replace(input)
|
152 |
return result.base64_image
|
153 |
|
154 |
+
|
155 |
async def main():
|
156 |
hopter = Hopter(
|
157 |
+
api_key=os.getenv("HOPTER_API_KEY"), environment=Environment.STAGING
|
|
|
158 |
)
|
159 |
image_file_path = "./assets/lakeview.jpg"
|
160 |
image_url = image_path_to_uri(image_file_path)
|
161 |
|
162 |
mask = await test_generate_mask(hopter, image_url)
|
163 |
+
magic_replace_result = await test_magic_replace(
|
164 |
+
hopter, image_url, mask, "remove the pole"
|
165 |
+
)
|
166 |
print(magic_replace_result)
|
167 |
|
168 |
+
|
169 |
if __name__ == "__main__":
|
170 |
asyncio.run(main())
|
src/models/generate_mask_instruction.py
CHANGED
@@ -1,19 +1,17 @@
|
|
1 |
from pydantic import BaseModel, Field
|
2 |
|
|
|
3 |
class GenerateMaskInstruction(BaseModel):
|
4 |
category: str = Field(
|
5 |
...,
|
6 |
-
description="The editing category based on the instruction. Must be one of: Addition, Remove, Local, Global, Background."
|
7 |
)
|
8 |
subject: str = Field(
|
9 |
...,
|
10 |
-
description="The subject of the editing instruction. Must be a noun in no more than 5 words."
|
11 |
-
)
|
12 |
-
caption: str = Field(
|
13 |
-
...,
|
14 |
-
description="The detailed description of the image."
|
15 |
)
|
|
|
16 |
target_caption: str = Field(
|
17 |
...,
|
18 |
-
description="Apply the editing instruction to the image caption. The target caption should describe the image after the editing instruction is applied."
|
19 |
-
)
|
|
|
1 |
from pydantic import BaseModel, Field
|
2 |
|
3 |
+
|
4 |
class GenerateMaskInstruction(BaseModel):
|
5 |
category: str = Field(
|
6 |
...,
|
7 |
+
description="The editing category based on the instruction. Must be one of: Addition, Remove, Local, Global, Background.",
|
8 |
)
|
9 |
subject: str = Field(
|
10 |
...,
|
11 |
+
description="The subject of the editing instruction. Must be a noun in no more than 5 words.",
|
|
|
|
|
|
|
|
|
12 |
)
|
13 |
+
caption: str = Field(..., description="The detailed description of the image.")
|
14 |
target_caption: str = Field(
|
15 |
...,
|
16 |
+
description="Apply the editing instruction to the image caption. The target caption should describe the image after the editing instruction is applied.",
|
17 |
+
)
|
src/services/generate_mask.py
CHANGED
@@ -6,6 +6,7 @@ from src.hopter.client import Hopter, RamGroundedSamInput, Environment
|
|
6 |
from src.models.generate_mask_instruction import GenerateMaskInstruction
|
7 |
from src.services.openai_file_upload import OpenAIFileUpload
|
8 |
from src.utils import download_image_to_data_uri
|
|
|
9 |
load_dotenv()
|
10 |
|
11 |
system_prompt = """
|
@@ -37,6 +38,7 @@ Do not output 'sorry, xxx', even if it's a guess, directly output the answer you
|
|
37 |
</task_3>
|
38 |
"""
|
39 |
|
|
|
40 |
class GenerateMaskService:
|
41 |
def __init__(self, hopter: Hopter):
|
42 |
self.llm = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
@@ -44,38 +46,29 @@ class GenerateMaskService:
|
|
44 |
self.openai_file_upload = OpenAIFileUpload()
|
45 |
self.hopter = hopter
|
46 |
|
47 |
-
def get_mask_generation_instruction(
|
|
|
|
|
48 |
messages = [
|
49 |
-
{
|
50 |
-
"role": "system",
|
51 |
-
"content": system_prompt
|
52 |
-
},
|
53 |
{
|
54 |
"role": "user",
|
55 |
"content": [
|
56 |
-
{
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
{
|
61 |
-
"type": "image_url",
|
62 |
-
"image_url": {
|
63 |
-
"url": image_url
|
64 |
-
}
|
65 |
-
}
|
66 |
-
]
|
67 |
-
}
|
68 |
]
|
69 |
|
70 |
response = self.llm.beta.chat.completions.parse(
|
71 |
-
model=self.model,
|
72 |
-
messages=messages,
|
73 |
-
response_format=GenerateMaskInstruction
|
74 |
)
|
75 |
instruction = response.choices[0].message.parsed
|
76 |
return instruction
|
77 |
-
|
78 |
-
def generate_mask(
|
|
|
|
|
79 |
"""
|
80 |
Generate a mask for the image editing instruction.
|
81 |
|
@@ -87,14 +80,18 @@ class GenerateMaskService:
|
|
87 |
"""
|
88 |
image_uri = download_image_to_data_uri(image_url)
|
89 |
input = RamGroundedSamInput(
|
90 |
-
text_prompt=mask_instruction.subject,
|
91 |
-
image_b64=image_uri
|
92 |
)
|
93 |
generate_mask_result = self.hopter.generate_mask(input)
|
94 |
return generate_mask_result.mask_b64
|
95 |
-
|
|
|
96 |
async def main():
|
97 |
-
service = GenerateMaskService(
|
|
|
|
|
|
|
|
|
98 |
edit_instruction = "remove the light post"
|
99 |
image_file_path = "./assets/lakeview.jpg"
|
100 |
with open(image_file_path, "rb") as image_file:
|
@@ -105,5 +102,6 @@ async def main():
|
|
105 |
mask = service.generate_mask(instruction, image_url)
|
106 |
print(mask)
|
107 |
|
|
|
108 |
if __name__ == "__main__":
|
109 |
asyncio.run(main())
|
|
|
6 |
from src.models.generate_mask_instruction import GenerateMaskInstruction
|
7 |
from src.services.openai_file_upload import OpenAIFileUpload
|
8 |
from src.utils import download_image_to_data_uri
|
9 |
+
|
10 |
load_dotenv()
|
11 |
|
12 |
system_prompt = """
|
|
|
38 |
</task_3>
|
39 |
"""
|
40 |
|
41 |
+
|
42 |
class GenerateMaskService:
|
43 |
def __init__(self, hopter: Hopter):
|
44 |
self.llm = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
|
|
46 |
self.openai_file_upload = OpenAIFileUpload()
|
47 |
self.hopter = hopter
|
48 |
|
49 |
+
def get_mask_generation_instruction(
|
50 |
+
self, edit_instruction: str, image_url: str
|
51 |
+
) -> GenerateMaskInstruction:
|
52 |
messages = [
|
53 |
+
{"role": "system", "content": system_prompt},
|
|
|
|
|
|
|
54 |
{
|
55 |
"role": "user",
|
56 |
"content": [
|
57 |
+
{"type": "text", "text": edit_instruction},
|
58 |
+
{"type": "image_url", "image_url": {"url": image_url}},
|
59 |
+
],
|
60 |
+
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
]
|
62 |
|
63 |
response = self.llm.beta.chat.completions.parse(
|
64 |
+
model=self.model, messages=messages, response_format=GenerateMaskInstruction
|
|
|
|
|
65 |
)
|
66 |
instruction = response.choices[0].message.parsed
|
67 |
return instruction
|
68 |
+
|
69 |
+
def generate_mask(
|
70 |
+
self, mask_instruction: GenerateMaskInstruction, image_url: str
|
71 |
+
) -> str:
|
72 |
"""
|
73 |
Generate a mask for the image editing instruction.
|
74 |
|
|
|
80 |
"""
|
81 |
image_uri = download_image_to_data_uri(image_url)
|
82 |
input = RamGroundedSamInput(
|
83 |
+
text_prompt=mask_instruction.subject, image_b64=image_uri
|
|
|
84 |
)
|
85 |
generate_mask_result = self.hopter.generate_mask(input)
|
86 |
return generate_mask_result.mask_b64
|
87 |
+
|
88 |
+
|
89 |
async def main():
|
90 |
+
service = GenerateMaskService(
|
91 |
+
Hopter(
|
92 |
+
api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING
|
93 |
+
)
|
94 |
+
)
|
95 |
edit_instruction = "remove the light post"
|
96 |
image_file_path = "./assets/lakeview.jpg"
|
97 |
with open(image_file_path, "rb") as image_file:
|
|
|
102 |
mask = service.generate_mask(instruction, image_url)
|
103 |
print(mask)
|
104 |
|
105 |
+
|
106 |
if __name__ == "__main__":
|
107 |
asyncio.run(main())
|
src/services/google_cloud_image_upload.py
CHANGED
@@ -7,14 +7,18 @@ from dotenv import load_dotenv
|
|
7 |
|
8 |
load_dotenv()
|
9 |
|
|
|
10 |
def get_credentials():
|
11 |
credentials_json_string = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
|
12 |
# create a temp file with the credentials
|
13 |
-
with tempfile.NamedTemporaryFile(
|
|
|
|
|
14 |
temp_file.write(credentials_json_string)
|
15 |
temp_file_path = temp_file.name
|
16 |
return temp_file_path
|
17 |
|
|
|
18 |
class GoogleCloudImageUploadService:
|
19 |
BUCKET_NAME = "picchat-assets"
|
20 |
MAX_DIMENSION = 1024
|
@@ -39,37 +43,49 @@ class GoogleCloudImageUploadService:
|
|
39 |
# Open and optionally resize the image, then save to a temporary file.
|
40 |
with Image.open(source_file_name) as image:
|
41 |
# Determine the original format. If it's not JPEG or PNG, default to JPEG.
|
42 |
-
original_format =
|
|
|
|
|
43 |
|
44 |
# Resize if needed.
|
45 |
-
if
|
|
|
|
|
|
|
46 |
image.thumbnail((self.MAX_DIMENSION, self.MAX_DIMENSION))
|
47 |
-
|
48 |
# Choose the file extension based on the image format.
|
49 |
suffix = ".jpg" if original_format == "JPEG" else ".png"
|
50 |
-
|
51 |
# Create a temporary file with the appropriate suffix.
|
52 |
-
with tempfile.NamedTemporaryFile(
|
|
|
|
|
53 |
temp_filename = temp_file.name
|
54 |
image.save(temp_filename, format=original_format)
|
55 |
|
56 |
try:
|
57 |
# Set content type based on the image format.
|
58 |
-
content_type =
|
|
|
|
|
59 |
blob.upload_from_filename(temp_filename, content_type=content_type)
|
60 |
blob.make_public()
|
61 |
finally:
|
62 |
# Remove the temporary file.
|
63 |
os.remove(temp_filename)
|
64 |
|
65 |
-
print(
|
|
|
|
|
66 |
return blob.public_url
|
67 |
except Exception as e:
|
68 |
print(f"An error occurred: {e}")
|
69 |
return None
|
70 |
|
|
|
71 |
if __name__ == "__main__":
|
72 |
image = "./assets/lakeview.jpg" # Replace with your JPEG or PNG image path.
|
73 |
upload_service = GoogleCloudImageUploadService()
|
74 |
url = upload_service.upload_image_to_gcs(image)
|
75 |
-
print(url)
|
|
|
7 |
|
8 |
load_dotenv()
|
9 |
|
10 |
+
|
11 |
def get_credentials():
|
12 |
credentials_json_string = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
|
13 |
# create a temp file with the credentials
|
14 |
+
with tempfile.NamedTemporaryFile(
|
15 |
+
mode="w+", delete=False, suffix=".json"
|
16 |
+
) as temp_file:
|
17 |
temp_file.write(credentials_json_string)
|
18 |
temp_file_path = temp_file.name
|
19 |
return temp_file_path
|
20 |
|
21 |
+
|
22 |
class GoogleCloudImageUploadService:
|
23 |
BUCKET_NAME = "picchat-assets"
|
24 |
MAX_DIMENSION = 1024
|
|
|
43 |
# Open and optionally resize the image, then save to a temporary file.
|
44 |
with Image.open(source_file_name) as image:
|
45 |
# Determine the original format. If it's not JPEG or PNG, default to JPEG.
|
46 |
+
original_format = (
|
47 |
+
image.format.upper() if image.format in ["JPEG", "PNG"] else "JPEG"
|
48 |
+
)
|
49 |
|
50 |
# Resize if needed.
|
51 |
+
if (
|
52 |
+
image.width > self.MAX_DIMENSION
|
53 |
+
or image.height > self.MAX_DIMENSION
|
54 |
+
):
|
55 |
image.thumbnail((self.MAX_DIMENSION, self.MAX_DIMENSION))
|
56 |
+
|
57 |
# Choose the file extension based on the image format.
|
58 |
suffix = ".jpg" if original_format == "JPEG" else ".png"
|
59 |
+
|
60 |
# Create a temporary file with the appropriate suffix.
|
61 |
+
with tempfile.NamedTemporaryFile(
|
62 |
+
delete=False, suffix=suffix
|
63 |
+
) as temp_file:
|
64 |
temp_filename = temp_file.name
|
65 |
image.save(temp_filename, format=original_format)
|
66 |
|
67 |
try:
|
68 |
# Set content type based on the image format.
|
69 |
+
content_type = (
|
70 |
+
"image/jpeg" if original_format == "JPEG" else "image/png"
|
71 |
+
)
|
72 |
blob.upload_from_filename(temp_filename, content_type=content_type)
|
73 |
blob.make_public()
|
74 |
finally:
|
75 |
# Remove the temporary file.
|
76 |
os.remove(temp_filename)
|
77 |
|
78 |
+
print(
|
79 |
+
f"File {source_file_name} uploaded to {blob_name} in bucket {self.BUCKET_NAME}."
|
80 |
+
)
|
81 |
return blob.public_url
|
82 |
except Exception as e:
|
83 |
print(f"An error occurred: {e}")
|
84 |
return None
|
85 |
|
86 |
+
|
87 |
if __name__ == "__main__":
|
88 |
image = "./assets/lakeview.jpg" # Replace with your JPEG or PNG image path.
|
89 |
upload_service = GoogleCloudImageUploadService()
|
90 |
url = upload_service.upload_image_to_gcs(image)
|
91 |
+
print(url)
|
src/services/image_uploader.py
CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
|
|
5 |
import os
|
6 |
from pydantic import BaseModel
|
7 |
|
|
|
8 |
class ImageInfo(BaseModel):
|
9 |
filename: str
|
10 |
name: str
|
@@ -12,6 +13,7 @@ class ImageInfo(BaseModel):
|
|
12 |
extension: str
|
13 |
url: str
|
14 |
|
|
|
15 |
class ImgBBData(BaseModel):
|
16 |
id: str
|
17 |
title: str
|
@@ -28,33 +30,35 @@ class ImgBBData(BaseModel):
|
|
28 |
medium: ImageInfo
|
29 |
delete_url: str
|
30 |
|
|
|
31 |
class ImgBBResponse(BaseModel):
|
32 |
data: ImgBBData
|
33 |
success: bool
|
34 |
status: int
|
35 |
|
|
|
36 |
class ImageUploader:
|
37 |
"""A class to handle image uploads to ImgBB service."""
|
38 |
-
|
39 |
def __init__(self, api_key: str):
|
40 |
"""
|
41 |
Initialize the ImageUploader with an API key.
|
42 |
-
|
43 |
Args:
|
44 |
api_key (str): The ImgBB API key
|
45 |
"""
|
46 |
self.api_key = api_key
|
47 |
self.base_url = "https://api.imgbb.com/1/upload"
|
48 |
-
|
49 |
def upload(
|
50 |
self,
|
51 |
image: Union[str, bytes, Path],
|
52 |
name: Optional[str] = None,
|
53 |
-
expiration: Optional[int] = None
|
54 |
) -> ImgBBResponse:
|
55 |
"""
|
56 |
Upload an image to ImgBB.
|
57 |
-
|
58 |
Args:
|
59 |
image: Can be:
|
60 |
- A file path (str or Path)
|
@@ -64,20 +68,20 @@ class ImageUploader:
|
|
64 |
- Bytes of an image
|
65 |
name: Optional name for the uploaded file
|
66 |
expiration: Optional expiration time in seconds (60-15552000)
|
67 |
-
|
68 |
Returns:
|
69 |
ImgBBResponse containing the parsed upload response from ImgBB
|
70 |
-
|
71 |
Raises:
|
72 |
ValueError: If the image format is invalid or upload fails
|
73 |
requests.RequestException: If the API request fails
|
74 |
"""
|
75 |
# Prepare the parameters
|
76 |
-
params = {
|
77 |
if expiration:
|
78 |
if not 60 <= expiration <= 15552000:
|
79 |
raise ValueError("Expiration must be between 60 and 15552000 seconds")
|
80 |
-
params[
|
81 |
|
82 |
# Handle different image input types
|
83 |
if isinstance(image, (str, Path)):
|
@@ -85,38 +89,40 @@ class ImageUploader:
|
|
85 |
files = {}
|
86 |
if os.path.isfile(image_str):
|
87 |
# It's a file path
|
88 |
-
with open(image_str,
|
89 |
-
files[
|
90 |
-
elif image_str.startswith((
|
91 |
# It's a URL
|
92 |
-
files[
|
93 |
-
elif image_str.startswith(
|
94 |
# It's a data URI
|
95 |
# Extract the base64 part after the comma
|
96 |
-
base64_data = image_str.split(
|
97 |
-
files[
|
98 |
else:
|
99 |
# Assume it's base64 data
|
100 |
-
files[
|
101 |
|
102 |
if name:
|
103 |
-
files[
|
104 |
response = requests.post(self.base_url, params=params, files=files)
|
105 |
elif isinstance(image, bytes):
|
106 |
# Convert bytes to base64
|
107 |
-
base64_image = base64.b64encode(image).decode(
|
108 |
-
files = {
|
109 |
-
'image': (None, base64_image)
|
110 |
-
}
|
111 |
if name:
|
112 |
-
files[
|
113 |
response = requests.post(self.base_url, params=params, files=files)
|
114 |
else:
|
115 |
-
raise ValueError(
|
|
|
|
|
116 |
|
117 |
# Check the response
|
118 |
if response.status_code != 200:
|
119 |
-
raise ValueError(
|
|
|
|
|
120 |
|
121 |
# Parse the response using Pydantic model
|
122 |
response_json = response.json()
|
@@ -126,16 +132,16 @@ class ImageUploader:
|
|
126 |
self,
|
127 |
file_path: Union[str, Path],
|
128 |
name: Optional[str] = None,
|
129 |
-
expiration: Optional[int] = None
|
130 |
) -> ImgBBResponse:
|
131 |
"""
|
132 |
Convenience method to upload an image file.
|
133 |
-
|
134 |
Args:
|
135 |
file_path: Path to the image file
|
136 |
name: Optional name for the uploaded file
|
137 |
expiration: Optional expiration time in seconds (60-15552000)
|
138 |
-
|
139 |
Returns:
|
140 |
ImgBBResponse containing the parsed upload response from ImgBB
|
141 |
"""
|
@@ -145,16 +151,16 @@ class ImageUploader:
|
|
145 |
self,
|
146 |
image_url: str,
|
147 |
name: Optional[str] = None,
|
148 |
-
expiration: Optional[int] = None
|
149 |
) -> ImgBBResponse:
|
150 |
"""
|
151 |
Convenience method to upload an image from a URL.
|
152 |
-
|
153 |
Args:
|
154 |
image_url: URL of the image to upload
|
155 |
name: Optional name for the uploaded file
|
156 |
expiration: Optional expiration time in seconds (60-15552000)
|
157 |
-
|
158 |
Returns:
|
159 |
ImgBBResponse containing the parsed upload response from ImgBB
|
160 |
"""
|
|
|
5 |
import os
|
6 |
from pydantic import BaseModel
|
7 |
|
8 |
+
|
9 |
class ImageInfo(BaseModel):
|
10 |
filename: str
|
11 |
name: str
|
|
|
13 |
extension: str
|
14 |
url: str
|
15 |
|
16 |
+
|
17 |
class ImgBBData(BaseModel):
|
18 |
id: str
|
19 |
title: str
|
|
|
30 |
medium: ImageInfo
|
31 |
delete_url: str
|
32 |
|
33 |
+
|
34 |
class ImgBBResponse(BaseModel):
|
35 |
data: ImgBBData
|
36 |
success: bool
|
37 |
status: int
|
38 |
|
39 |
+
|
40 |
class ImageUploader:
|
41 |
"""A class to handle image uploads to ImgBB service."""
|
42 |
+
|
43 |
def __init__(self, api_key: str):
|
44 |
"""
|
45 |
Initialize the ImageUploader with an API key.
|
46 |
+
|
47 |
Args:
|
48 |
api_key (str): The ImgBB API key
|
49 |
"""
|
50 |
self.api_key = api_key
|
51 |
self.base_url = "https://api.imgbb.com/1/upload"
|
52 |
+
|
53 |
def upload(
|
54 |
self,
|
55 |
image: Union[str, bytes, Path],
|
56 |
name: Optional[str] = None,
|
57 |
+
expiration: Optional[int] = None,
|
58 |
) -> ImgBBResponse:
|
59 |
"""
|
60 |
Upload an image to ImgBB.
|
61 |
+
|
62 |
Args:
|
63 |
image: Can be:
|
64 |
- A file path (str or Path)
|
|
|
68 |
- Bytes of an image
|
69 |
name: Optional name for the uploaded file
|
70 |
expiration: Optional expiration time in seconds (60-15552000)
|
71 |
+
|
72 |
Returns:
|
73 |
ImgBBResponse containing the parsed upload response from ImgBB
|
74 |
+
|
75 |
Raises:
|
76 |
ValueError: If the image format is invalid or upload fails
|
77 |
requests.RequestException: If the API request fails
|
78 |
"""
|
79 |
# Prepare the parameters
|
80 |
+
params = {"key": self.api_key}
|
81 |
if expiration:
|
82 |
if not 60 <= expiration <= 15552000:
|
83 |
raise ValueError("Expiration must be between 60 and 15552000 seconds")
|
84 |
+
params["expiration"] = expiration
|
85 |
|
86 |
# Handle different image input types
|
87 |
if isinstance(image, (str, Path)):
|
|
|
89 |
files = {}
|
90 |
if os.path.isfile(image_str):
|
91 |
# It's a file path
|
92 |
+
with open(image_str, "rb") as file:
|
93 |
+
files["image"] = file
|
94 |
+
elif image_str.startswith(("http://", "https://")):
|
95 |
# It's a URL
|
96 |
+
files["image"] = (None, image_str)
|
97 |
+
elif image_str.startswith("data:image/"):
|
98 |
# It's a data URI
|
99 |
# Extract the base64 part after the comma
|
100 |
+
base64_data = image_str.split(",", 1)[1]
|
101 |
+
files["image"] = (None, base64_data)
|
102 |
else:
|
103 |
# Assume it's base64 data
|
104 |
+
files["image"] = (None, image_str)
|
105 |
|
106 |
if name:
|
107 |
+
files["name"] = (None, name)
|
108 |
response = requests.post(self.base_url, params=params, files=files)
|
109 |
elif isinstance(image, bytes):
|
110 |
# Convert bytes to base64
|
111 |
+
base64_image = base64.b64encode(image).decode("utf-8")
|
112 |
+
files = {"image": (None, base64_image)}
|
|
|
|
|
113 |
if name:
|
114 |
+
files["name"] = (None, name)
|
115 |
response = requests.post(self.base_url, params=params, files=files)
|
116 |
else:
|
117 |
+
raise ValueError(
|
118 |
+
"Invalid image format. Must be file path, URL, base64 string, or bytes"
|
119 |
+
)
|
120 |
|
121 |
# Check the response
|
122 |
if response.status_code != 200:
|
123 |
+
raise ValueError(
|
124 |
+
f"Upload failed with status {response.status_code}: {response.text}"
|
125 |
+
)
|
126 |
|
127 |
# Parse the response using Pydantic model
|
128 |
response_json = response.json()
|
|
|
132 |
self,
|
133 |
file_path: Union[str, Path],
|
134 |
name: Optional[str] = None,
|
135 |
+
expiration: Optional[int] = None,
|
136 |
) -> ImgBBResponse:
|
137 |
"""
|
138 |
Convenience method to upload an image file.
|
139 |
+
|
140 |
Args:
|
141 |
file_path: Path to the image file
|
142 |
name: Optional name for the uploaded file
|
143 |
expiration: Optional expiration time in seconds (60-15552000)
|
144 |
+
|
145 |
Returns:
|
146 |
ImgBBResponse containing the parsed upload response from ImgBB
|
147 |
"""
|
|
|
151 |
self,
|
152 |
image_url: str,
|
153 |
name: Optional[str] = None,
|
154 |
+
expiration: Optional[int] = None,
|
155 |
) -> ImgBBResponse:
|
156 |
"""
|
157 |
Convenience method to upload an image from a URL.
|
158 |
+
|
159 |
Args:
|
160 |
image_url: URL of the image to upload
|
161 |
name: Optional name for the uploaded file
|
162 |
expiration: Optional expiration time in seconds (60-15552000)
|
163 |
+
|
164 |
Returns:
|
165 |
ImgBBResponse containing the parsed upload response from ImgBB
|
166 |
"""
|
src/services/openai_file_upload.py
CHANGED
@@ -4,6 +4,7 @@ import os
|
|
4 |
|
5 |
load_dotenv()
|
6 |
|
|
|
7 |
class OpenAIFileUpload:
|
8 |
def __init__(self):
|
9 |
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
|
|
4 |
|
5 |
load_dotenv()
|
6 |
|
7 |
+
|
8 |
class OpenAIFileUpload:
|
9 |
def __init__(self):
|
10 |
self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
src/utils.py
CHANGED
@@ -4,16 +4,21 @@ from src.services.google_cloud_image_upload import GoogleCloudImageUploadService
|
|
4 |
from PIL import Image
|
5 |
from urllib.request import urlopen
|
6 |
import io
|
|
|
|
|
7 |
def image_path_to_base64(image_path: str) -> str:
|
8 |
with open(image_path, "rb") as image_file:
|
9 |
return base64.b64encode(image_file.read()).decode("utf-8")
|
10 |
|
|
|
11 |
def upload_file_to_base64(file: UploadFile) -> str:
|
12 |
return base64.b64encode(file.file.read()).decode("utf-8")
|
13 |
|
|
|
14 |
def image_path_to_uri(image_path: str) -> str:
|
15 |
return f"data:image/jpeg;base64,{image_path_to_base64(image_path)}"
|
16 |
|
|
|
17 |
def upload_image(image_path: str) -> str:
|
18 |
"""
|
19 |
Upload an image to Google Cloud Storage and return the public URL.
|
@@ -27,6 +32,7 @@ def upload_image(image_path: str) -> str:
|
|
27 |
upload_service = GoogleCloudImageUploadService()
|
28 |
return upload_service.upload_image_to_gcs(image_path)
|
29 |
|
|
|
30 |
def download_image_to_data_uri(image_url: str) -> str:
|
31 |
# Open the image from the URL
|
32 |
response = urlopen(image_url)
|
@@ -34,16 +40,20 @@ def download_image_to_data_uri(image_url: str) -> str:
|
|
34 |
|
35 |
# Determine the image format; default to 'JPEG' if not found
|
36 |
image_format = img.format if img.format is not None else "JPEG"
|
37 |
-
|
38 |
# Build the MIME type; for 'JPEG', use 'image/jpeg'
|
39 |
-
mime_type =
|
40 |
-
|
|
|
|
|
|
|
|
|
41 |
# Save the image to an in-memory buffer using the detected format
|
42 |
buffered = io.BytesIO()
|
43 |
img.save(buffered, format=image_format)
|
44 |
-
|
45 |
# Encode the image bytes to base64
|
46 |
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
47 |
-
|
48 |
# Return the data URI with the correct MIME type
|
49 |
-
return f"data:{mime_type};base64,{img_base64}"
|
|
|
4 |
from PIL import Image
|
5 |
from urllib.request import urlopen
|
6 |
import io
|
7 |
+
|
8 |
+
|
9 |
def image_path_to_base64(image_path: str) -> str:
|
10 |
with open(image_path, "rb") as image_file:
|
11 |
return base64.b64encode(image_file.read()).decode("utf-8")
|
12 |
|
13 |
+
|
14 |
def upload_file_to_base64(file: UploadFile) -> str:
|
15 |
return base64.b64encode(file.file.read()).decode("utf-8")
|
16 |
|
17 |
+
|
18 |
def image_path_to_uri(image_path: str) -> str:
|
19 |
return f"data:image/jpeg;base64,{image_path_to_base64(image_path)}"
|
20 |
|
21 |
+
|
22 |
def upload_image(image_path: str) -> str:
|
23 |
"""
|
24 |
Upload an image to Google Cloud Storage and return the public URL.
|
|
|
32 |
upload_service = GoogleCloudImageUploadService()
|
33 |
return upload_service.upload_image_to_gcs(image_path)
|
34 |
|
35 |
+
|
36 |
def download_image_to_data_uri(image_url: str) -> str:
|
37 |
# Open the image from the URL
|
38 |
response = urlopen(image_url)
|
|
|
40 |
|
41 |
# Determine the image format; default to 'JPEG' if not found
|
42 |
image_format = img.format if img.format is not None else "JPEG"
|
43 |
+
|
44 |
# Build the MIME type; for 'JPEG', use 'image/jpeg'
|
45 |
+
mime_type = (
|
46 |
+
"image/jpeg"
|
47 |
+
if image_format.upper() == "JPEG"
|
48 |
+
else f"image/{image_format.lower()}"
|
49 |
+
)
|
50 |
+
|
51 |
# Save the image to an in-memory buffer using the detected format
|
52 |
buffered = io.BytesIO()
|
53 |
img.save(buffered, format=image_format)
|
54 |
+
|
55 |
# Encode the image bytes to base64
|
56 |
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
57 |
+
|
58 |
# Return the data URI with the correct MIME type
|
59 |
+
return f"data:{mime_type};base64,{img_base64}"
|
stream_utils.py
CHANGED
@@ -6,28 +6,29 @@ from rich.markdown import Markdown
|
|
6 |
from rich.panel import Panel
|
7 |
from rich.text import Text
|
8 |
|
|
|
9 |
class StreamResponseHandler:
|
10 |
"""
|
11 |
A utility class for handling streaming responses from API endpoints.
|
12 |
Provides rich formatting and real-time updates of the response content.
|
13 |
"""
|
14 |
-
|
15 |
def __init__(self, console=None):
|
16 |
"""
|
17 |
Initialize the stream response handler.
|
18 |
-
|
19 |
Args:
|
20 |
console (Console, optional): A Rich console instance. If not provided, a new one will be created.
|
21 |
"""
|
22 |
self.console = console or Console()
|
23 |
-
|
24 |
def check_server_health(self, health_url="http://localhost:8000/health"):
|
25 |
"""
|
26 |
Check if the server is running and accessible.
|
27 |
-
|
28 |
Args:
|
29 |
health_url (str, optional): The URL to check server health. Defaults to "http://localhost:8000/health".
|
30 |
-
|
31 |
Returns:
|
32 |
bool: True if the server is running and accessible, False otherwise.
|
33 |
"""
|
@@ -38,26 +39,32 @@ class StreamResponseHandler:
|
|
38 |
self.console.print("[bold green]✓ Server is running and accessible.[/]")
|
39 |
return True
|
40 |
else:
|
41 |
-
self.console.print(
|
|
|
|
|
42 |
return False
|
43 |
except requests.exceptions.ConnectionError:
|
44 |
-
self.console.print(
|
|
|
|
|
45 |
return False
|
46 |
except Exception as e:
|
47 |
self.console.print(f"[bold red]✗ Error checking server health:[/] {e}")
|
48 |
return False
|
49 |
-
|
50 |
-
def stream_response(
|
|
|
|
|
51 |
"""
|
52 |
Send a request to an endpoint and stream the output to the terminal.
|
53 |
-
|
54 |
Args:
|
55 |
url (str): The URL of the endpoint to send the request to.
|
56 |
payload (dict, optional): The JSON payload to send in the request body. Defaults to None.
|
57 |
params (dict, optional): The query parameters to send in the request. Defaults to None.
|
58 |
method (str, optional): The HTTP method to use. Defaults to "POST".
|
59 |
title (str, optional): The title to display in the panel. Defaults to "AI Response".
|
60 |
-
|
61 |
Returns:
|
62 |
bool: True if the streaming was successful, False otherwise.
|
63 |
"""
|
@@ -69,39 +76,42 @@ class StreamResponseHandler:
|
|
69 |
if params:
|
70 |
self.console.print("Parameters:", style="bold")
|
71 |
self.console.print(json.dumps(params, indent=2))
|
72 |
-
|
73 |
try:
|
74 |
# Prepare the request
|
75 |
-
request_kwargs = {
|
76 |
-
"stream": True
|
77 |
-
}
|
78 |
if payload:
|
79 |
request_kwargs["json"] = payload
|
80 |
if params:
|
81 |
request_kwargs["params"] = params
|
82 |
-
|
83 |
# Make the request
|
84 |
with getattr(requests, method.lower())(url, **request_kwargs) as response:
|
85 |
# Check if the request was successful
|
86 |
if response.status_code != 200:
|
87 |
-
self.console.print(
|
|
|
|
|
88 |
self.console.print(f"Response: {response.text}")
|
89 |
return False
|
90 |
-
|
91 |
# Initialize an empty response text
|
92 |
full_response = ""
|
93 |
-
|
94 |
# Use Rich's Live display to update the content in place
|
95 |
-
with Live(
|
|
|
|
|
|
|
96 |
# Process the streaming response
|
97 |
for line in response.iter_lines():
|
98 |
if line:
|
99 |
# Decode the line and parse it as JSON
|
100 |
-
decoded_line = line.decode(
|
101 |
try:
|
102 |
# Parse the JSON
|
103 |
data = json.loads(decoded_line)
|
104 |
-
|
105 |
# Extract and display the content
|
106 |
if isinstance(data, dict):
|
107 |
if "content" in data:
|
@@ -111,35 +121,82 @@ class StreamResponseHandler:
|
|
111 |
# Append to the full response
|
112 |
full_response += text_content
|
113 |
# Update the live display with the current full response
|
114 |
-
live.update(
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
elif content.get("type") == "image_url":
|
116 |
-
image_url = content.get(
|
|
|
|
|
117 |
# Add a note about the image URL
|
118 |
-
image_note =
|
|
|
|
|
119 |
full_response += image_note
|
120 |
-
live.update(
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
elif "edited_image_url" in data:
|
122 |
# Handle edited image URL from edit endpoint
|
123 |
image_url = data.get("edited_image_url", "")
|
124 |
-
image_note =
|
|
|
|
|
125 |
full_response += image_note
|
126 |
-
live.update(
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
else:
|
128 |
# For other types of data, just show the JSON
|
129 |
-
live.update(
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
else:
|
131 |
-
live.update(
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
except json.JSONDecodeError:
|
133 |
# If it's not valid JSON, just show the raw line
|
134 |
-
live.update(
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
self.console.print("[bold green]Stream completed.[/]")
|
137 |
return True
|
138 |
-
|
139 |
except requests.exceptions.ConnectionError:
|
140 |
-
self.console.print(
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
142 |
return False
|
143 |
except requests.exceptions.RequestException as e:
|
144 |
self.console.print(f"[bold red]Error:[/] {e}", style="red")
|
145 |
-
return False
|
|
|
6 |
from rich.panel import Panel
|
7 |
from rich.text import Text
|
8 |
|
9 |
+
|
10 |
class StreamResponseHandler:
|
11 |
"""
|
12 |
A utility class for handling streaming responses from API endpoints.
|
13 |
Provides rich formatting and real-time updates of the response content.
|
14 |
"""
|
15 |
+
|
16 |
def __init__(self, console=None):
|
17 |
"""
|
18 |
Initialize the stream response handler.
|
19 |
+
|
20 |
Args:
|
21 |
console (Console, optional): A Rich console instance. If not provided, a new one will be created.
|
22 |
"""
|
23 |
self.console = console or Console()
|
24 |
+
|
25 |
def check_server_health(self, health_url="http://localhost:8000/health"):
|
26 |
"""
|
27 |
Check if the server is running and accessible.
|
28 |
+
|
29 |
Args:
|
30 |
health_url (str, optional): The URL to check server health. Defaults to "http://localhost:8000/health".
|
31 |
+
|
32 |
Returns:
|
33 |
bool: True if the server is running and accessible, False otherwise.
|
34 |
"""
|
|
|
39 |
self.console.print("[bold green]✓ Server is running and accessible.[/]")
|
40 |
return True
|
41 |
else:
|
42 |
+
self.console.print(
|
43 |
+
f"[bold red]✗ Server health check failed[/] with status code: {response.status_code}"
|
44 |
+
)
|
45 |
return False
|
46 |
except requests.exceptions.ConnectionError:
|
47 |
+
self.console.print(
|
48 |
+
"[bold red]✗ Error:[/] Could not connect to the server. Make sure it's running."
|
49 |
+
)
|
50 |
return False
|
51 |
except Exception as e:
|
52 |
self.console.print(f"[bold red]✗ Error checking server health:[/] {e}")
|
53 |
return False
|
54 |
+
|
55 |
+
def stream_response(
|
56 |
+
self, url, payload=None, params=None, method="POST", title="AI Response"
|
57 |
+
):
|
58 |
"""
|
59 |
Send a request to an endpoint and stream the output to the terminal.
|
60 |
+
|
61 |
Args:
|
62 |
url (str): The URL of the endpoint to send the request to.
|
63 |
payload (dict, optional): The JSON payload to send in the request body. Defaults to None.
|
64 |
params (dict, optional): The query parameters to send in the request. Defaults to None.
|
65 |
method (str, optional): The HTTP method to use. Defaults to "POST".
|
66 |
title (str, optional): The title to display in the panel. Defaults to "AI Response".
|
67 |
+
|
68 |
Returns:
|
69 |
bool: True if the streaming was successful, False otherwise.
|
70 |
"""
|
|
|
76 |
if params:
|
77 |
self.console.print("Parameters:", style="bold")
|
78 |
self.console.print(json.dumps(params, indent=2))
|
79 |
+
|
80 |
try:
|
81 |
# Prepare the request
|
82 |
+
request_kwargs = {"stream": True}
|
|
|
|
|
83 |
if payload:
|
84 |
request_kwargs["json"] = payload
|
85 |
if params:
|
86 |
request_kwargs["params"] = params
|
87 |
+
|
88 |
# Make the request
|
89 |
with getattr(requests, method.lower())(url, **request_kwargs) as response:
|
90 |
# Check if the request was successful
|
91 |
if response.status_code != 200:
|
92 |
+
self.console.print(
|
93 |
+
f"[bold red]Error:[/] Received status code {response.status_code}"
|
94 |
+
)
|
95 |
self.console.print(f"Response: {response.text}")
|
96 |
return False
|
97 |
+
|
98 |
# Initialize an empty response text
|
99 |
full_response = ""
|
100 |
+
|
101 |
# Use Rich's Live display to update the content in place
|
102 |
+
with Live(
|
103 |
+
Panel("Waiting for response...", title=title, border_style="blue"),
|
104 |
+
refresh_per_second=10,
|
105 |
+
) as live:
|
106 |
# Process the streaming response
|
107 |
for line in response.iter_lines():
|
108 |
if line:
|
109 |
# Decode the line and parse it as JSON
|
110 |
+
decoded_line = line.decode("utf-8")
|
111 |
try:
|
112 |
# Parse the JSON
|
113 |
data = json.loads(decoded_line)
|
114 |
+
|
115 |
# Extract and display the content
|
116 |
if isinstance(data, dict):
|
117 |
if "content" in data:
|
|
|
121 |
# Append to the full response
|
122 |
full_response += text_content
|
123 |
# Update the live display with the current full response
|
124 |
+
live.update(
|
125 |
+
Panel(
|
126 |
+
Markdown(full_response),
|
127 |
+
title=title,
|
128 |
+
border_style="green",
|
129 |
+
)
|
130 |
+
)
|
131 |
elif content.get("type") == "image_url":
|
132 |
+
image_url = content.get(
|
133 |
+
"image_url", {}
|
134 |
+
).get("url", "")
|
135 |
# Add a note about the image URL
|
136 |
+
image_note = (
|
137 |
+
f"\n\n[Image URL: {image_url}]"
|
138 |
+
)
|
139 |
full_response += image_note
|
140 |
+
live.update(
|
141 |
+
Panel(
|
142 |
+
Markdown(full_response),
|
143 |
+
title=title,
|
144 |
+
border_style="green",
|
145 |
+
)
|
146 |
+
)
|
147 |
elif "edited_image_url" in data:
|
148 |
# Handle edited image URL from edit endpoint
|
149 |
image_url = data.get("edited_image_url", "")
|
150 |
+
image_note = (
|
151 |
+
f"\n\n[Edited Image URL: {image_url}]"
|
152 |
+
)
|
153 |
full_response += image_note
|
154 |
+
live.update(
|
155 |
+
Panel(
|
156 |
+
Markdown(full_response),
|
157 |
+
title=title,
|
158 |
+
border_style="green",
|
159 |
+
)
|
160 |
+
)
|
161 |
else:
|
162 |
# For other types of data, just show the JSON
|
163 |
+
live.update(
|
164 |
+
Panel(
|
165 |
+
Text(json.dumps(data, indent=2)),
|
166 |
+
title="Raw JSON Response",
|
167 |
+
border_style="yellow",
|
168 |
+
)
|
169 |
+
)
|
170 |
else:
|
171 |
+
live.update(
|
172 |
+
Panel(
|
173 |
+
Text(decoded_line),
|
174 |
+
title="Raw Response",
|
175 |
+
border_style="yellow",
|
176 |
+
)
|
177 |
+
)
|
178 |
except json.JSONDecodeError:
|
179 |
# If it's not valid JSON, just show the raw line
|
180 |
+
live.update(
|
181 |
+
Panel(
|
182 |
+
Text(f"Raw response: {decoded_line}"),
|
183 |
+
title="Invalid JSON",
|
184 |
+
border_style="red",
|
185 |
+
)
|
186 |
+
)
|
187 |
|
188 |
self.console.print("[bold green]Stream completed.[/]")
|
189 |
return True
|
190 |
+
|
191 |
except requests.exceptions.ConnectionError:
|
192 |
+
self.console.print(
|
193 |
+
f"[bold red]Error:[/] Could not connect to the server at {url}",
|
194 |
+
style="red",
|
195 |
+
)
|
196 |
+
self.console.print(
|
197 |
+
"Make sure the server is running and accessible.", style="red"
|
198 |
+
)
|
199 |
return False
|
200 |
except requests.exceptions.RequestException as e:
|
201 |
self.console.print(f"[bold red]Error:[/] {e}", style="red")
|
202 |
+
return False
|
test_edit_stream.py
CHANGED
@@ -2,19 +2,20 @@ import argparse
|
|
2 |
import os
|
3 |
import sys
|
4 |
import requests
|
5 |
-
import json
|
6 |
from dotenv import load_dotenv
|
7 |
from stream_utils import StreamResponseHandler
|
8 |
|
9 |
# Load environment variables
|
10 |
load_dotenv()
|
11 |
|
|
|
12 |
def get_default_image():
|
13 |
"""Get the default image path and convert it to a data URI."""
|
14 |
image_path = "./assets/lakeview.jpg"
|
15 |
if os.path.exists(image_path):
|
16 |
try:
|
17 |
from src.utils import image_path_to_uri
|
|
|
18 |
image_uri = image_path_to_uri(image_path)
|
19 |
print(f"Using default image: {image_path}")
|
20 |
return image_uri
|
@@ -25,80 +26,93 @@ def get_default_image():
|
|
25 |
print(f"Warning: Default image not found at {image_path}")
|
26 |
return None
|
27 |
|
|
|
28 |
def upload_image(handler, image_path):
|
29 |
"""
|
30 |
Upload an image to the server.
|
31 |
-
|
32 |
Args:
|
33 |
handler (StreamResponseHandler): The stream response handler.
|
34 |
image_path (str): Path to the image file to upload.
|
35 |
-
|
36 |
Returns:
|
37 |
str: The URL of the uploaded image, or None if upload failed.
|
38 |
"""
|
39 |
if not os.path.exists(image_path):
|
40 |
-
handler.console.print(
|
|
|
|
|
41 |
return None
|
42 |
-
|
43 |
try:
|
44 |
handler.console.print(f"Uploading image: [bold]{image_path}[/]")
|
45 |
-
with open(image_path,
|
46 |
-
files = {
|
47 |
response = requests.post("http://localhost:8000/upload", files=files)
|
48 |
if response.status_code == 200:
|
49 |
image_url = response.json().get("image_url")
|
50 |
-
handler.console.print(
|
|
|
|
|
51 |
return image_url
|
52 |
else:
|
53 |
-
handler.console.print(
|
|
|
|
|
54 |
handler.console.print(f"Response: {response.text}")
|
55 |
return None
|
56 |
except Exception as e:
|
57 |
handler.console.print(f"[bold red]Error uploading image:[/] {e}")
|
58 |
return None
|
59 |
|
|
|
60 |
def main():
|
61 |
# Create a stream response handler
|
62 |
handler = StreamResponseHandler()
|
63 |
-
|
64 |
# Parse command line arguments
|
65 |
parser = argparse.ArgumentParser(description="Test the image edit streaming API.")
|
66 |
-
parser.add_argument(
|
|
|
|
|
67 |
parser.add_argument("--image", "-img", help="The URL of the image to edit.")
|
68 |
parser.add_argument("--upload", "-u", help="Path to an image file to upload first.")
|
69 |
-
|
70 |
args = parser.parse_args()
|
71 |
-
|
72 |
# Check if the server is running
|
73 |
if not handler.check_server_health():
|
74 |
sys.exit(1)
|
75 |
-
|
76 |
image_url = args.image
|
77 |
-
|
78 |
# If upload is specified, upload the image first
|
79 |
if args.upload:
|
80 |
image_url = upload_image(handler, args.upload)
|
81 |
if not image_url:
|
82 |
-
handler.console.print(
|
83 |
-
|
|
|
|
|
84 |
# Use the default image if no image URL is provided
|
85 |
if not image_url:
|
86 |
image_url = get_default_image()
|
87 |
if not image_url:
|
88 |
-
handler.console.print(
|
|
|
|
|
89 |
handler.console.print("The agent may ask for an image if needed.")
|
90 |
-
|
91 |
# Prepare the payload for the edit request
|
92 |
-
payload = {
|
93 |
-
|
94 |
-
}
|
95 |
-
|
96 |
if image_url:
|
97 |
payload["image_url"] = image_url
|
98 |
-
|
99 |
# Stream the edit request
|
100 |
endpoint_url = "http://localhost:8000/edit/stream"
|
101 |
handler.stream_response(endpoint_url, payload=payload, title="Image Edit Response")
|
102 |
|
|
|
103 |
if __name__ == "__main__":
|
104 |
-
main()
|
|
|
2 |
import os
|
3 |
import sys
|
4 |
import requests
|
|
|
5 |
from dotenv import load_dotenv
|
6 |
from stream_utils import StreamResponseHandler
|
7 |
|
8 |
# Load environment variables
|
9 |
load_dotenv()
|
10 |
|
11 |
+
|
12 |
def get_default_image():
|
13 |
"""Get the default image path and convert it to a data URI."""
|
14 |
image_path = "./assets/lakeview.jpg"
|
15 |
if os.path.exists(image_path):
|
16 |
try:
|
17 |
from src.utils import image_path_to_uri
|
18 |
+
|
19 |
image_uri = image_path_to_uri(image_path)
|
20 |
print(f"Using default image: {image_path}")
|
21 |
return image_uri
|
|
|
26 |
print(f"Warning: Default image not found at {image_path}")
|
27 |
return None
|
28 |
|
29 |
+
|
30 |
def upload_image(handler, image_path):
|
31 |
"""
|
32 |
Upload an image to the server.
|
33 |
+
|
34 |
Args:
|
35 |
handler (StreamResponseHandler): The stream response handler.
|
36 |
image_path (str): Path to the image file to upload.
|
37 |
+
|
38 |
Returns:
|
39 |
str: The URL of the uploaded image, or None if upload failed.
|
40 |
"""
|
41 |
if not os.path.exists(image_path):
|
42 |
+
handler.console.print(
|
43 |
+
f"[bold red]Error:[/] Image file not found at {image_path}"
|
44 |
+
)
|
45 |
return None
|
46 |
+
|
47 |
try:
|
48 |
handler.console.print(f"Uploading image: [bold]{image_path}[/]")
|
49 |
+
with open(image_path, "rb") as f:
|
50 |
+
files = {"file": (os.path.basename(image_path), f)}
|
51 |
response = requests.post("http://localhost:8000/upload", files=files)
|
52 |
if response.status_code == 200:
|
53 |
image_url = response.json().get("image_url")
|
54 |
+
handler.console.print(
|
55 |
+
f"Image uploaded successfully. URL: [bold green]{image_url}[/]"
|
56 |
+
)
|
57 |
return image_url
|
58 |
else:
|
59 |
+
handler.console.print(
|
60 |
+
f"[bold red]Failed to upload image.[/] Status code: {response.status_code}"
|
61 |
+
)
|
62 |
handler.console.print(f"Response: {response.text}")
|
63 |
return None
|
64 |
except Exception as e:
|
65 |
handler.console.print(f"[bold red]Error uploading image:[/] {e}")
|
66 |
return None
|
67 |
|
68 |
+
|
69 |
def main():
|
70 |
# Create a stream response handler
|
71 |
handler = StreamResponseHandler()
|
72 |
+
|
73 |
# Parse command line arguments
|
74 |
parser = argparse.ArgumentParser(description="Test the image edit streaming API.")
|
75 |
+
parser.add_argument(
|
76 |
+
"--instruction", "-i", required=True, help="The edit instruction."
|
77 |
+
)
|
78 |
parser.add_argument("--image", "-img", help="The URL of the image to edit.")
|
79 |
parser.add_argument("--upload", "-u", help="Path to an image file to upload first.")
|
80 |
+
|
81 |
args = parser.parse_args()
|
82 |
+
|
83 |
# Check if the server is running
|
84 |
if not handler.check_server_health():
|
85 |
sys.exit(1)
|
86 |
+
|
87 |
image_url = args.image
|
88 |
+
|
89 |
# If upload is specified, upload the image first
|
90 |
if args.upload:
|
91 |
image_url = upload_image(handler, args.upload)
|
92 |
if not image_url:
|
93 |
+
handler.console.print(
|
94 |
+
"[yellow]Warning:[/] Failed to upload image. Continuing without image URL."
|
95 |
+
)
|
96 |
+
|
97 |
# Use the default image if no image URL is provided
|
98 |
if not image_url:
|
99 |
image_url = get_default_image()
|
100 |
if not image_url:
|
101 |
+
handler.console.print(
|
102 |
+
"[yellow]No image URL provided and default image not available.[/]"
|
103 |
+
)
|
104 |
handler.console.print("The agent may ask for an image if needed.")
|
105 |
+
|
106 |
# Prepare the payload for the edit request
|
107 |
+
payload = {"edit_instruction": args.instruction}
|
108 |
+
|
|
|
|
|
109 |
if image_url:
|
110 |
payload["image_url"] = image_url
|
111 |
+
|
112 |
# Stream the edit request
|
113 |
endpoint_url = "http://localhost:8000/edit/stream"
|
114 |
handler.stream_response(endpoint_url, payload=payload, title="Image Edit Response")
|
115 |
|
116 |
+
|
117 |
if __name__ == "__main__":
|
118 |
+
main()
|
test_generic_stream.py
CHANGED
@@ -6,24 +6,33 @@ from stream_utils import StreamResponseHandler
|
|
6 |
# Load environment variables
|
7 |
load_dotenv()
|
8 |
|
|
|
9 |
def main():
|
10 |
# Create a console for rich output
|
11 |
handler = StreamResponseHandler()
|
12 |
-
|
13 |
# Parse command line arguments
|
14 |
-
parser = argparse.ArgumentParser(
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
args = parser.parse_args()
|
18 |
-
|
19 |
# Check if the server is running
|
20 |
if not handler.check_server_health():
|
21 |
sys.exit(1)
|
22 |
-
|
23 |
# Stream the generic request
|
24 |
endpoint_url = "http://localhost:8000/test/stream"
|
25 |
params = {"query": args.query}
|
26 |
handler.stream_response(endpoint_url, params=params, title="Generic Agent Response")
|
27 |
|
|
|
28 |
if __name__ == "__main__":
|
29 |
-
main()
|
|
|
6 |
# Load environment variables
|
7 |
load_dotenv()
|
8 |
|
9 |
+
|
10 |
def main():
|
11 |
# Create a console for rich output
|
12 |
handler = StreamResponseHandler()
|
13 |
+
|
14 |
# Parse command line arguments
|
15 |
+
parser = argparse.ArgumentParser(
|
16 |
+
description="Test the generic agent streaming API."
|
17 |
+
)
|
18 |
+
parser.add_argument(
|
19 |
+
"--query",
|
20 |
+
"-q",
|
21 |
+
required=True,
|
22 |
+
help="The query or message to send to the generic agent.",
|
23 |
+
)
|
24 |
+
|
25 |
args = parser.parse_args()
|
26 |
+
|
27 |
# Check if the server is running
|
28 |
if not handler.check_server_health():
|
29 |
sys.exit(1)
|
30 |
+
|
31 |
# Stream the generic request
|
32 |
endpoint_url = "http://localhost:8000/test/stream"
|
33 |
params = {"query": args.query}
|
34 |
handler.stream_response(endpoint_url, params=params, title="Generic Agent Response")
|
35 |
|
36 |
+
|
37 |
if __name__ == "__main__":
|
38 |
+
main()
|