simonlee-cb committed on
Commit
9e822e4
·
1 Parent(s): d2eb85a

feat: working magic replace

Browse files
agent.py CHANGED
@@ -17,6 +17,18 @@ from pydantic_ai.messages import (
17
  )
18
  import asyncio
19
  from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  class ChatMessage(TypedDict):
22
  """Format of messages sent to the browser/API."""
@@ -60,7 +72,9 @@ async def run_agent(user_input: str, image_b64: str):
60
  ]
61
  deps = ImageEditDeps(
62
  edit_instruction=user_input,
63
- image_url=image_b64
 
 
64
  )
65
  async with mask_generation_agent.run_stream(
66
  messages,
 
17
  )
18
  import asyncio
19
  from src.agents.mask_generation_agent import mask_generation_agent, ImageEditDeps
20
+ from src.hopter.client import Hopter, Environment
21
+ import os
22
+ from src.services.generate_mask import GenerateMaskService
23
+ from dotenv import load_dotenv
24
+
25
+ load_dotenv()
26
+
27
+ hopter = Hopter(
28
+ api_key=os.getenv("HOPTER_API_KEY"),
29
+ environment=Environment.STAGING
30
+ )
31
+ mask_service = GenerateMaskService(hopter=hopter)
32
 
33
  class ChatMessage(TypedDict):
34
  """Format of messages sent to the browser/API."""
 
72
  ]
73
  deps = ImageEditDeps(
74
  edit_instruction=user_input,
75
+ image_url=image_b64,
76
+ hopter_client=hopter,
77
+ mask_service=mask_service
78
  )
