|
import os |
|
from typing import Optional, List |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch import autocast |
|
from diffusers import PNDMScheduler, LMSDiscreteScheduler |
|
from PIL import Image |
|
from cog import BasePredictor, Input, Path |
|
|
|
from image_to_image import ( |
|
StableDiffusionImg2ImgPipeline, |
|
preprocess_init_image, |
|
preprocess_mask, |
|
) |
|
|
|
def patch_conv(**patch):
    """Globally monkey-patch ``torch.nn.Conv2d.__init__`` so every Conv2d
    created afterwards is constructed with the given keyword overrides
    (e.g. ``padding_mode='circular'`` to make generated images tileable).

    The overrides in ``patch`` always win: if a caller explicitly passes
    one of the patched keywords, it is replaced rather than raising
    ``TypeError: got multiple values for keyword argument`` (which the
    previous ``**kwargs, **patch`` splat did).
    """
    cls = torch.nn.Conv2d
    init = cls.__init__

    def __init__(self, *args, **kwargs):
        # Merge caller kwargs with the patch; patch entries override so the
        # global setting is enforced even for explicit caller arguments.
        return init(self, *args, **{**kwargs, **patch})

    cls.__init__ = __init__
|
|
|
# Apply the Conv2d patch before any model is instantiated, so every
# convolution in the pipeline uses circular padding and the generated
# images tile seamlessly.
patch_conv(padding_mode='circular')

# Local directory holding the pre-downloaded diffusers weights
# (setup() loads with local_files_only=True, so the cache must exist).
MODEL_CACHE = "diffusers-cache"
|
|
|
|
|
class Predictor(BasePredictor):
    """Cog predictor wrapping a Stable Diffusion v1.4 img2img/inpainting
    pipeline.

    Because ``torch.nn.Conv2d`` is patched at import time to use circular
    padding, the generated images are seamlessly tileable.
    """

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        print("Loading pipeline...")
        # Default scheduler; predict() replaces it per request depending on
        # whether an init image is supplied.
        scheduler = PNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            scheduler=scheduler,
            revision="fp16",  # half-precision weights branch
            torch_dtype=torch.float16,
            cache_dir=MODEL_CACHE,
            local_files_only=True,  # weights must already be in MODEL_CACHE
        ).to("cuda")

    # inference_mode disables autograd bookkeeping; autocast runs the
    # forward passes in mixed precision on the GPU.
    @torch.inference_mode()
    @torch.cuda.amp.autocast()
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        width: int = Input(
            description="Width of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        height: int = Input(
            description="Height of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        init_image: Path = Input(
            description="Inital image to generate variations of. Will be resized to the specified width and height",
            default=None,
        ),
        mask: Path = Input(
            description="Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved. Experimental feature, tends to work better with prompt strength of 0.5-0.7",
            default=None,
        ),
        prompt_strength: float = Input(
            description="Prompt strength when using init image. 1.0 corresponds to full destruction of information in init image",
            default=0.8,
        ),
        num_outputs: int = Input(
            description="Number of images to output", choices=[1, 4], default=1
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
        """Run a single prediction on the model"""
        if seed is None:
            # NOTE(review): os.urandom(2) yields only 65536 distinct seeds —
            # consider 4 bytes if wider seed coverage matters.
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        # 1024x1024 is the only choices-combination that exceeds the memory
        # budget; all other width/height pairs are allowed.
        if width == height == 1024:
            raise ValueError(
                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits. Please select a lower width or height."
            )

        if init_image:
            # Rebind the Path input to a preprocessed CUDA tensor resized to
            # the requested dimensions.
            init_image = Image.open(init_image).convert("RGB")
            init_image = preprocess_init_image(init_image, width, height).to("cuda")

            # PNDM when starting from an init image, LMS for pure
            # text-to-image generation.
            scheduler = PNDMScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            )
        else:
            scheduler = LMSDiscreteScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            )

        # NOTE(review): mutates the shared pipeline per request — fine for
        # cog's one-request-at-a-time model, not safe under concurrency.
        self.pipe.scheduler = scheduler

        if mask:
            # Rebind the mask Path to a preprocessed CUDA tensor; per the
            # input description, black pixels are inpainted, white preserved.
            mask = Image.open(mask).convert("RGB")
            mask = preprocess_mask(mask, width, height).to("cuda")

        # Seed a CUDA generator so results are reproducible for a given seed.
        generator = torch.Generator("cuda").manual_seed(seed)
        output = self.pipe(
            # prompt default is "", so the None branch is effectively unused.
            prompt=[prompt] * num_outputs if prompt is not None else None,
            init_image=init_image,
            mask=mask,
            width=width,
            height=height,
            prompt_strength=prompt_strength,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
        )
        # Reject the whole batch if the safety checker flags any sample.
        if any(output["nsfw_content_detected"]):
            raise Exception("NSFW content detected, please try a different prompt")

        # Save each generated PIL image to /tmp and return the paths,
        # which cog uploads as the prediction output.
        output_paths = []
        for i, sample in enumerate(output["sample"]):
            output_path = f"/tmp/out-{i}.png"
            sample.save(output_path)
            output_paths.append(Path(output_path))

        return output_paths
|
|