# Provenance (copied from the header of the page hosting this file):
#   author: bokyeong1015
#   commit 07df82c: "Update demo.py with the released checkpoint"
#   file size: 5.11 kB
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
import torch
import copy
import time
# Hub ID of the full Stable Diffusion pipeline used as the baseline.
ORIGINAL_CHECKPOINT_ID = "CompVis/stable-diffusion-v1-4"
# Hub ID of the repository whose "unet" subfolder holds the compressed U-Net.
COMPRESSED_UNET_ID = "nota-ai/bk-sdm-small"
# Use the GPU when one is present; otherwise fall back to CPU so the demo can
# still start on CPU-only hosts instead of failing when moving pipelines to 'cuda'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class SdmCompressionDemo:
    """Compare the original Stable Diffusion pipeline with a compressed-U-Net variant.

    ``pipe_original`` is loaded from ``ORIGINAL_CHECKPOINT_ID``. ``pipe_compressed``
    uses copies of all of its non-U-Net components plus the U-Net loaded from
    ``COMPRESSED_UNET_ID``, so the two pipelines differ only in their U-Net.
    """

    def __init__(self, device) -> None:
        """Load both pipelines and, for CUDA devices, move them to *device*.

        Args:
            device: Torch device string. Any value containing ``'cuda'`` selects
                float16 weights and moves both pipelines to that device; otherwise
                float32 weights are used and the pipelines are not moved
                (presumably they stay on CPU after loading).
        """
        self.device = device
        use_cuda = 'cuda' in self.device
        self.torch_dtype = torch.float16 if use_cuda else torch.float32

        self.pipe_original = StableDiffusionPipeline.from_pretrained(
            ORIGINAL_CHECKPOINT_ID, torch_dtype=self.torch_dtype)
        compressed_unet = UNet2DConditionModel.from_pretrained(
            COMPRESSED_UNET_ID, subfolder="unet", torch_dtype=self.torch_dtype)

        # Deep-copy every component except the U-Net, which is replaced anyway.
        # Copying the whole pipeline would duplicate the original U-Net only to
        # discard it; the copies keep the two pipelines fully independent.
        components = {
            name: copy.deepcopy(module)
            for name, module in self.pipe_original.components.items()
            if name != "unet"
        }
        self.pipe_compressed = StableDiffusionPipeline(unet=compressed_unet, **components)

        if use_cuda:
            self.pipe_original = self.pipe_original.to(self.device)
            self.pipe_compressed = self.pipe_compressed.to(self.device)
        self.device_msg = 'Tested on GPU.' if use_cuda else 'Tested on CPU.'

    def _count_params(self, model):
        """Return the total number of elements across all parameters of *model*."""
        return sum(p.numel() for p in model.parameters())

    def get_sdm_params(self, pipe):
        """Summarize the parameter count of *pipe* as a display string.

        The total covers the U-Net, the text encoder and the VAE decoder only
        (the VAE encoder is not counted). Counts are in millions with one
        decimal, e.g. ``"Total 2.6M (U-Net 2.0M)"``.
        """
        params_unet = self._count_params(pipe.unet)
        params_text_enc = self._count_params(pipe.text_encoder)
        params_image_dec = self._count_params(pipe.vae.decoder)
        params_total = params_unet + params_text_enc + params_image_dec
        return f"Total {(params_total/1e6):.1f}M (U-Net {(params_unet/1e6):.1f}M)"

    def generate_image(self, pipe, text, negative, guidance_scale, steps, seed):
        """Run *pipe* once and return ``(image, nsfw_detected, elapsed)``.

        Args:
            pipe: Callable pipeline invoked like a diffusers text-to-image pipeline.
            text: Positive prompt.
            negative: Negative prompt.
            guidance_scale: Guidance scale forwarded to the pipeline.
            steps: Number of inference steps; coerced to ``int`` in case the
                UI supplies a float.
            seed: RNG seed; coerced to ``int`` because
                ``torch.Generator.manual_seed`` rejects floats.

        Returns:
            The first generated image; whether it was flagged as NSFW (``False``
            when the pipeline reports no safety result); and the elapsed
            generation time in seconds, formatted as a string with two decimals.
        """
        generator = torch.Generator(self.device).manual_seed(int(seed))
        start = time.perf_counter()  # monotonic clock, suited to measuring durations
        result = pipe(text, negative_prompt=negative, generator=generator,
                      guidance_scale=guidance_scale, num_inference_steps=int(steps))
        test_time = time.perf_counter() - start

        image = result.images[0]
        nsfw_flags = result.nsfw_content_detected
        # diffusers documents this field as None when safety checking was not performed.
        nsfw_detected = bool(nsfw_flags[0]) if nsfw_flags else False

        print(f"text {text} | Processed time: {test_time} sec | nsfw_flag {nsfw_detected}")
        print(f"negative {negative} | guidance_scale {guidance_scale} | steps {steps} ")
        print("===========")
        return image, nsfw_detected, format(test_time, ".2f")

    def error_msg(self, nsfw_detected):
        """Return the device message, plus a warning when NSFW content was flagged."""
        if nsfw_detected:
            return self.device_msg+" Black images are returned when potential harmful content is detected. Try different prompts or seeds."
        else:
            return self.device_msg

    def check_invalid_input(self, text):
        """Return True if *text* is missing, empty, or whitespace-only."""
        return text is None or not text.strip()

    def _run_inference(self, pipe, text, negative, guidance_scale, steps, seed):
        """Validate the prompt, run *pipe*, and build the ``(image, message, time)`` triple.

        For an invalid prompt, returns ``(None, "Please enter the input prompt.", None)``
        without running the pipeline.
        """
        if self.check_invalid_input(text):
            return None, "Please enter the input prompt.", None
        output_image, nsfw_detected, test_time = self.generate_image(
            pipe, text, negative, guidance_scale, steps, seed)
        return output_image, self.error_msg(nsfw_detected), test_time

    def infer_original_model(self, text, negative, guidance_scale, steps, seed):
        """Generate with the original pipeline; see ``_run_inference`` for the result."""
        print(f"=== ORIG model --- seed {seed}")
        return self._run_inference(self.pipe_original,
                                   text, negative, guidance_scale, steps, seed)

    def infer_compressed_model(self, text, negative, guidance_scale, steps, seed):
        """Generate with the compressed-U-Net pipeline; see ``_run_inference``."""
        print(f"=== COMPRESSED model --- seed {seed}")
        return self._run_inference(self.pipe_compressed,
                                   text, negative, guidance_scale, steps, seed)

    def get_example_list(self):
        """Return the example prompts offered in the demo (a fresh list per call)."""
        return [
            'a tropical bird sitting on a branch of a tree',
            'many decorative umbrellas hanging up',
            'an orange cat staring off with pretty eyes',
            'beautiful woman face with fancy makeup',
            'a decorated living room with a stylish feel',
            'a black vase holding a bouquet of roses',
            'very elegant bedroom featuring natural wood',
            'buffet-style food including cake and cheese',
            'a tall castle sitting under a cloudy sky',
            'closeup of a brown bear sitting in a grassy area',
            'a large basket with many fresh vegetables',
            'house being built with lots of wood',
            'a close up of a pizza with several toppings',
            'a golden vase with many different flows',
            'a statue of a lion face attached to brick wall',
            'something that looks particularly interesting',
            'table filled with a variety of different dishes',
            'a cinematic view of a large snowy peak',
            'a grand city in the year 2100, hyper realistic',
            'a blue eyed baby girl looking at the camera',
        ]