import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from torchvision.transforms.functional import rgb_to_grayscale
import segmentation_models_pytorch as smp
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils.torch_utils import randn_tensor
from transformers import PretrainedConfig, PreTrainedModel


class SDGrayInpaintConfig(PretrainedConfig):
    """Configuration for :class:`SDGrayInpaintModel`.

    Args:
        base_model: Hub id or local path of the Stable Diffusion inpainting
            pipeline whose image processor, scheduler, UNet and VAE are reused.
        height: Height (pixels) PIL inputs are resized to in ``forward``.
        width: Width (pixels) PIL inputs are resized to in ``forward``.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    model_type = "sd_gray_inpaint"

    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-inpainting",
        height=512,
        width=512,
        **kwargs,
    ):
        self.base_model = base_model
        self.height = height
        self.width = width
        super().__init__(**kwargs)


class SDGrayInpaintModel(PreTrainedModel):
    """Inpaint masked grayscale images with a Stable Diffusion inpainting UNet.

    Text conditioning is replaced by one learned embedding (``prompt_embeds``)
    shared across the batch. When no mask is supplied, a segmentation network
    (``mask_predictor``) predicts which pixels to regenerate.
    """

    config_class = SDGrayInpaintConfig

    def __init__(self, config):
        super().__init__(config)
        # Only the image processor, scheduler, UNet and VAE of the pipeline are
        # kept; the rest of the pipeline object is discarded after __init__.
        pipe = StableDiffusionInpaintPipeline.from_pretrained(config.base_model)
        # Single-output-channel segmentation net used to predict masks when
        # forward() is called without them.
        self.mask_predictor = smp.Unet(
            encoder_name="mit_b4",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
        )
        self.image_processor = pipe.image_processor
        self.scheduler = pipe.scheduler
        self.unet = pipe.unet
        self.vae = pipe.vae
        # Learned stand-in for the text-encoder hidden states. (1, 77, 1024)
        # presumably matches the SD2 CLIP output shape — TODO confirm if
        # base_model is changed to a different family.
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def _denoise(self, scheduler, masked_image_latents, mask_latents, prompt_embeds, generator):
        """Run the reverse-diffusion loop and return the final (VAE-scaled) latents.

        Args:
            scheduler: A scheduler on which ``set_timesteps`` has already been called.
            masked_image_latents: VAE latents of the masked input images.
            mask_latents: Mask resized to the latent spatial size.
            prompt_embeds: Conditioning passed as ``encoder_hidden_states``.
            generator: Seeded CPU generator for the initial noise.
        """
        # Noise is drawn on the generator's (CPU) device for reproducibility and
        # then moved next to the latents.
        latents = randn_tensor(
            masked_image_latents.shape,
            generator=generator,
            dtype=masked_image_latents.dtype,
        ).to(masked_image_latents.device)
        # Use the *configured* scheduler: for sigma-based schedulers the initial
        # noise scale can depend on the timesteps chosen by set_timesteps().
        latents = latents * scheduler.init_noise_sigma

        for t in scheduler.timesteps:
            # scale_model_input only prepares the UNet input; step() must be
            # given the unscaled latents (as in the diffusers inpaint pipeline).
            latent_model_input = scheduler.scale_model_input(latents, t)
            latent_model_input = torch.cat(
                [latent_model_input, mask_latents, masked_image_latents], dim=1
            )
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=prompt_embeds
            )[0]
            latents = scheduler.step(noise_pred, t, latents)[0]
        return latents

    def forward(
        self,
        images_gray_masked,
        masks=None,
        num_inference_steps=250,
        seed=42,
        input_type='pil',
        output_type='pil',
    ):
        """Fill the masked regions of the input images and return grayscale results.

        Args:
            images_gray_masked: Images with the regions to restore masked out.
                With ``input_type='pil'`` they are run through
                ``image_processor.preprocess`` and resized to
                ``(self.height, self.width)``. With ``input_type='pt'`` they are
                used as given. They presumably must already look like that
                preprocess output (3-channel, normalized) — TODO confirm.
            masks: Optional tensor where 1 marks pixels to regenerate and 0 marks
                pixels kept from the input. It is concatenated as the UNet's mask
                channel, so presumably shaped (B, 1, H, W). When ``None``, it is
                predicted by ``mask_predictor`` and thresholded at 0.5.
            num_inference_steps: Number of scheduler timesteps.
            seed: Seed for the initial latent noise.
            input_type: ``'pil'`` or ``'pt'``.
            output_type: ``'pil'``, ``'np'`` or ``'pt'`` (passed to
                ``image_processor.postprocess``), or ``'none'`` for the raw tensor.

        Returns:
            The single-channel restored images in the requested format.

        Raises:
            ValueError: If ``input_type`` or ``output_type`` is unsupported.
        """
        generator = torch.Generator()
        generator.manual_seed(seed)

        if input_type == 'pil':
            images_gray_masked = self.image_processor.preprocess(
                images_gray_masked, height=self.height, width=self.width
            ).float()
        elif input_type != 'pt':
            raise ValueError('unsupported input_type')
        # Fail fast instead of discovering a bad output_type after sampling.
        if output_type not in ('pil', 'np', 'pt', 'none'):
            raise ValueError('unsupported output_type')

        device = self.vae.device
        dtype = self.vae.dtype
        images_gray_masked = images_gray_masked.to(device=device, dtype=dtype)

        if masks is None:
            # The logits are thresholded (non-differentiable), so no graph is needed.
            with torch.no_grad():
                masks_logits = self.mask_predictor(images_gray_masked)
            masks = (torch.sigmoid(masks_logits) > 0.5) * 1.
        masks = masks.to(device=device, dtype=dtype)

        batch_size = images_gray_masked.shape[0]
        prompt_embeds = self.prompt_embeds.repeat(batch_size, 1, 1)

        # Work on a copy so set_timesteps() does not mutate the shared scheduler.
        scheduler = deepcopy(self.scheduler)
        scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device)

        # The generated content is never differentiated through (the original
        # detached it), so skip recording an autograd graph over every UNet call.
        with torch.no_grad():
            masked_image_latents = (
                self.vae.encode(images_gray_masked).latent_dist.mode()
                * self.vae.config.scaling_factor
            )
            # Resize to the actual latent size rather than the UNet's nominal
            # sample_size, so non-default height/width still line up.
            mask_latents = F.interpolate(
                masks, size=masked_image_latents.shape[-2:]
            ).to(masked_image_latents.dtype)
            latents = self._denoise(
                scheduler, masked_image_latents, mask_latents, prompt_embeds, generator
            )
            generated = self.vae.decode(latents / self.vae.config.scaling_factor)[0]

        # Keep the known pixels from the input; take generated pixels where mask == 1.
        images_gray_restored = images_gray_masked * (1 - masks) + generated * masks
        images_gray_restored = rgb_to_grayscale(images_gray_restored)

        if output_type != 'none':
            images_gray_restored = self.image_processor.postprocess(
                images_gray_restored, output_type
            )
        return images_gray_restored