simonlee-cb commited on
Commit
583b7ad
·
1 Parent(s): c16bc85

refactor: clean up

Browse files
image_edit_chat.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
3
  import os
4
  from src.hopter.client import Hopter, Environment
5
  from src.services.generate_mask import GenerateMaskService
@@ -9,10 +9,9 @@ from pydantic_ai.messages import (
9
  ToolCallPart,
10
  ToolReturnPart
11
  )
12
- from src.agents.mask_generation_agent import EditImageResult
13
- from pydantic_ai.agent import Agent
14
  from pydantic_ai.models.openai import OpenAIModel
15
- model = OpenAIModel(
 
16
  "gpt-4o",
17
  api_key=os.environ.get("OPENAI_API_KEY"),
18
  )
@@ -56,12 +55,6 @@ EXAMPLES = [
56
  }
57
  ]
58
 
59
- simple_agent = Agent(
60
- model,
61
- system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
62
- deps_type=ImageEditDeps
63
- )
64
-
65
  load_dotenv()
66
 
67
  def build_user_message(chat_input):
@@ -142,7 +135,7 @@ async def stream_from_agent(chat_input, chatbot, past_messages, current_image):
142
  )
143
 
144
  # Run the agent
145
- async with mask_generation_agent.run_stream(
146
  messages,
147
  deps=deps,
148
  message_history=past_messages
 
1
  import gradio as gr
2
+ from src.agents.image_edit_agent import image_edit_agent, ImageEditDeps, EditImageResult
3
  import os
4
  from src.hopter.client import Hopter, Environment
5
  from src.services.generate_mask import GenerateMaskService
 
9
  ToolCallPart,
10
  ToolReturnPart
11
  )
 
 
12
  from pydantic_ai.models.openai import OpenAIModel
13
+
14
+ model = OpenAIModel(
15
  "gpt-4o",
16
  api_key=os.environ.get("OPENAI_API_KEY"),
17
  )
 
55
  }
56
  ]
57
 
 
 
 
 
 
 
58
  load_dotenv()
59
 
60
  def build_user_message(chat_input):
 
135
  )
136
 
137
  # Run the agent
138
+ async with image_edit_agent.run_stream(
139
  messages,
140
  deps=deps,
141
  message_history=past_messages
image_edit_demo.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
3
  import os
4
  from src.hopter.client import Hopter, Environment
5
  from src.services.generate_mask import GenerateMaskService
@@ -7,7 +7,6 @@ from dotenv import load_dotenv
7
  from pydantic_ai.messages import (
8
  ToolReturnPart
9
  )
10
- from src.agents.mask_generation_agent import EditImageResult
11
  from src.utils import upload_image
12
  load_dotenv()
13
 
@@ -31,7 +30,7 @@ async def process_edit(image, instruction):
31
  hopter_client=hopter,
32
  mask_service=mask_service
33
  )
34
- result = await mask_generation_agent.run(
35
  messages,
36
  deps=deps
37
  )
 
1
  import gradio as gr
2
+ from src.agents.image_edit_agent import image_edit_agent, ImageEditDeps, EditImageResult
3
  import os
4
  from src.hopter.client import Hopter, Environment
5
  from src.services.generate_mask import GenerateMaskService
 
7
  from pydantic_ai.messages import (
8
  ToolReturnPart
9
  )
 
10
  from src.utils import upload_image
11
  load_dotenv()
12
 
 
30
  hopter_client=hopter,
31
  mask_service=mask_service
32
  )
33
+ result = await image_edit_agent.run(
34
  messages,
35
  deps=deps
36
  )