79
  async with mask_generation_agent.run_stream(
80
  messages,
src/agents/mask_generation_agent.py CHANGED
@@ -1,5 +1,4 @@
1
  from pydantic_ai import Agent, RunContext
2
- from pydantic_ai.settings import ModelSettings
3
  from pydantic_ai.models.openai import OpenAIModel
4
  from dotenv import load_dotenv
5
  import os
@@ -8,6 +7,7 @@ import base64
8
  from dataclasses import dataclass
9
  import logfire
10
  from src.services.generate_mask import GenerateMaskService
 
11
 
12
  load_dotenv()
13
 
@@ -23,6 +23,8 @@ if the edit instruction involved modifying parts of the image, please generate a
23
  class ImageEditDeps:
24
  edit_instruction: str
25
  image_url: str
 
 
26
 
27
  model = OpenAIModel(
28
  "gpt-4o",
@@ -33,26 +35,53 @@ model = OpenAIModel(
33
  class MaskGenerationResult:
34
  mask_image_base64: str
35
 
 
 
 
 
 
36
  mask_generation_agent = Agent(
37
  model,
38
  system_prompt=system_prompt,
39
  deps_type=ImageEditDeps
40
  )
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  @mask_generation_agent.tool
43
- async def generate_mask(ctx: RunContext[ImageEditDeps]) -> MaskGenerationResult:
44
  """
45
- Generate a mask for the image editing instruction.
46
  """
47
- print("Invoking generate_mask tool")
48
- service = GenerateMaskService()
49
- mask_instruction = await service.get_mask_generation_instruction(ctx.deps.edit_instruction, ctx.deps.image_url)
50
- response = mask_instruction.model_dump_json(indent=4)
51
- print(f"generate_mask tool response: {response}")
 
 
 
 
 
52
 
53
- mask = await service.generate_mask(mask_instruction, ctx.deps.image_url)
54
- print("Exiting generate_mask tool")
55
- return MaskGenerationResult(mask_image_base64=mask)
 
 
56
 
57
  async def main():
58
  image_file_path = "./assets/lakeview.jpg"
@@ -75,12 +104,23 @@ async def main():
75
  }
76
  ]
77
 
 
 
 
 
 
78
  deps = ImageEditDeps(
79
  edit_instruction=prompt,
80
- image_url=image_url
 
 
81
  )
82
- r = await mask_generation_agent.run(messages, deps=deps)
83
- print(r.all_messages())
 
 
 
 
84
 
85
 
86
  if __name__ == "__main__":
 
1
  from pydantic_ai import Agent, RunContext
 
2
  from pydantic_ai.models.openai import OpenAIModel
3
  from dotenv import load_dotenv
4
  import os
 
7
  from dataclasses import dataclass
8
  import logfire
9
  from src.services.generate_mask import GenerateMaskService
10
+ from src.hopter.client import Hopter, Environment, MagicReplaceInput
11
 
12
  load_dotenv()
13
 
 
23
  class ImageEditDeps:
24
  edit_instruction: str
25
  image_url: str
26
+ hopter_client: Hopter
27
+ mask_service: GenerateMaskService
28
 
29
  model = OpenAIModel(
30
  "gpt-4o",
 
35
  class MaskGenerationResult:
36
  mask_image_base64: str
37
 
38
+
39
@dataclass
class EditImageResult:
    """Value returned by the ``edit_object`` tool."""

    # Filled from the ``base64_image`` attribute of the magic-replace result.
    edited_image_base64: str
42
+
43
  mask_generation_agent = Agent(
44
  model,
45
  system_prompt=system_prompt,
46
  deps_type=ImageEditDeps
47
  )
48
 
49
+ # @mask_generation_agent.tool
50
+ # async def generate_mask(ctx: RunContext[ImageEditDeps]) -> MaskGenerationResult:
51
+ # """
52
+ # Generate a mask for the image editing instruction.
53
+ # """
54
+ # print("Invoking generate_mask tool")
55
+ # service = GenerateMaskService()
56
+ # mask_instruction = await service.get_mask_generation_instruction(ctx.deps.edit_instruction, ctx.deps.image_url)
57
+ # response = mask_instruction.model_dump_json(indent=4)
58
+ # print(f"generate_mask tool response: {response}")
59
+
60
+ # mask = await service.generate_mask(mask_instruction, ctx.deps.image_url)
61
+ # print("Exiting generate_mask tool")
62
+ # return MaskGenerationResult(mask_image_base64=mask)
63
+
64
  @mask_generation_agent.tool
65
async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
    """
    Edit an object in the image.
    """
    # NOTE: the docstring above is deliberately left untouched; the agent
    # framework exposes tool docstrings to the model as the tool description.
    #
    # Flow: build a mask instruction from the edit instruction and image,
    # generate the mask, then call Hopter magic replace with the image, the
    # mask and the instruction's ``target_caption`` as the prompt.
    #
    # Raises RuntimeError when magic replace produces no result.
    print("Invoking edit_object tool")
    deps = ctx.deps
    edit_instruction = deps.edit_instruction
    image_url = deps.image_url
    mask_service = deps.mask_service
    hopter_client = deps.hopter_client

    # Generate mask
    print("Generating mask")
    mask_instruction = await mask_service.get_mask_generation_instruction(edit_instruction, image_url)
    mask = await mask_service.generate_mask(mask_instruction, image_url)

    # Magic replace (local renamed so the builtin ``input`` is not shadowed)
    replace_input = MagicReplaceInput(image=image_url, mask=mask, prompt=mask_instruction.target_caption)
    result = await hopter_client.magic_replace(replace_input)
    if result is None:
        # The client may log a failure and hand back None; fail with a clear
        # message instead of an AttributeError on ``result.base64_image``.
        raise RuntimeError("Magic replace returned no result; see the Hopter client output for the cause")
    print("Exiting edit_object tool: ", result)
    return EditImageResult(edited_image_base64=result.base64_image)
85
 
86
  async def main():
87
  image_file_path = "./assets/lakeview.jpg"
 
104
  }
105
  ]
106
 
107
+ # Initialize services
108
+ hopter = Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
109
+ mask_service = GenerateMaskService(hopter=hopter)
110
+
111
+ # Initialize dependencies
112
  deps = ImageEditDeps(
113
  edit_instruction=prompt,
114
+ image_url=image_url,
115
+ hopter_client=hopter,
116
+ mask_service=mask_service
117
  )
118
+ async with mask_generation_agent.run_stream(
119
+ messages,
120
+ deps=deps
121
+ ) as result:
122
+ async for message in result.stream():
123
+ print(message)
124
 
125
 
126
  if __name__ == "__main__":
src/hopter/client.py CHANGED
@@ -32,6 +32,14 @@ class RamGroundedSamResult(BaseModel):
32
  confidence: float = Field(..., description="The confidence score of the mask.")
33
  bbox: List[float] = Field(..., description="The bounding box of the mask in the format [x1, y1, x2, y2].")
34
 
 
 
 
 
 
 
 
 
35
  class Hopter:
