jwengr's picture
Upload folder using huggingface_hub
4651dee verified
raw
history blame
4.22 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from torchvision.transforms.functional import rgb_to_grayscale
import segmentation_models_pytorch as smp
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils.torch_utils import randn_tensor
from transformers import PretrainedConfig, PreTrainedModel
class SDGrayInpaintConfig(PretrainedConfig):
    """Configuration for the grayscale Stable Diffusion inpainting model.

    Args:
        base_model: Hub id or local path of the Stable Diffusion inpainting
            pipeline whose components the model is built from.
        height: Target height used when preprocessing PIL inputs.
        width: Target width used when preprocessing PIL inputs.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "sd_gray_inpaint"

    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-inpainting",
        height=512,
        width=512,
        **kwargs,
    ):
        # Store the custom fields first (same order as before), then let the
        # base class handle every remaining standard config option.
        self.base_model, self.height, self.width = base_model, height, width
        super().__init__(**kwargs)
class SDGrayInpaintModel(PreTrainedModel):
    """Grayscale image inpainting built on a Stable Diffusion inpainting UNet.

    A segmentation U-Net optionally predicts which pixels to regenerate; the
    masked image is then inpainted by a fixed-prompt diffusion loop (a learned
    ``prompt_embeds`` parameter replaces text-encoder output) and the result is
    converted to a single grayscale channel.
    """

    config_class = SDGrayInpaintConfig

    def __init__(self, config):
        super().__init__(config)
        # The full pipeline is loaded only to harvest its components; anything
        # not stored on `self` below is dropped when `pipe` goes out of scope.
        pipe = StableDiffusionInpaintPipeline.from_pretrained(config.base_model)
        # Predicts a 1-channel mask logit map from a 3-channel image.
        self.mask_predictor = smp.Unet(
            encoder_name="mit_b4",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
        )
        self.image_processor = pipe.image_processor
        self.scheduler = pipe.scheduler
        self.unet = pipe.unet
        self.vae = pipe.vae
        # Learned conditioning passed as `encoder_hidden_states`: 77 tokens wide
        # enough for the UNet's cross-attention (the previously hard-coded 1024
        # matches the default SD2 base model).
        self.prompt_embeds = nn.Parameter(
            torch.randn(1, 77, self.unet.config.cross_attention_dim)
        )
        self.height = config.height
        self.width = config.width

    @torch.no_grad()  # pure inference: the output never carries gradient
    def forward(
        self,
        images_gray_masked,
        masks=None,
        num_inference_steps=250,
        seed=42,
        input_type='pil',
        output_type='pil'
    ):
        """Inpaint masked regions of the input images and return grayscale output.

        Args:
            images_gray_masked: PIL image(s) when ``input_type='pil'`` (they are
                preprocessed by ``image_processor`` to ``height`` x ``width``),
                or an already-preprocessed tensor when ``input_type='pt'``
                (presumably ``(B, 3, H, W)`` in the VAE input range — confirm
                with callers).
            masks: Optional tensor broadcastable against the image where 1 marks
                pixels to regenerate. It is NOT preprocessed, so it must be a
                tensor even for PIL input. When ``None`` it is predicted by
                ``mask_predictor`` and thresholded at 0.5.
            num_inference_steps: Number of scheduler denoising steps.
            seed: Seed for the CPU generator that draws the initial noise.
            input_type: ``'pil'`` or ``'pt'``.
            output_type: ``'pil'``, ``'np'`` or ``'pt'`` (postprocessed by
                ``image_processor``), or ``'none'`` for the raw 1-channel tensor.

        Returns:
            The restored grayscale images in the requested format.

        Raises:
            ValueError: If ``input_type`` or ``output_type`` is unsupported.
        """
        # Validate up front so a bad option does not cost a full diffusion run.
        if output_type not in ('pil', 'np', 'pt', 'none'):
            raise ValueError('unsupported output_type')

        device = self.vae.device
        generator = torch.Generator()
        generator.manual_seed(seed)

        if input_type == 'pil':
            images_gray_masked = self.image_processor.preprocess(
                images_gray_masked, height=self.height, width=self.width
            ).float()
        elif input_type != 'pt':
            raise ValueError('unsupported input_type')
        images_gray_masked = images_gray_masked.to(device)

        if masks is None:
            masks_logits = self.mask_predictor(images_gray_masked)
            masks = (torch.sigmoid(masks_logits) > 0.5) * 1.
        masks = masks.float().to(device)
        # Zero out the region to be regenerated.
        images_gray_masked = (1 - masks) * images_gray_masked

        batch_size = images_gray_masked.shape[0]
        prompt_embeds = self.prompt_embeds.repeat(batch_size, 1, 1)

        # Per-call copy: set_timesteps/step mutate scheduler state.
        scheduler = deepcopy(self.scheduler)
        scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device)

        masked_image_latents = (
            self.vae.encode(images_gray_masked).latent_dist.mode()
            * self.vae.config.scaling_factor
        )
        # Match the mask to the actual latent grid, not a config constant, so
        # non-square / non-default sizes work.
        mask_latents = F.interpolate(masks, size=tuple(masked_image_latents.shape[-2:]))
        # init_noise_sigma must come from the configured copy (it can depend on
        # set_timesteps for some schedulers).
        latents = (
            randn_tensor(masked_image_latents.shape, generator=generator).to(device)
            * scheduler.init_noise_sigma
        )

        for t in scheduler.timesteps:
            # Scale only the UNet input; the scheduler must step the unscaled state.
            scaled_latents = scheduler.scale_model_input(latents, t)
            latent_model_input = torch.cat(
                [scaled_latents, mask_latents, masked_image_latents], dim=1
            )
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=prompt_embeds
            )[0]
            latents = scheduler.step(noise_pred, t, latents)[0]

        images_gray_restored = self.vae.decode(latents / self.vae.config.scaling_factor)[0]
        # Keep original pixels outside the mask; use generated ones inside it.
        images_gray_restored = (
            images_gray_masked * (1 - masks) + images_gray_restored * masks
        )
        images_gray_restored = rgb_to_grayscale(images_gray_restored)

        if output_type == 'none':
            return images_gray_restored
        return self.image_processor.postprocess(images_gray_restored, output_type)