|
from typing import Any |
|
import torch, base64 |
|
from PIL import Image |
|
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler |
|
from diffusers.utils import load_image |
|
from io import BytesIO |
|
|
|
class EndpointHandler:
    """Inference handler that renders artistic QR codes.

    A Stable Diffusion 2.1 ControlNet img2img pipeline is used: the request's
    QR code image is the ControlNet condition, and a second "init" image is the
    img2img starting point that gets denoised toward the prompt.
    """

    CONTROLNET_ID = "DionTimmer/controlnet_qrcode-control_v11p_sd21"
    BASE_MODEL_ID = "stabilityai/stable-diffusion-2-1"

    # Defaults reproduce the previously hard-coded request settings; each one
    # can be overridden per request (see ``__call__``).
    DEFAULT_INIT_IMAGE_URL = 'https://images.squarespace-cdn.com/content/v1/59413d96e6f2e1c6837c7ecd/1536503659130-R84NUPOY4QPQTEGCTSAI/15fe1e62172035.5a87280d713e4.png'
    DEFAULT_PROMPT = "a bilboard in NYC with a qrcode"
    DEFAULT_NEGATIVE_PROMPT = "ugly, disfigured, low quality, blurry, nsfw, worst quality, illustration, drawing"
    DEFAULT_SEED = 123121231
    DEFAULT_RESOLUTION = 768
    DEFAULT_GUIDANCE_SCALE = 20.0
    DEFAULT_CONTROLNET_CONDITIONING_SCALE = 2.5
    DEFAULT_STRENGTH = 0.9
    DEFAULT_NUM_INFERENCE_STEPS = 150

    def __init__(self, path=""):
        """Load the ControlNet and the base img2img pipeline once.

        Args:
            path: Model directory argument; unused, because both models are
                loaded by their Hub ids.
        """
        self.controlnet = ControlNetModel.from_pretrained(
            self.CONTROLNET_ID, torch_dtype=torch.float16
        )
        self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            self.BASE_MODEL_ID,
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        )
        self.pipe.enable_xformers_memory_efficient_attention()
        # Swap in a DDIM scheduler built from the pipeline's own scheduler config.
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        # Offload sub-models to CPU between uses to reduce GPU memory pressure.
        self.pipe.enable_model_cpu_offload()

    @staticmethod
    def _target_size(width, height, resolution):
        """Return ``(width, height)`` scaled so the shorter side is ``resolution``.

        Both sides are then rounded to the nearest multiple of 64.

        Raises:
            ValueError: if either input side is not positive.
        """
        if min(width, height) <= 0:
            raise ValueError(f"Image dimensions must be positive, got {width}x{height}.")
        scale = float(resolution) / min(width, height)
        new_width = int(round(width * scale / 64.0)) * 64
        new_height = int(round(height * scale / 64.0)) * 64
        return new_width, new_height

    @classmethod
    def _resize_image(cls, image: "Image.Image", resolution: int) -> "Image.Image":
        """Convert ``image`` to RGB and resize it per ``_target_size`` (Lanczos)."""
        rgb = image.convert("RGB")
        return rgb.resize(cls._target_size(*rgb.size, resolution), resample=Image.LANCZOS)

    @staticmethod
    def _encode_jpeg_base64(image) -> str:
        """Serialize ``image`` as JPEG and return it as a base64 text string."""
        with BytesIO() as buffer:
            image.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode()

    @staticmethod
    def _numeric_param(params, key, default, cast):
        """Return ``cast(params[key])``, or ``default`` when the key is missing or None."""
        value = params.get(key)
        return default if value is None else cast(value)

    def __call__(self, data: dict) -> dict:
        """Generate one QR-code-conditioned image.

        Args:
            data: Request payload.
                inputs (str): URL or local path of the QR code image (required).
                parameters (dict, optional): ``prompt``, ``negative_prompt``,
                    ``init_image`` (URL/path of the img2img starting image),
                    ``seed``, ``resolution``, ``guidance_scale``,
                    ``controlnet_conditioning_scale``, ``strength`` and
                    ``num_inference_steps``. Any option that is omitted uses the
                    class ``DEFAULT_*`` value. If ``parameters`` is absent, these
                    keys are read from the top level of ``data`` instead.
                ``inputs`` and ``parameters`` are popped from ``data``.

        Returns:
            ``{"image": <base64-encoded JPEG string>}``.

        Raises:
            ValueError: if ``data`` has no ``inputs`` entry.
        """
        if "inputs" not in data:
            raise ValueError(
                "Request payload must contain an 'inputs' entry with the QR code image URL or path."
            )
        inputs = data.pop("inputs")
        # Fall back to the top-level payload so clients may omit "parameters";
        # an explicit null "parameters" is treated as no parameters.
        params = data.pop("parameters", data) or {}

        prompt = params.get("prompt") or self.DEFAULT_PROMPT
        negative_prompt = params.get("negative_prompt") or self.DEFAULT_NEGATIVE_PROMPT
        init_image_source = params.get("init_image") or self.DEFAULT_INIT_IMAGE_URL
        resolution = self._numeric_param(params, "resolution", self.DEFAULT_RESOLUTION, int)
        seed = self._numeric_param(params, "seed", self.DEFAULT_SEED, int)

        condition_image = self._resize_image(load_image(inputs), resolution)
        width, height = condition_image.size
        # Size the init image exactly like the condition image so the img2img
        # latents and the ControlNet conditioning always have matching shapes.
        init_image = load_image(init_image_source).convert("RGB").resize(
            (width, height), resample=Image.LANCZOS
        )

        # A private CPU generator gives the same seeded noise as torch.manual_seed
        # without reseeding the process-wide RNG on every request.
        generator = torch.Generator(device="cpu").manual_seed(seed)

        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=init_image,
            control_image=condition_image,
            width=width,
            height=height,
            guidance_scale=self._numeric_param(
                params, "guidance_scale", self.DEFAULT_GUIDANCE_SCALE, float
            ),
            controlnet_conditioning_scale=self._numeric_param(
                params,
                "controlnet_conditioning_scale",
                self.DEFAULT_CONTROLNET_CONDITIONING_SCALE,
                float,
            ),
            generator=generator,
            strength=self._numeric_param(params, "strength", self.DEFAULT_STRENGTH, float),
            num_inference_steps=self._numeric_param(
                params, "num_inference_steps", self.DEFAULT_NUM_INFERENCE_STEPS, int
            ),
        )

        return {"image": self._encode_jpeg_base64(result.images[0])}