simonlee-cb commited on
Commit
7daa838
·
1 Parent(s): c55fe6a

feat: integrated super resolution

Browse files
src/agents/mask_generation_agent.py CHANGED
@@ -6,7 +6,7 @@ import asyncio
6
  from dataclasses import dataclass
7
  import logfire
8
  from src.services.generate_mask import GenerateMaskService
9
- from src.hopter.client import Hopter, Environment, MagicReplaceInput
10
  from src.services.image_uploader import ImageUploader
11
  from src.utils import image_path_to_uri
12
 
@@ -67,25 +67,36 @@ async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
67
  """
68
  Edit an object in the image.
69
  """
70
- print("Invoking edit_object tool")
71
  edit_instruction = ctx.deps.edit_instruction
72
  image_url = ctx.deps.image_url
73
  mask_service = ctx.deps.mask_service
74
  hopter_client = ctx.deps.hopter_client
75
 
76
  # Generate mask
77
- print("Generating mask")
78
  mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
79
  mask = mask_service.generate_mask(mask_instruction, image_url)
80
 
81
  # Magic replace
82
  input = MagicReplaceInput(image=image_url, mask=mask, prompt=mask_instruction.target_caption)
83
  result = hopter_client.magic_replace(input)
84
- print("Exiting edit_object tool: ", result)
85
  uploader = ImageUploader(os.environ.get("IMG_BB_API_KEY"))
86
  uploaded_image = uploader.upload_url(result.base64_image)
87
  return EditImageResult(edited_image_url=uploaded_image.data.url)
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  async def main():
90
  image_file_path = "./assets/lakeview.jpg"
91
  image_url = image_path_to_uri(image_file_path)
 
6
  from dataclasses import dataclass
7
  import logfire
8
  from src.services.generate_mask import GenerateMaskService
9
+ from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
10
  from src.services.image_uploader import ImageUploader
11
  from src.utils import image_path_to_uri
12
 
 
67
  """
68
  Edit an object in the image.
69
  """
 
70
  edit_instruction = ctx.deps.edit_instruction
71
  image_url = ctx.deps.image_url
72
  mask_service = ctx.deps.mask_service
73
  hopter_client = ctx.deps.hopter_client
74
 
75
  # Generate mask
 
76
  mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
77
  mask = mask_service.generate_mask(mask_instruction, image_url)
78
 
79
  # Magic replace
80
  input = MagicReplaceInput(image=image_url, mask=mask, prompt=mask_instruction.target_caption)
81
  result = hopter_client.magic_replace(input)
 
82
  uploader = ImageUploader(os.environ.get("IMG_BB_API_KEY"))
83
  uploaded_image = uploader.upload_url(result.base64_image)
84
  return EditImageResult(edited_image_url=uploaded_image.data.url)
85
 
86
+ @mask_generation_agent.tool
87
+ async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
88
+ """
89
+ run super resolution, upscale, or enhance the image
90
+ """
91
+ image_url = ctx.deps.image_url
92
+ hopter_client = ctx.deps.hopter_client
93
+
94
+ input = SuperResolutionInput(image_b64=image_url, scale=4, use_face_enhancement=False)
95
+ result = hopter_client.super_resolution(input)
96
+ uploader = ImageUploader(os.environ.get("IMG_BB_API_KEY"))
97
+ uploaded_image = uploader.upload_url(result.scaled_image)
98
+ return EditImageResult(edited_image_url=uploaded_image.data.url)
99
+
100
  async def main():
101
  image_file_path = "./assets/lakeview.jpg"
102
  image_url = image_path_to_uri(image_file_path)
src/hopter/client.py CHANGED
@@ -39,6 +39,14 @@ class MagicReplaceInput(BaseModel):
39
  class MagicReplaceResult(BaseModel):
40
  base64_image: str = Field(..., description="The edited image in base64 format.")
41
 
 
 
 
 
 
 
 
 
42
  class Hopter:
43
  def __init__(
44
  self,
@@ -95,6 +103,29 @@ class Hopter:
95
  except Exception as exc:
96
  print(f"An unexpected error occurred: {exc}")
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
99
  input = RamGroundedSamInput(
100
  text_prompt="pole",
 
39
  class MagicReplaceResult(BaseModel):
40
  base64_image: str = Field(..., description="The edited image in base64 format.")
41
 
42
+ class SuperResolutionInput(BaseModel):
43
+ image_b64: str = Field(..., description="The image in base64 format.")
44
+ scale: int = Field(4, description="The scale of the image to upscale to.")
45
+ use_face_enhancement: bool = Field(False, description="Whether to use face enhancement.")
46
+
47
+ class SuperResolutionResult(BaseModel):
48
+ scaled_image: str = Field(..., description="The super-resolved image in base64 format.")
49
+
50
  class Hopter:
51
  def __init__(
52
  self,
 
103
  except Exception as exc:
104
  print(f"An unexpected error occurred: {exc}")
105
 
106
+ def super_resolution(self, input: SuperResolutionInput) -> SuperResolutionResult:
107
+ try:
108
+ response = self.client.post(
109
+ f"{self.base_url}/api/v1/services/super-resolution-esrgan/predictions",
110
+ headers={
111
+ "Authorization": f"Bearer {self.api_key}",
112
+ "Content-Type": "application/json"
113
+ },
114
+ json={
115
+ "input": input.model_dump()
116
+ },
117
+ timeout=None
118
+ )
119
+ response.raise_for_status() # Raise an error for bad responses
120
+ instance = response.json().get("output")
121
+ print("Super-resolutin done")
122
+ return SuperResolutionResult(**instance)
123
+ except httpx.HTTPStatusError as exc:
124
+ print(f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}")
125
+ except Exception as exc:
126
+ print(f"An unexpected error occurred: {exc}")
127
+
128
+
129
  async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
130
  input = RamGroundedSamInput(
131
  text_prompt="pole",