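# Page header captured with this file (not part of the module):
#   author: simonlee-cb
#   commit: 7daa838 - "feat: integrated super resolution"
#   page links: raw | history | blame
#   size: 6.22 kB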
import asyncio
import httpx
from enum import Enum
from src.utils import image_path_to_uri
from dotenv import load_dotenv
import os
from pydantic import BaseModel, Field
from typing import List
# Load variables from a local .env file into the process environment at
# import time, before main() reads HOPTER_API_KEY via os.getenv.
load_dotenv()
class Environment(Enum):
    """Deployment targets of the Hopter serving API."""

    STAGING = "staging"
    PRODUCTION = "production"

    @property
    def base_url(self) -> str:
        """Return the root URL of the serving host for this environment."""
        # The table is built inside the property on purpose: a dict assigned
        # in the class body would be turned into an enum member.
        hosts = {
            Environment.STAGING: "https://serving.hopter.staging.picc.co",
            Environment.PRODUCTION: "https://serving.hopter.picc.co",
        }
        return hosts[self]
# Request payload that Hopter.generate_mask posts to the
# "ram-grounded-sam-api" service. Documented with comments rather than a
# docstring so the model's generated JSON schema stays unchanged.
class RamGroundedSamInput(BaseModel):
    # Both fields are required (default=... marks a field as required).
    text_prompt: str = Field(
        default=...,
        description="The text prompt for the mask generation.",
    )
    image_b64: str = Field(
        default=...,
        description="The image in base64 format.",
    )
# One mask instance as returned by the "ram-grounded-sam-api" service;
# Hopter.generate_mask builds it from the first entry of the response's
# "instances" list. Comments are used instead of a docstring so the model's
# generated JSON schema stays unchanged.
class RamGroundedSamResult(BaseModel):
    # All fields are required (default=... marks a field as required).
    mask_b64: str = Field(
        default=...,
        description="The mask image in base64 format.",
    )
    class_label: str = Field(
        default=...,
        description="The class label of the mask.",
    )
    confidence: float = Field(
        default=...,
        description="The confidence score of the mask.",
    )
    bbox: List[float] = Field(
        default=...,
        description="The bounding box of the mask in the format [x1, y1, x2, y2].",
    )
# Request payload that Hopter.magic_replace posts to the
# "sdxl-magic-replace" service. Comments are used instead of a docstring so
# the model's generated JSON schema stays unchanged.
class MagicReplaceInput(BaseModel):
    # All fields are required (default=... marks a field as required).
    image: str = Field(
        default=...,
        description="The image in base64 format.",
    )
    mask: str = Field(
        default=...,
        description="The mask in base64 format.",
    )
    prompt: str = Field(
        default=...,
        description="The prompt for the magic replace.",
    )
# Parsed "output" object of an "sdxl-magic-replace" prediction, as built by
# Hopter.magic_replace. Comments are used instead of a docstring so the
# model's generated JSON schema stays unchanged.
class MagicReplaceResult(BaseModel):
    # Required field (default=... marks a field as required).
    base64_image: str = Field(
        default=...,
        description="The edited image in base64 format.",
    )
# Request payload that Hopter.super_resolution posts to the
# "super-resolution-esrgan" service. Comments are used instead of a docstring
# so the model's generated JSON schema stays unchanged.
class SuperResolutionInput(BaseModel):
    # Required: the source image.
    image_b64: str = Field(
        default=...,
        description="The image in base64 format.",
    )
    # Optional: defaults to 4 when omitted.
    scale: int = Field(
        default=4,
        description="The scale of the image to upscale to.",
    )
    # Optional: face enhancement is off unless requested.
    use_face_enhancement: bool = Field(
        default=False,
        description="Whether to use face enhancement.",
    )
# Parsed "output" object of a "super-resolution-esrgan" prediction, as built
# by Hopter.super_resolution. Comments are used instead of a docstring so the
# model's generated JSON schema stays unchanged.
class SuperResolutionResult(BaseModel):
    # Required field (default=... marks a field as required).
    scaled_image: str = Field(
        default=...,
        description="The super-resolved image in base64 format.",
    )
class Hopter:
    """Synchronous client for the Hopter prediction services.

    Each public method posts a pydantic input model to
    ``{base_url}/api/v1/services/<service>/predictions`` and parses the
    ``"output"`` object of the JSON response into the matching result model.

    Failures (HTTP error statuses, transport errors, unexpected response
    shapes, validation errors) are reported on stdout and signalled by
    returning ``None`` instead of raising, so callers must check the result
    before using it.

    The instance owns an ``httpx.Client``. Call :meth:`close`, or use the
    instance as a context manager, to release its connections.
    """

    def __init__(
        self,
        api_key: str,
        environment: Environment = Environment.PRODUCTION,
        timeout: float | None = None,
    ):
        """Create a client bound to one environment.

        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            environment: Deployment whose ``base_url`` requests are sent to.
            timeout: Per-request timeout in seconds passed to httpx. The
                default ``None`` disables timeouts, which matches the
                previous hard-coded behaviour.
        """
        self.api_key = api_key
        self.base_url = environment.base_url
        self.timeout = timeout
        self.client = httpx.Client()

    def close(self) -> None:
        """Close the underlying HTTP client and its pooled connections."""
        self.client.close()

    def __enter__(self) -> "Hopter":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def _request_output(self, service: str, payload: BaseModel) -> dict | None:
        """POST ``payload`` to ``service`` and return the response's "output".

        Raises ``httpx.HTTPStatusError`` on a 4xx/5xx status and lets any
        transport or JSON decoding error propagate; the public methods turn
        those into a printed message and a ``None`` result. Returns ``None``
        when the response JSON has no ``"output"`` key.
        """
        response = self.client.post(
            f"{self.base_url}/api/v1/services/{service}/predictions",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={"input": payload.model_dump()},
            timeout=self.timeout,
        )
        response.raise_for_status()  # Raise an error for bad responses
        return response.json().get("output")

    @staticmethod
    def _report_failure(exc: Exception) -> None:
        """Print the failure message for ``exc``."""
        if isinstance(exc, httpx.HTTPStatusError):
            print(f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}")
        else:
            print(f"An unexpected error occurred: {exc}")

    def generate_mask(self, input: RamGroundedSamInput) -> RamGroundedSamResult | None:
        """Request a mask for ``input.text_prompt`` on the given image.

        Returns the first instance reported by the "ram-grounded-sam-api"
        service, or ``None`` if the request failed or no instance was
        returned.
        """
        print(f"Generating mask with input: {input.text_prompt}")
        try:
            output = self._request_output("ram-grounded-sam-api", input)
            instances = output.get("instances")
            if not instances:
                # Say what happened instead of surfacing an opaque
                # IndexError/TypeError message from indexing below.
                print("No mask instances were returned.")
                return None
            print("Generated mask.")
            return RamGroundedSamResult(**instances[0])
        except Exception as exc:
            self._report_failure(exc)
            return None

    def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult | None:
        """Edit the masked region of an image according to ``input.prompt``.

        Returns the parsed output of the "sdxl-magic-replace" service, or
        ``None`` if the request failed.
        """
        print(f"Magic replacing with input: {input.prompt}")
        try:
            output = self._request_output("sdxl-magic-replace", input)
            print("Magic replaced.")
            return MagicReplaceResult(**output)
        except Exception as exc:
            self._report_failure(exc)
            return None

    def super_resolution(self, input: SuperResolutionInput) -> SuperResolutionResult | None:
        """Upscale an image with the "super-resolution-esrgan" service.

        Returns the parsed output, or ``None`` if the request failed.
        """
        try:
            output = self._request_output("super-resolution-esrgan", input)
            print("Super-resolution done.")
            return SuperResolutionResult(**output)
        except Exception as exc:
            self._report_failure(exc)
            return None
async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
    """Generate a mask for the hard-coded prompt "pole" and return it.

    Args:
        hopter: Client used to call the mask-generation service.
        image_url: Image value passed through as ``image_b64``.

    Returns:
        The ``mask_b64`` field of the generated mask.

    Raises:
        RuntimeError: If ``hopter.generate_mask`` returned ``None``, which
            it does after printing the error whenever the request fails.
    """
    request = RamGroundedSamInput(text_prompt="pole", image_b64=image_url)
    mask = hopter.generate_mask(request)
    if mask is None:
        # Without this check the failure surfaces as an opaque
        # AttributeError on None.mask_b64.
        raise RuntimeError("Mask generation failed; see the error printed above.")
    return mask.mask_b64
async def test_magic_replace(hopter: Hopter, image_url: str, mask: str, prompt: str) -> str:
    """Run a magic-replace edit and return the edited image.

    Args:
        hopter: Client used to call the magic-replace service.
        image_url: Image value passed through as ``image``.
        mask: Mask value passed through as ``mask``.
        prompt: Edit instruction passed through as ``prompt``.

    Returns:
        The ``base64_image`` field of the service result.

    Raises:
        RuntimeError: If ``hopter.magic_replace`` returned ``None``, which
            it does after printing the error whenever the request fails.
    """
    request = MagicReplaceInput(image=image_url, mask=mask, prompt=prompt)
    result = hopter.magic_replace(request)
    if result is None:
        # Without this check the failure surfaces as an opaque
        # AttributeError on None.base64_image.
        raise RuntimeError("Magic replace failed; see the error printed above.")
    return result.base64_image
async def main():
    """Demo: mask the "pole" in a sample image, then edit it out.

    Reads the API key from the ``HOPTER_API_KEY`` environment variable,
    targets the staging environment, and prints the edited image returned
    by the magic-replace step.

    Raises:
        RuntimeError: If ``HOPTER_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("HOPTER_API_KEY")
    if not api_key:
        # Fail before any request is made; otherwise the client would send
        # the literal header "Bearer None".
        raise RuntimeError(
            "HOPTER_API_KEY is not set; export it or add it to a .env file."
        )

    hopter = Hopter(api_key=api_key, environment=Environment.STAGING)
    try:
        image_file_path = "./assets/lakeview.jpg"
        # NOTE(review): image_path_to_uri is a project helper; presumably it
        # encodes the file in the form the services accept - confirm.
        image_url = image_path_to_uri(image_file_path)
        mask = await test_generate_mask(hopter, image_url)
        magic_replace_result = await test_magic_replace(
            hopter, image_url, mask, "remove the pole"
        )
        print(magic_replace_result)
    finally:
        # Release the HTTP connections held by the client.
        hopter.client.close()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())