# Spaces: Running
from math import ceil
from typing import Any, Optional

import numpy as np
import torch
from diffusers import LCMScheduler

from backend.device import is_openvino_device
from backend.models.lcmdiffusion_setting import DiffusionTask
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting
from backend.models.lcmdiffusion_setting import LCMLora
from backend.openvino.pipelines import (
    get_ov_text_to_image_pipeline,
    ov_load_taesd,
    get_ov_image_to_image_pipeline,
)
from backend.pipelines.lcm import (
    get_lcm_model_pipeline,
    load_taesd,
    get_image_to_image_pipeline,
)
from backend.pipelines.lcm_lora import get_lcm_lora_pipeline
from constants import DEVICE
from image_ops import resize_pil_image
class LCMTextToImage:
    """Build, cache and run LCM / LCM-LoRA diffusion pipelines.

    ``init`` (re)creates the underlying pipeline whenever one of the tracked
    settings differs from the values used for the previous build, and
    ``generate`` runs text-to-image or image-to-image inference with the
    cached pipeline.
    """

    def __init__(
        self,
        device: str = "cpu",
    ) -> None:
        """Create an empty wrapper; no pipeline is built until ``init`` runs.

        Args:
            device: Accepted for interface compatibility. It is not read
                here; ``self.device`` is assigned later by ``init``.
        """
        self.pipeline = None
        self.use_openvino = False
        self.device = ""
        # Snapshot of the settings used for the last pipeline build. ``init``
        # compares against these to decide whether a rebuild is required.
        # Every value read by that comparison is initialised here so the
        # check never depends on ``or`` short-circuiting to avoid an
        # AttributeError.
        self.previous_model_id = None
        self.previous_use_tae_sd = False
        self.previous_use_lcm_lora = False
        self.previous_lcm_lora_base_id = None
        self.previous_lcm_lora_id = None
        self.previous_ov_model_id = ""
        self.previous_safety_checker = False
        self.previous_use_openvino = False
        self.previous_task_type = None
        self.img_to_img_pipeline = None
        self.is_openvino_init = False
        self.task_type = DiffusionTask.text_to_image
        # float32 on OpenVINO devices and on "mps", float16 everywhere else.
        self.torch_data_type = (
            torch.float32 if is_openvino_device() or DEVICE == "mps" else torch.float16
        )
        print(f"Torch datatype : {self.torch_data_type}")

    def _pipeline_to_device(self):
        """Move ``self.pipeline`` to ``DEVICE`` using ``self.torch_data_type``."""
        print(f"Pipeline device : {DEVICE}")
        print(f"Pipeline dtype : {self.torch_data_type}")
        self.pipeline.to(
            torch_device=DEVICE,
            torch_dtype=self.torch_data_type,
        )

    def _add_freeu(self):
        """Enable FreeU on ``self.pipeline`` when it uses an ``LCMScheduler``.

        The parameter set is chosen by the pipeline's class name; pipelines
        of any other class are left untouched.
        """
        pipeline_class = self.pipeline.__class__.__name__
        if isinstance(self.pipeline.scheduler, LCMScheduler):
            if pipeline_class == "StableDiffusionPipeline":
                print("Add FreeU - SD")
                self.pipeline.enable_freeu(
                    s1=0.9,
                    s2=0.2,
                    b1=1.2,
                    b2=1.4,
                )
            elif pipeline_class == "StableDiffusionXLPipeline":
                print("Add FreeU - SDXL")
                self.pipeline.enable_freeu(
                    s1=0.6,
                    s2=0.4,
                    b1=1.1,
                    b2=1.2,
                )

    def _update_lcm_scheduler_params(self):
        """Recreate the ``LCMScheduler`` with overridden beta range.

        Does nothing when the pipeline's scheduler is not an ``LCMScheduler``.
        """
        if isinstance(self.pipeline.scheduler, LCMScheduler):
            self.pipeline.scheduler = LCMScheduler.from_config(
                self.pipeline.scheduler.config,
                beta_start=0.001,
                beta_end=0.01,
            )

    def init(
        self,
        device: str = "cpu",
        lcm_diffusion_setting: Optional[LCMDiffusionSetting] = None,
    ) -> None:
        """Prepare the pipeline(s) required by ``lcm_diffusion_setting``.

        For image-to-image tasks the init image stored on the setting is
        resized in place. The pipeline is rebuilt only when no pipeline
        exists yet or when one of the tracked settings changed since the
        previous build.

        Args:
            device: Stored on ``self.device``.
            lcm_diffusion_setting: Generation settings. When omitted, a fresh
                ``LCMDiffusionSetting()`` is created for this call.
        """
        # A default instance in the signature would be created once and
        # shared by every call; this method mutates ``init_image`` on the
        # setting, so build a new one per call instead.
        if lcm_diffusion_setting is None:
            lcm_diffusion_setting = LCMDiffusionSetting()

        self.device = device
        self.use_openvino = lcm_diffusion_setting.use_openvino
        model_id = lcm_diffusion_setting.lcm_model_id
        use_local_model = lcm_diffusion_setting.use_offline_model
        use_tiny_auto_encoder = lcm_diffusion_setting.use_tiny_auto_encoder
        use_lora = lcm_diffusion_setting.use_lcm_lora
        lcm_lora: LCMLora = lcm_diffusion_setting.lcm_lora
        ov_model_id = lcm_diffusion_setting.openvino_lcm_model_id

        if lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
            # First scale to the target width while keeping the aspect ratio,
            # then hand the result to resize_pil_image with the target size.
            width, height = lcm_diffusion_setting.init_image.size
            new_width = lcm_diffusion_setting.image_width
            new_height = int(height * new_width / width)
            img = lcm_diffusion_setting.init_image.resize((new_width, new_height))
            # Debug output; the label means "new image".
            print("新图", new_height, new_width, lcm_diffusion_setting.image_height)
            lcm_diffusion_setting.init_image = resize_pil_image(
                img,
                lcm_diffusion_setting.image_width,
                lcm_diffusion_setting.image_height,
            )
            # Debug output; the label means "image size".
            print("图片大小", lcm_diffusion_setting.init_image)

        if (
            self.pipeline is None
            or self.previous_model_id != model_id
            or self.previous_use_tae_sd != use_tiny_auto_encoder
            or self.previous_lcm_lora_base_id != lcm_lora.base_model_id
            or self.previous_lcm_lora_id != lcm_lora.lcm_lora_id
            or self.previous_use_lcm_lora != use_lora
            or self.previous_ov_model_id != ov_model_id
            or self.previous_safety_checker != lcm_diffusion_setting.use_safety_checker
            or self.previous_use_openvino != lcm_diffusion_setting.use_openvino
            or self.previous_task_type != lcm_diffusion_setting.diffusion_task
        ):
            if self.use_openvino and is_openvino_device():
                # Drop the old pipeline before loading the new one.
                if self.pipeline:
                    del self.pipeline
                    self.pipeline = None
                self.is_openvino_init = True
                if (
                    lcm_diffusion_setting.diffusion_task
                    == DiffusionTask.text_to_image.value
                ):
                    print(f"***** Init Text to image (OpenVINO) - {ov_model_id} *****")
                    self.pipeline = get_ov_text_to_image_pipeline(
                        ov_model_id,
                        use_local_model,
                    )
                elif (
                    lcm_diffusion_setting.diffusion_task
                    == DiffusionTask.image_to_image.value
                ):
                    print(f"***** Image to image (OpenVINO) - {ov_model_id} *****")
                    self.pipeline = get_ov_image_to_image_pipeline(
                        ov_model_id,
                        use_local_model,
                    )
            else:
                # Drop both old pipelines before loading the new one.
                if self.pipeline:
                    del self.pipeline
                    self.pipeline = None
                if self.img_to_img_pipeline:
                    del self.img_to_img_pipeline
                    self.img_to_img_pipeline = None

                if use_lora:
                    print(
                        f"***** Init LCM-LoRA pipeline - {lcm_lora.base_model_id} *****"
                    )
                    self.pipeline = get_lcm_lora_pipeline(
                        lcm_lora.base_model_id,
                        lcm_lora.lcm_lora_id,
                        use_local_model,
                        torch_data_type=self.torch_data_type,
                    )
                else:
                    print(f"***** Init LCM Model pipeline - {model_id} *****")
                    self.pipeline = get_lcm_model_pipeline(
                        model_id,
                        use_local_model,
                    )

                if (
                    lcm_diffusion_setting.diffusion_task
                    == DiffusionTask.image_to_image.value
                ):
                    # The image-to-image pipeline is derived from the base one.
                    self.img_to_img_pipeline = get_image_to_image_pipeline(
                        self.pipeline
                    )
                self._pipeline_to_device()

            if use_tiny_auto_encoder:
                if self.use_openvino and is_openvino_device():
                    print("Using Tiny Auto Encoder (OpenVINO)")
                    ov_load_taesd(
                        self.pipeline,
                        use_local_model,
                    )
                else:
                    print("Using Tiny Auto Encoder")
                    if (
                        lcm_diffusion_setting.diffusion_task
                        == DiffusionTask.text_to_image.value
                    ):
                        load_taesd(
                            self.pipeline,
                            use_local_model,
                            self.torch_data_type,
                        )
                    elif (
                        lcm_diffusion_setting.diffusion_task
                        == DiffusionTask.image_to_image.value
                    ):
                        load_taesd(
                            self.img_to_img_pipeline,
                            use_local_model,
                            self.torch_data_type,
                        )

            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
                and lcm_diffusion_setting.use_openvino
            ):
                # OpenVINO image-to-image: swap in an LCMScheduler built from
                # the existing config, without the beta overrides.
                self.pipeline.scheduler = LCMScheduler.from_config(
                    self.pipeline.scheduler.config,
                )
            else:
                self._update_lcm_scheduler_params()

            if use_lora:
                self._add_freeu()

            # Remember what this build used so the next call can skip the
            # rebuild when nothing relevant changed.
            self.previous_model_id = model_id
            self.previous_ov_model_id = ov_model_id
            self.previous_use_tae_sd = use_tiny_auto_encoder
            self.previous_lcm_lora_base_id = lcm_lora.base_model_id
            self.previous_lcm_lora_id = lcm_lora.lcm_lora_id
            self.previous_use_lcm_lora = use_lora
            self.previous_safety_checker = lcm_diffusion_setting.use_safety_checker
            self.previous_use_openvino = lcm_diffusion_setting.use_openvino
            self.previous_task_type = lcm_diffusion_setting.diffusion_task

            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                print(f"Pipeline : {self.pipeline}")
            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                if self.use_openvino and is_openvino_device():
                    print(f"Pipeline : {self.pipeline}")
                else:
                    print(f"Pipeline : {self.img_to_img_pipeline}")

    def generate(
        self,
        lcm_diffusion_setting: LCMDiffusionSetting,
        reshape: bool = False,
    ) -> Any:
        """Run inference with the pipeline prepared by ``init``.

        Args:
            lcm_diffusion_setting: Generation settings (prompt, sizes, steps,
                seed, task, ...).
            reshape: On an OpenVINO pipeline, reshape and recompile the model
                for the requested size, except on the first call after the
                pipeline was (re)built.

        Returns:
            The ``images`` attribute of the pipeline call result.

        Raises:
            ZeroDivisionError: For an image-to-image task whose ``strength``
                is 0, because the step count is derived from ``1 / strength``.
            UnboundLocalError: When ``diffusion_task`` is neither the
                text-to-image nor the image-to-image value.
        """
        guidance_scale = lcm_diffusion_setting.guidance_scale
        img_to_img_inference_steps = lcm_diffusion_setting.inference_steps
        check_step_value = int(
            lcm_diffusion_setting.inference_steps * lcm_diffusion_setting.strength
        )
        if (
            lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value
            and check_step_value < 1
        ):
            # steps * strength would truncate to 0, so raise the step count
            # to the smallest value for which that product reaches 1.
            img_to_img_inference_steps = ceil(1 / lcm_diffusion_setting.strength)
            print(
                f"Strength: {lcm_diffusion_setting.strength},{img_to_img_inference_steps}"
            )

        if lcm_diffusion_setting.use_seed:
            cur_seed = lcm_diffusion_setting.seed
            # NumPy is seeded when OpenVINO was selected in init(), torch
            # otherwise.
            if self.use_openvino:
                np.random.seed(cur_seed)
            else:
                torch.manual_seed(cur_seed)

        is_openvino_pipe = lcm_diffusion_setting.use_openvino and is_openvino_device()
        if is_openvino_pipe:
            print("Using OpenVINO")
            if reshape and not self.is_openvino_init:
                print("Reshape and compile")
                self.pipeline.reshape(
                    batch_size=-1,
                    height=lcm_diffusion_setting.image_height,
                    width=lcm_diffusion_setting.image_width,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                )
                self.pipeline.compile()

            # The "freshly initialised" flag only suppresses the reshape on
            # the first generate() call after a rebuild.
            if self.is_openvino_init:
                self.is_openvino_init = False

        if not lcm_diffusion_setting.use_safety_checker:
            self.pipeline.safety_checker = None
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
                and not is_openvino_pipe
            ):
                self.img_to_img_pipeline.safety_checker = None

        if (
            not lcm_diffusion_setting.use_lcm_lora
            and not lcm_diffusion_setting.use_openvino
            and lcm_diffusion_setting.guidance_scale != 1.0
        ):
            print("Not using LCM-LoRA so setting guidance_scale 1.0")
            guidance_scale = 1.0

        if lcm_diffusion_setting.use_openvino:
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                result_images = self.pipeline(
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=lcm_diffusion_setting.inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                ).images
            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                # This path passes three times the computed step count and no
                # explicit width/height.
                result_images = self.pipeline(
                    image=lcm_diffusion_setting.init_image,
                    strength=lcm_diffusion_setting.strength,
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=img_to_img_inference_steps * 3,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                ).images
        else:
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                result_images = self.pipeline(
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=lcm_diffusion_setting.inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                ).images
            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                result_images = self.img_to_img_pipeline(
                    image=lcm_diffusion_setting.init_image,
                    strength=lcm_diffusion_setting.strength,
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=img_to_img_inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                ).images
        return result_images