File size: 6,180 Bytes
a21dee1
 
 
c55fe6a
a21dee1
 
 
 
 
 
 
fcb8f25
a21dee1
 
 
 
 
 
 
 
 
 
 
fcb8f25
 
a21dee1
fcb8f25
 
 
a21dee1
 
fcb8f25
a21dee1
 
 
 
fcb8f25
 
 
 
a21dee1
9e822e4
 
 
 
 
fcb8f25
9e822e4
 
 
fcb8f25
7daa838
 
 
fcb8f25
 
 
 
7daa838
 
fcb8f25
 
 
 
7daa838
a21dee1
fcb8f25
a21dee1
 
c55fe6a
a21dee1
c55fe6a
a21dee1
 
c55fe6a
a21dee1
 
 
fcb8f25
c55fe6a
fcb8f25
 
a21dee1
 
 
 
 
 
fcb8f25
 
 
a21dee1
 
fcb8f25
c55fe6a
9e822e4
 
c55fe6a
9e822e4
 
 
fcb8f25
9e822e4
fcb8f25
 
9e822e4
 
 
 
 
 
fcb8f25
 
 
9e822e4
 
 
7daa838
 
 
 
 
 
fcb8f25
7daa838
fcb8f25
 
7daa838
 
 
 
 
 
fcb8f25
 
 
7daa838
 
 
 
9e822e4
fcb8f25
c55fe6a
9e822e4
fcb8f25
 
 
 
 
 
c55fe6a
9e822e4
a21dee1
fcb8f25
a21dee1
 
fcb8f25
a21dee1
 
c55fe6a
 
 
fcb8f25
 
 
c55fe6a
a21dee1
fcb8f25
a21dee1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import asyncio
import httpx
from enum import Enum
from src.utils import image_path_to_uri
from dotenv import load_dotenv
import os
from pydantic import BaseModel, Field
from typing import List

# Load variables from a .env file into the process environment at import time;
# main() below reads HOPTER_API_KEY via os.getenv.
load_dotenv()


class Environment(Enum):
    """Hopter deployments a client can be pointed at."""

    STAGING = "staging"
    PRODUCTION = "production"

    @property
    def base_url(self) -> str:
        """Return the root URL of the serving endpoint for this environment."""
        # Enum members are singletons, so an identity check selects the member.
        if self is Environment.STAGING:
            return "https://serving.hopter.staging.picc.co"
        if self is Environment.PRODUCTION:
            return "https://serving.hopter.picc.co"


class RamGroundedSamInput(BaseModel):
    # Input model that Hopter.generate_mask serialises into its request body.
    # Documented with "#" comments (not a docstring) so the model's generated
    # JSON schema stays exactly as it was. Both fields are required.
    text_prompt: str = Field(description="The text prompt for the mask generation.")
    image_b64: str = Field(description="The image in base64 format.")


class RamGroundedSamResult(BaseModel):
    # One mask instance as parsed by Hopter.generate_mask from the service reply.
    # Documented with "#" comments (not a docstring) so the model's generated
    # JSON schema stays exactly as it was. All fields are required.
    mask_b64: str = Field(description="The mask image in base64 format.")
    class_label: str = Field(description="The class label of the mask.")
    confidence: float = Field(description="The confidence score of the mask.")
    bbox: List[float] = Field(
        description="The bounding box of the mask in the format [x1, y1, x2, y2].",
    )


class MagicReplaceInput(BaseModel):
    # Input model that Hopter.magic_replace serialises into its request body.
    # Documented with "#" comments (not a docstring) so the model's generated
    # JSON schema stays exactly as it was. All fields are required.
    image: str = Field(description="The image in base64 format.")
    mask: str = Field(description="The mask in base64 format.")
    prompt: str = Field(description="The prompt for the magic replace.")


class MagicReplaceResult(BaseModel):
    # Result model built by Hopter.magic_replace from the reply's "output" value.
    # Documented with "#" comments (not a docstring) so the model's generated
    # JSON schema stays exactly as it was. The single field is required.
    base64_image: str = Field(
        description="The edited image in base64 format.",
    )


class SuperResolutionInput(BaseModel):
    # Input model that Hopter.super_resolution serialises into its request body.
    # Documented with "#" comments (not a docstring) so the model's generated
    # JSON schema stays exactly as it was.
    # Only image_b64 is required; scale defaults to 4 and face enhancement to off.
    image_b64: str = Field(description="The image in base64 format.")
    scale: int = Field(default=4, description="The scale of the image to upscale to.")
    use_face_enhancement: bool = Field(
        default=False,
        description="Whether to use face enhancement.",
    )


class SuperResolutionResult(BaseModel):
    # Result model built by Hopter.super_resolution from the reply's "output" value.
    # Documented with "#" comments (not a docstring) so the model's generated
    # JSON schema stays exactly as it was. The single field is required.
    scaled_image: str = Field(description="The super-resolved image in base64 format.")


