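"""Inference script for super-resolution with a Stable Diffusion ControlNet (CCSR) pipeline.

Loads a pretrained SD backbone, a ControlNet, and an optional standalone VAE,
upscales a single image or a directory of images (optionally with tiled
diffusion/VAE for large outputs), and records per-image inference time.
"""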
import os
import glob
import math
import time
import argparse

import numpy as np
from PIL import Image
import safetensors.torch

import torch
from torchvision import transforms
import torchvision.transforms.functional as F

from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    UniPCMultistepScheduler,
    DPMSolverMultistepScheduler,
    DDPMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor

from pipelines.pipeline_ccsr import StableDiffusionControlNetPipeline
from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
from models.controlnet import ControlNetModel
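

# Builds the CCSR validation pipeline from the frozen SD components (VAE, text
# encoder, UNet) plus the ControlNet, with optional xformers attention and
# optional tiled VAE encode/decode.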
def load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention):
    scheduler_mapping = {
        'unipcmultistep': UniPCMultistepScheduler,
        'ddpm': DDPMScheduler,
        'dpmmultistep': DPMSolverMultistepScheduler,
    }

    try:
        scheduler_cls = scheduler_mapping[args.sample_method]
    except KeyError:
        raise ValueError(f"Invalid sample_method: {args.sample_method}")
    scheduler = scheduler_cls.from_pretrained(args.pretrained_model_path, subfolder="scheduler")

    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    feature_extractor = CLIPImageProcessor.from_pretrained(os.path.join(args.pretrained_model_path, "feature_extractor"))
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_path, subfolder="unet")
    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_path, subfolder="controlnet")

    vae_path = args.vae_model_path if args.vae_model_path else args.pretrained_model_path
    vae = AutoencoderKL.from_pretrained(vae_path, subfolder="vae")

    # Freeze models
    for model in [vae, text_encoder, unet, controlnet]:
        model.requires_grad_(False)

    # Enable xformers if available
    if enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
            controlnet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Ensure it is installed correctly.")
    # Initialize pipeline
    validation_pipeline = StableDiffusionControlNetPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=None,
        requires_safety_checker=False,
    )
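
    # Tiled VAE encode/decode (optional) bounds memory use on large images;
    # tile sizes come from --vae_encoder_tile_size / --vae_decoder_tile_size.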
    if args.tile_vae:
        validation_pipeline._init_tiled_vae(
            encoder_tile_size=args.vae_encoder_tile_size,
            decoder_tile_size=args.vae_decoder_tile_size,
        )

    # Set weight dtype based on mixed precision
    dtype_mapping = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    weight_dtype = dtype_mapping.get(accelerator.mixed_precision, torch.float32)

    # Move models to accelerator device with appropriate dtype
    for model in [text_encoder, vae, unet, controlnet]:
        model.to(accelerator.device, dtype=weight_dtype)

    return validation_pipeline
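

# Runs super-resolution over one image or a directory of images: builds a
# config-named output directory, samples each input --sample_times times,
# optionally applies color fixing, and records per-image inference time.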
def main(args, enable_xformers_memory_efficient_attention=True):
    detailed_output_dir = os.path.join(
        args.output_dir,
        f"sr_{args.baseline_name}_{args.sample_method}_{str(args.num_inference_steps).zfill(3)}steps_{args.start_point}{args.start_steps}_size{args.process_size}_cfg{args.guidance_scale}"
    )

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )

    # If passed along, set the seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Create the output folder and initialize trackers; both happen only on the main process.
    if accelerator.is_main_process:
        os.makedirs(detailed_output_dir, exist_ok=True)
        accelerator.init_trackers("Controlnet")

    pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention)

    if accelerator.is_main_process:
        generator = torch.Generator(device=accelerator.device)
        if args.seed is not None:
            generator.manual_seed(args.seed)
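
        # --image_path may point to a single image file or to a directory of images.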
        image_paths = sorted(glob.glob(os.path.join(args.image_path, "*.*"))) if os.path.isdir(args.image_path) else [args.image_path]

        time_records = []
        for image_path in image_paths:
            validation_image = Image.open(image_path).convert("RGB")
            negative_prompt = args.negative_prompt
            validation_prompt = args.added_prompt
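
            # Pre-resize so the short side is at least process_size/upscale, upscale by
            # the requested factor, then round the size down to a multiple of 8
            # (the VAE works on an 8x-downsampled latent).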
            ori_width, ori_height = validation_image.size
            resize_flag = False
            rscale = args.upscale
            if ori_width < args.process_size // rscale or ori_height < args.process_size // rscale:
                scale = (args.process_size // rscale) / min(ori_width, ori_height)
                validation_image = validation_image.resize((round(scale * ori_width), round(scale * ori_height)))
                resize_flag = True

            validation_image = validation_image.resize((validation_image.size[0] * rscale, validation_image.size[1] * rscale))
            validation_image = validation_image.resize((validation_image.size[0] // 8 * 8, validation_image.size[1] // 8 * 8))
            width, height = validation_image.size
            resize_flag = True

            for sample_idx in range(args.sample_times):
                os.makedirs(f'{detailed_output_dir}/sample{str(sample_idx).zfill(2)}/', exist_ok=True)

            for sample_idx in range(args.sample_times):
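                # The CCSR pipeline takes the sampling interval bounds (t_max, t_min) and the
                # tile-diffusion settings as positional arguments, followed by the prompt and
                # the low-resolution conditioning image.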
                inference_time, image = pipeline(
                    args.t_max,
                    args.t_min,
                    args.tile_diffusion,
                    args.tile_diffusion_size,
                    args.tile_diffusion_stride,
                    args.added_prompt,
                    validation_image,
                    num_inference_steps=args.num_inference_steps,
                    generator=generator,
                    height=height,
                    width=width,
                    guidance_scale=args.guidance_scale,
                    negative_prompt=args.negative_prompt,
                    conditioning_scale=args.conditioning_scale,
                    start_steps=args.start_steps,
                    start_point=args.start_point,
                    use_vae_encode_condition=args.use_vae_encode_condition,
                )
                image = image.images[0]
                print(f"Inference time: {inference_time:.4f} seconds")
                time_records.append(inference_time)

                # Apply color fixing if specified
                if args.align_method != 'nofix':
                    fix_func = wavelet_color_fix if args.align_method == 'wavelet' else adain_color_fix
                    image = fix_func(image, validation_image)
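
                # Map the result back to the original resolution times the upscale factor
                # (undoes the pre-resize and the multiple-of-8 rounding).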
                if resize_flag:
                    image = image.resize((ori_width * rscale, ori_height * rscale))

                image_tensor = torch.clamp(F.to_tensor(image), 0, 1)
                final_image = transforms.ToPILImage()(image_tensor)

                base_name = os.path.splitext(os.path.basename(image_path))[0]
                save_path = os.path.join(detailed_output_dir, f"sample{str(sample_idx).zfill(2)}", f"{base_name}.png")
                final_image.save(save_path)

        # Calculate the average inference time, excluding the first few for stabilization
        if len(time_records) > 3:
            average_time = np.mean(time_records[3:])
        else:
            average_time = np.mean(time_records)
        if accelerator.is_main_process:
            print(f"Average inference time: {average_time:.4f} seconds")

        # Save the run settings to a file
        settings_path = os.path.join(detailed_output_dir, "settings.txt")
        with open(settings_path, 'w') as f:
            f.write("------------------ start ------------------\n")
            for key, value in vars(args).items():
                f.write(f"{key} : {value}\n")
            f.write("------------------- end -------------------\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Stable Diffusion ControlNet Pipeline for Super-Resolution")
    parser.add_argument("--controlnet_model_path", type=str, default="", help="Path to ControlNet model")
    parser.add_argument("--pretrained_model_path", type=str, default="", help="Path to pretrained model")
    parser.add_argument("--vae_model_path", type=str, default="", help="Path to VAE model")
    parser.add_argument("--added_prompt", type=str, default="clean, high-resolution, 8k", help="Additional prompt for generation")
    parser.add_argument("--negative_prompt", type=str, default="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed", help="Negative prompt to avoid certain features")
    parser.add_argument("--image_path", type=str, default="", help="Path to input image or directory")
    parser.add_argument("--output_dir", type=str, default="", help="Directory to save outputs")
    parser.add_argument("--mixed_precision", type=str, choices=["no", "fp16", "bf16"], default="fp16", help="Mixed precision mode")
    parser.add_argument("--guidance_scale", type=float, default=1.0, help="Guidance scale for generation")
    parser.add_argument("--conditioning_scale", type=float, default=1.0, help="Conditioning scale")
    parser.add_argument("--num_inference_steps", type=int, default=1, help="Number of inference steps (not the final inference time)")
    # final_inference_time = num_inference_steps * (t_max - t_min) + 1
    parser.add_argument("--t_max", type=float, default=0.6666, help="Maximum timestep")
    parser.add_argument("--t_min", type=float, default=0.0, help="Minimum timestep")
    parser.add_argument("--process_size", type=int, default=512, help="Processing size of the image")
    parser.add_argument("--upscale", type=int, default=1, help="Upscaling factor")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument("--sample_times", type=int, default=5, help="Number of samples to generate per image")
    parser.add_argument("--sample_method", type=str, choices=['unipcmultistep', 'ddpm', 'dpmmultistep'], default='ddpm', help="Sampling method")
    parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain', help="Alignment method for color fixing")
    parser.add_argument("--start_steps", type=int, default=999, help="Starting step")
    parser.add_argument("--start_point", type=str, choices=['lr', 'noise'], default='lr', help="Starting point for generation")
    parser.add_argument("--baseline_name", type=str, default='ccsr-v2', help="Baseline name for output naming")
    parser.add_argument("--use_vae_encode_condition", action='store_true', help="Use the VAE to encode the LQ condition image")
    # Tiling settings for high-resolution SR
    parser.add_argument("--tile_diffusion", action="store_true", help="Enable tile-based diffusion (optional)")
    parser.add_argument("--tile_diffusion_size", type=int, default=512, help="Tile size for diffusion")
    parser.add_argument("--tile_diffusion_stride", type=int, default=256, help="Stride size for diffusion tiles")
    parser.add_argument("--tile_vae", action="store_true", help="Enable tiled VAE (optional)")
    parser.add_argument("--vae_decoder_tile_size", type=int, default=224, help="Tile size for VAE decoder")
    parser.add_argument("--vae_encoder_tile_size", type=int, default=1024, help="Tile size for VAE encoder")

    args = parser.parse_args()
    main(args)
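
# Example invocation (all paths and values below are illustrative placeholders):
#   python <this_script>.py \
#       --pretrained_model_path /path/to/stable-diffusion-base \
#       --controlnet_model_path /path/to/ccsr_controlnet \
#       --image_path ./inputs --output_dir ./results \
#       --upscale 4 --sample_method ddpm --num_inference_steps 6 \
#       --tile_diffusion --tile_vae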