File size: 4,220 Bytes
f6018b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f9096c
f6018b4
 
 
 
 
 
 
 
 
4651dee
f6018b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy
from torchvision.transforms.functional import rgb_to_grayscale
import segmentation_models_pytorch as smp
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils.torch_utils import randn_tensor
from transformers import PretrainedConfig, PreTrainedModel

class SDGrayInpaintConfig(PretrainedConfig):
    """Hugging Face configuration for the grayscale SD inpainting model.

    Args:
        base_model: Hub id or local path of the Stable Diffusion inpainting
            pipeline to load components from.
        height: Target image height used when preprocessing PIL inputs.
        width: Target image width used when preprocessing PIL inputs.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "sd_gray_inpaint"

    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-inpainting",
        height=512,
        width=512,
        **kwargs
    ):
        # Store the model-specific fields first, then let the base class
        # consume the generic HF config keyword arguments.
        own_fields = (
            ("base_model", base_model),
            ("height", height),
            ("width", width),
        )
        for field_name, field_value in own_fields:
            setattr(self, field_name, field_value)
        super().__init__(**kwargs)

class SDGrayInpaintModel(PreTrainedModel):
    """Fill masked regions of grayscale images using a Stable Diffusion inpainting UNet.

    The UNet, VAE, scheduler and image processor are taken from a
    ``StableDiffusionInpaintPipeline`` loaded from ``config.base_model``. Text
    conditioning is replaced by a single learnable embedding
    (``self.prompt_embeds``). When no mask is supplied, a segmentation UNet
    (``self.mask_predictor``) predicts one from the input images.
    """

    config_class = SDGrayInpaintConfig

    def __init__(self, config):
        super().__init__(config)
        pipe = StableDiffusionInpaintPipeline.from_pretrained(config.base_model)
        # Mask predictor: 3-channel image in, one logit channel out.
        self.mask_predictor = smp.Unet(
            encoder_name="mit_b4",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
        )
        self.image_processor = pipe.image_processor
        self.scheduler = pipe.scheduler
        self.unet = pipe.unet
        self.vae = pipe.vae
        # Learnable stand-in for the text-encoder output. NOTE(review): the
        # (1, 77, 1024) shape presumably matches SD2's cross-attention width;
        # confirm against unet.config.cross_attention_dim.
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def forward(
        self,
        images_gray_masked,
        masks=None,
        num_inference_steps=250,
        seed=42,
        input_type='pil',
        output_type='pil'
    ):
        """Restore the masked regions of a batch of grayscale images.

        Args:
            images_gray_masked: PIL image(s) when ``input_type='pil'`` (resized
                to ``self.height`` x ``self.width`` by the image processor), or
                a ``torch.Tensor`` when ``input_type='pt'``. NOTE(review): a
                'pt' input is used as-is, so it presumably must already be a
                (B, 3, H, W) tensor in the VAE's [-1, 1] range; confirm.
            masks: Optional tensor, broadcastable against the images, where 1
                marks pixels to regenerate and 0 marks pixels to keep. If None,
                it is predicted by ``self.mask_predictor`` and thresholded at
                sigmoid probability 0.5.
            num_inference_steps: Number of scheduler denoising steps.
            seed: Seed for the CPU generator that draws the initial noise.
            input_type: ``'pil'`` or ``'pt'``.
            output_type: ``'pil'``, ``'np'`` or ``'pt'`` (forwarded to the image
                processor's ``postprocess``), or ``'none'`` for the raw
                single-channel tensor.

        Returns:
            The restored grayscale images in the requested output format.

        Raises:
            ValueError: If ``input_type`` or ``output_type`` is unsupported.
        """
        # Fail fast, before the expensive sampling loop.
        if output_type not in ('pil', 'np', 'pt', 'none'):
            raise ValueError('unsupported output_type')

        generator = torch.Generator()
        generator.manual_seed(seed)

        if input_type == 'pil':
            images_gray_masked = self.image_processor.preprocess(
                images_gray_masked, height=self.height, width=self.width
            ).float()
        elif input_type != 'pt':
            raise ValueError('unsupported input_type')
        device = self.vae.device
        images_gray_masked = images_gray_masked.to(device)

        if masks is None:
            # Thresholding makes the mask non-differentiable, so no graph is needed.
            with torch.no_grad():
                masks_logits = self.mask_predictor(images_gray_masked)
            masks = (torch.sigmoid(masks_logits) > 0.5) * 1.
        masks = masks.float().to(device)
        images_gray_masked = (1 - masks) * images_gray_masked
        batch_size = images_gray_masked.shape[0]
        prompt_embeds = self.prompt_embeds.repeat(batch_size, 1, 1)

        # Configure a private copy so repeated calls never mutate self.scheduler.
        scheduler = deepcopy(self.scheduler)
        scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device)

        # Everything produced here is detached before it reaches the output, so
        # no gradient can flow through it. no_grad avoids keeping an autograd
        # graph alive across every UNet step.
        with torch.no_grad():
            masked_image_latents = (
                self.vae.encode(images_gray_masked).latent_dist.mode()
                * self.vae.config.scaling_factor
            )
            # Resize the pixel mask to the latent grid the VAE actually produced,
            # rather than assuming unet.config.sample_size.
            mask_latents = F.interpolate(
                masks, size=tuple(masked_image_latents.shape[-2:])
            ).to(masked_image_latents.dtype)
            # init_noise_sigma must come from the scheduler whose timesteps were set.
            latents = randn_tensor(
                masked_image_latents.shape,
                generator=generator,
                dtype=masked_image_latents.dtype,
            ).to(device) * scheduler.init_noise_sigma
            for t in scheduler.timesteps:
                # Scale only the UNet input; step() expects the unscaled sample.
                latent_model_input = torch.cat(
                    [scheduler.scale_model_input(latents, t), mask_latents, masked_image_latents],
                    dim=1,
                )
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=prompt_embeds
                )[0]
                latents = scheduler.step(noise_pred, t, latents)[0]
            decoded = self.vae.decode(latents / self.vae.config.scaling_factor)[0]

        # Keep the known pixels and take generated content only inside the mask.
        images_gray_restored = images_gray_masked * (1 - masks) + decoded * masks
        images_gray_restored = rgb_to_grayscale(images_gray_restored)

        if output_type == 'none':
            return images_gray_restored
        return self.image_processor.postprocess(images_gray_restored, output_type)