from dataclasses import dataclass, field
from typing import Optional

import torch
from torch import nn
from einops import rearrange, repeat
from omegaconf import II

from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.training_utils import EMAModel
from diffusers.utils import logging
from diffusers.utils.hub_utils import PushToHubMixin
from transformers import CLIPTextModel, CLIPTokenizer

# from hydra.utils import instantiate
from peft import get_peft_model
from peft import LoraConfig as PeftLoraConfig  # aliased so it does not clash with the Hydra-style LoraConfig dataclass below

from layers import PositionalEncodingPermute1D
from trainer.noise_schedulers.scheduling_ddpm_zerosnr import DDPMScheduler

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class BaseModelConfig:
    pass


# Hydra-style config node that mirrors the arguments of peft.LoraConfig;
# the `_target_` field lets it be resolved via hydra.utils.instantiate.
@dataclass
class LoraConfig:
    _target_: str = "peft.LoraConfig"
    r: int = 8
    lora_alpha: int = 32
    target_modules: list = field(default_factory=lambda: ["to_q", "to_v", "query", "value"])
    lora_dropout: float = 0.0
    bias: str = "none"


@dataclass
class SDModelConfig(BaseModelConfig):
    _target_: str = "sd_model.SDModel"
    pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
    conditioning_dropout_prob: float = 0.05
    use_ema: bool = True
    concat_all_steps: bool = False
    positional_encoding_type: Optional[str] = "sinusoidal"
    positional_encoding_length: Optional[int] = None
    image_positional_encoding_type: Optional[str] = None #"sinusoidal"
    image_positional_encoding_length: Optional[int] = None
    broadcast_positional_encoding: bool = True
    sequence_length: Optional[int] = 6
    text_sequence_length: Optional[int] = 7
    use_lora: bool = False
    # lora_cfg: Any = LoraConfig()
    zero_snr: bool = True
    # seed: int = 42 # TODO: inherit from higher config
    # lora: LoraConfig = LoraConfig(
    #     )
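
# Example override (assumption, for illustration only): disable the per-step text
# positional encoding and fine-tune with LoRA adapters instead of the full UNet:
#   cfg = SDModelConfig(positional_encoding_type=None, use_lora=True)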
    
    
class SDModel(ModelMixin, ConfigMixin, PushToHubMixin):
    def __init__(self, cfg: Optional[SDModelConfig] = None) -> None:
        super().__init__()
        
        if cfg is None: # workaround for default
            self.cfg = SDModelConfig()
        else:
            self.cfg = cfg

        print(self.cfg)
        
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            self.cfg.pretrained_model_name_or_path, 
            subfolder="scheduler",
            zero_snr=self.cfg.zero_snr)
        
        
        
        self.text_encoder = CLIPTextModel.from_pretrained(
            self.cfg.pretrained_model_name_or_path, subfolder="text_encoder", 
        )
        self.tokenizer = CLIPTokenizer.from_pretrained(
            self.cfg.pretrained_model_name_or_path, subfolder="tokenizer"
        )
        
        self.vae = AutoencoderKL.from_pretrained(self.cfg.pretrained_model_name_or_path, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(
            self.cfg.pretrained_model_name_or_path, subfolder="unet"
        )
        
        # 4 channels for the noisy target latents + 4 for the conditioning-image latents
        # (InstructPix2Pix-style concatenated UNet input).
        in_channels = 8  # TODO make part of cfg
        out_channels = self.unet.conv_in.out_channels
        self.unet.register_to_config(in_channels=in_channels)

        with torch.no_grad():
            new_conv_in = nn.Conv2d(
                in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding
            )
            new_conv_in.weight.zero_()
            # Copy the pretrained weights into the first 4 input channels and leave the new channels at zero.
            new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
            # Also copy the pretrained bias: an important change w.r.t. the initial diffusers script,
            # which leaves the new bias at its default initialization.
            new_conv_in.bias.copy_(self.unet.conv_in.bias)
            self.unet.conv_in = new_conv_in
            
        self.init_pos()
        self.init_image_pos()
        
        
        if self.cfg.use_lora:
            # Use peft's own LoraConfig here; the plain dataclass defined above is only a
            # Hydra config node and is not a valid argument for get_peft_model.
            config = PeftLoraConfig(
                r=8,
                lora_alpha=32,
                target_modules=["to_q", "to_v", "query", "value"],
                lora_dropout=0.0,
                bias="none",
            )
            self.unet = get_peft_model(self.unet, config)
            # NOTE: this makes the whole input conv trainable, not just the new parameters!
            # Consider whether that is really what you want.
            self.unet.conv_in.requires_grad_(True)
            self.unet.print_trainable_parameters()
            print(self.unet)
            
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        
        # use_ema = True
        # if use_ema:
        if self.cfg.use_ema:
            self.ema_unet = EMAModel(self.unet.parameters(), model_cls=UNet2DConditionModel, model_config=self.unet.config)
            
        self.generator = None  # torch.Generator(self.unet.device).manual_seed(42)  # TODO: inherit seed from higher config

    def init_pos(self):
        self.cfg.positional_encoding_length = self.cfg.text_sequence_length
        if not self.cfg.broadcast_positional_encoding:
            # Without broadcasting, the encoding has to cover every token (77 per instruction step).
            self.cfg.positional_encoding_length *= 77
        if self.cfg.positional_encoding_type == 'sinusoidal':
            self.unet.pos = PositionalEncodingPermute1D(self.cfg.positional_encoding_length)
        elif self.cfg.positional_encoding_type is None or self.cfg.positional_encoding_type == 'None':
            self.unet.pos = nn.Identity()
        else:
            raise ValueError(f'Unknown positional encoding type {self.cfg.positional_encoding_type}')
    
    def init_image_pos(self):
        self.cfg.image_positional_encoding_length = self.cfg.sequence_length
        if self.cfg.image_positional_encoding_type == 'sinusoidal':
            self.unet.image_pos = PositionalEncodingPermute1D(self.cfg.image_positional_encoding_length)
        elif self.cfg.image_positional_encoding_type is None:
            self.unet.image_pos = nn.Identity()
        else:
            raise ValueError(f'Unknown image positional encoding type {self.cfg.image_positional_encoding_type}')
        
    def tokenize_captions(self, captions):
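        # Returns a LongTensor of shape (len(captions), 77): each caption is padded or
        # truncated to the CLIP tokenizer's model_max_length (77 for Stable Diffusion v1.5).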
        inputs = self.tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids
            
    def forward(self, batch): # replace with input_ids, edited_pixel_values, original_pixel_values
        batch_size = batch["input_ids"].shape[0]
        condition_image = batch["original_pixel_values"]
        input_ids = batch["input_ids"].to(self.text_encoder.device)
        # We want to learn the denoising process w.r.t the edited images which
        # are conditioned on the original image (which was edited) and the edit instruction.
        # So, first, convert images to latent space.
        edited_images = batch["edited_pixel_values"]#.to(self.cfg.weight_dtype) #TODO check dtype thing
        output_seq_length = edited_images.shape[1]
        # edited_images = edited_images.flatten(0,1)
        edited_images = rearrange(edited_images, 'b s c h w -> (b s) c h w')
        
        latents = self.vae.encode(edited_images).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor
        
        latents = rearrange(latents, '(b s) c h w -> b c (s h) w', s=output_seq_length)
        # latents = latents.unflatten(0,(batch_size,output_seq_length)).transpose(1,2).flatten(2,3) # TODO: change the (batch_size, 3) to (batch_size, output_seq_length)
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        
        if self.cfg.image_positional_encoding_type is not None:
            # Encode the frame index into the noisy latents; `latents` itself must stay
            # untouched because it is still needed for the v-prediction loss target.
            noisy_latents = self.apply_image_positional_encoding(noisy_latents, output_seq_length)
        
        if len(input_ids.shape) == 2:
            input_ids = input_ids.unsqueeze(0)

        encoder_hidden_states = self.input_ids_to_text_condition(input_ids)
        if self.cfg.positional_encoding_type is not None:
            encoder_hidden_states = self.apply_step_positional_encoding(encoder_hidden_states)

        # Get the additional image embedding for conditioning.
        # Instead of getting a diagonal Gaussian here, we simply take the mode.
        original_image_embeds = self.vae.encode(condition_image).latent_dist.mode() #.to(self.cfg.weight_dtype)).latent_dist.mode() #TODO check dtype thing

        # Conditioning dropout to support classifier-free guidance during inference. For more details
        # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
        if self.cfg.conditioning_dropout_prob is not None:
            encoder_hidden_states, original_image_embeds = self.apply_conditioning_dropout(encoder_hidden_states, original_image_embeds)

        # original_image_embeds = original_image_embeds.repeat(1,1,2,1)
        # original_image_embeds = original_image_embeds.unsqueeze(2).expand(-1, -1, output_seq_length, -1, -1).reshape(batch_size, 4, 32*output_seq_length, 32)
        original_image_embeds = repeat(original_image_embeds, 'b c h w -> b c (s h) w', s=output_seq_length) # TODO unify with pipeline get_image_latents
        
        # Concatenate the `original_image_embeds` with the `noisy_latents`.
        concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)

        target = self.get_loss_target(latents, noise, timesteps)

        # Predict the noise residual and compute loss
        model_pred = self.unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
        return model_pred, target

    def get_loss_target(self, latents, noise, timesteps):
        # Get the target for loss depending on the prediction type
        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
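            # get_velocity returns the v-parameterization target:
            #   v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents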
            target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
        return target

    def apply_conditioning_dropout(self, encoder_hidden_states, original_image_embeds):
        # With dropout probability p, this follows the InstructPix2Pix scheme: the text
        # conditioning is dropped when random_p < 2p and the image conditioning when
        # p <= random_p < 3p, so text-only, image-only and joint dropout each occur with probability p.
        bsz = original_image_embeds.shape[0]  # same as latents.shape[0] in the original diffusers script, TODO check
        random_p = torch.rand(bsz, device=encoder_hidden_states.device, generator=self.generator)  # was originally latents.device, TODO check
        # Sample masks for the edit prompts.
        prompt_mask = random_p < 2 * self.cfg.conditioning_dropout_prob
        prompt_mask = prompt_mask.reshape(bsz, 1, 1)
        # Final text conditioning.
        null_conditioning = self.get_null_conditioning()
        encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)

        # Sample masks for the original images.
        image_mask_dtype = original_image_embeds.dtype
        image_mask = 1 - (
            (random_p >= self.cfg.conditioning_dropout_prob).to(image_mask_dtype)
            * (random_p < 3 * self.cfg.conditioning_dropout_prob).to(image_mask_dtype)
        )
        image_mask = image_mask.reshape(bsz, 1, 1, 1)
        # Final image conditioning.
        original_image_embeds = image_mask * original_image_embeds
        return encoder_hidden_states, original_image_embeds

    def get_null_conditioning(self):
        null_token = self.tokenize_captions([""]).to(self.text_encoder.device)
        # null_conditioning = self.input_ids_to_text_condition(null_token) # would apply positional encoding twice
        null_conditioning = self.text_encoder(null_token)[0] # TODO fuse with input_ids_to_text_condition
        if not self.cfg.concat_all_steps:
            null_conditioning = repeat(null_conditioning, 'b t l -> b (s t) l', s=self.cfg.text_sequence_length)
        return null_conditioning

    def input_ids_to_text_condition(self, input_ids):
        # Get the text embedding for conditioning.
        if self.cfg.concat_all_steps:
            encoder_hidden_states = self.text_encoder(input_ids)[0] # text padded to 77 tokens; encoder_hidden_states.shape = (bsz, 77, 768)
        else:
            input_ids = rearrange(input_ids, 'b s t -> (b s) t')
            # Text padded to 77 tokens; encoder_hidden_states.shape = ((b s), 77, 768).
            # TODO check why this doesn't match concatenating the encodings of the three tokens;
            # the ones that don't match are the 769-1535 dims of the feature, for tokens 15-76.
            encoder_hidden_states = self.text_encoder(input_ids)[0]
            
            # if args.use_positional_encoding: # old way: added before concat which doesn't make sense
            #     encoder_hidden_states = pos(encoder_hidden_states) + encoder_hidden_states
            encoder_hidden_states = rearrange(encoder_hidden_states, '(b s) t d->b (s t) d', s=self.cfg.text_sequence_length)

        return encoder_hidden_states

    def apply_step_positional_encoding(self, encoder_hidden_states):
        positional_encoding = self.unet.pos(encoder_hidden_states)
        if self.cfg.broadcast_positional_encoding:
            positional_encoding = repeat(positional_encoding, 'b s d -> b (s t) d', t=77) # TODO check this
        encoder_hidden_states = positional_encoding + encoder_hidden_states
        return encoder_hidden_states
    
    def apply_image_positional_encoding(self, latents, output_seq_length):
        original_latents_shape = latents.shape
        h = original_latents_shape[2]//output_seq_length
        latents = rearrange(latents, 'b c (s h) w -> b s (c h w)', s=output_seq_length)
        image_pos = self.unet.image_pos(latents)
        latents = latents + image_pos
        # Confirmed that, without the pos addition in between, this reshaping restores the original tensor.
        latents = rearrange(latents, 'b s (c h w) -> b c (s h) w', s=output_seq_length, c=original_latents_shape[1], h=h, w=original_latents_shape[3])
        return latents
    
    def instantiate_pipeline(self):
        pass
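

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original training entry point; the
# __main__ guard and the shapes below are assumptions inferred from forward()).
# Running it downloads the Stable Diffusion v1.5 weights from the Hugging Face Hub.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = SDModelConfig(use_ema=False, use_lora=False)
    model = SDModel(cfg)

    # forward() expects a batch dict with (assumed) shapes:
    #   original_pixel_values: (b, 3, H, W)      conditioning image
    #   edited_pixel_values:   (b, s, 3, H, W)   s = cfg.sequence_length edited frames
    #   input_ids:             (b, t, 77)        t = cfg.text_sequence_length instruction steps
    captions = [f"edit step {i}" for i in range(cfg.text_sequence_length)]
    input_ids = model.tokenize_captions(captions)
    print(input_ids.shape)  # -> torch.Size([t, 77])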