Commit · 583b7ad
Parent(s): c16bc85
refactor: clean up
- image_edit_chat.py +4 -11
- image_edit_demo.py +2 -3
- src/agents/image-edit-agent.py +0 -105
- src/agents/image_edit_agent.py +140 -0
- src/agents/mask_generation_agent.py +31 -108
- src/services/generate_mask.py +0 -2
- src/services/google_cloud_image_upload.py +3 -1
image_edit_chat.py  CHANGED

@@ -1,5 +1,5 @@
 import gradio as gr
-from src.agents.
+from src.agents.image_edit_agent import image_edit_agent, ImageEditDeps, EditImageResult
 import os
 from src.hopter.client import Hopter, Environment
 from src.services.generate_mask import GenerateMaskService
@@ -9,10 +9,9 @@ from pydantic_ai.messages import (
     ToolCallPart,
     ToolReturnPart
 )
-from src.agents.mask_generation_agent import EditImageResult
-from pydantic_ai.agent import Agent
 from pydantic_ai.models.openai import OpenAIModel
-
+
+model = OpenAIModel(
     "gpt-4o",
     api_key=os.environ.get("OPENAI_API_KEY"),
 )
@@ -56,12 +55,6 @@ EXAMPLES = [
     }
 ]
 
-simple_agent = Agent(
-    model,
-    system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
-    deps_type=ImageEditDeps
-)
-
 load_dotenv()
 
 def build_user_message(chat_input):
@@ -142,7 +135,7 @@ async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
     )
 
     # Run the agent
-    async with
+    async with image_edit_agent.run_stream(
         messages,
         deps=deps,
         message_history=past_messages
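Note: after this change the chat app streams directly from the consolidated image_edit_agent. A minimal sketch of the updated call site, assuming messages, deps, and past_messages are built as shown earlier in this file; anything not visible in the diff is an assumption:

    # Sketch: streaming the refactored agent's output (deps/messages as in the diff)
    async with image_edit_agent.run_stream(
        messages,
        deps=deps,
        message_history=past_messages
    ) as result:
        async for message in result.stream():
            ...  # append each partial message to the Gradio chatbot state
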
image_edit_demo.py  CHANGED

@@ -1,5 +1,5 @@
 import gradio as gr
-from src.agents.
+from src.agents.image_edit_agent import image_edit_agent, ImageEditDeps, EditImageResult
 import os
 from src.hopter.client import Hopter, Environment
 from src.services.generate_mask import GenerateMaskService
@@ -7,7 +7,6 @@ from dotenv import load_dotenv
 from pydantic_ai.messages import (
     ToolReturnPart
 )
-from src.agents.mask_generation_agent import EditImageResult
 from src.utils import upload_image
 load_dotenv()
 
@@ -31,7 +30,7 @@ async def process_edit(image, instruction):
         hopter_client=hopter,
         mask_service=mask_service
     )
-    result = await
+    result = await image_edit_agent.run(
         messages,
         deps=deps
     )
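The demo path uses the non-streaming run() call. A sketch of the full wiring, assuming the Hopter client and GenerateMaskService are constructed the same way as in image_edit_agent.py's main(); reading the edited image URL from result.data is an assumption based on the EditImageResult dataclass added in this commit:

    # Sketch: one-shot edit through the refactored agent (wiring as in this commit)
    hopter = Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
    mask_service = GenerateMaskService(hopter=hopter)
    deps = ImageEditDeps(
        edit_instruction=instruction,
        image_url=image_url,
        hopter_client=hopter,
        mask_service=mask_service
    )
    result = await image_edit_agent.run(messages, deps=deps)
    # result.data is expected to carry an EditImageResult with edited_image_url
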
src/agents/image-edit-agent.py  DELETED

@@ -1,105 +0,0 @@
-from pydantic_ai import Agent, RunContext
-from pydantic_ai.settings import ModelSettings
-from pydantic_ai.models.openai import OpenAIModel
-from dotenv import load_dotenv
-import os
-import asyncio
-from src.utils import image_path_to_base64
-from dataclasses import dataclass
-
-load_dotenv()
-
-@dataclass
-class ImageEditDeps:
-    edit_instruction: str
-    image_url: str
-
-model = OpenAIModel(
-    "gpt-4o",
-    api_key=os.environ.get("OPENAI_API_KEY"),
-)
-
-image_edit_agent = Agent(
-    model,
-    system_prompt=[
-        'Be concise, reply with one sentence.',
-        "You are an image editing agent. You will be given an image and an editing instruction. Use the tools available to you and come up with a plan to edit the image according to the instruction."
-    ],
-    deps_type=ImageEditDeps
-)
-
-
-@image_edit_agent.tool
-async def identify_editing_subject(ctx: RunContext[ImageEditDeps]) -> str:
-    """
-    Identify the subject of the image editing instruction.
-
-    Args:
-        instruction: The image editing instruction.
-        image_url: The URL of the image.
-
-    Returns:
-        The subject of the image editing instruction.
-    """
-    messages = [
-        {
-            "type": "text",
-            "text": ctx.deps.edit_instruction
-        },
-        {
-            "type": "image_url",
-            "image_url": {
-                "url": ctx.deps.image_url
-            }
-        }
-    ]
-    r = await mask_generation_agent.run(messages, usage=ctx.usage, deps=ctx.deps)
-    return r.data
-
-mask_generation_agent = Agent(
-    model,
-    system_prompt=[
-        "I will give you an editing instruction of the image. Please output the object needed to be edited.",
-        "You only need to output the basic description of the object in no more than 5 words.",
-        "The output should only contain one noun.",
-        "For example, the editing instruction is 'Change the white cat to a black dog'. Then you need to output: 'white cat'. Only output the new content. Do not output anything else."
-    ],
-    deps_type=ImageEditDeps
-)
-
-@mask_generation_agent.tool
-async def generate_mask(ctx: RunContext[ImageEditDeps], mask_subject: str) -> str:
-    """
-    Generate a mask for the image editing instruction.
-    """
-    pass
-
-async def main():
-    image_file_path = "./assets/lakeview.jpg"
-    image_base64 = image_path_to_base64(image_file_path)
-    image_url = f"data:image/jpeg;base64,{image_base64}"
-
-    prompt = "remove the light post"
-    messages = [
-        {
-            "type": "text",
-            "text": prompt
-        },
-        {
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }
-    ]
-
-    deps = ImageEditDeps(
-        edit_instruction=prompt,
-        image_url=image_url
-    )
-    r = await mask_generation_agent.run(messages, deps=deps)
-    print(r.data)
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
src/agents/image_edit_agent.py  ADDED

@@ -0,0 +1,140 @@
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.models.openai import OpenAIModel
+from dotenv import load_dotenv
+import os
+import asyncio
+from dataclasses import dataclass
+from typing import Optional
+import logfire
+from src.services.generate_mask import GenerateMaskService
+from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
+from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
+import base64
+import tempfile
+
+load_dotenv()
+
+logfire.configure(token=os.environ.get("LOGFIRE_TOKEN"))
+logfire.instrument_openai()
+
+system_prompt = """
+I will give you an editing instruction of the image.
+if the edit instruction involved modifying parts of the image, please generate a mask for it.
+if images are not provided, ask the user to provide an image.
+"""
+
+@dataclass
+class ImageEditDeps:
+    edit_instruction: str
+    hopter_client: Hopter
+    mask_service: GenerateMaskService
+    image_url: Optional[str] = None
+
+model = OpenAIModel(
+    "gpt-4o",
+    api_key=os.environ.get("OPENAI_API_KEY"),
+)
+
+
+@dataclass
+class EditImageResult:
+    edited_image_url: str
+
+image_edit_agent = Agent(
+    model,
+    system_prompt=system_prompt,
+    deps_type=ImageEditDeps
+)
+
+def upload_image_from_base64(base64_image: str) -> str:
+    image_format = base64_image.split(",")[0]
+    image_data = base64.b64decode(base64_image.split(",")[1])
+    suffix = ".jpg" if image_format == "image/jpeg" else ".png"
+    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
+        temp_filename = temp_file.name
+    with open(temp_filename, "wb") as f:
+        f.write(image_data)
+    return upload_image(temp_filename)
+
+@image_edit_agent.tool
+async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
+    """
+    Use this tool to edit an object in the image. for example:
+    - remove the pole
+    - replace the dog with a cat
+    - change the background to a beach
+    - remove the person in the image
+    - change the hair color to red
+    - change the hat to a cap
+    """
+    edit_instruction = ctx.deps.edit_instruction
+    image_url = ctx.deps.image_url
+    mask_service = ctx.deps.mask_service
+    hopter_client = ctx.deps.hopter_client
+
+    image_uri = download_image_to_data_uri(image_url)
+
+    # Generate mask
+    mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
+    mask = mask_service.generate_mask(mask_instruction, image_uri)
+
+    # Magic replace
+    input = MagicReplaceInput(image=image_uri, mask=mask, prompt=mask_instruction.target_caption)
+    result = hopter_client.magic_replace(input)
+    uploaded_image = upload_image_from_base64(result.base64_image)
+    return EditImageResult(edited_image_url=uploaded_image)
+
+@image_edit_agent.tool
+async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
+    """
+    run super resolution, upscale, or enhance the image
+    """
+    image_url = ctx.deps.image_url
+    hopter_client = ctx.deps.hopter_client
+
+    image_uri = download_image_to_data_uri(image_url)
+
+    input = SuperResolutionInput(image_b64=image_uri, scale=4, use_face_enhancement=False)
+    result = hopter_client.super_resolution(input)
+    uploaded_image = upload_image_from_base64(result.scaled_image)
+    return EditImageResult(edited_image_url=uploaded_image)
+
+async def main():
+    image_file_path = "./assets/lakeview.jpg"
+    image_url = image_path_to_uri(image_file_path)
+
+    prompt = "remove the light post"
+    messages = [
+        {
+            "type": "text",
+            "text": prompt
+        },
+        {
+            "type": "image_url",
+            "image_url": {
+                "url": image_url
+            }
+        }
+    ]
+
+    # Initialize services
+    hopter = Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
+    mask_service = GenerateMaskService(hopter=hopter)
+
+    # Initialize dependencies
+    deps = ImageEditDeps(
+        edit_instruction=prompt,
+        image_url=image_url,
+        hopter_client=hopter,
+        mask_service=mask_service
+    )
+    async with image_edit_agent.run_stream(
+        messages,
+        deps=deps
+    ) as result:
+        async for message in result.stream():
+            print(message)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
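upload_image_from_base64() picks the file suffix by comparing the text before the first comma against "image/jpeg". A small sketch of the two input shapes that comparison implies; whether download_image_to_data_uri() and the Hopter responses emit a bare MIME prefix or a full data-URI header is an assumption, not something this diff shows:

    # Sketch of possible input shapes (illustrative, not part of the diff)
    bare_prefix = "image/jpeg,/9j/4AAQSkZJRg..."              # split(",")[0] == "image/jpeg" -> ".jpg"
    data_uri = "data:image/jpeg;base64,/9j/4AAQSkZJRg..."      # split(",")[0] is the full header -> ".png"
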
src/agents/mask_generation_agent.py  CHANGED

@@ -8,11 +8,9 @@ from typing import Optional
 import logfire
 from src.services.generate_mask import GenerateMaskService
 from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
-from src.services.image_uploader import ImageUploader
 from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
 import base64
 import tempfile
-from PIL import Image
 
 load_dotenv()
 
@@ -20,127 +18,52 @@ logfire.configure(token=os.environ.get("LOGFIRE_TOKEN"))
 logfire.instrument_openai()
 
 system_prompt = """
-I will give you an editing instruction of the image.
-
-
+I will give you an editing instruction of the image. Perform the following tasks:
+
+<task_1>
+Please output which type of editing category it is in.
+You can choose from the following categories:
+1. Addition: Adding new objects within the images, e.g., add a bird
+2. Remove: Removing objects, e.g., remove the mask
+3. Local: Replace local parts of an object and later the object's attributes (e.g., make it smile) or alter an object's visual appearance without affecting its structure (e.g., change the cat to a dog)
+4. Global: Edit the entire image, e.g., let's see it in winter
+5. Background: Change the scene's background, e.g., have her walk on water, change the background to a beach, make the hedgehog in France, etc.
+Only output a single word, e.g., 'Addition'.
+</task_1>
+
+<task_2>
+Please output the subject needed to be edited. You only need to output the basic description of the object in no more than 5 words. The output should only contain one noun.
+
+For example, the editing instruction is 'Change the white cat to a black dog'. Then you need to output: 'white cat'. Only output the new content. Do not output anything else.
+</task_2>
+
+<task_3>
+Please describe the new content that should be present in the image after applying the instruction.
+
+For example, if the original image content shows a grandmother wearing a mask and the instruction is 'remove the mask', your output should be: 'a grandmother'.
+The output should only include elements that remain in the image after the edit and should not mention elements that have been changed or removed, such as 'mask' in this example.
+Do not output 'sorry, xxx', even if it's a guess, directly output the answer you think is correct.
+</task_3>
 """
 
-@dataclass
-class ImageEditDeps:
-    edit_instruction: str
-    hopter_client: Hopter
-    mask_service: GenerateMaskService
-    image_url: Optional[str] = None
-
 model = OpenAIModel(
     "gpt-4o",
     api_key=os.environ.get("OPENAI_API_KEY"),
 )
 
+
 @dataclass
 class MaskGenerationResult:
     mask_image_base64: str
 
-
-@dataclass
-class EditImageResult:
-    edited_image_url: str
-
 mask_generation_agent = Agent(
     model,
-    system_prompt=system_prompt
-    deps_type=ImageEditDeps
+    system_prompt=system_prompt
 )
 
-def upload_image_from_base64(base64_image: str) -> str:
-    image_format = base64_image.split(",")[0]
-    image_data = base64.b64decode(base64_image.split(",")[1])
-    suffix = ".jpg" if image_format == "image/jpeg" else ".png"
-    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
-        temp_filename = temp_file.name
-    with open(temp_filename, "wb") as f:
-        f.write(image_data)
-    return upload_image(temp_filename)
-
-@mask_generation_agent.tool
-async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
-    """
-    Use this tool to edit an object in the image. for example:
-    - remove the pole
-    - replace the dog with a cat
-    - change the background to a beach
-    - remove the person in the image
-    - change the hair color to red
-    - change the hat to a cap
-    """
-    edit_instruction = ctx.deps.edit_instruction
-    image_url = ctx.deps.image_url
-    mask_service = ctx.deps.mask_service
-    hopter_client = ctx.deps.hopter_client
-
-    image_uri = download_image_to_data_uri(image_url)
-
-    # Generate mask
-    mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
-    mask = mask_service.generate_mask(mask_instruction, image_uri)
-
-    # Magic replace
-    input = MagicReplaceInput(image=image_uri, mask=mask, prompt=mask_instruction.target_caption)
-    result = hopter_client.magic_replace(input)
-    uploaded_image = upload_image_from_base64(result.base64_image)
-    return EditImageResult(edited_image_url=uploaded_image)
-
 @mask_generation_agent.tool
-async def
+async def generate_mask(edit_instruction: str, image_url: str) -> MaskGenerationResult:
     """
-
+    Use this tool to generate a mask for the image.
     """
-
-    hopter_client = ctx.deps.hopter_client
-
-    image_uri = download_image_to_data_uri(image_url)
-
-    input = SuperResolutionInput(image_b64=image_uri, scale=4, use_face_enhancement=False)
-    result = hopter_client.super_resolution(input)
-    uploaded_image = upload_image_from_base64(result.scaled_image)
-    return EditImageResult(edited_image_url=uploaded_image)
-
-async def main():
-    image_file_path = "./assets/lakeview.jpg"
-    image_url = image_path_to_uri(image_file_path)
-
-    prompt = "remove the light post"
-    messages = [
-        {
-            "type": "text",
-            "text": prompt
-        },
-        {
-            "type": "image_url",
-            "image_url": {
-                "url": image_url
-            }
-        }
-    ]
-
-    # Initialize services
-    hopter = Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
-    mask_service = GenerateMaskService(hopter=hopter)
-
-    # Initialize dependencies
-    deps = ImageEditDeps(
-        edit_instruction=prompt,
-        image_url=image_url,
-        hopter_client=hopter,
-        mask_service=mask_service
-    )
-    async with mask_generation_agent.run_stream(
-        messages,
-        deps=deps
-    ) as result:
-        async for message in result.stream():
-            print(message)
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
+    pass
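The rewritten system prompt asks the model for three things per instruction: an editing category, the subject to mask, and a caption for the post-edit content. A worked example of the expected answers for one instruction, following the examples embedded in the prompt (the combined output framing below is an assumption; the prompt itself does not fix one):

    # Worked example for the three prompt tasks (illustrative only)
    instruction = "Change the white cat to a black dog"
    expected = {
        "task_1_category": "Local",          # alters the object's appearance, not the whole scene
        "task_2_subject": "white cat",       # object to be edited, at most 5 words, one noun
        "task_3_new_content": "a black dog"  # what should be present after the edit
    }
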
src/services/generate_mask.py  CHANGED

@@ -1,8 +1,6 @@
-from pydantic import BaseModel, Field
 from openai import OpenAI
 import os
 from dotenv import load_dotenv
-import base64
 import asyncio
 from src.hopter.client import Hopter, RamGroundedSamInput, Environment
 from src.models.generate_mask_instruction import GenerateMaskInstruction
src/services/google_cloud_image_upload.py  CHANGED

@@ -3,10 +3,12 @@ from PIL import Image
 import os
 import uuid
 import tempfile
+from dotenv import load_dotenv
+
+load_dotenv()
 
 def get_credentials():
     credentials_json_string = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
-
     # create a temp file with the credentials
     with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as temp_file:
         temp_file.write(credentials_json_string)
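Calling load_dotenv() at module import means GOOGLE_APPLICATION_CREDENTIALS_JSON can come from a local .env file rather than only the process environment. A minimal sketch of what get_credentials() now relies on; the .env location is an assumption:

    # Sketch: environment expected by get_credentials() after this change
    from dotenv import load_dotenv
    import os

    load_dotenv()  # reads a .env file from the working directory, if present
    creds_json = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")  # service-account JSON as a string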