# Copyright © Alibaba, Inc. and its affiliates.
import random
from typing import Any, Dict

import numpy as np
import py360convert
import torch
from diffusers import (ControlNetModel, DiffusionPipeline,
                       EulerAncestralDiscreteScheduler,
                       UniPCMultistepScheduler)
from PIL import Image
from RealESRGAN import RealESRGAN

from .pipeline_i2p import StableDiffusionImage2PanoPipeline
from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline


class LazyRealESRGAN:
    """Defer construction of a Real-ESRGAN upsampler until first use.

    The ``RealESRGAN`` model is only built, and its weights loaded from
    ``model_path``, the first time :meth:`load_model` or :meth:`predict` is
    called.
    """

    def __init__(self, device, scale, model_path=None):
        """
        Args:
            device: device handed to ``RealESRGAN`` when the model is built.
            scale: upscaling factor handed to ``RealESRGAN``.
            model_path: path of the weight file. It may also be assigned to
                the ``model_path`` attribute after construction, as long as
                that happens before the first prediction.
        """
        self.device = device
        self.scale = scale
        self.model = None
        self.model_path = model_path

    def load_model(self):
        """Build the model and load its weights, once.

        Raises:
            ValueError: if ``model_path`` has not been set.
        """
        if self.model is not None:
            return
        if self.model_path is None:
            raise ValueError(
                'LazyRealESRGAN.model_path must be set before the model is loaded')
        model = RealESRGAN(self.device, scale=self.scale)
        # NOTE(review): download=False presumably disables fetching missing
        # weights, so the file must already exist at model_path - confirm.
        model.load_weights(self.model_path, download=False)
        # Cache only after the weights loaded successfully. Otherwise a failed
        # load would leave a weightless model cached and later calls would
        # silently skip loading.
        self.model = model

    def predict(self, img):
        """Upscale ``img`` with the lazily loaded Real-ESRGAN model."""
        self.load_model()
        return self.model.predict(img)


