# Source: 360PanoImage / txt2panoimg/text_to_360panorama_image_pipeline.py
# Revision 9e8b7a5 by gokaygokay: "Update txt2panoimg/text_to_360panorama_image_pipeline.py"
# Copyright © Alibaba, Inc. and its affiliates.
import random
from typing import Any, Dict
import numpy as np
import torch
from diffusers import (ControlNetModel, DiffusionPipeline,
EulerAncestralDiscreteScheduler,
UniPCMultistepScheduler)
from PIL import Image
from RealESRGAN import RealESRGAN
from .pipeline_base import StableDiffusionBlendExtendPipeline
from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline
class LazyRealESRGAN:
    """Wrapper that defers building the RealESRGAN model until first use.

    The underlying ``RealESRGAN`` object is only constructed (and its weights
    only read from ``model_path``) the first time :meth:`load_model` or
    :meth:`predict` is called.

    Args:
        device: Device handed to ``RealESRGAN`` when the model is built.
        scale: Upscaling factor handed to ``RealESRGAN``.
        model_path: Optional path of the weights file. It may also be
            assigned later through the ``model_path`` attribute, which is
            how existing callers configure it.
    """

    def __init__(self, device, scale, model_path=None):
        self.device = device
        self.scale = scale
        # Stays None until load_model() has fully succeeded.
        self.model = None
        self.model_path = model_path

    def load_model(self):
        """Build the model and load its weights once; no-op when already loaded.

        Raises:
            ValueError: If ``model_path`` has not been set.
        """
        if self.model is not None:
            return
        if self.model_path is None:
            # Fail early with a clear message instead of an obscure error
            # from inside the weight-loading code.
            raise ValueError(
                'LazyRealESRGAN.model_path must be set before the model '
                'can be loaded')
        # Only publish the model after its weights loaded successfully, so a
        # failed load is retried on the next call instead of leaving a
        # weightless model behind.
        model = RealESRGAN(self.device, scale=self.scale)
        model.load_weights(self.model_path, download=False)
        self.model = model

    def predict(self, img):
        """Load the model if necessary and return ``model.predict(img)``."""
        self.load_model()
        return self.model.predict(img)
