Spaces:
Running
Running
File size: 6,180 Bytes
a21dee1 c55fe6a a21dee1 fcb8f25 a21dee1 fcb8f25 a21dee1 fcb8f25 a21dee1 fcb8f25 a21dee1 fcb8f25 a21dee1 9e822e4 fcb8f25 9e822e4 fcb8f25 7daa838 fcb8f25 7daa838 fcb8f25 7daa838 a21dee1 fcb8f25 a21dee1 c55fe6a a21dee1 c55fe6a a21dee1 c55fe6a a21dee1 fcb8f25 c55fe6a fcb8f25 a21dee1 fcb8f25 a21dee1 fcb8f25 c55fe6a 9e822e4 c55fe6a 9e822e4 fcb8f25 9e822e4 fcb8f25 9e822e4 fcb8f25 9e822e4 7daa838 fcb8f25 7daa838 fcb8f25 7daa838 fcb8f25 7daa838 9e822e4 fcb8f25 c55fe6a 9e822e4 fcb8f25 c55fe6a 9e822e4 a21dee1 fcb8f25 a21dee1 fcb8f25 a21dee1 c55fe6a fcb8f25 c55fe6a a21dee1 fcb8f25 a21dee1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import asyncio
import httpx
from enum import Enum
from src.utils import image_path_to_uri
from dotenv import load_dotenv
import os
from pydantic import BaseModel, Field
from typing import List
# Load variables from a local .env file into the process environment so that
# os.getenv("HOPTER_API_KEY") in main() can find the API key without it
# having to be exported in the shell.
load_dotenv()
class Environment(Enum):
    """Deployment target of the Hopter serving API.

    ``base_url`` maps each member to the host that ``Hopter`` prefixes onto
    every service endpoint path.
    """

    STAGING = "staging"
    PRODUCTION = "production"

    @property
    def base_url(self) -> str:
        """Return the root URL of the serving host for this environment."""
        hosts = {
            Environment.STAGING: "https://serving.hopter.staging.picc.co",
            Environment.PRODUCTION: "https://serving.hopter.picc.co",
        }
        return hosts[self]
class RamGroundedSamInput(BaseModel):
    """Request payload for the ``ram-grounded-sam-api`` mask service.

    Both fields are required (no default is given to ``Field``).
    """

    text_prompt: str = Field(description="The text prompt for the mask generation.")
    image_b64: str = Field(description="The image in base64 format.")
class RamGroundedSamResult(BaseModel):
    """One detected instance returned by the ``ram-grounded-sam-api`` service.

    All fields are required.
    """

    mask_b64: str = Field(description="The mask image in base64 format.")
    class_label: str = Field(description="The class label of the mask.")
    confidence: float = Field(description="The confidence score of the mask.")
    # Corner coordinates as described below; the length is not validated here.
    bbox: list[float] = Field(
        description="The bounding box of the mask in the format [x1, y1, x2, y2]."
    )
class MagicReplaceInput(BaseModel):
    """Request payload for the ``sdxl-magic-replace`` service.

    All fields are required; ``mask`` selects the region that ``prompt``
    describes how to edit.
    """

    image: str = Field(description="The image in base64 format.")
    mask: str = Field(description="The mask in base64 format.")
    prompt: str = Field(description="The prompt for the magic replace.")
class MagicReplaceResult(BaseModel):
    """Response ``output`` of the ``sdxl-magic-replace`` service."""

    base64_image: str = Field(description="The edited image in base64 format.")
class SuperResolutionInput(BaseModel):
    """Request payload for the ``super-resolution-esrgan`` service.

    Only ``image_b64`` is required; ``scale`` defaults to 4 and face
    enhancement is off unless requested.
    """

    image_b64: str = Field(description="The image in base64 format.")
    scale: int = Field(default=4, description="The scale of the image to upscale to.")
    use_face_enhancement: bool = Field(
        default=False,
        description="Whether to use face enhancement.",
    )
class SuperResolutionResult(BaseModel):
    """Response ``output`` of the ``super-resolution-esrgan`` service."""

    scaled_image: str = Field(description="The super-resolved image in base64 format.")