class Image2360PanoramaImagePipeline(DiffusionPipeline):
    """ Stable Diffusion pipeline that turns an input image into a 360 panorama.

    The input image and mask are turned into a ControlNet condition (see
    :meth:`process_control_image`) and a 1024x512 panorama is generated.
    The panorama can then optionally be super-resolved and refined.

    Example:

    >>> import torch
    >>> from PIL import Image
    >>> # NOTE(review): import Image2360PanoramaImagePipeline from wherever
    >>> # this module lives in your package.
    >>> inputs = {
    ...     'prompt': 'The living room',
    ...     'image': Image.open('image.png'),
    ...     'mask': Image.open('mask.png'),
    ...     'upscale': True,
    ... }
    >>> model_id = 'models/'
    >>> img2panoimg = Image2360PanoramaImagePipeline(model_id, torch_dtype=torch.float16)
    >>> output = img2panoimg(inputs)
    >>> output.save('result.png')
    """

    def __init__(self, model: str, device: str = 'cuda', **kwargs):
        """
        Use `model` to create a stable diffusion pipeline for 360 panorama image generation.

        Args:
            model: model location prefix. The sub-paths ``/sd-i2p``,
                ``/sr-base``, ``/sr-control`` and ``/RealESRGAN_x2plus.pth``
                are loaded from it.
            device: torch device to run on. ``'gpu'`` is accepted as an alias
                for ``'cuda'``. ``None`` picks CUDA when available and CPU
                otherwise.
            torch_dtype (kwarg): dtype of all diffusion models. Defaults to
                ``torch.float16``.
            enable_xformers_memory_efficient_attention (kwarg): try to enable
                xformers attention. Defaults to True. Failures are printed and
                ignored.
        """
        super().__init__()

        if device is None:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        elif device == 'gpu':
            device = torch.device('cuda')

        torch_dtype = kwargs.get('torch_dtype', torch.float16)
        enable_xformers = kwargs.get(
            'enable_xformers_memory_efficient_attention', True)

        # Image-to-panorama (i2p) model. Its ControlNet now honors torch_dtype
        # instead of hard-coding float16.
        controlnet = ControlNetModel.from_pretrained(
            model + '/sd-i2p', torch_dtype=torch_dtype)
        self.pipe = StableDiffusionImage2PanoPipeline.from_pretrained(
            model + '/sr-base/', controlnet=controlnet,
            torch_dtype=torch_dtype).to(device)
        self.pipe.vae.enable_tiling()
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config)
        self._try_enable_xformers(self.pipe, enable_xformers)

        # ControlNet super-resolution model.
        controlnet = ControlNetModel.from_pretrained(
            model + '/sr-control', torch_dtype=torch_dtype)
        self.pipe_sr = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(
            model + '/sr-base', controlnet=controlnet,
            torch_dtype=torch_dtype).to(device)
        # NOTE(review): built from the i2p pipe's (Euler-ancestral) scheduler
        # config rather than pipe_sr's own. Both pipes load from sr-base;
        # confirm this is intended.
        self.pipe_sr.scheduler = UniPCMultistepScheduler.from_config(
            self.pipe.scheduler.config)
        self.pipe_sr.vae.enable_tiling()
        self._try_enable_xformers(self.pipe_sr, enable_xformers)

        # x2 Real-ESRGAN upsampler, loaded lazily on first use. It runs on the
        # same device as the diffusion pipelines. The device was previously
        # hard-coded to CUDA, which broke CPU-only runs.
        self.upsampler = LazyRealESRGAN(
            device=torch.device(device), scale=2,
            model_path=model + '/RealESRGAN_x2plus.pth')

    @staticmethod
    def _try_enable_xformers(pipe, enabled):
        """Best-effort: enable xformers attention on ``pipe`` if ``enabled``.

        xformers is optional, so any failure is printed and ignored rather
        than aborting pipeline construction.
        """
        if not enabled:
            return
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(e)

    @staticmethod
    def process_control_image(image, mask):
        """Build the ControlNet condition tensor from an input image and mask.

        ``image`` is used as face 0 of a six-face cube map whose other faces
        are black. The cube map is projected to a 512x1024 equirectangular
        image with ``py360convert.c2e``. The result is then blended per pixel
        as ``(1 - mask) * background + mask * projection``, where
        ``background`` is ``image`` itself stretched to 1024x512.

        Args:
            image: PIL image. NOTE(review): it must convert to a 3-D
                (H, W, C) array because ``to_tensor`` transposes a 4-D
                array. Presumably RGB and square, since it is used as a cube
                face; confirm.
            mask: PIL image, also 3-D when converted to an array and
                broadcast-compatible with ``image`` (e.g. the same mode).
                Its pixels are scaled to [0, 1].

        Returns:
            float32 tensor of shape (1, 1 + C, 512, 1024). Channel 0 is the
            first mask channel; the remaining channels hold the blended image
            in [0, 1].
        """

        def to_tensor(img, batch_size, width=1024, height=512):
            # Resize, scale to [0, 1] and lay out as (batch, C, height, width).
            img = img.resize((width, height), resample=Image.BICUBIC)
            arr = np.array(img).astype(np.float32) / 255.0
            arr = np.vstack([arr[None].transpose(0, 3, 1, 2)] * batch_size)
            return torch.from_numpy(arr)

        face = np.array(image)
        zeros = np.zeros_like(face)
        cube_faces = [face if i == 0 else zeros for i in range(6)]
        projected = py360convert.c2e(cube_faces, 512, 1024, cube_format='list')

        bk_image = to_tensor(image, batch_size=1)
        control_image = to_tensor(
            Image.fromarray(projected.astype(np.uint8)), batch_size=1)
        mask_image = to_tensor(mask, batch_size=1)

        control_image = (1 - mask_image) * bk_image + mask_image * control_image
        # Prepend one mask channel to the blended image.
        return torch.cat([mask_image[:, :1, :, :], control_image], dim=1)

    @staticmethod
    def blend_h(a, b, blend_extent):
        """Cross-fade the right edge of ``a`` into the left edge of ``b``.

        For each column ``x`` in ``[0, blend_extent)``, column ``x`` of the
        result is
        ``a[:, -blend_extent + x] * (1 - x / blend_extent) + b[:, x] * (x / blend_extent)``.
        ``blend_extent`` is clamped to the narrower input width. Both inputs
        are copied with ``np.array`` and neither is modified. The result keeps
        ``b``'s dtype, so blended values are truncated for integer arrays.

        Returns:
            the blended copy of ``b`` as a numpy array.
        """
        a = np.array(a)
        b = np.array(b)
        blend_extent = min(a.shape[1], b.shape[1], blend_extent)
        for x in range(blend_extent):
            b[:, x, :] = a[:, -blend_extent + x, :] * (1 - x / blend_extent) + b[:, x, :] * (
                x / blend_extent)
        return b

    def _sr_pass(self, prompt, negative_prompt, image, sr_scale,
                 guidance_scale, generator):
        """Run one ControlNet super-resolution pass on ``image``.

        ``image`` is resized to ``(1536 * sr_scale, 768 * sr_scale)`` and
        passed as both ``image`` and ``control_image`` to ``self.pipe_sr``.

        Returns:
            the first image of the pipeline output.
        """
        size = (int(1536 * sr_scale), int(768 * sr_scale))
        return self.pipe_sr(
            prompt,
            negative_prompt=negative_prompt,
            image=image.resize(size),
            num_inference_steps=7,
            generator=generator,
            control_image=image.resize(size),
            strength=0.8,
            controlnet_conditioning_scale=1.0,
            guidance_scale=guidance_scale,
        ).images[0]

    def _upscale_wrapped(self, image, blend_extend=10, outscale=2):
        """Upscale ``image`` with Real-ESRGAN, handling the horizontal wrap.

        The first ``blend_extend`` columns are appended to the right edge
        before upscaling. Presumably this keeps the 360-degree left/right
        seam continuous. Afterwards the left edge is cross-faded with the
        upscaled duplicate strip (``blend_h``) and the result is cropped back
        to ``width * outscale`` columns.

        ``outscale`` must match the upsampler's scale, which ``__init__``
        sets to 2.
        """
        width = image.size[0]
        arr = np.array(image)
        arr = np.concatenate([arr, arr[:, :blend_extend, :]], axis=1)
        upscaled = self.upsampler.predict(arr)
        blended = self.blend_h(upscaled, upscaled, blend_extend * outscale)
        return Image.fromarray(blended[:, :width * outscale, :])

    def __call__(self, inputs: Dict[str, Any],
                 **forward_params) -> Image.Image:
        """Generate a 360 panorama from ``inputs``.

        Args:
            inputs: dict with the following keys.
                - ``'image'`` (required) and ``'mask'`` (required): see
                  :meth:`process_control_image`.
                - ``'prompt'``: text prompt. If absent,
                  ``forward_params['prompt']`` is used, defaulting to
                  ``'the living room'``.
                - ``'num_inference_steps'`` (20) and ``'guidance_scale'``
                  (7.0): settings for the panorama pass.
                - ``'add_prompt'`` and ``'negative_prompt'``: default to
                  built-in presets.
                - ``'seed'`` (-1): -1 draws a random seed in [0, 65535].
                - ``'upscale'`` (True): run super-resolution.
                - ``'refinement'`` (True): run the extra x4 SR pass; only
                  used when upscaling.
                - ``'guidance_scale_sr_step1'`` (15) and
                  ``'guidance_scale_sr_step2'`` (17): guidance for the x2 SR
                  pass and the refinement pass.
            forward_params: only ``'prompt'`` is read, as a fallback.

        Returns:
            PIL image. The panorama pass is requested at 1024x512; with
            ``upscale`` the output is presumably 6144x3072.

        Raises:
            ValueError: if ``inputs`` is not a dict.
            KeyError: if ``'image'`` or ``'mask'`` is missing.
        """
        if not isinstance(inputs, dict):
            raise ValueError(
                f'Expected the input to be a dictionary, but got {type(inputs)}'
            )
        num_inference_steps = inputs.get('num_inference_steps', 20)
        guidance_scale = inputs.get('guidance_scale', 7.0)
        preset_a_prompt = 'photorealistic, trend on artstation, ((best quality)), ((ultra high res))'
        add_prompt = inputs.get('add_prompt', preset_a_prompt)
        preset_n_prompt = ('persons, complex texture, small objects, sheltered, blur, worst quality, '
                           'low quality, zombie, logo, text, watermark, username, monochrome, '
                           'complex lighting')
        negative_prompt = inputs.get('negative_prompt', preset_n_prompt)
        seed = inputs.get('seed', -1)
        upscale = inputs.get('upscale', True)
        refinement = inputs.get('refinement', True)
        guidance_scale_sr_step1 = inputs.get('guidance_scale_sr_step1', 15)
        # Fix: this previously read 'guidance_scale_sr_step1' again, so the
        # refinement guidance could never be set on its own.
        guidance_scale_sr_step2 = inputs.get('guidance_scale_sr_step2', 17)

        image = inputs['image']
        mask = inputs['mask']
        control_image = self.process_control_image(image, mask)

        if 'prompt' in inputs:
            prompt = inputs['prompt']
        else:
            # for demo_service
            prompt = forward_params.get('prompt', 'the living room')
            print(f'Test with prompt: {prompt}')

        if seed == -1:
            seed = random.randint(0, 65535)
        print(f'global seed: {seed}')
        # torch.manual_seed seeds the global RNG and returns its generator.
        generator = torch.manual_seed(seed)

        prompt = '<360panorama>, ' + prompt + ', ' + add_prompt
        output_img = self.pipe(
            prompt,
            # Channels 1: hold the blended image in [0, 1]; map them to [-1, 1].
            image=(control_image[:, 1:, :, :] / 0.5 - 1.0),
            control_image=control_image,
            controlnet_conditioning_scale=1.0,
            strength=1.0,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=512,
            width=1024,
            guidance_scale=guidance_scale,
            generator=generator).images[0]

        if not upscale:
            print('finished')
            return output_img

        print('inputs: upscale=True, running upscaler.')
        # The SR model is prompted without the panorama trigger token.
        sr_prompt = prompt.replace('<360panorama>, ', '')

        print('running upscaler step1. Initial super-resolution')
        output_img = self._sr_pass(
            sr_prompt, negative_prompt, output_img, sr_scale=2.0,
            guidance_scale=guidance_scale_sr_step1, generator=generator)

        print('running upscaler step2. Super-resolution with Real-ESRGAN')
        output_img = output_img.resize((1536 * 2, 768 * 2))
        output_img = self._upscale_wrapped(
            output_img, blend_extend=10, outscale=2)

        if refinement:
            print(
                'inputs: refinement=True, running refinement. This is a bit time-consuming.'
            )
            output_img = self._sr_pass(
                sr_prompt, negative_prompt, output_img, sr_scale=4,
                guidance_scale=guidance_scale_sr_step2, generator=generator)
        print('finished')
        return output_img