Spaces:
Running
Running
Commit
·
7daa838
1
Parent(s):
c55fe6a
feat: integrated super resolution
Browse files- src/agents/mask_generation_agent.py +15 -4
- src/hopter/client.py +31 -0
src/agents/mask_generation_agent.py
CHANGED
@@ -6,7 +6,7 @@ import asyncio
|
|
6 |
from dataclasses import dataclass
|
7 |
import logfire
|
8 |
from src.services.generate_mask import GenerateMaskService
|
9 |
-
from src.hopter.client import Hopter, Environment, MagicReplaceInput
|
10 |
from src.services.image_uploader import ImageUploader
|
11 |
from src.utils import image_path_to_uri
|
12 |
|
@@ -67,25 +67,36 @@ async def edit_object(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
|
67 |
"""
|
68 |
Edit an object in the image.
|
69 |
"""
|
70 |
-
print("Invoking edit_object tool")
|
71 |
edit_instruction = ctx.deps.edit_instruction
|
72 |
image_url = ctx.deps.image_url
|
73 |
mask_service = ctx.deps.mask_service
|
74 |
hopter_client = ctx.deps.hopter_client
|
75 |
|
76 |
# Generate mask
|
77 |
-
print("Generating mask")
|
78 |
mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
|
79 |
mask = mask_service.generate_mask(mask_instruction, image_url)
|
80 |
|
81 |
# Magic replace
|
82 |
input = MagicReplaceInput(image=image_url, mask=mask, prompt=mask_instruction.target_caption)
|
83 |
result = hopter_client.magic_replace(input)
|
84 |
-
print("Exiting edit_object tool: ", result)
|
85 |
uploader = ImageUploader(os.environ.get("IMG_BB_API_KEY"))
|
86 |
uploaded_image = uploader.upload_url(result.base64_image)
|
87 |
return EditImageResult(edited_image_url=uploaded_image.data.url)
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
async def main():
|
90 |
image_file_path = "./assets/lakeview.jpg"
|
91 |
image_url = image_path_to_uri(image_file_path)
|
|
|
6 |
from dataclasses import dataclass
|
7 |
import logfire
|
8 |
from src.services.generate_mask import GenerateMaskService
|
9 |
+
from src.hopter.client import Hopter, Environment, MagicReplaceInput, SuperResolutionInput
|
10 |
from src.services.image_uploader import ImageUploader
|
11 |
from src.utils import image_path_to_uri
|
12 |
|
|
|
67 |
"""
|
68 |
Edit an object in the image.
|
69 |
"""
|
|
|
70 |
edit_instruction = ctx.deps.edit_instruction
|
71 |
image_url = ctx.deps.image_url
|
72 |
mask_service = ctx.deps.mask_service
|
73 |
hopter_client = ctx.deps.hopter_client
|
74 |
|
75 |
# Generate mask
|
|
|
76 |
mask_instruction = mask_service.get_mask_generation_instruction(edit_instruction, image_url)
|
77 |
mask = mask_service.generate_mask(mask_instruction, image_url)
|
78 |
|
79 |
# Magic replace
|
80 |
input = MagicReplaceInput(image=image_url, mask=mask, prompt=mask_instruction.target_caption)
|
81 |
result = hopter_client.magic_replace(input)
|
|
|
82 |
uploader = ImageUploader(os.environ.get("IMG_BB_API_KEY"))
|
83 |
uploaded_image = uploader.upload_url(result.base64_image)
|
84 |
return EditImageResult(edited_image_url=uploaded_image.data.url)
|
85 |
|
86 |
+
@mask_generation_agent.tool
|
87 |
+
async def super_resolution(ctx: RunContext[ImageEditDeps]) -> EditImageResult:
|
88 |
+
"""
|
89 |
+
run super resolution, upscale, or enhance the image
|
90 |
+
"""
|
91 |
+
image_url = ctx.deps.image_url
|
92 |
+
hopter_client = ctx.deps.hopter_client
|
93 |
+
|
94 |
+
input = SuperResolutionInput(image_b64=image_url, scale=4, use_face_enhancement=False)
|
95 |
+
result = hopter_client.super_resolution(input)
|
96 |
+
uploader = ImageUploader(os.environ.get("IMG_BB_API_KEY"))
|
97 |
+
uploaded_image = uploader.upload_url(result.scaled_image)
|
98 |
+
return EditImageResult(edited_image_url=uploaded_image.data.url)
|
99 |
+
|
100 |
async def main():
|
101 |
image_file_path = "./assets/lakeview.jpg"
|
102 |
image_url = image_path_to_uri(image_file_path)
|
src/hopter/client.py
CHANGED
@@ -39,6 +39,14 @@ class MagicReplaceInput(BaseModel):
|
|
39 |
class MagicReplaceResult(BaseModel):
|
40 |
base64_image: str = Field(..., description="The edited image in base64 format.")
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
class Hopter:
|
43 |
def __init__(
|
44 |
self,
|
@@ -95,6 +103,29 @@ class Hopter:
|
|
95 |
except Exception as exc:
|
96 |
print(f"An unexpected error occurred: {exc}")
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
|
99 |
input = RamGroundedSamInput(
|
100 |
text_prompt="pole",
|
|
|
39 |
class MagicReplaceResult(BaseModel):
|
40 |
base64_image: str = Field(..., description="The edited image in base64 format.")
|
41 |
|
42 |
+
class SuperResolutionInput(BaseModel):
|
43 |
+
image_b64: str = Field(..., description="The image in base64 format.")
|
44 |
+
scale: int = Field(4, description="The scale of the image to upscale to.")
|
45 |
+
use_face_enhancement: bool = Field(False, description="Whether to use face enhancement.")
|
46 |
+
|
47 |
+
class SuperResolutionResult(BaseModel):
|
48 |
+
scaled_image: str = Field(..., description="The super-resolved image in base64 format.")
|
49 |
+
|
50 |
class Hopter:
|
51 |
def __init__(
|
52 |
self,
|
|
|
103 |
except Exception as exc:
|
104 |
print(f"An unexpected error occurred: {exc}")
|
105 |
|
106 |
+
def super_resolution(self, input: SuperResolutionInput) -> SuperResolutionResult:
|
107 |
+
try:
|
108 |
+
response = self.client.post(
|
109 |
+
f"{self.base_url}/api/v1/services/super-resolution-esrgan/predictions",
|
110 |
+
headers={
|
111 |
+
"Authorization": f"Bearer {self.api_key}",
|
112 |
+
"Content-Type": "application/json"
|
113 |
+
},
|
114 |
+
json={
|
115 |
+
"input": input.model_dump()
|
116 |
+
},
|
117 |
+
timeout=None
|
118 |
+
)
|
119 |
+
response.raise_for_status() # Raise an error for bad responses
|
120 |
+
instance = response.json().get("output")
|
121 |
+
print("Super-resolutin done")
|
122 |
+
return SuperResolutionResult(**instance)
|
123 |
+
except httpx.HTTPStatusError as exc:
|
124 |
+
print(f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}")
|
125 |
+
except Exception as exc:
|
126 |
+
print(f"An unexpected error occurred: {exc}")
|
127 |
+
|
128 |
+
|
129 |
async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
|
130 |
input = RamGroundedSamInput(
|
131 |
text_prompt="pole",
|