import asyncio import httpx from enum import Enum from src.utils import image_path_to_uri from dotenv import load_dotenv import os from pydantic import BaseModel, Field from typing import List load_dotenv() class Environment(Enum): STAGING = "staging" PRODUCTION = "production" @property def base_url(self) -> str: match self: case Environment.STAGING: return "https://serving.hopter.staging.picc.co" case Environment.PRODUCTION: return "https://serving.hopter.picc.co" class RamGroundedSamInput(BaseModel): text_prompt: str = Field( ..., description="The text prompt for the mask generation." ) image_b64: str = Field(..., description="The image in base64 format.") class RamGroundedSamResult(BaseModel): mask_b64: str = Field(..., description="The mask image in base64 format.") class_label: str = Field(..., description="The class label of the mask.") confidence: float = Field(..., description="The confidence score of the mask.") bbox: List[float] = Field( ..., description="The bounding box of the mask in the format [x1, y1, x2, y2]." ) class MagicReplaceInput(BaseModel): image: str = Field(..., description="The image in base64 format.") mask: str = Field(..., description="The mask in base64 format.") prompt: str = Field(..., description="The prompt for the magic replace.") class MagicReplaceResult(BaseModel): base64_image: str = Field(..., description="The edited image in base64 format.") class SuperResolutionInput(BaseModel): image_b64: str = Field(..., description="The image in base64 format.") scale: int = Field(4, description="The scale of the image to upscale to.") use_face_enhancement: bool = Field( False, description="Whether to use face enhancement." ) class SuperResolutionResult(BaseModel): scaled_image: str = Field( ..., description="The super-resolved image in base64 format." ) class Hopter: def __init__(self, api_key: str, environment: Environment = Environment.PRODUCTION): self.api_key = api_key self.base_url = environment.base_url self.client = httpx.Client() def generate_mask(self, input: RamGroundedSamInput) -> RamGroundedSamResult: print(f"Generating mask with input: {input.text_prompt}") try: response = self.client.post( f"{self.base_url}/api/v1/services/ram-grounded-sam-api/predictions", headers={ "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", }, json={"input": input.model_dump()}, timeout=None, ) response.raise_for_status() # Raise an error for bad responses instance = response.json().get("output").get("instances")[0] print("Generated mask.") return RamGroundedSamResult(**instance) except httpx.HTTPStatusError as exc: print( f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}" ) except Exception as exc: print(f"An unexpected error occurred: {exc}") def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult: print(f"Magic replacing with input: {input.prompt}") try: response = self.client.post( f"{self.base_url}/api/v1/services/sdxl-magic-replace/predictions", headers={ "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", }, json={"input": input.model_dump()}, timeout=None, ) response.raise_for_status() # Raise an error for bad responses instance = response.json().get("output") print("Magic replaced.") return MagicReplaceResult(**instance) except httpx.HTTPStatusError as exc: print( f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}" ) except Exception as exc: print(f"An unexpected error occurred: {exc}") def super_resolution(self, input: SuperResolutionInput) -> SuperResolutionResult: try: response = self.client.post( f"{self.base_url}/api/v1/services/super-resolution-esrgan/predictions", headers={ "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", }, json={"input": input.model_dump()}, timeout=None, ) response.raise_for_status() # Raise an error for bad responses instance = response.json().get("output") print("Super-resolutin done") return SuperResolutionResult(**instance) except httpx.HTTPStatusError as exc: print( f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}" ) except Exception as exc: print(f"An unexpected error occurred: {exc}") async def test_generate_mask(hopter: Hopter, image_url: str) -> str: input = RamGroundedSamInput(text_prompt="pole", image_b64=image_url) mask = hopter.generate_mask(input) return mask.mask_b64 async def test_magic_replace( hopter: Hopter, image_url: str, mask: str, prompt: str ) -> str: input = MagicReplaceInput(image=image_url, mask=mask, prompt=prompt) result = hopter.magic_replace(input) return result.base64_image async def main(): hopter = Hopter( api_key=os.getenv("HOPTER_API_KEY"), environment=Environment.STAGING ) image_file_path = "./assets/lakeview.jpg" image_url = image_path_to_uri(image_file_path) mask = await test_generate_mask(hopter, image_url) magic_replace_result = await test_magic_replace( hopter, image_url, mask, "remove the pole" ) print(magic_replace_result) if __name__ == "__main__": asyncio.run(main())