Spaces:
Running
Running
import asyncio | |
import httpx | |
from enum import Enum | |
from src.utils import image_path_to_uri | |
from dotenv import load_dotenv | |
import os | |
from pydantic import BaseModel, Field | |
from typing import List | |
# Populate the process environment from a .env file if one is found
# (python-dotenv); main() later reads HOPTER_API_KEY via os.getenv.
load_dotenv()
class Environment(Enum):
    """Deployment target of the Hopter serving API."""

    STAGING = "staging"
    PRODUCTION = "production"

    def base_url(self) -> str:
        """Return the root URL of the serving API for this environment.

        This is a regular method (not a property), so call it:
        ``Environment.STAGING.base_url()``.
        """
        if self is Environment.STAGING:
            return "https://serving.hopter.staging.picc.co"
        if self is Environment.PRODUCTION:
            return "https://serving.hopter.picc.co"
class RamGroundedSamInput(BaseModel):
    """Request payload for the ``ram-grounded-sam-api`` mask-generation service."""

    text_prompt: str = Field(
        default=...,
        description="The text prompt for the mask generation.",
    )
    image_b64: str = Field(
        default=...,
        description="The image in base64 format.",
    )
class RamGroundedSamResult(BaseModel):
    """One instance from the ``instances`` list of a mask-generation response."""

    mask_b64: str = Field(
        default=...,
        description="The mask image in base64 format.",
    )
    class_label: str = Field(
        default=...,
        description="The class label of the mask.",
    )
    confidence: float = Field(
        default=...,
        description="The confidence score of the mask.",
    )
    bbox: List[float] = Field(
        default=...,
        description="The bounding box of the mask in the format [x1, y1, x2, y2].",
    )
class MagicReplaceInput(BaseModel):
    """Request payload for the ``sdxl-magic-replace`` service."""

    image: str = Field(
        default=...,
        description="The image in base64 format.",
    )
    mask: str = Field(
        default=...,
        description="The mask in base64 format.",
    )
    prompt: str = Field(
        default=...,
        description="The prompt for the magic replace.",
    )
class MagicReplaceResult(BaseModel):
    """Response ``output`` of the ``sdxl-magic-replace`` service."""

    base64_image: str = Field(
        default=...,
        description="The edited image in base64 format.",
    )
class SuperResolutionInput(BaseModel):
    """Request payload for the ``super-resolution-esrgan`` service."""

    image_b64: str = Field(
        default=...,
        description="The image in base64 format.",
    )
    # Optional; defaults to 4 when omitted.
    scale: int = Field(
        default=4,
        description="The scale of the image to upscale to.",
    )
    # Optional; face enhancement is off unless requested.
    use_face_enhancement: bool = Field(
        default=False,
        description="Whether to use face enhancement.",
    )
class SuperResolutionResult(BaseModel):
    """Response ``output`` of the ``super-resolution-esrgan`` service."""

    scaled_image: str = Field(
        default=...,
        description="The super-resolved image in base64 format.",
    )
class Hopter:
    """Synchronous client for the Hopter model-serving prediction endpoints.

    Every public method POSTs ``{"input": <request.model_dump()>}`` to
    ``{base_url}/api/v1/services/<service>/predictions``. It then builds a typed
    result from the ``output`` field of the JSON response.

    The instance owns an ``httpx.Client``. Call ``close()`` or use the client
    as a context manager (``with Hopter(...) as hopter:``) to release it.
    """

    def __init__(
        self,
        api_key: str,
        environment: Environment = Environment.PRODUCTION,
        timeout: float | None = None,
    ):
        """Create a client.

        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            environment: Deployment to target; selects the base URL.
            timeout: Per-request timeout in seconds, forwarded to httpx.
                ``None`` (the default) disables the timeout entirely.
        """
        self.api_key = api_key
        # ``Environment.base_url`` is a regular method, so it must be called.
        # Storing the bound method would put "<bound method ...>" into every URL.
        self.base_url = environment.base_url()
        self.timeout = timeout
        self.client = httpx.Client()

    def close(self) -> None:
        """Close the underlying HTTP connection pool."""
        self.client.close()

    def __enter__(self) -> "Hopter":
        """Return this client for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the client when the ``with`` block exits."""
        self.close()

    def _predict(self, service: str, payload: BaseModel, parse):
        """POST ``payload`` to ``service`` and return ``parse(output)``.

        Args:
            service: Service slug inserted into the predictions URL.
            payload: Request model, serialised with ``model_dump()``.
            parse: Callable that turns the response's ``output`` value into
                the caller's result model.

        Raises:
            httpx.HTTPStatusError: The server answered with an error status.
            Exception: Any other failure (transport error, unexpected response
                shape, result validation). The error is printed, then re-raised.
        """
        try:
            response = self.client.post(
                f"{self.base_url}/api/v1/services/{service}/predictions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
                json={"input": payload.model_dump()},
                timeout=self.timeout,
            )
            response.raise_for_status()  # Raise an error for bad responses
            return parse(response.json()["output"])
        except httpx.HTTPStatusError as exc:
            print(f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}")
            raise
        except Exception as exc:
            # Re-raise instead of falling through to an implicit ``None``.
            # Returning None would break the declared return types and make
            # callers fail later with an unrelated AttributeError.
            print(f"An unexpected error occurred: {exc}")
            raise

    def generate_mask(self, input: RamGroundedSamInput) -> RamGroundedSamResult:
        """Generate a mask for ``input.text_prompt`` in ``input.image_b64``.

        Returns:
            The first entry of the response's ``output["instances"]`` list.

        Raises:
            httpx.HTTPStatusError: On an HTTP error status.
            IndexError: If the service returned no instances.
        """
        print(f"Generating mask with input: {input.text_prompt}")
        result = self._predict(
            "ram-grounded-sam-api",
            input,
            lambda output: RamGroundedSamResult(**output["instances"][0]),
        )
        print("Generated mask.")
        return result

    def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult:
        """Edit the masked region of ``input.image`` according to ``input.prompt``.

        Raises:
            httpx.HTTPStatusError: On an HTTP error status.
        """
        print(f"Magic replacing with input: {input.prompt}")
        result = self._predict(
            "sdxl-magic-replace",
            input,
            lambda output: MagicReplaceResult(**output),
        )
        print("Magic replaced.")
        return result

    def super_resolution(self, input: SuperResolutionInput) -> SuperResolutionResult:
        """Upscale ``input.image_b64`` with the super-resolution service.

        Raises:
            httpx.HTTPStatusError: On an HTTP error status.
        """
        result = self._predict(
            "super-resolution-esrgan",
            input,
            lambda output: SuperResolutionResult(**output),
        )
        print("Super-resolution done.")
        return result