class Hopter:
    """Synchronous client for the Hopter prediction services.

    Each public method POSTs one request to
    ``{base_url}/api/v1/services/<service>/predictions`` and parses the
    ``output`` value of the JSON reply into a result model.

    The API is best-effort: failures are reported on stdout and signalled by
    returning ``None`` instead of raising, so callers must check the result
    before using it.

    The instance owns an ``httpx.Client``. Call :meth:`close`, or use the
    instance as a context manager, to release its connections.
    """

    def __init__(
        self,
        api_key: str,
        environment: Environment = Environment.PRODUCTION,
        timeout: float | None = None,
    ):
        """Create a client.

        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            environment: Deployment whose ``base_url`` requests are sent to.
            timeout: Per-request timeout in seconds. ``None`` (the default,
                matching the previous hard-coded behaviour) disables the
                timeout, so a request may wait indefinitely.
        """
        self.api_key = api_key
        self.base_url = environment.base_url
        self.timeout = timeout
        self.client = httpx.Client()

    def close(self) -> None:
        """Close the underlying HTTP client and its pooled connections."""
        self.client.close()

    def __enter__(self) -> "Hopter":
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the HTTP client when the ``with`` block exits."""
        self.close()

    def _request_output(self, service: str, payload: BaseModel):
        """POST ``payload`` to ``service`` and return the reply's ``output`` value.

        Raises:
            httpx.HTTPStatusError: If the service answers with an error status.
            ValueError: If the JSON reply carries no ``output`` value.
        """
        response = self.client.post(
            f"{self.base_url}/api/v1/services/{service}/predictions",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={"input": payload.model_dump()},
            timeout=self.timeout,
        )
        response.raise_for_status()  # Raise an error for bad responses
        output = response.json().get("output")
        if output is None:
            raise ValueError(f"the {service} response contains no 'output'")
        return output

    @staticmethod
    def _report_failure(exc: Exception) -> None:
        """Print a description of a failed prediction request."""
        if isinstance(exc, httpx.HTTPStatusError):
            print(
                f"HTTP error occurred: {exc.response.status_code} - {exc.response.text}"
            )
        else:
            print(f"An unexpected error occurred: {exc}")

    def generate_mask(self, input: RamGroundedSamInput) -> RamGroundedSamResult | None:
        """Generate a mask for ``input.text_prompt`` on the given image.

        Only the first instance returned by the service is used.

        Returns:
            The first mask instance, or ``None`` if the request failed, the
            reply was malformed, or no instance was returned.
        """
        print(f"Generating mask with input: {input.text_prompt}")
        try:
            output = self._request_output("ram-grounded-sam-api", input)
            instances = output.get("instances")
            if not instances:
                raise ValueError("the service returned no mask instances")
            result = RamGroundedSamResult(**instances[0])
        except Exception as exc:  # best-effort boundary: report, never raise
            self._report_failure(exc)
            return None
        print("Generated mask.")
        return result

    def magic_replace(self, input: MagicReplaceInput) -> MagicReplaceResult | None:
        """Edit the masked region of an image according to ``input.prompt``.

        Returns:
            The edited image, or ``None`` if the request failed or the reply
            was malformed.
        """
        print(f"Magic replacing with input: {input.prompt}")
        try:
            output = self._request_output("sdxl-magic-replace", input)
            result = MagicReplaceResult(**output)
        except Exception as exc:  # best-effort boundary: report, never raise
            self._report_failure(exc)
            return None
        print("Magic replaced.")
        return result

    def super_resolution(
        self, input: SuperResolutionInput
    ) -> SuperResolutionResult | None:
        """Upscale an image with the super-resolution service.

        Returns:
            The upscaled image, or ``None`` if the request failed or the reply
            was malformed.
        """
        try:
            output = self._request_output("super-resolution-esrgan", input)
            result = SuperResolutionResult(**output)
        except Exception as exc:  # best-effort boundary: report, never raise
            self._report_failure(exc)
            return None
        print("Super-resolution done")
        return result


async def test_generate_mask(hopter: Hopter, image_url: str) -> str:
    """Request a mask for the prompt "pole" and return it as a string.

    Args:
        hopter: Client used to call the mask generation service.
        image_url: Image passed as ``image_b64`` of the request.

    Returns:
        The ``mask_b64`` field of the generated mask.

    Raises:
        RuntimeError: If the client signals failure by returning ``None``.
    """
    request = RamGroundedSamInput(text_prompt="pole", image_b64=image_url)
    # NOTE: this is a synchronous HTTP call, so it blocks the event loop.
    mask = hopter.generate_mask(request)
    if mask is None:
        # The client prints the cause and returns None rather than raising.
        raise RuntimeError("Mask generation failed; see the error printed above.")
    return mask.mask_b64


async def test_magic_replace(
    hopter: Hopter, image_url: str, mask: str, prompt: str
) -> str:
    """Run magic replace on an image and return the edited image as a string.

    Args:
        hopter: Client used to call the magic replace service.
        image_url: Image passed as ``image`` of the request.
        mask: Mask passed as ``mask`` of the request.
        prompt: Instruction describing the edit.

    Returns:
        The ``base64_image`` field of the result.

    Raises:
        RuntimeError: If the client signals failure by returning ``None``.
    """
    request = MagicReplaceInput(image=image_url, mask=mask, prompt=prompt)
    # NOTE: this is a synchronous HTTP call, so it blocks the event loop.
    result = hopter.magic_replace(request)
    if result is None:
        # The client prints the cause and returns None rather than raising.
        raise RuntimeError("Magic replace failed; see the error printed above.")
    return result.base64_image


async def main():
    """Demo: generate a mask on a sample image, then magic-replace it.

    Reads the API key from the ``HOPTER_API_KEY`` environment variable, runs
    against the staging environment, and prints the edited image string.

    Raises:
        RuntimeError: If ``HOPTER_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("HOPTER_API_KEY")
    if not api_key:
        # Fail early instead of sending "Authorization: Bearer None".
        raise RuntimeError("HOPTER_API_KEY is not set in the environment or .env.")

    hopter = Hopter(api_key=api_key, environment=Environment.STAGING)
    try:
        image_file_path = "./assets/lakeview.jpg"
        image_url = image_path_to_uri(image_file_path)

        mask = await test_generate_mask(hopter, image_url)
        magic_replace_result = await test_magic_replace(
            hopter, image_url, mask, "remove the pole"
        )
        print(magic_replace_result)
    finally:
        # Release the HTTP connections held by the client.
        hopter.client.close()


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())