class Text2360PanoramaImagePipeline(DiffusionPipeline):
    """ Stable Diffusion for 360 Panorama Image Generation Pipeline.

    Wraps three models: a base text-to-panorama pipeline (``self.pipe``), a
    ControlNet img2img super-resolution pipeline (``self.pipe_sr``) and a
    lazily loaded Real-ESRGAN upsampler (``self.upsampler``).

    Example:
    >>> import torch
    >>> from txt2panoimg import Text2360PanoramaImagePipeline
    >>> prompt = 'The mountains'
    >>> input = {'prompt': prompt, 'upscale': True}
    >>> model_id = 'models/'
    >>> txt2panoimg = Text2360PanoramaImagePipeline(model_id, torch_dtype=torch.float16)
    >>> output = txt2panoimg(input)
    >>> output.save('result.png')
    """

    def __init__(self, model: str, device: str = 'cuda', **kwargs):
        """
        Use `model` to create a stable diffusion pipeline for 360 panorama image generation.

        Args:
            model: prefix under which the sub-models are looked up:
                ``model + '/sd-base/'``, ``model + '/sr-base'``,
                ``model + '/sr-control'`` and
                ``model + '/RealESRGAN_x2plus.pth'``.
            device: device the pipelines are moved to. ``None`` selects
                CUDA when available and CPU otherwise; ``'gpu'`` is
                accepted as an alias for CUDA.
            **kwargs: ``torch_dtype`` (default ``torch.float16``) and
                ``enable_xformers_memory_efficient_attention``
                (default ``True``) are read; other keys are ignored.
        """
        super().__init__()

        if device is None:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        if device == 'gpu':
            device = torch.device('cuda')

        torch_dtype = kwargs.get('torch_dtype', torch.float16)
        enable_xformers_memory_efficient_attention = kwargs.get(
            'enable_xformers_memory_efficient_attention', True)

        # init base model
        model_id = model + '/sd-base/'
        self.pipe = StableDiffusionBlendExtendPipeline.from_pretrained(
            model_id, torch_dtype=torch_dtype).to(device)
        self.pipe.vae.enable_tiling()
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config)
        if enable_xformers_memory_efficient_attention:
            self._try_enable_xformers(self.pipe)

        # init controlnet-sr model
        base_model_path = model + '/sr-base'
        controlnet_path = model + '/sr-control'
        controlnet = ControlNetModel.from_pretrained(
            controlnet_path, torch_dtype=torch_dtype)
        self.pipe_sr = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(
            base_model_path, controlnet=controlnet,
            torch_dtype=torch_dtype).to(device)
        # NOTE(review): the SR scheduler is configured from the *base*
        # pipeline's scheduler config rather than pipe_sr's own. Kept as-is
        # because changing it would alter generated images - confirm intent.
        self.pipe_sr.scheduler = UniPCMultistepScheduler.from_config(
            self.pipe.scheduler.config)
        self.pipe_sr.vae.enable_tiling()
        if enable_xformers_memory_efficient_attention:
            self._try_enable_xformers(self.pipe_sr)

        # Real-ESRGAN upsampler; weights are only read on first use.
        # Use the resolved device instead of a hard-coded CUDA device so the
        # upsampler follows the same `device` argument as the pipelines.
        if isinstance(device, torch.device):
            upsampler_device = device
        else:
            upsampler_device = torch.device(device)
        self.upsampler = LazyRealESRGAN(device=upsampler_device, scale=2)
        self.upsampler.model_path = model + '/RealESRGAN_x2plus.pth'

    @staticmethod
    def _try_enable_xformers(pipe):
        """Best-effort switch to xformers attention; failures are only printed."""
        # xformers is optional: a missing or failing install must not
        # prevent the pipeline from being constructed.
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(e)

    @staticmethod
    def blend_h(a, b, blend_extent):
        """Linearly cross-fade the right edge of `a` into the left edge of `b`.

        Both inputs are copied into new arrays first, so the arguments are
        never modified. Column ``x`` of the result (for ``x`` below the
        blend width) is a mix of column ``-blend_extent + x`` of `a` and
        column ``x`` of `b`, with the weight of `b` growing from 0 towards 1.

        Args:
            a: image whose trailing columns are blended in (indexed as
                rows x columns x channels).
            b: image whose leading columns are overwritten in the copy.
            blend_extent: number of columns to blend; clamped to the width
                of both images.

        Returns:
            The blended copy of `b` as a numpy array.
        """
        a = np.array(a)
        b = np.array(b)
        blend_extent = min(a.shape[1], b.shape[1], blend_extent)
        for x in range(blend_extent):
            weight_b = x / blend_extent
            b[:, x, :] = (a[:, -blend_extent + x, :] * (1 - weight_b)
                          + b[:, x, :] * weight_b)
        return b

    def _super_resolve(self, prompt, negative_prompt, image, sr_scale,
                       guidance_scale, generator):
        """Run one ControlNet img2img pass at (1536, 768) * `sr_scale`."""
        target_size = (int(1536 * sr_scale), int(768 * sr_scale))
        return self.pipe_sr(
            # the SR pipeline is prompted without the panorama trigger token
            prompt.replace('<360panorama>, ', ''),
            negative_prompt=negative_prompt,
            image=image.resize(target_size),
            num_inference_steps=7,
            generator=generator,
            control_image=image.resize(target_size),
            strength=0.8,
            controlnet_conditioning_scale=1.0,
            guidance_scale=guidance_scale,
        ).images[0]

    def _upscale_with_esrgan(self, image):
        """Upscale `image` 2x with Real-ESRGAN, blending across the wrap seam."""
        image = image.resize((1536 * 2, 768 * 2))
        width = image.size[0]
        blend_extend = 10
        outscale = 2
        pixels = np.array(image)
        # Append a copy of the leftmost columns to the right edge so the
        # upsampler sees the content on the other side of the seam
        # (presumably to keep the panorama continuous where it wraps).
        pixels = np.concatenate(
            [pixels, pixels[:, :blend_extend, :]], axis=1)
        upscaled = self.upsampler.predict(pixels)
        # Fade the upscaled padding into the left edge, then crop it off.
        blended = self.blend_h(upscaled, upscaled, blend_extend * outscale)
        return Image.fromarray(blended[:, :width * outscale, :])

    def __call__(self, inputs: Dict[str, Any],
                 **forward_params) -> Image.Image:
        """Generate a panorama image from the settings in `inputs`.

        Args:
            inputs: dictionary of settings. Recognised keys and defaults:
                ``prompt`` (falls back to ``forward_params['prompt']``,
                then ``'the living room'``), ``add_prompt``,
                ``negative_prompt``, ``num_inference_steps`` (20),
                ``guidance_scale`` (7.5), ``seed`` (-1 draws a random
                seed), ``upscale`` (True), ``refinement`` (True),
                ``guidance_scale_sr_step1`` (15) and
                ``guidance_scale_sr_step2`` (17).
            **forward_params: only ``prompt`` is read, as a fallback.

        Returns:
            The generated PIL image: the 1024x512 base output when
            ``upscale`` is false, otherwise the super-resolved result.

        Raises:
            ValueError: If `inputs` is not a dictionary.
        """
        if not isinstance(inputs, dict):
            raise ValueError(
                f'Expected the input to be a dictionary, but got {type(inputs)}'
            )

        num_inference_steps = inputs.get('num_inference_steps', 20)
        guidance_scale = inputs.get('guidance_scale', 7.5)
        preset_a_prompt = 'photorealistic, trend on artstation, ((best quality)), ((ultra high res))'
        add_prompt = inputs.get('add_prompt', preset_a_prompt)
        preset_n_prompt = 'persons, complex texture, small objects, sheltered, blur, worst quality, '\
            'low quality, zombie, logo, text, watermark, username, monochrome, '\
            'complex lighting'
        negative_prompt = inputs.get('negative_prompt', preset_n_prompt)
        seed = inputs.get('seed', -1)
        upscale = inputs.get('upscale', True)
        refinement = inputs.get('refinement', True)
        guidance_scale_sr_step1 = inputs.get('guidance_scale_sr_step1', 15)
        # Read the step-2 value from its own key (previously this looked up
        # the step-1 key, so step 2 could not be configured independently).
        guidance_scale_sr_step2 = inputs.get('guidance_scale_sr_step2', 17)

        if 'prompt' in inputs:
            prompt = inputs['prompt']
        else:
            # for demo_service
            prompt = forward_params.get('prompt', 'the living room')
        print(f'Test with prompt: {prompt}')

        if seed == -1:
            seed = random.randint(0, 65535)
        print(f'global seed: {seed}')
        # torch.manual_seed seeds the global RNG and returns its generator,
        # which is shared by every pipeline call below.
        generator = torch.manual_seed(seed)

        prompt = '<360panorama>, ' + prompt + ', ' + add_prompt
        output_img = self.pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=512,
            width=1024,
            guidance_scale=guidance_scale,
            generator=generator).images[0]

        if upscale:
            print('inputs: upscale=True, running upscaler.')
            print('running upscaler step1. Initial super-resolution')
            output_img = self._super_resolve(
                prompt, negative_prompt, output_img, 2.0,
                guidance_scale_sr_step1, generator)

            print('running upscaler step2. Super-resolution with Real-ESRGAN')
            output_img = self._upscale_with_esrgan(output_img)

            # Refinement only runs on an upscaled image.
            if refinement:
                print(
                    'inputs: refinement=True, running refinement. This is a bit time-consuming.'
                )
                output_img = self._super_resolve(
                    prompt, negative_prompt, output_img, 4,
                    guidance_scale_sr_step2, generator)

        print('finished')
        return output_img