# This extension works with [Mikubill/sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet)
# version: v1.1.229
LOG_PREFIX = '[ControlNet-Travel]'
# ↓↓↓ EXIT EARLY IF EXTERNAL REPOSITORY NOT FOUND ↓↓↓
CTRLNET_REPO_NAME = 'sdcontrol'
if 'external repo sanity check':
from pathlib import Path
from modules.scripts import basedir
from traceback import print_exc
ME_PATH = Path(basedir())
CTRLNET_PATH = ME_PATH.parent / CTRLNET_REPO_NAME
controlnet_found = False
try:
import sys ; sys.path.append(str(CTRLNET_PATH))
#from scripts.controlnet import Script as ControlNetScript # NOTE: this will mess up the import order
from scripts.external_code import ControlNetUnit
from scripts.hook import UNetModel, UnetHook, ControlParams
from scripts.hook import *
controlnet_found = True
print(f'{LOG_PREFIX} extension {CTRLNET_REPO_NAME} found, ControlNet-Travel loaded :)')
except ImportError:
print(f'{LOG_PREFIX} extension {CTRLNET_REPO_NAME} not found, ControlNet-Travel ignored :(')
exit(0)
except:
print_exc()
exit(0)
# ↑↑↑ EXIT EARLY IF EXTERNAL REPOSITORY NOT FOUND ↑↑↑
import sys
from PIL import Image
from ldm.models.diffusion.ddpm import LatentDiffusion
from modules import shared, devices, lowvram
from modules.processing import StableDiffusionProcessing as Processing
from scripts.prompt_travel import *
from manager import run_cmd
class InterpMethod(Enum):
LINEAR = 'linear (weight sum)'
RIFE = 'rife (optical flow)'
if 'consts':
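# helper: look up a previously saved UI value for this script's txt2img controls in webui's opts.data, falling back to the given default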
__ = lambda key, value=None: opts.data.get(f'customscript/controlnet_travel.py/txt2img/{key}/value', value)
LABEL_CTRLNET_REF_DIR = 'Reference image folder (one ref image per stage :)'
LABEL_INTERP_METH = 'Interpolate method'
LABEL_SKIP_FUSE = 'Ext. skip latent fusion'
LABEL_DEBUG_RIFE = 'Save RIFE intermediates'
DEFAULT_STEPS = 10
DEFAULT_CTRLNET_REF_DIR = str(ME_PATH / 'img' / 'ref_ctrlnet')
DEFAULT_INTERP_METH = __(LABEL_INTERP_METH, InterpMethod.LINEAR.value)
DEFAULT_SKIP_FUSE = __(LABEL_SKIP_FUSE, False)
DEFAULT_DEBUG_RIFE = __(LABEL_DEBUG_RIFE, False)
CHOICES_INTERP_METH = [x.value for x in InterpMethod]
if 'vars':
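# interp_alpha == 0.0 marks a key frame (the caches below are written);
# 0 < interp_alpha < 1 marks an in-between frame (the caches are read and blended)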
skip_fuse_plan: List[bool] = [] # n_blocks (13)
interp_alpha: float = 0.0
interp_ip: int = 0 # 0 ~ n_sampling_step-1
from_hint_cond: List[Tensor] = [] # n_controlnet_set
to_hint_cond: List[Tensor] = []
mid_hint_cond: List[Tensor] = []
from_control_tensors: List[List[Tensor]] = [] # n_sampling_step x n_blocks
to_control_tensors: List[List[Tensor]] = []
caches: List[list] = [from_hint_cond, to_hint_cond, mid_hint_cond, from_control_tensors, to_control_tensors]
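# grouped so all travel buffers can be cleared in one go between runs (see Script.run / run_linear)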
# ↓↓↓ the following is modified from 'sd-webui-controlnet/scripts/hook.py' ↓↓↓
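# hook_hijack replaces UnetHook.hook so that the ControlNet-patched UNet forward below can
# cache hint conds / control tensors on key frames and re-inject interpolated ones on in-between frames.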
def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_params:List[ControlParams], process:Processing):
self.model = model
self.sd_ldm = sd_ldm
self.control_params = control_params
outer = self
def process_sample(*args, **kwargs):
# ControlNet must know whether a prompt is a conditional (positive) prompt or an unconditional (negative) prompt.
# Use hook.py's `mark_prompt_context` to mark the prompts that ControlNet will see.
# Let XXX be a MulticondLearnedConditioning, a ComposableScheduledPromptConditioning, a ScheduledPromptConditioning, or a list of these:
#   call mark_prompt_context(XXX, positive=True) if XXX is a positive prompt,
#   call mark_prompt_context(XXX, positive=False) if XXX is a negative prompt.
# Once the prompts are marked, ControlNet knows which is cond/uncond, works as expected, and the mismatch errors disappear.
mark_prompt_context(kwargs.get('conditioning', []), positive=True)
mark_prompt_context(kwargs.get('unconditional_conditioning', []), positive=False)
mark_prompt_context(getattr(process, 'hr_c', []), positive=True)
mark_prompt_context(getattr(process, 'hr_uc', []), positive=False)
return process.sample_before_CN_hack(*args, **kwargs)
# NOTE: ↓↓↓ only hack this method ↓↓↓
def forward(self:UNetModel, x:Tensor, timesteps:Tensor=None, context:Tensor=None, **kwargs):
total_controlnet_embedding = [0.0] * 13
total_t2i_adapter_embedding = [0.0] * 4
require_inpaint_hijack = False
is_in_high_res_fix = False
batch_size = int(x.shape[0])
# NOTE: declare globals
global from_hint_cond, to_hint_cond, from_control_tensors, to_control_tensors, mid_hint_cond, interp_alpha, interp_ip
x: Tensor # [1, 4, 64, 64]
timesteps: Tensor # [1]
context: Tensor # [1, 78, 768]
kwargs: dict # {}
# Handle cond-uncond marker
cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context)
# logger.info(str(cond_mark[:, 0, 0, 0].detach().cpu().numpy().tolist()) + ' - ' + str(outer.current_uc_indices))
# High-res fix
for param in outer.control_params:
# select which hint_cond to use
if param.used_hint_cond is None:
param.used_hint_cond = param.hint_cond # NOTE: input hint cond tensor, [1, 3, 512, 512]
param.used_hint_cond_latent = None
param.used_hint_inpaint_hijack = None
# has high-res fix
if param.hr_hint_cond is not None and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4:
_, _, h_lr, w_lr = param.hint_cond.shape
_, _, h_hr, w_hr = param.hr_hint_cond.shape
_, _, h, w = x.shape
h, w = h * 8, w * 8
if abs(h - h_lr) < abs(h - h_hr):
is_in_high_res_fix = False
if param.used_hint_cond is not param.hint_cond:
param.used_hint_cond = param.hint_cond
param.used_hint_cond_latent = None
param.used_hint_inpaint_hijack = None
else:
is_in_high_res_fix = True
if param.used_hint_cond is not param.hr_hint_cond:
param.used_hint_cond = param.hr_hint_cond
param.used_hint_cond_latent = None
param.used_hint_inpaint_hijack = None
# NOTE: hint shallow fusion, overwrite param.used_hint_cond
for i, param in enumerate(outer.control_params):
if interp_alpha == 0.0: # collect hint_cond on key frames
if len(to_hint_cond) < len(outer.control_params):
to_hint_cond.append(param.used_hint_cond.clone().detach().cpu())
else: # interp with cached hint_cond
param.used_hint_cond = mid_hint_cond[i].to(x.device)
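# mid_hint_cond is filled by Script.run_linear (linear weighted_sum or RIFE) before each in-between frame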
# Convert control image to latent
for param in outer.control_params:
if param.used_hint_cond_latent is not None:
continue
if param.control_model_type not in [ControlModelType.AttentionInjection] \
and 'colorfix' not in param.preprocessor['name'] \
and 'inpaint_only' not in param.preprocessor['name']:
continue
param.used_hint_cond_latent = outer.call_vae_using_process(process, param.used_hint_cond, batch_size=batch_size)
# handle prompt token control
for param in outer.control_params:
if param.guidance_stopped:
continue
if param.control_model_type not in [ControlModelType.T2I_StyleAdapter]:
continue
param.control_model.to(devices.get_device_for("controlnet"))
control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context)
control = torch.cat([control.clone() for _ in range(batch_size)], dim=0)
control *= param.weight
control *= cond_mark[:, :, :, 0]
context = torch.cat([context, control.clone()], dim=1)
# handle ControlNet / T2I_Adapter
for param in outer.control_params:
if param.guidance_stopped:
continue
if param.control_model_type not in [ControlModelType.ControlNet, ControlModelType.T2I_Adapter]:
continue
param.control_model.to(devices.get_device_for("controlnet"))
# inpaint model workaround
x_in = x
control_model = param.control_model.control_model
if param.control_model_type == ControlModelType.ControlNet:
if x.shape[1] != control_model.input_blocks[0][0].in_channels and x.shape[1] == 9:
# inpaint_model: 4 data + 4 downscaled image + 1 mask
x_in = x[:, :4, ...]
require_inpaint_hijack = True
assert param.used_hint_cond is not None, "ControlNet is enabled but no input image is given"
hint = param.used_hint_cond
# ControlNet inpaint protocol
if hint.shape[1] == 4:
c = hint[:, 0:3, :, :]
m = hint[:, 3:4, :, :]
m = (m > 0.5).float()
hint = c * (1 - m) - m
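# masked pixels become -1, unmasked pixels keep their original hint values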
# NOTE: len(control) == 13, control[i]:Tensor
control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context)
control_scales = ([param.weight] * 13)
if outer.lowvram:
param.control_model.to("cpu")
if param.cfg_injection or param.global_average_pooling:
if param.control_model_type == ControlModelType.T2I_Adapter:
control = [torch.cat([c.clone() for _ in range(batch_size)], dim=0) for c in control]
control = [c * cond_mark for c in control]
high_res_fix_forced_soft_injection = False
if is_in_high_res_fix:
if 'canny' in param.preprocessor['name']:
high_res_fix_forced_soft_injection = True
if 'mlsd' in param.preprocessor['name']:
high_res_fix_forced_soft_injection = True
# if high_res_fix_forced_soft_injection:
# logger.info('[ControlNet] Forced soft_injection in high_res_fix in enabled.')
if param.soft_injection or high_res_fix_forced_soft_injection:
# important! use the soft weights with high-res fix can significantly reduce artifacts.
if param.control_model_type == ControlModelType.T2I_Adapter:
control_scales = [param.weight * x for x in (0.25, 0.62, 0.825, 1.0)]
elif param.control_model_type == ControlModelType.ControlNet:
control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)]
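# geometric ramp over the 13 residuals: the deepest block keeps full weight, the shallowest gets 0.825**12 ≈ 0.1 of it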
if param.advanced_weighting is not None:
control_scales = param.advanced_weighting
control = [c * scale for c, scale in zip(control, control_scales)]
if param.global_average_pooling:
control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]
for idx, item in enumerate(control):
target = None
if param.control_model_type == ControlModelType.ControlNet:
target = total_controlnet_embedding
if param.control_model_type == ControlModelType.T2I_Adapter:
target = total_t2i_adapter_embedding
if target is not None:
target[idx] = item + target[idx]
# Replace x_t to support inpaint models
for param in outer.control_params:
if param.used_hint_cond.shape[1] != 4:
continue
if x.shape[1] != 9:
continue
if param.used_hint_inpaint_hijack is None:
mask_pixel = param.used_hint_cond[:, 3:4, :, :]
image_pixel = param.used_hint_cond[:, 0:3, :, :]
mask_pixel = (mask_pixel > 0.5).to(mask_pixel.dtype)
masked_latent = outer.call_vae_using_process(process, image_pixel, batch_size, mask=mask_pixel)
mask_latent = torch.nn.functional.max_pool2d(mask_pixel, (8, 8))
if mask_latent.shape[0] != batch_size:
mask_latent = torch.cat([mask_latent.clone() for _ in range(batch_size)], dim=0)
param.used_hint_inpaint_hijack = torch.cat([mask_latent, masked_latent], dim=1)
param.used_hint_inpaint_hijack = param.used_hint_inpaint_hijack.to(x.dtype).to(x.device)
x = torch.cat([x[:, :4, :, :], param.used_hint_inpaint_hijack], dim=1)
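# x is now [B, 4+1+4, H, W]: noisy latent, downsampled inpaint mask, masked-image latent (the inpaint model's 9-channel input)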
# A1111 fix for medvram.
if shared.cmd_opts.medvram:
try:
# Trigger the register_forward_pre_hook
outer.sd_ldm.model()
except:
pass
# Clear attention and AdaIn cache
for module in outer.attn_module_list:
module.bank = []
module.style_cfgs = []
for module in outer.gn_module_list:
module.mean_bank = []
module.var_bank = []
module.style_cfgs = []
# Handle attention and AdaIn control
for param in outer.control_params:
if param.guidance_stopped:
continue
if param.used_hint_cond_latent is None:
continue
if param.control_model_type not in [ControlModelType.AttentionInjection]:
continue
ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps.float()).long())
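# noise the reference latent to the current timestep so it matches the noise level of x_t before the Write pass below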
# Inpaint Hijack
if x.shape[1] == 9:
ref_xt = torch.cat([
ref_xt,
torch.zeros_like(ref_xt)[:, 0:1, :, :],
param.used_hint_cond_latent
], dim=1)
outer.current_style_fidelity = float(param.preprocessor['threshold_a'])
outer.current_style_fidelity = max(0.0, min(1.0, outer.current_style_fidelity))
if param.cfg_injection:
outer.current_style_fidelity = 1.0
elif param.soft_injection or is_in_high_res_fix:
outer.current_style_fidelity = 0.0
control_name = param.preprocessor['name']
if control_name in ['reference_only', 'reference_adain+attn']:
outer.attention_auto_machine = AutoMachine.Write
outer.attention_auto_machine_weight = param.weight
if control_name in ['reference_adain', 'reference_adain+attn']:
outer.gn_auto_machine = AutoMachine.Write
outer.gn_auto_machine_weight = param.weight
outer.original_forward(
x=ref_xt.to(devices.dtype_unet),
timesteps=timesteps.to(devices.dtype_unet),
context=context.to(devices.dtype_unet)
)
outer.attention_auto_machine = AutoMachine.Read
outer.gn_auto_machine = AutoMachine.Read
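# the banks filled during the Write pass above are consumed (and cleared) by the hacked attention / group-norm forwards during the real denoising pass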
# NOTE: hint latent fusion, overwrite control tensors
total_control = total_controlnet_embedding
if interp_alpha == 0.0: # collect control tensors on key frames
tensors: List[Tensor] = []
for i, t in enumerate(total_control):
if len(skip_fuse_plan) and skip_fuse_plan[i]:
tensors.append(None)
else:
tensors.append(t.clone().detach().cpu())
to_control_tensors.append(tensors)
else: # interp with cached control tensors
device = total_control[0].device
for i, (ctrlA, ctrlB) in enumerate(zip(from_control_tensors[interp_ip], to_control_tensors[interp_ip])):
if ctrlA is not None and ctrlB is not None:
ctrlC = weighted_sum(ctrlA.to(device), ctrlB.to(device), interp_alpha)
#print(' ctrl diff:', (ctrlC - total_control[i]).abs().mean().item())
total_control[i].data = ctrlC
interp_ip += 1
# NOTE: warn on T2I adapter
if isinstance(total_t2i_adapter_embedding[0], Tensor): # a Tensor here means a T2I-Adapter actually produced features
    print(f'{LOG_PREFIX} warn: currently t2i_adapter is not supported. if you want this, file a feature request at Kahsolt/stable-diffusion-webui-prompt-travel')
# U-Net Encoder
hs = []
with th.no_grad():
t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
for i, module in enumerate(self.input_blocks):
h = module(h, emb, context)
if (i + 1) % 3 == 0:
h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack)
hs.append(h)
h = self.middle_block(h, emb, context)
# U-Net Middle Block
h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack)
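# .pop() consumes the 13 residuals from deepest to shallowest: the last one goes to the middle block here, the remaining 12 are added onto the skip connections in the decoder loop below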
# U-Net Decoder
for i, module in enumerate(self.output_blocks):
h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1)
h = module(h, emb, context)
# U-Net Output
h = h.type(x.dtype)
h = self.out(h)
# Post-processing for color fix
for param in outer.control_params:
if param.used_hint_cond_latent is None:
continue
if 'colorfix' not in param.preprocessor['name']:
continue
k = int(param.preprocessor['threshold_a'])
if is_in_high_res_fix:
k *= 2
# Inpaint hijack
xt = x[:, :4, :, :]
x0_origin = param.used_hint_cond_latent
t = torch.round(timesteps.float()).long()
x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h)
x0 = x0_prd - blur(x0_prd, k) + blur(x0_origin, k)
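# swap in the low-frequency (blurred) component of the reference latent while keeping the prediction's high frequencies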
if '+sharp' in param.preprocessor['name']:
detail_weight = float(param.preprocessor['threshold_b']) * 0.01
neg = detail_weight * blur(x0, k) + (1 - detail_weight) * x0
x0 = cond_mark * x0 + (1 - cond_mark) * neg
eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0)
w = max(0.0, min(1.0, float(param.weight)))
h = eps_prd * w + h * (1 - w)
# Post-processing for restore
for param in outer.control_params:
if param.used_hint_cond_latent is None:
continue
if 'inpaint_only' not in param.preprocessor['name']:
continue
if param.used_hint_cond.shape[1] != 4:
continue
# Inpaint hijack
xt = x[:, :4, :, :]
mask = param.used_hint_cond[:, 3:4, :, :]
mask = torch.nn.functional.max_pool2d(mask, (10, 10), stride=(8, 8), padding=1)
x0_origin = param.used_hint_cond_latent
t = torch.round(timesteps.float()).long()
x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h)
x0 = x0_prd * mask + x0_origin * (1 - mask)
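# inside the (dilated) mask keep the model's prediction, outside force x0 back to the reference latent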
eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0)
w = max(0.0, min(1.0, float(param.weight)))
h = eps_prd * w + h * (1 - w)
return h
def forward_webui(*args, **kwargs):
# webui will handle other components
try:
if shared.cmd_opts.lowvram:
lowvram.send_everything_to_cpu()
return forward(*args, **kwargs)
finally:
if self.lowvram:
for param in self.control_params:
if isinstance(param.control_model, torch.nn.Module):
param.control_model.to("cpu")
def hacked_basic_transformer_inner_forward(self, x, context=None):
x_norm1 = self.norm1(x)
self_attn1 = None
if self.disable_self_attn:
# Do not use self-attention
self_attn1 = self.attn1(x_norm1, context=context)
else:
# Use self-attention
self_attention_context = x_norm1
if outer.attention_auto_machine == AutoMachine.Write:
if outer.attention_auto_machine_weight > self.attn_weight:
self.bank.append(self_attention_context.detach().clone())
self.style_cfgs.append(outer.current_style_fidelity)
if outer.attention_auto_machine == AutoMachine.Read:
if len(self.bank) > 0:
style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1))
self_attn1_c = self_attn1_uc.clone()
if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
self_attn1_c[outer.current_uc_indices] = self.attn1(
x_norm1[outer.current_uc_indices],
context=self_attention_context[outer.current_uc_indices])
self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
self.bank = []
self.style_cfgs = []
if self_attn1 is None:
self_attn1 = self.attn1(x_norm1, context=self_attention_context)
x = self_attn1.to(x.dtype) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
def hacked_group_norm_forward(self, *args, **kwargs):
eps = 1e-6
x = self.original_forward(*args, **kwargs)
y = None
if outer.gn_auto_machine == AutoMachine.Write:
if outer.gn_auto_machine_weight > self.gn_weight:
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
self.mean_bank.append(mean)
self.var_bank.append(var)
self.style_cfgs.append(outer.current_style_fidelity)
if outer.gn_auto_machine == AutoMachine.Read:
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
var_acc = sum(self.var_bank) / float(len(self.var_bank))
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
y_uc = (((x - mean) / std) * std_acc) + mean_acc
y_c = y_uc.clone()
if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
y_c[outer.current_uc_indices] = x.to(y_c.dtype)[outer.current_uc_indices]
y = style_cfg * y_c + (1.0 - style_cfg) * y_uc
self.mean_bank = []
self.var_bank = []
self.style_cfgs = []
if y is None:
y = x
return y.to(x.dtype)
if getattr(process, 'sample_before_CN_hack', None) is None:
process.sample_before_CN_hack = process.sample
process.sample = process_sample
model._original_forward = model.forward
outer.original_forward = model.forward
model.forward = forward_webui.__get__(model, UNetModel)
all_modules = torch_dfs(model)
attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])
for i, module in enumerate(attn_modules):
if getattr(module, '_original_inner_forward', None) is None:
module._original_inner_forward = module._forward
module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
module.bank = []
module.style_cfgs = []
module.attn_weight = float(i) / float(len(attn_modules))
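# attn_weight ramps from 0 to 1 over the sorted blocks; in Write mode only blocks with attn_weight below the unit's ControlNet weight record reference attention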
gn_modules = [model.middle_block]
model.middle_block.gn_weight = 0
input_block_indices = [4, 5, 7, 8, 10, 11]
for w, i in enumerate(input_block_indices):
module = model.input_blocks[i]
module.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
gn_modules.append(module)
output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
for w, i in enumerate(output_block_indices):
module = model.output_blocks[i]
module.gn_weight = float(w) / float(len(output_block_indices))
gn_modules.append(module)
for i, module in enumerate(gn_modules):
if getattr(module, 'original_forward', None) is None:
module.original_forward = module.forward
module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module)
module.mean_bank = []
module.var_bank = []
module.style_cfgs = []
module.gn_weight *= 2
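# doubling gn_weight means that with a unit weight of 1.0 only modules whose original gn_weight was below 0.5 take part in AdaIN reference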
outer.attn_module_list = attn_modules
outer.gn_module_list = gn_modules
scripts.script_callbacks.on_cfg_denoiser(self.guidance_schedule_handler)
# ↑↑↑ the above is modified from 'sd-webui-controlnet/scripts/hook.py' ↑↑↑
def reset_cuda():
devices.torch_gc()
import gc; gc.collect()
try:
import os
import psutil
mem = psutil.Process(os.getpid()).memory_info()
print(f'[Mem] rss: {mem.rss/2**30:.3f} GB, vms: {mem.vms/2**30:.3f} GB')
from modules.shared import mem_mon as vram_mon
free, total = vram_mon.cuda_mem_get_info()
print(f'[VRAM] free: {free/2**30:.3f} GB, total: {total/2**30:.3f} GB')
except:
pass
class Script(scripts.Script):
def title(self):
return 'ControlNet Travel'
def describe(self):
return 'Travel from one controlnet hint condition to another in the tensor space.'
def show(self, is_img2img):
return controlnet_found
def ui(self, is_img2img):
with gr.Row(variant='compact'):
interp_meth = gr.Dropdown(label=LABEL_INTERP_METH, value=lambda: DEFAULT_INTERP_METH, choices=CHOICES_INTERP_METH)
steps = gr.Text (label=LABEL_STEPS, value=lambda: DEFAULT_STEPS, max_lines=1)
reset = gr.Button(value='Reset Cuda', variant='tool')
reset.click(fn=reset_cuda, show_progress=False)
with gr.Row(variant='compact'):
ctrlnet_ref_dir = gr.Text(label=LABEL_CTRLNET_REF_DIR, value=lambda: DEFAULT_CTRLNET_REF_DIR, max_lines=1)
with gr.Group(visible=DEFAULT_SKIP_FUSE) as tab_ext_skip_fuse:
with gr.Row(variant='compact'):
skip_in_0 = gr.Checkbox(label='in_0')
skip_in_3 = gr.Checkbox(label='in_3')
skip_out_0 = gr.Checkbox(label='out_0')
skip_out_3 = gr.Checkbox(label='out_3')
with gr.Row(variant='compact'):
skip_in_1 = gr.Checkbox(label='in_1')
skip_in_4 = gr.Checkbox(label='in_4')
skip_out_1 = gr.Checkbox(label='out_1')
skip_out_4 = gr.Checkbox(label='out_4')
with gr.Row(variant='compact'):
skip_in_2 = gr.Checkbox(label='in_2')
skip_in_5 = gr.Checkbox(label='in_5')
skip_out_2 = gr.Checkbox(label='out_2')
skip_out_5 = gr.Checkbox(label='out_5')
with gr.Row(variant='compact'):
skip_mid = gr.Checkbox(label='mid')
with gr.Row(variant='compact', visible=DEFAULT_UPSCALE) as tab_ext_upscale:
upscale_meth = gr.Dropdown(label=LABEL_UPSCALE_METH, value=lambda: DEFAULT_UPSCALE_METH, choices=CHOICES_UPSCALER)
upscale_ratio = gr.Slider (label=LABEL_UPSCALE_RATIO, value=lambda: DEFAULT_UPSCALE_RATIO, minimum=1.0, maximum=16.0, step=0.1)
upscale_width = gr.Slider (label=LABEL_UPSCALE_WIDTH, value=lambda: DEFAULT_UPSCALE_WIDTH, minimum=0, maximum=2048, step=8)
upscale_height = gr.Slider (label=LABEL_UPSCALE_HEIGHT, value=lambda: DEFAULT_UPSCALE_HEIGHT, minimum=0, maximum=2048, step=8)
with gr.Row(variant='compact', visible=DEFAULT_VIDEO) as tab_ext_video:
video_fmt = gr.Dropdown(label=LABEL_VIDEO_FMT, value=lambda: DEFAULT_VIDEO_FMT, choices=CHOICES_VIDEO_FMT)
video_fps = gr.Number (label=LABEL_VIDEO_FPS, value=lambda: DEFAULT_VIDEO_FPS)
video_pad = gr.Number (label=LABEL_VIDEO_PAD, value=lambda: DEFAULT_VIDEO_PAD, precision=0)
video_pick = gr.Text (label=LABEL_VIDEO_PICK, value=lambda: DEFAULT_VIDEO_PICK, max_lines=1)
with gr.Row(variant='compact') as tab_ext:
ext_video = gr.Checkbox(label=LABEL_VIDEO, value=lambda: DEFAULT_VIDEO)
ext_upscale = gr.Checkbox(label=LABEL_UPSCALE, value=lambda: DEFAULT_UPSCALE)
ext_skip_fuse = gr.Checkbox(label=LABEL_SKIP_FUSE, value=lambda: DEFAULT_SKIP_FUSE)
dbg_rife = gr.Checkbox(label=LABEL_DEBUG_RIFE, value=lambda: DEFAULT_DEBUG_RIFE)
ext_video .change(gr_show, inputs=ext_video, outputs=tab_ext_video, show_progress=False)
ext_upscale .change(gr_show, inputs=ext_upscale, outputs=tab_ext_upscale, show_progress=False)
ext_skip_fuse.change(gr_show, inputs=ext_skip_fuse, outputs=tab_ext_skip_fuse, show_progress=False)
skip_fuses = [
skip_in_0,
skip_in_1,
skip_in_2,
skip_in_3,
skip_in_4,
skip_in_5,
skip_mid,
skip_out_0,
skip_out_1,
skip_out_2,
skip_out_3,
skip_out_4,
skip_out_5,
]
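# order matters: this list becomes skip_fuse_plan, indexed per ControlNet residual slot (13 entries)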
return [
interp_meth, steps, ctrlnet_ref_dir,
upscale_meth, upscale_ratio, upscale_width, upscale_height,
video_fmt, video_fps, video_pad, video_pick,
ext_video, ext_upscale, ext_skip_fuse, dbg_rife,
*skip_fuses,
]
def run(self, p:Processing,
interp_meth:str, steps:str, ctrlnet_ref_dir:str,
upscale_meth:str, upscale_ratio:float, upscale_width:int, upscale_height:int,
video_fmt:str, video_fps:float, video_pad:int, video_pick:str,
ext_video:bool, ext_upscale:bool, ext_skip_fuse:bool, dbg_rife:bool,
*skip_fuses:bool,
):
# Prepare ControlNet
#self.controlnet_script: ControlNetScript = None
self.controlnet_script = None
try:
for script in p.scripts.alwayson_scripts:
if hasattr(script, "latest_network") and script.title().lower() == "controlnet":
script_args: Tuple[ControlNetUnit] = p.script_args[script.args_from:script.args_to]
if not any([u.enabled for u in script_args]): return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not enabled')
self.controlnet_script = script
break
except ImportError:
return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not installed')
except:
print_exc()
if not self.controlnet_script: return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not loaded')
# Enum lookup
interp_meth: InterpMethod = InterpMethod(interp_meth)
video_fmt: VideoFormat = VideoFormat (video_fmt)
# Param check & type convert
if ext_video:
if video_pad < 0: return Processed(p, [], p.seed, f'video_pad must >= 0, but got {video_pad}')
if video_fps <= 0: return Processed(p, [], p.seed, f'video_fps must > 0, but got {video_fps}')
try: video_slice = parse_slice(video_pick)
except: return Processed(p, [], p.seed, 'syntax error in video_slice')
if ext_skip_fuse:
global skip_fuse_plan
skip_fuse_plan = skip_fuses
# Prepare ref-images
if not ctrlnet_ref_dir: return Processed(p, [], p.seed, f'invalid image folder path: {ctrlnet_ref_dir}')
ctrlnet_ref_dir: Path = Path(ctrlnet_ref_dir)
if not ctrlnet_ref_dir.is_dir(): return Processed(p, [], p.seed, f'invalid image folder path: {ctrlnet_ref_dir} :(')
self.ctrlnet_ref_fps = [fp for fp in list(ctrlnet_ref_dir.iterdir()) if fp.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']]
n_stages = len(self.ctrlnet_ref_fps)
if n_stages == 0: return Processed(p, [], p.seed, f'no image files (*.jpg/*.jpeg/*.png/*.bmp/*.webp) found in folder path: {ctrlnet_ref_dir}')
if n_stages == 1: return Processed(p, [], p.seed, 'requires at least two images to travel between, but found only 1 :(')
# Prepare steps (n_interp)
try: steps: List[int] = [int(s.strip()) for s in steps.strip().split(',')]
except: return Processed(p, [], p.seed, f'cannot parse steps options: {steps}')
if len(steps) == 1: steps = [steps[0]] * (n_stages - 1)
elif len(steps) != n_stages - 1: return Processed(p, [], p.seed, f'stage count mismatch: len(steps) == {len(steps)}, but expect n_stages - 1 == {n_stages - 1}')
n_frames = sum(steps) + n_stages
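# e.g. 3 reference images with steps '10' -> 2 * 10 in-betweens + 3 key frames = 23 frames in total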
if 'show_debug':
print('n_stages:', n_stages)
print('n_frames:', n_frames)
print('steps:', steps)
steps.insert(0, -1) # fixup the first stage
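# so that steps[i] aligns with stage i in run_linear (stage 0 is the initial key frame and has no preceding in-betweens)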
# Custom saving path
travel_path = os.path.join(p.outpath_samples, 'prompt_travel')
os.makedirs(travel_path, exist_ok=True)
travel_number = get_next_sequence_number(travel_path)
self.log_dp = os.path.join(travel_path, f'{travel_number:05}')
p.outpath_samples = self.log_dp
os.makedirs(self.log_dp, exist_ok=True)
self.tmp_dp = Path(self.log_dp) / 'ctrl_cond' # cache for rife
self.tmp_fp = self.tmp_dp / 'tmp.png' # cache for rife
# Force Batch Count and Batch Size to 1
p.n_iter = 1
p.batch_size = 1
# Random unified const seed
p.seed = get_fixed_seed(p.seed) # fix it to assure all processes using the same major seed
self.subseed = p.subseed # stash it to allow using random subseed for each process (when -1)
if 'show_debug':
print('seed:', p.seed)
print('subseed:', p.subseed)
print('subseed_strength:', p.subseed_strength)
# Start job
state.job_count = n_frames
# Pack params
self.n_stages = n_stages
self.steps = steps
self.interp_meth = interp_meth
self.dbg_rife = dbg_rife
def upscale_image_callback(params:ImageSaveParams):
params.image = upscale_image(params.image, p.width, p.height, upscale_meth, upscale_ratio, upscale_width, upscale_height)
images: List[PILImage] = []
info: str = None
try:
if ext_upscale: on_before_image_saved(upscale_image_callback)
self.UnetHook_hook_original = UnetHook.hook
UnetHook.hook = hook_hijack
[c.clear() for c in caches]
images, info = self.run_linear(p)
except:
info = format_exc()
print(info)
finally:
if self.tmp_fp.exists(): os.unlink(self.tmp_fp)
[c.clear() for c in caches]
UnetHook.hook = self.UnetHook_hook_original
self.controlnet_script.input_image = None
if self.controlnet_script.latest_network:
self.controlnet_script.latest_network: UnetHook
self.controlnet_script.latest_network.restore(p.sd_model.model.diffusion_model)
self.controlnet_script.latest_network = None
if ext_upscale: remove_callbacks_for_function(upscale_image_callback)
reset_cuda()
# Save video
if ext_video: save_video(images, video_slice, video_pad, video_fps, video_fmt, os.path.join(self.log_dp, f'travel-{travel_number:05}'))
return Processed(p, images, p.seed, info)
def run_linear(self, p:Processing) -> RunResults:
global from_hint_cond, to_hint_cond, from_control_tensors, to_control_tensors, interp_alpha, interp_ip
images: List[PILImage] = []
info: str = None
def process_p(append:bool=True) -> Optional[List[PILImage]]:
nonlocal p, images, info
proc = process_images(p)
if not info: info = proc.info
if append: images.extend(proc.images)
else: return proc.images
''' ↓↓↓ rife interp utils ↓↓↓ '''
def save_ctrl_cond(idx:int):
self.tmp_dp.mkdir(exist_ok=True)
for i, x in enumerate(to_hint_cond):
x = x[0]
if len(x.shape) == 3:
if x.shape[0] == 1: x = x.squeeze_(0) # [C=1, H, W] => [H, W]
elif x.shape[0] == 3: x = x.permute([1, 2, 0]) # [C=3, H, W] => [H, W, C]
else: raise ValueError(f'unknown cond shape: {x.shape}')
else:
raise ValueError(f'unknown cond shape: {x.shape}')
im = (x.detach().clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
Image.fromarray(im).save(self.tmp_dp / f'{idx}-{i}.png')
def rife_interp(i:int, j:int, k:int, alpha:float) -> Tensor:
''' interp between i-th and j-th cond of the k-th ctrlnet set '''
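# shells out to the external rife-ncnn-vulkan binary via run_cmd (assumed to be installed and reachable),
# synthesizing an optical-flow in-between of the two saved key-frame hint images at position alpha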
fp0 = self.tmp_dp / f'{i}-{k}.png'
fp1 = self.tmp_dp / f'{j}-{k}.png'
fpo = self.tmp_dp / f'{i}-{j}-{alpha:.3f}.png' if self.dbg_rife else self.tmp_fp
assert run_cmd(f'rife-ncnn-vulkan -m rife-v4 -s {alpha:.3f} -0 "{fp0}" -1 "{fp1}" -o "{fpo}"')
x = torch.from_numpy(np.asarray(Image.open(fpo)) / 255.0)
if len(x.shape) == 2: x = x.unsqueeze_(0) # [H, W] => [C=1, H, W]
elif len(x.shape) == 3: x = x.permute([2, 0, 1]) # [H, W, C] => [C, H, W]
else: raise ValueError(f'unknown cond shape: {x.shape}')
x = x.unsqueeze(dim=0)
return x
''' ↑↑↑ rife interp utils ↑↑↑ '''
''' ↓↓↓ filename reorder utils ↓↓↓ '''
iframe = 0
def rename_image_filename(idx:int, param: ImageSaveParams):
fn = param.filename
stem, suffix = os.path.splitext(os.path.basename(fn))
param.filename = os.path.join(os.path.dirname(fn), f'{idx:05d}' + suffix)
class on_before_image_saved_wrapper:
def __init__(self, callback_fn):
self.callback_fn = callback_fn
def __enter__(self):
on_before_image_saved(self.callback_fn)
def __exit__(self, exc_type, exc_value, exc_traceback):
remove_callbacks_for_function(self.callback_fn)
''' ↑↑↑ filename reorder utils ↑↑↑ '''
# Step 1: draw the init image
setattr(p, 'init_images', [Image.open(self.ctrlnet_ref_fps[0])])
interp_alpha = 0.0
with on_before_image_saved_wrapper(partial(rename_image_filename, 0)):
process_p()
iframe += 1
save_ctrl_cond(0)
# travel through stages
for i in range(1, self.n_stages):
if state.interrupted: break
# Step 3: move to the next stage
from_hint_cond = [t for t in to_hint_cond] ; to_hint_cond .clear()
from_control_tensors = [t for t in to_control_tensors] ; to_control_tensors.clear()
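# the 'to' caches of the previous stage become the 'from' side of this leg; the new key frame will refill the 'to' caches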
setattr(p, 'init_images', [Image.open(self.ctrlnet_ref_fps[i])])
interp_alpha = 0.0
with on_before_image_saved_wrapper(partial(rename_image_filename, iframe + self.steps[i])):
cached_images = process_p(append=False)
save_ctrl_cond(i)
# Step 2: draw the interpolated images
is_interrupted = False
n_inter = self.steps[i] + 1
for t in range(1, n_inter):
if state.interrupted: is_interrupted = True ; break
interp_alpha = t / n_inter # [1/T, 2/T, .. T-1/T]
mid_hint_cond.clear()
device = devices.get_device_for("controlnet")
if self.interp_meth == InterpMethod.LINEAR:
for hintA, hintB in zip(from_hint_cond, to_hint_cond):
hintC = weighted_sum(hintA.to(device), hintB.to(device), interp_alpha)
mid_hint_cond.append(hintC)
elif self.interp_meth == InterpMethod.RIFE:
dtype = to_hint_cond[0].dtype
for k in range(len(to_hint_cond)):
hintC = rife_interp(i-1, i, k, interp_alpha).to(device, dtype)
mid_hint_cond.append(hintC)
else: raise ValueError(f'unknown interp_meth: {self.interp_meth}')
interp_ip = 0
with on_before_image_saved_wrapper(partial(rename_image_filename, iframe)):
process_p()
iframe += 1
# adjust order: the key frame of this stage was rendered (and cached) before its in-between frames, so append its image only now to keep the output chronological
images.extend(cached_images)
iframe += 1
if is_interrupted: break
return images, info