src/agents/image-edit-agent.py DELETED
@@ -1,105 +0,0 @@
1
- from pydantic_ai import Agent, RunContext
2
- from pydantic_ai.settings import ModelSettings
3
- from pydantic_ai.models.openai import OpenAIModel
4
- from dotenv import load_dotenv
5
- import os
6
- import asyncio
7
- from src.utils import image_path_to_base64
8
- from dataclasses import dataclass
9
-
10
- load_dotenv()
11
-
12
- @dataclass
13
- class ImageEditDeps:
14
- edit_instruction: str
15
- image_url: str
16
-
17
- model = OpenAIModel(
18
- "gpt-4o",
19
- api_key=os.environ.get("OPENAI_API_KEY"),
20
- )
21
-
22
- image_edit_agent = Agent(
23
- model,
24
- system_prompt=[
25
- 'Be concise, reply with one sentence.',
26
- "You are an image editing agent. You will be given an image and an editing instruction. Use the tools available to you and come up with a plan to edit the image according to the instruction."
27
- ],
28
- deps_type=ImageEditDeps
29
- )
30
-
31
-
32
- @image_edit_agent.tool
33
- async def identify_editing_subject(ctx: RunContext[ImageEditDeps]) -> str:
34
- """
35
- Identify the subject of the image editing instruction.
36
-
37
- Args:
38
- instruction: The image editing instruction.
39
- image_url: The URL of the image.
40
-
41
- Returns:
42
- The subject of the image editing instruction.
43
- """
44
- messages = [
45
- {
46
- "type": "text",
47
- "text": ctx.deps.edit_instruction
48
- },
49
- {
50
- "type": "image_url",
51
- "image_url": {
52
- "url": ctx.deps.image_url
53
- }
54
- }
55
- ]
56
- r = await mask_generation_agent.run(messages, usage=ctx.usage, deps=ctx.deps)
57
- return r.data
58
-
59
- mask_generation_agent = Agent(
60
- model,
61
- system_prompt=[
62
- "I will give you an editing instruction of the image. Please output the object needed to be edited.",
63
- "You only need to output the basic description of the object in no more than 5 words.",
64
- "The output should only contain one noun.",
65
- "For example, the editing instruction is 'Change the white cat to a black dog'. Then you need to output: 'white cat'. Only output the new content. Do not output anything else."
66
- ],
67
- deps_type=ImageEditDeps
68
- )
69
-
70
- @mask_generation_agent.tool
71
- async def generate_mask(ctx: RunContext[ImageEditDeps], mask_subject: str) -> str:
72
- """
73
- Generate a mask for the image editing instruction.
74
- """
75
- pass
76
-
77
- async def main():
78
- image_file_path = "./assets/lakeview.jpg"
79
- image_base64 = image_path_to_base64(image_file_path)
80
- image_url = f"data:image/jpeg;base64,{image_base64}"
81
-
82
- prompt = "remove the light post"
83
- messages = [
84
- {
85
- "type": "text",
86
- "text": prompt
87
- },
88
- {
89
- "type": "image_url",
90
- "image_url": {
91
- "url": image_url
92
- }
93
- }
94
- ]
95
-
96
- deps = ImageEditDeps(
97
- edit_instruction=prompt,
98
- image_url=image_url
99
- )
100
- r = await mask_generation_agent.run(messages, deps=deps)
101
- print(r.data)
102
-
103
-
104
- if __name__ == "__main__":
105
- asyncio.run(main())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/agents/image_edit_agent.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic_ai import Agent, RunContext
2
+ from pydantic_ai.models.openai import OpenAIModel
3
+ from dotenv import load_dotenv
4
+ import os
5
+ import asyncio
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+ import logfire
9
+ from src.services.generate_mask import GenerateMaskService
10
+ from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
11
+ from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
12
+ import base64
13
+ import tempfile
14
+
15
+ load_dotenv()
16
+
17
+ logfire.configure(token=os.environ.get("LOGFIRE_TOKEN"))
18
+ logfire.instrument_openai()
19
+
20
+ system_prompt = """
21
+ I will give you an editing instruction of the image.
22
+ if the edit instruction involved modifying parts of the image, please generate a mask for it.
23
+ if images are not provided, ask the user to provide an image.
24
+ """
25
+
26
+ @dataclass
27
+ class ImageEditDeps:
28
+ edit_instruction: str
29
+ hopter_client: Hopter
30
+ mask_service: GenerateMaskService
31
+ image_url: Optional[str] = None
32
+
33
+ model = OpenAIModel(
34
+ "gpt-4o",
35
+ api_key=os.environ.get("OPENAI_API_KEY"),
36
+ )
37
+
38
+
39
+ @dataclass
40
+ class EditImageResult:
41
+ edited_image_url: str
42
+
43
+ image_edit_agent = Agent(
44
+ model,
45
+ system_prompt=system_prompt,
46
+ deps_type=ImageEditDeps
47
+ )
48
+
49
+ def upload_image_from_base64(base64_image: str) -> str:
50
+ image_format = base64_image.split(",")[0]
51
+ image_data = base64.b64decode(base64_image.split(",")[1])
52
+ suffix = ".jpg" if image_format == "image/jpeg" else ".png"
53
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
54
+ temp_filename = temp_file.name
55
+ with open(temp_filename, "wb") as f:
56
+ f.write(image_data)
57
+ return upload_image(temp_filename)
58
+
59
+ @image_edit_agent.tool
60
+ async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
61
+ """
62
+ Use this tool to edit an object in the image. for example:
63
+ - remove the pole
64
+ - replace the dog with a cat
65
+ - change the background to a beach
66
+ - remove the person in the image
67
+ - change the hair color to red
68
+ - change the hat to a cap
69
+ """
70
+ edit_instruction = ctx.deps.edit_instruction
71
+ image_url = ctx.deps.image_url
72
+ mask_service = ctx.deps.mask_service
73
+ hopter_client = ctx.deps.hopter_client
74
+
75
+ image_uri = download_image_to_data_uri(image_url)
76
+
77
+ # Generate mask
78
+ mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
79
+ mask = mask_service.generate_mask(mask_instruction, image_uri)
80
+
81
+ # Magic replace
82
+ input = MagicReplaceInput(image=image_uri, mask=mask, prompt=mask_instruction.target_caption)
83
+ result = hopter_client.magic_replace(input)
84
+ uploaded_image = upload_image_from_base64(result.base64_image)
85
+ return EditImageResult(edited_image_url=uploaded_image)
86
+
87
+ @image_edit_agent.tool
88
+ async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
89
+ """
90
+ run super resolution, upscale, or enhance the image
91
+ """
92
+ image_url = ctx.deps.image_url
93
+ hopter_client = ctx.deps.hopter_client
94
+
95
+ image_uri = download_image_to_data_uri(image_url)
96
+
97
+ input = SuperResolutionInput(image_b64=image_uri, scale=4, use_face_enhancement=False)
98
+ result = hopter_client.super_resolution(input)
99
+ uploaded_image = upload_image_from_base64(result.scaled_image)
100
+ return EditImageResult(edited_image_url=uploaded_image)
101
+
102
+ async def main():
103
+ image_file_path = "./assets/lakeview.jpg"
104
+ image_url = image_path_to_uri(image_file_path)
105
+
106
+ prompt = "remove the light post"
107
+ messages = [
108
+ {
109
+ "type": "text",
110
+ "text": prompt
111
+ },
112
+ {
113
+ "type": "image_url",
114
+ "image_url": {
115
+ "url": image_url
116
+ }
117
+ }
118
+ ]
119
+
120
+ # Initialize services
121
+ hopter = Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
122
+ mask_service = GenerateMaskService(hopter=hopter)
123
+
124
+ # Initialize dependencies
125
+ deps = ImageEditDeps(
126
+ edit_instruction=prompt,
127
+ image_url=image_url,
128
+ hopter_client=hopter,
129
+ mask_service=mask_service
130
+ )
131
+ async with image_edit_agent.run_stream(
132
+ messages,
133
+ deps=deps
134
+ ) as result:
135
+ async for message in result.stream():
136
+ print(message)
137
+
138
+
139
+ if __name__ == "__main__":
140
+ asyncio.run(main())
src/agents/mask_generation_agent.py CHANGED
@@ -8,11 +8,9 @@ from typing import Optional
8
  import logfire