36
  def __init__(
37
  self,
@@ -64,7 +72,45 @@ class Hopter:
64
  except Exception as exc:
65
  print(f"An unexpected error occurred: {exc}")
66
 
67
- # async def _ram_grounded_sam(self, prompt: str, image_base64: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  async def main():
70
  hopter = Hopter(
@@ -77,12 +123,8 @@ async def main():
77
  image_base64 = base64.b64encode(image_bytes).decode("utf-8")
78
  image_url = f"data:image/jpeg;base64,{image_base64}"
79
 
80
- input = RamGroundedSamInput(
81
- text_prompt="pole",
82
- image_b64=image_url
83
- )
84
- mask = await hopter.generate_mask(input)
85
- print(mask)
86
-
87
  if __name__ == "__main__":
88
  asyncio.run(main())
 
32
  confidence: float = Field(..., description="The confidence score of the mask.")
33
  bbox: List[float] = Field(..., description="The bounding box of the mask in the format [x1, y1, x2, y2].")
34
 
35
class MagicReplaceInput(BaseModel):
    """Request fields for the magic-replace call: source image, mask and prompt."""

    image: str = Field(
        ...,
        description="The image in base64 format.",
    )
    mask: str = Field(
        ...,
        description="The mask in base64 format.",
    )
    prompt: str = Field(
        ...,
        description="The prompt for the magic replace.",
    )
39
+
40
class MagicReplaceResult(BaseModel):
    """Parsed output of the magic-replace call."""

    base64_image: str = Field(
        ...,
        description="The edited image in base64 format.",
    )
42
+
43
  class Hopter:
44
  def __init__(
45
  self,
 
72
  except Exception as exc:
73
  print(f"An unexpected error occurred: {exc}")
74
 
75
async def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult:
    """Call the ``sdxl-magic-replace`` prediction endpoint and parse its output.

    The request body is ``{"input": input.model_dump()}`` and is sent with a
    bearer token built from ``self.api_key``.

    Args:
        input: Image, mask and prompt for the replacement.

    Returns:
        MagicReplaceResult built from the ``"output"`` object of the response.

    Raises:
        httpx.HTTPStatusError: If the service answers with an error status.
        ValueError: If the response JSON has no ``"output"`` entry.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    print(f"Magic replacing with input: {input.prompt}")
    try:
        response = await self.client.post(
            f"{self.base_url}/api/v1/services/sdxl-magic-replace/predictions",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            },
            json={
                "input": input.model_dump()
            }
        )
        print(response)
        response.raise_for_status()  # Raise an error for bad responses
        instance = response.json().get("output")
        if instance is None:
            # Without this check the failure would surface as a confusing
            # TypeError from ``MagicReplaceResult(**None)``.
            raise ValueError("Magic replace response did not contain an 'output' field")
        print("Magic replaced.")
        return MagicReplaceResult(**instance)
    except httpx.HTTPStatusError as exc:
        print(f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}")
        # Re-raise: the annotated return type is MagicReplaceResult, and
        # callers dereference the result immediately, so returning None
        # only hides the real cause.
        raise
    except Exception as exc:
        print(f"An unexpected error occurred: {exc}")
        raise
97
+
98
async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
    """Request a mask for the fixed prompt "pole" and return the result's ``mask_b64``."""
    sam_request = RamGroundedSamInput(text_prompt="pole", image_b64=image_url)
    mask_result = await hopter.generate_mask(sam_request)
    return mask_result.mask_b64
105
+
106
async def test_magic_replace(hopter: Hopter, image_url: str, mask: str, prompt: str) -> str:
    """Run magic replace with the given image, mask and prompt; return the result's ``base64_image``."""
    replace_request = MagicReplaceInput(image=image_url, mask=mask, prompt=prompt)
    replaced = await hopter.magic_replace(replace_request)
    return replaced.base64_image
114
 
115
  async def main():
116
  hopter = Hopter(
 
123
  image_base64 = base64.b64encode(image_bytes).decode("utf-8")
124
  image_url = f"data:image/jpeg;base64,{image_base64}"
125
 
126
+ mask = await test_generate_mask(hopter, image_url)
127
+ magic_replace_result = await test_magic_replace(hopter, image_url, mask, "remove the pole")
128
+ print(magic_replace_result)
 
 
 
 
129
  if __name__ == "__main__":
130
  asyncio.run(main())
src/services/generate_mask.py CHANGED
@@ -6,6 +6,8 @@ import base64
6
  import asyncio
7
  from src.hopter.client import Hopter, RamGroundedSamInput, Environment
8
  from src.models.generate_mask_instruction import GenerateMaskInstruction
 
 
9
  load_dotenv()
10
 
11
  system_prompt = """
@@ -38,13 +40,13 @@ Do not output 'sorry, xxx', even if it's a guess, directly output the answer you
38
  """
39
 
40
  class GenerateMaskService:
41
- def __init__(self):
42
  self.llm = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
43
- self.hopter = Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING)
44
  self.model = "gpt-4o"
 
 
45
 
46
  async def get_mask_generation_instruction(self, edit_instruction: str, image_url: str) -> GenerateMaskInstruction:
47
-
48
  messages = [
49
  {
50
  "role": "system",
@@ -93,13 +95,11 @@ class GenerateMaskService:
93
  return generate_mask_result.mask_b64
94
 
95
  async def main():
96
- service = GenerateMaskService()
97
  edit_instruction = "remove the light post"
98
  image_file_path = "./assets/lakeview.jpg"
99
  with open(image_file_path, "rb") as image_file:
100
- image_bytes = image_file.read()
101
- image_base64 = base64.b64encode(image_bytes).decode("utf-8")
102
- image_url = f"data:image/jpeg;base64,{image_base64}"
103
 
104
  instruction = await service.get_mask_generation_instruction(edit_instruction, image_url)
105
  print(instruction)
 
6
  import asyncio
7
  from src.hopter.client import Hopter, RamGroundedSamInput, Environment
8
  from src.models.generate_mask_instruction import GenerateMaskInstruction
9
+ from src.services.openai_file_upload import OpenAIFileUpload
10
+
11
  load_dotenv()
12
 
13
  system_prompt = """
 
40
  """
41
 
42
  class GenerateMaskService:
43
def __init__(self, hopter: Hopter):
    """Create the OpenAI client and file-upload helper, and keep the given Hopter client.

    Args:
        hopter: Hopter client stored on the service as ``self.hopter``.
    """
    openai_key = os.environ.get("OPENAI_API_KEY")
    self.llm = OpenAI(api_key=openai_key)
    self.model = "gpt-4o"
    self.openai_file_upload = OpenAIFileUpload()
    self.hopter = hopter
48
 
49
  async def get_mask_generation_instruction(self, edit_instruction: str, image_url: str) -> GenerateMaskInstruction:
 
50
  messages = [
51
  {
52
  "role": "system",
 
95
  return generate_mask_result.mask_b64
96
 
97
  async def main():
98
+ service = GenerateMaskService(Hopter(api_key=os.environ.get("HOPTER_API_KEY"), environment=Environment.STAGING))
99
  edit_instruction = "remove the light post"
100
  image_file_path = "./assets/lakeview.jpg"
101
  with open(image_file_path, "rb") as image_file:
102
+ image_url = service.openai_file_upload.upload_image(image_file.read(), "vision")
 
 
103
 
104
  instruction = await service.get_mask_generation_instruction(edit_instruction, image_url)
105
  print(instruction)
src/services/image_uploader.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from typing import Optional, Union
3
+ import base64
4
+ from pathlib import Path
5
+ import os
6
+ from pydantic import BaseModel
7
+
8
class ImageInfo(BaseModel):
    """File details for one entry (``image``, ``thumb`` or ``medium``) of ``ImgBBData``."""

    filename: str
    name: str
    mime: str
    extension: str
    url: str
14
+
15
class ImgBBData(BaseModel):
    """Contents of the ``data`` field of an ``ImgBBResponse``."""

    # Identifier and links
    id: str
    title: str
    url_viewer: str
    url: str
    display_url: str

    # Numeric metadata (units are not established in this module)
    width: int
    height: int
    size: int
    time: int
    expiration: int

    # Per-entry file details
    image: ImageInfo
    thumb: ImageInfo
    medium: ImageInfo

    delete_url: str
30
+
31
class ImgBBResponse(BaseModel):
    """Top-level shape of the JSON body parsed after an upload request."""

    data: ImgBBData
    success: bool
    status: int
35
+
36
class ImageUploader:
    """A class to handle image uploads to ImgBB service."""

    def __init__(self, api_key: str):
        """
        Initialize the ImageUploader with an API key.

        Args:
            api_key (str): The ImgBB API key
        """
        self.api_key = api_key
        self.base_url = "https://api.imgbb.com/1/upload"

    @staticmethod
    def _image_field(image: Union[str, bytes, Path]) -> tuple:
        """Turn any supported image input into a multipart ``image`` entry.

        Returns a ``(filename, content)`` tuple as accepted by the ``files``
        argument of ``requests.post``. ``filename`` is None when the content
        is sent as a plain form value (URL or base64 text).

        Raises:
            ValueError: If the input type is unsupported or a data URI has no
                payload after the comma.
        """
        if isinstance(image, bytes):
            # Raw bytes are sent as base64 text.
            return (None, base64.b64encode(image).decode('utf-8'))

        if not isinstance(image, (str, Path)):
            raise ValueError("Invalid image format. Must be file path, URL, base64 string, or bytes")

        image_str = str(image)
        if os.path.isfile(image_str):
            # Read the content while the file is open. Handing the file
            # object itself to requests fails because the request is only
            # sent after this ``with`` block has closed it.
            with open(image_str, 'rb') as file:
                return (os.path.basename(image_str), file.read())

        if image_str.startswith(('http://', 'https://')):
            return (None, image_str)

        if image_str.startswith('data:image/'):
            # Keep only the base64 part after the first comma.
            _, comma, base64_data = image_str.partition(',')
            if not comma or not base64_data:
                raise ValueError("Invalid data URI: missing base64 payload after ','")
            return (None, base64_data)

        # Anything else is assumed to already be base64 data.
        return (None, image_str)

    def upload(
        self,
        image: Union[str, bytes, Path],
        name: Optional[str] = None,
        expiration: Optional[int] = None
    ) -> ImgBBResponse:
        """
        Upload an image to ImgBB.

        Args:
            image: Can be:
                - A file path (str or Path)
                - Base64 encoded string
                - Base64 data URI (e.g., data:image/jpeg;base64,...)
                - URL to an image
                - Bytes of an image
            name: Optional name for the uploaded file
            expiration: Optional expiration time in seconds (60-15552000)

        Returns:
            ImgBBResponse containing the parsed upload response from ImgBB

        Raises:
            ValueError: If the image format is invalid or upload fails
            requests.RequestException: If the API request fails
        """
        # Prepare the parameters
        params = {'key': self.api_key}
        if expiration:
            if not 60 <= expiration <= 15552000:
                raise ValueError("Expiration must be between 60 and 15552000 seconds")
            params['expiration'] = expiration

        # Every input type goes through one helper and one POST.
        files = {'image': self._image_field(image)}
        if name:
            files['name'] = (None, name)
        response = requests.post(self.base_url, params=params, files=files)

        # Check the response
        if response.status_code != 200:
            raise ValueError(f"Upload failed with status {response.status_code}: {response.text}")

        # Parse the response using Pydantic model
        # NOTE(review): parse_obj is deprecated under pydantic v2 in favour of
        # model_validate; switch once the pinned pydantic version is confirmed.
        return ImgBBResponse.parse_obj(response.json())

    def upload_file(
        self,
        file_path: Union[str, Path],
        name: Optional[str] = None,
        expiration: Optional[int] = None
    ) -> ImgBBResponse:
        """
        Convenience method to upload an image file.

        Args:
            file_path: Path to the image file
            name: Optional name for the uploaded file
            expiration: Optional expiration time in seconds (60-15552000)

        Returns:
            ImgBBResponse containing the parsed upload response from ImgBB
        """
        return self.upload(file_path, name=name, expiration=expiration)

    def upload_url(
        self,
        image_url: str,
        name: Optional[str] = None,
        expiration: Optional[int] = None
    ) -> ImgBBResponse:
        """
        Convenience method to upload an image from a URL.

        Args:
            image_url: URL of the image to upload
            name: Optional name for the uploaded file
            expiration: Optional expiration time in seconds (60-15552000)

        Returns:
            ImgBBResponse containing the parsed upload response from ImgBB
        """
        return self.upload(image_url, name=name, expiration=expiration)
src/services/openai_file_upload.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ from dotenv import load_dotenv
3
+ import os
4
+
5
+ load_dotenv()
6
+
7
class OpenAIFileUpload:
    """Small wrapper around the OpenAI client's ``files.create`` call."""

    def __init__(self):
        """Build the OpenAI client from the ``OPENAI_API_KEY`` environment variable."""
        api_key = os.environ.get("OPENAI_API_KEY")
        self.client = OpenAI(api_key=api_key)

    def upload_image(self, image, purpose) -> str:
        """Create a file from ``image`` with the given ``purpose`` and return its id.

        The return value is the created file's ``id`` attribute, not a URL.
        """
        created = self.client.files.create(file=image, purpose=purpose)
        return created.id