# Copyright © Alibaba, Inc. and its affiliates.
import random
from typing import Any, Dict

import numpy as np
import torch
from diffusers import (ControlNetModel, DiffusionPipeline,
                       EulerAncestralDiscreteScheduler,
                       UniPCMultistepScheduler)
from PIL import Image
from RealESRGAN import RealESRGAN

from .pipeline_base import StableDiffusionBlendExtendPipeline
from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline


class LazyRealESRGAN:
    """Defer construction and weight loading of a RealESRGAN model until first use.

    Args:
        device: device handed to ``RealESRGAN`` when the model is built.
        scale: upscaling factor handed to ``RealESRGAN``.
        model_path: path of the weights file. It may instead be assigned to
            the ``model_path`` attribute after construction, as long as that
            happens before the first ``predict`` call.
    """

    def __init__(self, device, scale, model_path=None):
        self.device = device
        self.scale = scale
        self.model = None  # built on the first load_model() call
        self.model_path = model_path

    def load_model(self):
        """Build the model and load its weights once; later calls are no-ops.

        Raises:
            ValueError: if ``model_path`` has not been set.
        """
        if self.model is not None:
            return
        if self.model_path is None:
            raise ValueError(
                'LazyRealESRGAN.model_path must be set before loading the model')
        model = RealESRGAN(self.device, scale=self.scale)
        model.load_weights(self.model_path, download=False)
        # Cache only after the weights loaded successfully, so a failed load
        # is retried next time instead of reusing an unweighted model.
        self.model = model

    def predict(self, img):
        """Upscale ``img`` with the (lazily loaded) RealESRGAN model."""
        self.load_model()
        return self.model.predict(img)