9
  from src.services.generate_mask import GenerateMaskService
10
  from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
11
- from src.services.image_uploader import ImageUploader
12
  from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
13
  import base64
14
  import tempfile
15
- from PIL import Image
16
 
17
  load_dotenv()
18
 
@@ -20,127 +18,52 @@ logfire.configure(token=os.environ.get("LOGFIRE_TOKEN"))
20
  logfire.instrument_openai()
21
 
22
  system_prompt = """
23
- I will give you an editing instruction of the image.
24
- if the edit instruction involved modifying parts of the image, please generate a mask for it.
25
- if images are not provided, ask the user to provide an image.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  """
27
 
28
- @dataclass
29
- class ImageEditDeps:
30
- edit_instruction: str
31
- hopter_client: Hopter
32
- mask_service: GenerateMaskService
33
- image_url: Optional[str] = None
34
-
35
  model = OpenAIModel(
36
  "gpt-4o",
37
  api_key=os.environ.get("OPENAI_API_KEY"),
38
  )
39
 
 
40
  @dataclass
41
  class MaskGenerationResult:
42
  mask_image_base64: str
43
 
44
-
45
- @dataclass
46
- class EditImageResult:
47
- edited_image_url: str
48
-
49
  mask_generation_agent = Agent(
50
  model,
51
- system_prompt=system_prompt,
52
- deps_type=ImageEditDeps
53
  )
