|
from diffusers import StableDiffusionPipeline, UNet2DConditionModel |
|
import torch |
|
import copy |
|
|
|
import time |
|
|
|
# Hugging Face Hub id of the baseline Stable Diffusion v1.4 pipeline.
ORIGINAL_CHECKPOINT_ID = "CompVis/stable-diffusion-v1-4"
# Hub id of the BK-SDM-Small checkpoint; its "unet" subfolder supplies the
# compressed U-Net.
COMPRESSED_UNET_ID = "nota-ai/bk-sdm-small"

# Prefer the GPU, but fall back to CPU on hosts without CUDA instead of
# failing later (e.g. torch.Generator('cuda') raises when CUDA is absent).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
class SdmCompressionDemo:
    """Side-by-side demo of Stable Diffusion and a BK-SDM compressed variant.

    The compressed pipeline starts as a deep copy of the original pipeline and
    only its U-Net is replaced, so the two differ solely in the U-Net.
    """

    def __init__(self, device) -> None:
        """Load the original and compressed pipelines.

        Args:
            device: Device string such as ``'cuda'`` or ``'cpu'``. When it
                contains ``'cuda'`` the weights are loaded in float16 and both
                pipelines are moved to that device; otherwise float32 is used
                and the pipelines are not moved.
        """
        self.device = device
        is_cuda = 'cuda' in self.device
        self.torch_dtype = torch.float16 if is_cuda else torch.float32

        self.pipe_original = StableDiffusionPipeline.from_pretrained(
            ORIGINAL_CHECKPOINT_ID, torch_dtype=self.torch_dtype)
        # Clone the whole pipeline so both variants use identical text
        # encoder / VAE / scheduler setups, then swap in the compressed U-Net.
        self.pipe_compressed = copy.deepcopy(self.pipe_original)
        self.pipe_compressed.unet = UNet2DConditionModel.from_pretrained(
            COMPRESSED_UNET_ID, subfolder="unet", torch_dtype=self.torch_dtype)

        if is_cuda:
            self.pipe_original = self.pipe_original.to(self.device)
            self.pipe_compressed = self.pipe_compressed.to(self.device)

        self.device_msg = 'Tested on GPU.' if is_cuda else 'Tested on CPU.'

    def _count_params(self, model):
        """Return the total number of elements across ``model.parameters()``."""
        return sum(p.numel() for p in model.parameters())

    def get_sdm_params(self, pipe):
        """Summarize a pipeline's parameter count in millions.

        Counts the U-Net, the text encoder and the VAE decoder. The VAE
        encoder is not counted (presumably because text-to-image sampling only
        decodes latents).

        Returns:
            A string like ``"Total 1040.0M (U-Net 859.5M)"``.
        """
        params_unet = self._count_params(pipe.unet)
        params_text_enc = self._count_params(pipe.text_encoder)
        params_image_dec = self._count_params(pipe.vae.decoder)
        params_total = params_unet + params_text_enc + params_image_dec
        return f"Total {(params_total/1e6):.1f}M (U-Net {(params_unet/1e6):.1f}M)"

    def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
        """Generate one image with ``pipe`` and time the call.

        Args:
            pipe: Pipeline to call (original or compressed).
            text: Prompt.
            negative: Negative prompt.
            guidance_scale: Classifier-free guidance scale.
            steps: Number of inference steps; coerced to ``int``.
            seed: RNG seed; coerced to ``int`` because
                ``torch.Generator.manual_seed`` rejects floats, which UI
                number widgets commonly return.

        Returns:
            ``(image, nsfw_detected, elapsed)``. ``elapsed`` is the generation
            time in seconds formatted with two decimals.
        """
        generator = torch.Generator(self.device).manual_seed(int(seed))

        # perf_counter is monotonic, unlike time.time().
        start = time.perf_counter()
        result = pipe(text, negative_prompt=negative, generator=generator,
                      guidance_scale=guidance_scale,
                      num_inference_steps=int(steps))
        test_time = time.perf_counter() - start

        image = result.images[0]
        # diffusers reports None here when the pipeline has no safety checker.
        nsfw_flags = result.nsfw_content_detected
        nsfw_detected = bool(nsfw_flags[0]) if nsfw_flags else False

        print(f"text {text} | Processed time: {test_time} sec | nsfw_flag {nsfw_detected}")
        print(f"negative {negative} | guidance_scale {guidance_scale} | steps {steps} ")
        print("===========")

        return image, nsfw_detected, format(test_time, ".2f")

    def error_msg(self, nsfw_detected):
        """Return the status message, with an NSFW warning if flagged."""
        if nsfw_detected:
            return self.device_msg+" Black images are returned when potential harmful content is detected. Try different prompts or seeds."
        else:
            return self.device_msg

    def check_invalid_input(self, text):
        """Return True when the prompt is missing, empty or whitespace-only."""
        return not (text and text.strip())

    def _infer(self, pipe, label, text, negative, guidance_scale, steps, seed):
        """Validate the prompt, run ``pipe`` and build the UI return triple.

        Returns:
            ``(image, message, elapsed)``, or
            ``(None, "Please enter the input prompt.", None)`` for a blank prompt.
        """
        print(f"=== {label} model --- seed {seed}")
        if self.check_invalid_input(text):
            return None, "Please enter the input prompt.", None
        output_image, nsfw_detected, test_time = self.generate_image(
            pipe, text, negative, guidance_scale, steps, seed)
        return output_image, self.error_msg(nsfw_detected), test_time

    def infer_original_model(self, text, negative, guidance_scale, steps, seed):
        """Generate with the original pipeline; see ``_infer`` for the return shape."""
        return self._infer(self.pipe_original, "ORIG",
                           text, negative, guidance_scale, steps, seed)

    def infer_compressed_model(self, text, negative, guidance_scale, steps, seed):
        """Generate with the compressed pipeline; see ``_infer`` for the return shape."""
        return self._infer(self.pipe_compressed, "COMPRESSED",
                           text, negative, guidance_scale, steps, seed)

    def get_example_list(self):
        """Return the example prompts shown in the demo UI."""
        return [
            'a tropical bird sitting on a branch of a tree',
            'many decorative umbrellas hanging up',
            'an orange cat staring off with pretty eyes',
            'beautiful woman face with fancy makeup',
            'a decorated living room with a stylish feel',
            'a black vase holding a bouquet of roses',
            'very elegant bedroom featuring natural wood',
            'buffet-style food including cake and cheese',
            'a tall castle sitting under a cloudy sky',
            'closeup of a brown bear sitting in a grassy area',
            'a large basket with many fresh vegetables',
            'house being built with lots of wood',
            'a close up of a pizza with several toppings',
            'a golden vase with many different flows',
            'a statue of a lion face attached to brick wall',
            'something that looks particularly interesting',
            'table filled with a variety of different dishes',
            'a cinematic view of a large snowy peak',
            'a grand city in the year 2100, hyper realistic',
            'a blue eyed baby girl looking at the camera',
        ]
|
|
|
|
|
|
|
|
|
|