class Text2360PanoramaImagePipeline(DiffusionPipeline):
    """ Stable Diffusion for 360 Panorama Image Generation Pipeline.

    Example:

    >>> import torch
    >>> from txt2panoimg import Text2360PanoramaImagePipeline
    >>> prompt = 'The mountains'
    >>> inputs = {'prompt': prompt, 'upscale': True}
    >>> model_id = 'models/'
    >>> txt2panoimg = Text2360PanoramaImagePipeline(model_id, torch_dtype=torch.float16)
    >>> output = txt2panoimg(inputs)
    >>> output.save('result.png')
    """

    def __init__(self, model: str, device: str = 'cuda', **kwargs):
        """
        Use `model` to create a stable diffusion pipeline for 360 panorama image generation.

        Args:
            model: path of the model directory (e.g. a model downloaded from the
                modelscope hub). It must contain the ``sd-base``, ``sr-base`` and
                ``sr-control`` sub-models and ``RealESRGAN_x2plus.pth``.
            device: device to run every model on. ``'gpu'`` is accepted as an
                alias for ``'cuda'``. ``None`` selects CUDA when available,
                otherwise CPU.
            **kwargs: ``torch_dtype`` (default ``torch.float16``) and
                ``enable_xformers_memory_efficient_attention`` (default ``True``).
        """
        super().__init__()

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        elif device == 'gpu':
            device = 'cuda'
        if not isinstance(device, torch.device):
            device = torch.device(device)

        torch_dtype = kwargs.get('torch_dtype', torch.float16)
        enable_xformers = kwargs.get(
            'enable_xformers_memory_efficient_attention', True)

        # init base model
        self.pipe = StableDiffusionBlendExtendPipeline.from_pretrained(
            model + '/sd-base/', torch_dtype=torch_dtype).to(device)
        self.pipe.vae.enable_tiling()
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config)
        if enable_xformers:
            self._try_enable_xformers(self.pipe)

        # init controlnet-sr model
        controlnet = ControlNetModel.from_pretrained(
            model + '/sr-control', torch_dtype=torch_dtype)
        self.pipe_sr = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(
            model + '/sr-base', controlnet=controlnet,
            torch_dtype=torch_dtype).to(device)
        # NOTE(review): UniPC is configured from the *base* pipeline's scheduler
        # config, not pipe_sr's own; kept unchanged — confirm this is intended.
        self.pipe_sr.scheduler = UniPCMultistepScheduler.from_config(
            self.pipe.scheduler.config)
        self.pipe_sr.vae.enable_tiling()
        if enable_xformers:
            self._try_enable_xformers(self.pipe_sr)

        # RealESRGAN is loaded lazily on first use, on the same device as the
        # diffusion pipelines (previously hard-coded to CUDA).
        self.upsampler = LazyRealESRGAN(
            device=device, scale=2,
            model_path=model + '/RealESRGAN_x2plus.pth')

    @staticmethod
    def _try_enable_xformers(pipe):
        """Best-effort: enable xformers attention on ``pipe``.

        Any failure (e.g. xformers not installed) is printed, not raised.
        """
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(e)

    @staticmethod
    def blend_h(a, b, blend_extent):
        """Cross-fade the rightmost columns of ``a`` into the left edge of ``b``.

        For ``0 <= x < blend_extent``, result column ``x`` is
        ``a[:, x - blend_extent] * (1 - x / blend_extent) + b[:, x] * (x / blend_extent)``.
        All other columns are copied from ``b``. Both inputs are converted with
        ``np.array``, so the caller's objects are never modified. The result
        keeps ``b``'s dtype, which means blended values are truncated on
        assignment for integer images.

        Args:
            a: image-like (indexed as rows x columns x channels) whose
                rightmost columns are faded in.
            b: image-like providing the base of the result.
            blend_extent: requested blend width in columns; it is clamped to
                both widths.

        Returns:
            np.ndarray: the blended copy of ``b``.
        """
        a = np.array(a)
        b = np.array(b)
        blend_extent = min(a.shape[1], b.shape[1], blend_extent)
        if blend_extent <= 0:
            return b
        # Per-column weight of ``b``, broadcast over rows and channels.
        weights = (np.arange(blend_extent) / blend_extent)[None, :, None]
        b[:, :blend_extent, :] = (
            a[:, -blend_extent:, :] * (1 - weights)
            + b[:, :blend_extent, :] * weights)
        return b

    def _refine_with_controlnet_sr(self, image, prompt, negative_prompt,
                                   generator, sr_scale, guidance_scale):
        """Run one ControlNet super-resolution pass on a (1536, 768) * sr_scale canvas.

        The resized image serves as both the img2img init image and the
        ControlNet conditioning image.
        """
        size = (int(1536 * sr_scale), int(768 * sr_scale))
        resized = image.resize(size)
        return self.pipe_sr(
            prompt,
            negative_prompt=negative_prompt,
            image=resized,
            num_inference_steps=7,
            generator=generator,
            control_image=resized,
            strength=0.8,
            controlnet_conditioning_scale=1.0,
            guidance_scale=guidance_scale,
        ).images[0]

    def _upscale_with_realesrgan(self, image, blend_extend=10):
        """Upscale ``image`` with RealESRGAN while keeping the horizontal wrap seamless.

        The first ``blend_extend`` columns are appended to the right edge
        before upscaling. Afterwards, the upscaled overlap is cross-faded into
        the left edge with ``blend_h`` and then cropped off.
        """
        width = image.size[0]
        outscale = self.upsampler.scale
        pixels = np.array(image)
        padded = np.concatenate([pixels, pixels[:, :blend_extend, :]], axis=1)
        upscaled = self.upsampler.predict(padded)
        blended = self.blend_h(upscaled, upscaled, blend_extend * outscale)
        return Image.fromarray(blended[:, :width * outscale, :])

    def __call__(self, inputs: Dict[str, Any],
                 **forward_params) -> Image.Image:
        """Generate a 360-degree panorama image from a text prompt.

        Args:
            inputs: generation options. Recognised keys, with defaults:
                ``prompt`` (falls back to ``forward_params['prompt']``, then
                ``'the living room'``), ``num_inference_steps`` (20),
                ``guidance_scale`` (7.5), ``add_prompt`` (preset quality tags),
                ``negative_prompt`` (preset list), ``seed`` (-1 = random in
                [0, 65535]), ``upscale`` (True), ``refinement`` (True; only
                applied when upscaling), ``guidance_scale_sr_step1`` (15) and
                ``guidance_scale_sr_step2`` (17).
            **forward_params: only ``prompt`` is read, and only when
                ``inputs`` has none.

        Returns:
            PIL.Image.Image: the panorama. The base pass requests 1024x512;
            upscaling passes enlarge it.

        Raises:
            ValueError: if ``inputs`` is not a dict.
        """
        if not isinstance(inputs, dict):
            raise ValueError(
                f'Expected the input to be a dictionary, but got {type(inputs)}'
            )
        num_inference_steps = inputs.get('num_inference_steps', 20)
        guidance_scale = inputs.get('guidance_scale', 7.5)
        preset_a_prompt = 'photorealistic, trend on artstation, ((best quality)), ((ultra high res))'
        add_prompt = inputs.get('add_prompt', preset_a_prompt)
        preset_n_prompt = (
            'persons, complex texture, small objects, sheltered, blur, worst quality, '
            'low quality, zombie, logo, text, watermark, username, monochrome, '
            'complex lighting')
        negative_prompt = inputs.get('negative_prompt', preset_n_prompt)
        seed = inputs.get('seed', -1)
        upscale = inputs.get('upscale', True)
        refinement = inputs.get('refinement', True)
        guidance_scale_sr_step1 = inputs.get('guidance_scale_sr_step1', 15)
        guidance_scale_sr_step2 = inputs.get('guidance_scale_sr_step2', 17)

        if 'prompt' in inputs:
            prompt = inputs['prompt']
        else:
            # for demo_service
            prompt = forward_params.get('prompt', 'the living room')
        print(f'Test with prompt: {prompt}')

        if seed == -1:
            seed = random.randint(0, 65535)
        print(f'global seed: {seed}')
        generator = torch.manual_seed(seed)

        prompt = '<360panorama>, ' + prompt + ', ' + add_prompt
        output_img = self.pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=512,
            width=1024,
            guidance_scale=guidance_scale,
            generator=generator).images[0]

        if not upscale:
            print('finished')
            return output_img

        print('inputs: upscale=True, running upscaler.')
        # The SR passes are prompted without the panorama trigger token.
        sr_prompt = prompt.replace('<360panorama>, ', '')

        print('running upscaler step1. Initial super-resolution')
        output_img = self._refine_with_controlnet_sr(
            output_img, sr_prompt, negative_prompt, generator,
            sr_scale=2.0, guidance_scale=guidance_scale_sr_step1)

        print('running upscaler step2. Super-resolution with Real-ESRGAN')
        output_img = self._upscale_with_realesrgan(
            output_img.resize((1536 * 2, 768 * 2)))

        if refinement:
            print(
                'inputs: refinement=True, running refinement. This is a bit time-consuming.'
            )
            output_img = self._refine_with_controlnet_sr(
                output_img, sr_prompt, negative_prompt, generator,
                sr_scale=4, guidance_scale=guidance_scale_sr_step2)

        print('finished')
        return output_img