54
 
55
- def upload_image_from_base64(base64_image: str) -> str:
56
- image_format = base64_image.split(",")[0]
57
- image_data = base64.b64decode(base64_image.split(",")[1])
58
- suffix = ".jpg" if image_format == "image/jpeg" else ".png"
59
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
60
- temp_filename = temp_file.name
61
- with open(temp_filename, "wb") as f:
62
- f.write(image_data)
63
- return upload_image(temp_filename)
64
-
65
- @mask_generation_agent.tool
66
- async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
67
- """
68
- Use this tool to edit an object in the image. for example:
69
- - remove the pole
70
- - replace the dog with a cat
71
- - change the background to a beach
72
- - remove the person in the image
73
- - change the hair color to red
74
- - change the hat to a cap
75
- """
76
- edit_instruction = ctx.deps.edit_instruction
77
- image_url = ctx.deps.image_url
78
- mask_service = ctx.deps.mask_service
79
- hopter_client = ctx.deps.hopter_client
80
-
81
- image_uri = download_image_to_data_uri(image_url)
82
-
83
- # Generate mask
84
- mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
85
- mask = mask_service.generate_mask(mask_instruction, image_uri)
86
-
87
- # Magic replace
88
- input = MagicReplaceInput(image=image_uri, mask=mask, prompt=mask_instruction.target_caption)
89
- result = hopter_client.magic_replace(input)
90
- uploaded_image = upload_image_from_base64(result.base64_image)
91
- return EditImageResult(edited_image_url=uploaded_image)
92
-
93
  @mask_generation_agent.tool
94
- async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
95
  """
96
- run super resolution, upscale, or enhance the image
97
  """
98
- image_url = ctx.deps.image_url
99
- hopter_client = ctx.deps.hopter_client
100
-
101
- image_uri = download_image_to_data_uri(image_url)
102
-
103
- input = SuperResolutionInput(image_b64=image_uri, scale=4, use_face_enhancement=False)
104
- result = hopter_client.super_resolution(input)
105
- uploaded_image = upload_image_from_base64(result.scaled_image)
106
- return EditImageResult(edited_image_url=uploaded_image)
107
-
108
- async def main():
109
- image_file_path = "./assets/lakeview.jpg"
110
- image_url = image_path_to_uri(image_file_path)
111
-
112
- prompt = "remove the light post"
113
- messages = [
114
- {
115
- "type": "text",
116
- "text": prompt
117
- },
118
- {
119
- "type": "image_url",
120
- "image_url": {
121
- "url": image_url
122
- }
123
- }
124
- ]
125
-
126
- # Initialize services
127
- hopter = Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
128
- mask_service = GenerateMaskService(hopter=hopter)
129
-
130
- # Initialize dependencies
131
- deps = ImageEditDeps(
132
- edit_instruction=prompt,
133
- image_url=image_url,
134
- hopter_client=hopter,
135
- mask_service=mask_service
136
- )
137
- async with mask_generation_agent.run_stream(
138
- messages,
139
- deps=deps
140
- ) as result:
141
- async for message in result.stream():
142
- print(message)
143
-
144
-
145
- if __name__ == "__main__":
146
- asyncio.run(main())
 