async def test_generate_mask(hopter: Hopter, image_url: str, text_prompt: str = "pole") -> str:
    """Generate a mask for ``text_prompt`` in the image and return it as base64.

    Args:
        hopter: Client used to call the mask-generation service.
        image_url: Encoded image, sent unchanged as ``image_b64``
            (``main`` builds it with ``image_path_to_uri``).
        text_prompt: What to segment; defaults to ``"pole"``.

    Returns:
        The ``mask_b64`` field of the service result.

    Raises:
        RuntimeError: If the client returned no result.
    """
    request = RamGroundedSamInput(text_prompt=text_prompt, image_b64=image_url)
    # Hopter.generate_mask is a blocking httpx call. Run it in a worker thread
    # so this coroutine does not stall the event loop.
    mask = await asyncio.to_thread(hopter.generate_mask, request)
    if mask is None:
        # The client may print an error and return None instead of raising.
        raise RuntimeError("Mask generation failed; see the error printed above.")
    return mask.mask_b64
async def test_magic_replace(hopter: Hopter, image_url: str, mask: str, prompt: str) -> str:
    """Apply ``prompt`` to the masked region of the image; return the edit as base64.

    Args:
        hopter: Client used to call the magic-replace service.
        image_url: Encoded image, sent unchanged as ``image``.
        mask: Base64 mask, e.g. as returned by ``test_generate_mask``.
        prompt: Edit instruction for the masked region.

    Returns:
        The ``base64_image`` field of the service result.

    Raises:
        RuntimeError: If the client returned no result.
    """
    request = MagicReplaceInput(image=image_url, mask=mask, prompt=prompt)
    # Hopter.magic_replace blocks on HTTP; keep the event loop responsive.
    result = await asyncio.to_thread(hopter.magic_replace, request)
    if result is None:
        # The client may print an error and return None instead of raising.
        raise RuntimeError("Magic replace failed; see the error printed above.")
    return result.base64_image
async def main():
    """Demo pipeline: mask the pole in the sample image, then replace it.

    Reads ``HOPTER_API_KEY`` from the environment, targets the staging
    deployment, and prints the resulting base64 image.

    Raises:
        RuntimeError: If ``HOPTER_API_KEY`` is not set.
    """
    api_key = os.getenv("HOPTER_API_KEY")
    if not api_key:
        # Without this check the request goes out as "Bearer None" and fails
        # later with an opaque HTTP error.
        raise RuntimeError("HOPTER_API_KEY environment variable is not set.")
    hopter = Hopter(
        api_key=api_key,
        environment=Environment.STAGING,
    )
    try:
        image_file_path = "./assets/lakeview.jpg"
        image_url = image_path_to_uri(image_file_path)
        mask = await test_generate_mask(hopter, image_url)
        magic_replace_result = await test_magic_replace(hopter, image_url, mask, "remove the pole")
        print(magic_replace_result)
    finally:
        # Release the httpx.Client created in Hopter.__init__.
        hopter.client.close()


if __name__ == "__main__":
    asyncio.run(main())