|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from copy import deepcopy |
|
from torchvision.transforms.functional import rgb_to_grayscale |
|
import segmentation_models_pytorch as smp |
|
from diffusers import StableDiffusionInpaintPipeline |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
class SDGrayInpaintConfig(PretrainedConfig):
    """Configuration for :class:`SDGrayInpaintModel`.

    Args:
        base_model: Identifier of the Stable Diffusion inpainting pipeline
            that the model loads its components from.
        height: Height stored on the config; the model passes it to the
            image processor when preprocessing PIL inputs.
        width: Width stored on the config, used alongside ``height``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "sd_gray_inpaint"

    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-inpainting",
        height=512,
        width=512,
        **kwargs,
    ):
        # The custom fields are assigned first, then the remaining keyword
        # arguments are handed to the base configuration class.
        self.base_model = base_model
        self.height, self.width = height, width
        super().__init__(**kwargs)
|
|
|
class SDGrayInpaintModel(PreTrainedModel):
    """Grayscale inpainting model assembled from a Stable Diffusion inpainting pipeline.

    The UNet, VAE, scheduler and image processor are taken from the pipeline
    named by ``config.base_model``. Conditioning comes from a single learnable
    embedding (``prompt_embeds``) instead of an encoded text prompt, and a
    segmentation U-Net (``mask_predictor``) estimates the region to restore
    when the caller does not supply a mask.
    """

    config_class = SDGrayInpaintConfig

    def __init__(self, config):
        """Load the base pipeline and keep the components used by ``forward``.

        Args:
            config: A ``SDGrayInpaintConfig`` providing ``base_model``,
                ``height`` and ``width``.
        """
        super().__init__(config)
        pipe = StableDiffusionInpaintPipeline.from_pretrained(config.base_model)
        self.mask_predictor = smp.Unet(
            encoder_name="mit_b4",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
        )
        self.image_processor = pipe.image_processor
        self.scheduler = pipe.scheduler
        self.unet = pipe.unet
        self.vae = pipe.vae
        # Learned conditioning fed to the UNet as ``encoder_hidden_states``.
        # NOTE(review): (1, 77, 1024) presumably matches the text-encoder
        # output shape of the default base model -- confirm if base_model changes.
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def forward(
        self,
        images_gray_masked,
        masks=None,
        num_inference_steps=250,
        seed=42,
        input_type='pil',
        output_type='pil',
    ):
        """Restore the masked region of the input images by latent diffusion.

        Args:
            images_gray_masked: Images to restore. With ``input_type='pil'``
                they are run through the pipeline's image processor at the
                configured height/width; with ``input_type='pt'`` they are
                used as given (assumed to be an already-preprocessed
                4-D tensor -- TODO confirm expected value range).
            masks: Optional mask tensor; where it is 1 the output is taken
                from the generated image, where it is 0 the input pixel is
                kept. Assumed to be shaped like ``(B, 1, H, W)``. When
                ``None`` the mask is predicted by ``mask_predictor`` and
                thresholded at probability 0.5.
            num_inference_steps: Number of scheduler steps.
            seed: Seed for the generator that draws the initial latents.
            input_type: ``'pil'`` or ``'pt'``.
            output_type: ``'pil'``, ``'np'``, ``'pt'`` (all converted by the
                image processor) or ``'none'`` (raw tensor is returned).

        Returns:
            The restored single-channel images in the requested format.

        Raises:
            ValueError: If ``input_type`` or ``output_type`` is unsupported.
        """
        if input_type not in ('pil', 'pt'):
            raise ValueError('unsupported input_type')
        # Reject a bad output_type before running the (expensive) diffusion
        # loop instead of after it.
        if output_type not in ('pil', 'np', 'pt', 'none'):
            raise ValueError('unsupported output_type')

        generator = torch.Generator()
        generator.manual_seed(seed)

        if input_type == 'pil':
            images_gray_masked = self.image_processor.preprocess(
                images_gray_masked, height=self.height, width=self.width
            ).float()
        images_gray_masked = images_gray_masked.to(self.vae.device)

        if masks is None:
            masks_logits = self.mask_predictor(images_gray_masked)
            masks = (torch.sigmoid(masks_logits) > 0.5) * 1.
        masks = masks.float().to(self.vae.device)

        # Unpacking four values also enforces a 4-D image batch.
        batch_size, _, _, _ = images_gray_masked.shape
        prompt_embeds = self.prompt_embeds.repeat(batch_size, 1, 1)

        # Work on a private copy so per-call timestep state never leaks into
        # the scheduler stored on the module.
        scheduler = deepcopy(self.scheduler)
        scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=self.vae.device)

        masked_image_latents = (
            self.vae.encode(images_gray_masked).latent_dist.mode()
            * self.vae.config.scaling_factor
        )
        # Resize the mask to the actual latent resolution so it can be
        # concatenated with the latents for any input size, not only the
        # UNet's nominal ``sample_size``.
        mask_latents = F.interpolate(masks, size=masked_image_latents.shape[-2:])

        # Scale the initial noise with the scheduler that was just configured
        # (the copy), since init_noise_sigma may depend on the timesteps.
        latents = randn_tensor(masked_image_latents.shape, generator=generator).to(self.device)
        latents = latents * scheduler.init_noise_sigma

        for t in scheduler.timesteps:
            # Only the UNet input is scaled; ``scheduler.step`` must receive
            # the unscaled latents from the previous iteration.
            scaled_latents = scheduler.scale_model_input(latents, t)
            latent_model_input = torch.cat(
                [scaled_latents, mask_latents, masked_image_latents], dim=1
            )
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=prompt_embeds
            )[0]
            latents = scheduler.step(noise_pred, t, latents)[0]

        latents = latents / self.vae.config.scaling_factor
        images_gray_restored = self.vae.decode(latents.detach())[0]
        # Keep the original pixels outside the mask and use the generated
        # pixels inside it.
        images_gray_restored = (
            images_gray_masked * (1 - masks) + images_gray_restored.detach() * masks
        )
        images_gray_restored = rgb_to_grayscale(images_gray_restored)

        if output_type == 'none':
            return images_gray_restored
        # 'pil' is the image processor's default output type, so passing it
        # explicitly matches calling postprocess without an output type.
        return self.image_processor.postprocess(images_gray_restored, output_type)
|
|
|
|
|
|