class Hopter:
    """Synchronous client for the Hopter model-serving API.

    Every service is called by POSTing ``{"input": <model_dump()>}`` to
    ``{base_url}/api/v1/services/<service>/predictions`` with a Bearer token.
    Public methods never raise for request/response failures: they print the
    error and return ``None``.

    The underlying ``httpx.Client`` keeps pooled connections open; release
    them with :meth:`close`, or use the instance as a context manager.
    """

    def __init__(
        self,
        api_key: str,
        environment: Environment = Environment.PRODUCTION,
        timeout: float | httpx.Timeout | None = None,
    ):
        """Create a client.

        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            environment: Selects the serving host via ``Environment.base_url``.
            timeout: Per-request timeout handed to httpx. ``None`` (the
                default, identical to the previous hard-coded value) disables
                the timeout, so a stalled request waits indefinitely.
        """
        self.api_key = api_key
        self.base_url = environment.base_url
        self.timeout = timeout
        self.client = httpx.Client()

    def close(self) -> None:
        """Close the underlying HTTP client and its pooled connections."""
        self.client.close()

    def __enter__(self) -> "Hopter":
        return self

    def __exit__(self, *exc_info) -> None:
        # Always release the connection pool; exceptions still propagate.
        self.close()

    def _predict(self, service: str, payload: BaseModel, parse, done_message: str):
        """POST ``payload`` to ``service`` and turn the JSON body into a result.

        Args:
            service: Service slug inserted into the predictions URL.
            payload: Request model; sent as ``{"input": payload.model_dump()}``.
            parse: Callable mapping the decoded JSON body to the result model.
            done_message: Printed only after ``parse`` succeeded.

        Returns:
            Whatever ``parse`` returns, or ``None`` if the HTTP call, status
            check, JSON decoding or ``parse`` failed (the error is printed).
        """
        try:
            response = self.client.post(
                f"{self.base_url}/api/v1/services/{service}/predictions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
                json={"input": payload.model_dump()},
                timeout=self.timeout,
            )
            response.raise_for_status()  # Turn 4xx/5xx into HTTPStatusError.
            result = parse(response.json())
        except httpx.HTTPStatusError as exc:
            print(
                f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}"
            )
            return None
        except Exception as exc:
            print(f"An unexpected error occurred: {exc}")
            return None
        # Announce success only once the response has been validated.
        print(done_message)
        return result

    def generate_mask(self, input: RamGroundedSamInput) -> RamGroundedSamResult | None:
        """Request a mask for ``input.text_prompt`` from ``ram-grounded-sam-api``.

        Returns:
            The FIRST entry of ``output.instances`` in the response, or
            ``None`` on any failure (the error is printed).
        """
        print(f"Generating mask with input: {input.text_prompt}")
        return self._predict(
            "ram-grounded-sam-api",
            input,
            lambda body: RamGroundedSamResult(**body.get("output").get("instances")[0]),
            "Generated mask.",
        )

    def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult | None:
        """Edit the masked region of ``input.image`` via ``sdxl-magic-replace``.

        Returns:
            The parsed ``output`` of the response, or ``None`` on any failure
            (the error is printed).
        """
        print(f"Magic replacing with input: {input.prompt}")
        return self._predict(
            "sdxl-magic-replace",
            input,
            lambda body: MagicReplaceResult(**body.get("output")),
            "Magic replaced.",
        )

    def super_resolution(
        self, input: SuperResolutionInput
    ) -> SuperResolutionResult | None:
        """Upscale ``input.image_b64`` via the ``super-resolution-esrgan`` service.

        Returns:
            The parsed ``output`` of the response, or ``None`` on any failure
            (the error is printed).
        """
        print(f"Super-resolving with scale: {input.scale}")
        return self._predict(
            "super-resolution-esrgan",
            input,
            lambda body: SuperResolutionResult(**body.get("output")),
            "Super-resolution done.",
        )
async def test_generate_mask(
    hopter: Hopter, image_url: str, text_prompt: str = "pole"
) -> str:
    """Request a mask for ``text_prompt`` in the given image; return it as base64.

    Declared ``async`` so ``main`` can await it, but the call inside is the
    synchronous ``Hopter.generate_mask``, so it blocks the event loop while
    the HTTP request runs.

    Args:
        hopter: Client used to reach the mask-generation service.
        image_url: Image sent as ``image_b64`` (``main`` passes the value
            returned by ``image_path_to_uri``).
        text_prompt: What to mask; defaults to ``"pole"`` as before.

    Raises:
        RuntimeError: if ``generate_mask`` returned ``None``. It prints the
            underlying error and returns ``None`` instead of raising.
    """
    request = RamGroundedSamInput(text_prompt=text_prompt, image_b64=image_url)
    mask = hopter.generate_mask(request)
    if mask is None:
        raise RuntimeError(f"Mask generation failed for prompt {text_prompt!r}")
    return mask.mask_b64
async def test_magic_replace(
    hopter: Hopter, image_url: str, mask: str, prompt: str
) -> str:
    """Apply ``prompt`` to the masked region of the image; return it as base64.

    Declared ``async`` so ``main`` can await it; ``Hopter.magic_replace`` is
    synchronous and blocks the event loop for the duration of the request.

    Args:
        hopter: Client used to reach the magic-replace service.
        image_url: Image sent as ``image`` (``main`` passes the value returned
            by ``image_path_to_uri``).
        mask: Base64 mask, e.g. the output of ``test_generate_mask``.
        prompt: Instruction describing the edit.

    Raises:
        RuntimeError: if ``magic_replace`` returned ``None``. It prints the
            underlying error and returns ``None`` instead of raising.
    """
    request = MagicReplaceInput(image=image_url, mask=mask, prompt=prompt)
    result = hopter.magic_replace(request)
    if result is None:
        raise RuntimeError(f"Magic replace failed for prompt {prompt!r}")
    return result.base64_image
async def main():
    """Demo: mask the pole in ./assets/lakeview.jpg, then ask magic replace to remove it.

    Talks to the staging environment, reads the key from ``HOPTER_API_KEY``
    (a local .env file is loaded at import time), and prints the resulting
    base64 image.

    Raises:
        SystemExit: if ``HOPTER_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("HOPTER_API_KEY")
    if not api_key:
        # Otherwise the key would be sent as the literal header "Bearer None".
        raise SystemExit("HOPTER_API_KEY is not set (export it or add it to .env).")
    hopter = Hopter(api_key=api_key, environment=Environment.STAGING)
    try:
        image_file_path = "./assets/lakeview.jpg"
        image_url = image_path_to_uri(image_file_path)
        mask = await test_generate_mask(hopter, image_url)
        magic_replace_result = await test_magic_replace(
            hopter, image_url, mask, "remove the pole"
        )
        print(magic_replace_result)
    finally:
        # Release the client's pooled connections even if a step failed.
        hopter.client.close()


if __name__ == "__main__":
    asyncio.run(main())
|