from typing import Optional, Union, List

import PIL
import numpy as np
import torch
import torchvision.transforms as T
from einops import repeat

from kandinsky3.model.unet import UNet
from kandinsky3.movq import MoVQ
from kandinsky3.condition_encoders import T5TextConditionEncoder
from kandinsky3.condition_processors import T5TextConditionProcessor
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
from kandinsky3.utils import resize_image_for_diffusion, resize_mask_for_diffusion


class Kandinsky3InpaintingPipeline:

    def __init__(
        self,
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict],
        unet: UNet,
        null_embedding: torch.Tensor,
        t5_processor: T5TextConditionProcessor,
        t5_encoder: T5TextConditionEncoder,
        movq: MoVQ,
    ):
        self.device_map = device_map
        self.dtype_map = dtype_map
        self.to_pil = T.ToPILImage()
        self.to_tensor = T.ToTensor()

        self.unet = unet
        self.null_embedding = null_embedding
        self.t5_processor = t5_processor
        self.t5_encoder = t5_encoder
        self.movq = movq

    def shared_step(self, batch: dict) -> dict:
        image = batch['image']
        condition_model_input = batch['text']
        negative_condition_model_input = batch['negative_text']

        # Build the masked image: either it is provided directly, or, for a
        # 9-channel inpainting UNet, zero out the pixels to be repainted.
        masked_latent = None
        mask = batch['mask']
        if 'masked_image' in batch:
            masked_latent = batch['masked_image']
        elif self.unet.in_layer.in_channels == 9:
            masked_latent = image.masked_fill((1 - mask).bool(), 0)
        else:
            raise ValueError(
                'batch must contain "masked_image", or the UNet must take 9 input channels'
            )

        # Encode the masked image into MoVQ latent space and resize the mask
        # to the latent resolution.
        with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
            masked_latent = self.movq.encode(masked_latent)
        mask = torch.nn.functional.interpolate(mask, size=(masked_latent.shape[2], masked_latent.shape[3]))

        # Encode the positive (and optional negative) prompts with T5.
        with torch.cuda.amp.autocast(dtype=self.dtype_map['text_encoder']):
            context, context_mask = self.t5_encoder(condition_model_input)
            if negative_condition_model_input is not None:
                negative_context, negative_context_mask = self.t5_encoder(negative_condition_model_input)
            else:
                negative_context, negative_context_mask = None, None

        return {
            'context': context,
            'context_mask': context_mask,
            'negative_context': negative_context,
            'negative_context_mask': negative_context_mask,
            'image': image,
            'masked_latent': masked_latent,
            'mask': mask,
        }

    def prepare_batch(
        self,
        text: str,
        negative_text: Optional[str],
        image: PIL.Image.Image,
        mask: np.ndarray,
    ) -> dict:
        condition_model_input, negative_condition_model_input = self.t5_processor.encode(
            text=text, negative_text=negative_text
        )

        # Normalize the image to [-1, 1] and invert the mask so that, in the
        # batch, 1 marks pixels to keep and 0 marks pixels to repaint.
        batch = {
            'image': self.to_tensor(resize_image_for_diffusion(image.convert("RGB"))) * 2 - 1,
            'mask': 1 - self.to_tensor(resize_mask_for_diffusion(mask)),
            'text': condition_model_input,
            'negative_text': negative_condition_model_input,
        }
        batch['mask'] = batch['mask'].type(self.dtype_map['movq'])
        batch['image'] = batch['image'].unsqueeze(0).to(self.device_map['movq'])
        batch['text']['input_ids'] = batch['text']['input_ids'].unsqueeze(0).to(self.device_map['text_encoder'])
        batch['text']['attention_mask'] = batch['text']['attention_mask'].unsqueeze(0).to(
            self.device_map['text_encoder'])
        batch['mask'] = batch['mask'].unsqueeze(0).to(self.device_map['movq'])

        if negative_condition_model_input is not None:
            batch['negative_text']['input_ids'] = batch['negative_text']['input_ids'].to(
                self.device_map['text_encoder'])
            batch['negative_text']['attention_mask'] = batch['negative_text']['attention_mask'].to(
                self.device_map['text_encoder'])

        return batch

    def __call__(
        self,
        text: str,
        image: PIL.Image.Image,
        mask: np.ndarray,
        negative_text: Optional[str] = None,
        images_num: int = 1,
        bs: int = 1,
        steps: int = 50,
        guidance_weight_text: float = 4,
        eta: float = 1.0,
    ) -> List[PIL.Image.Image]:
        with torch.no_grad():
            batch = self.prepare_batch(text, negative_text, image, mask)
            processed = self.shared_step(batch)

        betas = get_named_beta_schedule('cosine', 1000)
        base_diffusion = BaseDiffusion(betas, percentile=0.95)
        times = list(range(999, 0, -1000 // steps))

        # Generate images_num samples in minibatches of at most bs.
        pil_images = []
        k, m = images_num // bs, images_num % bs
        for minibatch in [bs] * k + [m]:
            if minibatch == 0:
                continue

            # Broadcast the single-sample conditioning to the minibatch size.
            bs_context = repeat(processed['context'], '1 n d -> b n d', b=minibatch)
            bs_context_mask = repeat(processed['context_mask'], '1 n -> b n', b=minibatch)
            if processed['negative_context'] is not None:
                bs_negative_context = repeat(processed['negative_context'], '1 n d -> b n d', b=minibatch)
                bs_negative_context_mask = repeat(processed['negative_context_mask'], '1 n -> b n', b=minibatch)
            else:
                bs_negative_context, bs_negative_context_mask = None, None

            mask = processed['mask'].repeat_interleave(minibatch, dim=0)
            masked_latent = processed['masked_latent'].repeat_interleave(minibatch, dim=0)
            minibatch = masked_latent.shape[0]

            # Sample latents with classifier-free guidance, conditioned on the
            # mask and the masked-image latent.
            with torch.cuda.amp.autocast(dtype=self.dtype_map['unet']):
                with torch.no_grad():
                    images = base_diffusion.p_sample_loop(
                        self.unet, (minibatch, 4, masked_latent.shape[2], masked_latent.shape[3]), times,
                        self.device_map['unet'], bs_context, bs_context_mask,
                        self.null_embedding, guidance_weight_text, eta,
                        negative_context=bs_negative_context, negative_context_mask=bs_negative_context_mask,
                        mask=mask, masked_latent=masked_latent, gan=False
                    )

            # Decode latents in two chunks to limit peak MoVQ memory use, then
            # map from [-1, 1] back to [0, 1].
            with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
                images = torch.cat([self.movq.decode(image) for image in images.chunk(2)])
                images = torch.clip((images + 1.) / 2., 0., 1.).cpu()

            pil_images += [self.to_pil(img) for img in images]

        return pil_images
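

# Usage sketch, not part of the pipeline itself. The loader helper, weight
# handling, and device/dtype maps below are assumptions to illustrate the
# call signature; adapt them to however your checkout constructs the
# pipeline's components (UNet, T5 encoder, MoVQ, null embedding).
if __name__ == '__main__':
    import PIL.Image
    import kandinsky3  # assumption: the package root exposes a loader helper

    device_map = {'unet': 'cuda:0', 'text_encoder': 'cuda:0', 'movq': 'cuda:0'}
    dtype_map = {'unet': torch.float16, 'text_encoder': torch.float32, 'movq': torch.float32}

    # Hypothetical loader call; see the repo README for the supported way to
    # instantiate Kandinsky3InpaintingPipeline with pretrained weights.
    pipe = kandinsky3.get_inpainting_pipeline(device_map, dtype_map)

    source = PIL.Image.open('photo.png')
    # Mask convention implied by prepare_batch/shared_step above:
    # 1 marks pixels to repaint, 0 marks pixels to keep.
    inpaint_mask = np.zeros((source.height, source.width), dtype=np.float32)
    inpaint_mask[256:512, 256:512] = 1.0

    results = pipe('a red sports car', source, inpaint_mask,
                   images_num=1, bs=1, steps=50, guidance_weight_text=4)
    results[0].save('inpainted.png')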