import torch
from torchvision import transforms
from config import Args
from pydantic import BaseModel, Field
from util import ParamsModel
from PIL import Image
from pipelines.pix2pix.pix2pix_turbo import Pix2Pix_Turbo
from pipelines.utils.canny_gpu import ScharrOperator
default_prompt = "close-up photo of the joker"
page_content = """
<h1 class="text-3xl font-bold">Real-Time pix2pix_turbo</h1>
<h3 class="text-xl font-bold">pix2pix turbo</h3>
<p class="text-sm">
This demo showcases
<a
href="https://github.com/GaParmar/img2img-turbo"
target="_blank"
class="text-blue-500 underline hover:no-underline">One-Step Image Translation with Text-to-Image Models
</a>
</p>
<p class="text-sm text-gray-500">
Web app <a href="https://github.com/radames/Real-Time-Latent-Consistency-Model" target="_blank" class="text-blue-500 underline hover:no-underline">
Real-Time Latent Consistency Models
</a>
</p>
"""
class Pipeline:
class Info(BaseModel):
name: str = "img2img"
title: str = "Image-to-Image SDXL"
description: str = "Generates an image from a text prompt"
input_mode: str = "image"
page_content: str = page_content
class InputParams(ParamsModel):
prompt: str = Field(
default_prompt,
title="Prompt",
field="textarea",
id="prompt",
)
width: int = Field(
512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
)
height: int = Field(
512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
)
seed: int = Field(
2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
)
noise_r: float = Field(
1.0,
min=0.01,
max=3.0,
step=0.001,
title="Noise R",
field="range",
hide=True,
id="noise_r",
)
deterministic: bool = Field(
True,
hide=True,
title="Deterministic",
field="checkbox",
id="deterministic",
)
canny_low_threshold: float = Field(
0.0,
min=0,
max=1.0,
step=0.001,
title="Canny Low Threshold",
field="range",
hide=True,
id="canny_low_threshold",
)
canny_high_threshold: float = Field(
1.0,
min=0,
max=1.0,
step=0.001,
title="Canny High Threshold",
field="range",
hide=True,
id="canny_high_threshold",
)
debug_canny: bool = Field(
False,
title="Debug Canny",
field="checkbox",
hide=True,
id="debug_canny",
)
def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
self.model = Pix2Pix_Turbo("edge_to_image")
self.canny_torch = ScharrOperator(device=device)
self.device = device
self.last_time = 0.0
def predict(self, params: "Pipeline.InputParams") -> Image.Image:
canny_pil, canny_tensor = self.canny_torch(
params.image,
params.canny_low_threshold,
params.canny_high_threshold,
output_type="pil,tensor",
)
torch.manual_seed(params.seed)
noise = torch.randn(
(1, 4, params.width // 8, params.height // 8), device=self.device
)
canny_tensor = torch.cat((canny_tensor, canny_tensor, canny_tensor), dim=1)
output_image = self.model(
canny_tensor,
params.prompt,
params.deterministic,
params.noise_r,
noise,
)
output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
result_image = output_pil
if params.debug_canny:
# paste control_image on top of result_image
w0, h0 = (200, 200)
control_image = canny_pil.resize((w0, h0))
w1, h1 = result_image.size
result_image.paste(control_image, (w1 - w0, h1 - h0))
return result_image