8
  import logfire
9
  from src.services.generate_mask import GenerateMaskService
10
  from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
 
11
  from src.utils import image_path_to_uri, download_image_to_data_uri, upload_image
12
  import base64
13
  import tempfile
 
14
 
15
  load_dotenv()
16
 
 
18
  logfire.instrument_openai()
19
 
20
  system_prompt = """
21
+ I will give you an editing instruction of the image. Perform the following tasks:
22
+
23
+ <task_1>
24
+ Please output which type of editing category it is in.
25
+ You can choose from the following categories:
26
+ 1. Addition: Adding new objects within the images, e.g., add a bird
27
+ 2. Remove: Removing objects, e.g., remove the mask
28
+ 3. Local: Replace local parts of an object and later the object's attributes (e.g., make it smile) or alter an object's visual appearance without affecting its structure (e.g., change the cat to a dog)
29
+ 4. Global: Edit the entire image, e.g., let's see it in winter
30
+ 5. Background: Change the scene's background, e.g., have her walk on water, change the background to a beach, make the hedgehog in France, etc.
31
+ Only output a single word, e.g., 'Addition'.
32
+ </task_1>
33
+
34
+ <task_2>
35
+ Please output the subject needed to be edited. You only need to output the basic description of the object in no more than 5 words. The output should only contain one noun.
36
+
37
+ For example, the editing instruction is 'Change the white cat to a black dog'. Then you need to output: 'white cat'. Only output the new content. Do not output anything else.
38
+ </task_2>
39
+
40
+ <task_3>
41
+ Please describe the new content that should be present in the image after applying the instruction.
42
+
43
+ For example, if the original image content shows a grandmother wearing a mask and the instruction is 'remove the mask', your output should be: 'a grandmother'.
44
+ The output should only include elements that remain in the image after the edit and should not mention elements that have been changed or removed, such as 'mask' in this example.
45
+ Do not output 'sorry, xxx', even if it's a guess, directly output the answer you think is correct.
46
+ </task_3>
47
  """
48
 
 
 
 
 
 
 
 
49
  model = OpenAIModel(
50
  "gpt-4o",
51
  api_key=os.environ.get("OPENAI_API_KEY"),
52
  )
53
 
54
+
55
  @dataclass
56
  class MaskGenerationResult:
57
  mask_image_base64: str
58
 
 
 
 
 
 
59
  mask_generation_agent = Agent(
60
  model,
61
+ system_prompt=system_prompt
 
62
  )
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  @mask_generation_agent.tool
65
+ async def generate_mask(edit_instruction: str, image_url: str) -> MaskGenerationResult:
66
  """
67
+ Use this tool to generate a mask for the image.
68
  """
69
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/services/generate_mask.py CHANGED
@@ -1,8 +1,6 @@
1
- from pydantic import BaseModel, Field
2
  from openai import OpenAI
3
  import os
4
  from dotenv import load_dotenv
5
- import base64
6
  import asyncio
7
  from src.hopter.client import Hopter, RamGroundedSamInput, Environment
8
  from src.models.generate_mask_instruction import GenerateMaskInstruction
 
 
1
  from openai import OpenAI
2
  import os
3
  from dotenv import load_dotenv
 
4
  import asyncio
5
  from src.hopter.client import Hopter, RamGroundedSamInput, Environment
6
  from src.models.generate_mask_instruction import GenerateMaskInstruction
src/services/google_cloud_image_upload.py CHANGED
@@ -3,10 +3,12 @@ from PIL import Image
3
  import os
4
  import uuid
5
  import tempfile
 
 
 
6
 
7
  def get_credentials():
8
  credentials_json_string = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
9
-
10
  # create a temp file with the credentials
11
  with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as temp_file:
12
  temp_file.write(credentials_json_string)
 
3
  import os
4
  import uuid
5
  import tempfile
6
+ from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
 
10
  def get_credentials():
11
  credentials_json_string = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
 
12
  # create a temp file with the credentials
13
  with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as temp_file:
14
  temp_file.write(credentials_json_string)