File size: 7,159 Bytes
fc843fe
 
 
 
 
 
 
 
 
 
 
ebdd902
fc843fe
8182d33
 
 
 
 
9c65818
fc843fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebdd902
8182d33
 
ebdd902
 
 
 
fc843fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebdd902
 
 
 
 
fc843fe
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
Modified version of DiffMorpher's main.py: adds LCM-LoRA support, extra CLI parameters,
run logging, and speed-up optimizations.
"""
import os
import torch
import numpy as np
from PIL import Image
from argparse import ArgumentParser
from model import DiffMorpherPipeline
import time
import logging
import gc

# os.environ["HF_HOME"] = "/app/hf_cache"
# os.environ["DIFFUSERS_CACHE"] = "/app/hf_cache"
# os.environ["TORCH_HOME"] = "/app/torch_cache"
# os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
# os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache/datasets"

# Each run writes to its own timestamped file: logs/execution_<YYYYmmdd_HHMMSS>.log
logs_folder = "logs"
os.makedirs(logs_folder, exist_ok=True)

log_filename = os.path.join(
    logs_folder,
    f"execution_{time.strftime('%Y%m%d_%H%M%S')}.log",
)

# Send INFO-and-above records from the root logger to that file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    filename=log_filename,
)

# Reference point for the total execution time reported at the end of the script.
start_time = time.time()

# Command-line interface for the morphing script.
# Every option has a default, so the parser itself never rejects an empty command line.
parser = ArgumentParser()

# --- Base model -------------------------------------------------------------
parser.add_argument(
    "--model_path", type=str, default="stabilityai/stable-diffusion-2-1-base",
    help="Pretrained model to use (default: %(default)s)"
)
# Available SDV1-5 versions:
# sd-legacy/stable-diffusion-v1-5
# lykon/dreamshaper-7

# Original DiffMorpher SD:
# stabilityai/stable-diffusion-2-1-base

# Quantized models to try (non-functional, possible extension for future)
# DarkFlameUniverse/Stable-Diffusion-2-1-Base-8bit
# Xerox32/SD2.1-base-Int8

# --- Input images and their prompts -----------------------------------------
parser.add_argument(
    "--image_path_0", type=str, default="",
    help="Path of the first image (default: %(default)s)"
)
parser.add_argument(
    "--prompt_0", type=str, default="",
    help="Prompt of the first image (default: %(default)s)"
)
parser.add_argument(
    "--image_path_1", type=str, default="",
    help="Path of the second image (default: %(default)s)"
)
parser.add_argument(
    "--prompt_1", type=str, default="",
    help="Prompt of the second image (default: %(default)s)"
)

# --- Output and LoRA locations ----------------------------------------------
parser.add_argument(
    "--output_path", type=str, default="./results",
    help="Path of the output image (default: %(default)s)"
)
parser.add_argument(
    "--save_lora_dir", type=str, default="./lora",
    help="Path for saving LoRA weights (default: %(default)s)"
)
parser.add_argument(
    "--load_lora_path_0", type=str, default="",
    help="Path of the LoRA weights for the first image (default: %(default)s)"
)
parser.add_argument(
    "--load_lora_path_1", type=str, default="",
    help="Path of the LoRA weights for the second image (default: %(default)s)"
)

# --- Sampling ---------------------------------------------------------------
parser.add_argument(
    "--num_inference_steps", type=int, default=50,
    help="Number of inference steps (default: %(default)s)"
)
parser.add_argument(
    # The default is written as a float: argparse does not pass non-string
    # defaults through `type`, so an int literal here would make the value an
    # int unless the flag was given. 1.0 matches current DiffMorpher behaviour.
    "--guidance_scale", type=float, default=1.0,
    help="Guidance scale for classifier-free guidance (default: %(default)s)"
)

# --- Morphing behaviour -----------------------------------------------------
parser.add_argument(
    "--use_adain", action="store_true",
    help="Use AdaIN (default: %(default)s)"
)
parser.add_argument(
    "--use_reschedule", action="store_true",
    help="Use reschedule sampling (default: %(default)s)"
)
parser.add_argument(
    "--lamb", type=float, default=0.6,
    help="Lambda for self-attention replacement (default: %(default)s)"
)
parser.add_argument(
    "--fix_lora_value", type=float, default=None,
    help="Fix lora value (default: LoRA Interp., not fixed)"
)
parser.add_argument(
    "--save_inter", action="store_true",
    help="Save intermediate results (default: %(default)s)"
)
parser.add_argument(
    "--num_frames", type=int, default=16,
    help="Number of frames to generate (default: %(default)s)"
)
parser.add_argument(
    "--duration", type=int, default=100,
    help="Duration of each frame (default: %(default)s ms)"
)
parser.add_argument(
    "--no_lora", action="store_true",
    help="Disable style LoRA (default: %(default)s)"
)

# New argument for LCM LoRA acceleration
parser.add_argument(
    "--use_lcm", action="store_true",
    help="Enable LCM-LoRA acceleration for faster sampling"
)
# Parse the command line and make sure the output directory exists before any heavy work.
args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True)

# Clear any existing PyTorch GPU allocations
# torch.cuda.empty_cache()
# gc.collect()

# Set environment variable for memory allocation
# NOTE(review): torch is already imported by the time this runs; presumably the CUDA
# allocator reads this setting lazily on first CUDA use - confirm it takes effect here.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Create the pipeline from the given model path (weights loaded as float32)
pipeline = DiffMorpherPipeline.from_pretrained(args.model_path, torch_dtype=torch.float32)

# Memory optimisations: VAE and attention slicing break computations into smaller chunks
# so they fit better in memory. Author's note: this can lead to more efficient caching and
# memory access (better memory locality) and is particularly helpful on GPUs with limited VRAM.
pipeline.enable_vae_slicing()
pipeline.enable_attention_slicing()

# Move the pipeline to the GPU; a CUDA device is required from here on.
pipeline.to("cuda")

# Add these AFTER device movement
torch.backends.cudnn.benchmark = True # finds efficient convolution algo by running short benchmark, minimal speed-up.
torch.set_float32_matmul_precision("high")  # Better for modern GPUs, reduces about 7 seconds of inference time (author's measurement).

# Integrate LCM-LoRA if flagged, OUTSIDE any of the style LoRA loading / training steps.
if args.use_lcm:
    # Imported lazily so the lcm_lora package is only needed when --use_lcm is given.
    from lcm_lora.lcm_schedule import LCMScheduler
    # Replace scheduler using LCM's configuration
    pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
    # Load the LCM LoRA weights (LCM provides an add-on network)
    # NOTE(review): the repo name says SD v1.5, while --model_path defaults to a
    # stable-diffusion-2-1-base checkpoint - confirm --use_lcm is only combined with
    # an SD 1.5-family --model_path.
    pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
    # Set the lcm_inference_steps; this overwrites any --num_inference_steps given by the user.
    args.num_inference_steps = 8  # Override with LCM-recommended steps
    # set CFG (range allowed by legacy code: 0 to 1, 1 performs best);
    # this overwrites any --guidance_scale given by the user.
    args.guidance_scale = 1
# Generate the morph sequence. Keyword arguments are grouped by concern; the
# values come straight from the parsed CLI options (after any LCM overrides).
images = pipeline(
    # source images and their prompts
    img_path_0=args.image_path_0,
    prompt_0=args.prompt_0,
    img_path_1=args.image_path_1,
    prompt_1=args.prompt_1,
    # style LoRA handling
    use_lora=not args.no_lora,
    fix_lora=args.fix_lora_value,
    save_lora_dir=args.save_lora_dir,
    load_lora_path_0=args.load_lora_path_0,
    load_lora_path_1=args.load_lora_path_1,
    # sampling (steps and guidance are forced to the LCM values when --use_lcm is set)
    use_lcm=args.use_lcm,
    num_inference_steps=args.num_inference_steps,
    guidance_scale=args.guidance_scale,
    # morphing behaviour; note the pipeline's keyword is spelled `lamd`
    lamd=args.lamb,
    use_adain=args.use_adain,
    use_reschedule=args.use_reschedule,
    num_frames=args.num_frames,
    # outputs
    output_path=args.output_path,
    save_intermediates=args.save_inter,
)

# Write all frames into one endlessly looping GIF inside the output directory;
# --duration is handed to the image writer as the per-frame duration.
images[0].save(
    f"{args.output_path}/output.gif",
    save_all=True,
    append_images=images[1:],
    loop=0,
    duration=args.duration,
)

# Ensure memory is freed after completion.
# Order matters: drop the reference, let the garbage collector reclaim objects
# that are only freed by collection, and only then ask PyTorch to release its
# cached CUDA blocks. Emptying the cache first would leave that memory occupied.
pipeline = None
gc.collect()
torch.cuda.empty_cache()

# Total wall-clock time since start_time, including the cleanup above.
end_time = time.time()
elapsed_time = end_time - start_time

# Log the execution details and parameters. The values reflect what was
# actually used, i.e. after any overrides applied when LCM is enabled.
logging.info("Total execution time: %.2f seconds", elapsed_time)
logging.info("Model Path: %s", args.model_path)
logging.info("Image Path 0: %s", args.image_path_0)
logging.info("Image Path 1: %s", args.image_path_1)
logging.info("Use LCM: %s", args.use_lcm)
logging.info("Number of inference steps: %s", args.num_inference_steps)
logging.info("Guidance scale: %s", args.guidance_scale)

# Echo the summary on stdout and say where the log file was written.
print(f"Total execution time: {elapsed_time:.2f} seconds, log file saved as {log_filename}")