|
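# ControlNet-Travel: travel between ControlNet hint conditions in tensor space,
# riding on top of the sd-webui-controlnet extension's UnetHook.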
|
|
|
|
|
|
LOG_PREFIX = '[ControlNet-Travel]' |
|
|
|
|
|
|
|
CTRLNET_REPO_NAME = 'sd-webui-controlnet'
|
if 'external repo sanity check':
|
from pathlib import Path |
|
from modules.scripts import basedir |
|
from traceback import print_exc |
|
|
|
ME_PATH = Path(basedir()) |
|
    CTRLNET_PATH = ME_PATH.parent / CTRLNET_REPO_NAME
|
|
|
controlnet_found = False |
|
try: |
|
import sys ; sys.path.append(str(CTRLNET_PATH)) |
|
|
|
from scripts.external_code import ControlNetUnit |
|
from scripts.hook import UNetModel, UnetHook, ControlParams |
|
from scripts.hook import * |
|
|
|
controlnet_found = True |
|
print(f'{LOG_PREFIX} extension {CTRLNET_REPO_NAME} found, ControlNet-Travel loaded :)') |
|
except ImportError: |
|
print(f'{LOG_PREFIX} extension {CTRLNET_REPO_NAME} not found, ControlNet-Travel ignored :(') |
|
exit(0) |
|
except: |
|
print_exc() |
|
exit(0) |
|
|
|
|
|
|
|
|
|
import sys |
|
from PIL import Image |
|
|
|
from ldm.models.diffusion.ddpm import LatentDiffusion |
|
from modules import shared, devices, lowvram |
|
from modules.processing import StableDiffusionProcessing as Processing |
|
|
|
from scripts.prompt_travel import * |
|
from manager import run_cmd |
|
|
|
class InterpMethod(Enum): |
|
LINEAR = 'linear (weight sum)' |
|
RIFE = 'rife (optical flow)' |
|
|
|
if 'consts': |
|
__ = lambda key, value=None: opts.data.get(f'customscript/controlnet_travel.py/txt2img/{key}/value', value) |
|
|
|
|
|
LABEL_CTRLNET_REF_DIR = 'Reference image folder (one ref image per stage :)' |
|
LABEL_INTERP_METH = 'Interpolate method' |
|
LABEL_SKIP_FUSE = 'Ext. skip latent fusion' |
|
LABEL_DEBUG_RIFE = 'Save RIFE intermediates' |
|
|
|
DEFAULT_STEPS = 10 |
|
DEFAULT_CTRLNET_REF_DIR = str(ME_PATH / 'img' / 'ref_ctrlnet') |
|
DEFAULT_INTERP_METH = __(LABEL_INTERP_METH, InterpMethod.LINEAR.value) |
|
DEFAULT_SKIP_FUSE = __(LABEL_SKIP_FUSE, False) |
|
DEFAULT_DEBUG_RIFE = __(LABEL_DEBUG_RIFE, False) |
|
|
|
CHOICES_INTERP_METH = [x.value for x in InterpMethod] |
|
|
|
if 'vars': |
|
skip_fuse_plan: List[bool] = [] |
|
|
|
interp_alpha: float = 0.0 |
|
interp_ip: int = 0 |
|
from_hint_cond: List[Tensor] = [] |
|
to_hint_cond: List[Tensor] = [] |
|
mid_hint_cond: List[Tensor] = [] |
|
from_control_tensors: List[List[Tensor]] = [] |
|
to_control_tensors: List[List[Tensor]] = [] |
|
|
|
caches: List[list] = [from_hint_cond, to_hint_cond, mid_hint_cond, from_control_tensors, to_control_tensors] |
|
|
|
|
|
|
|
|
|
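# hijacked UnetHook.hook: same wiring as sd-webui-controlnet's original, plus the
# ControlNet-Travel caching / interpolation passages marked with ↓↓↓ below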
def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_params:List[ControlParams], process:Processing): |
|
self.model = model |
|
self.sd_ldm = sd_ldm |
|
self.control_params = control_params |
|
|
|
outer = self |
|
|
|
def process_sample(*args, **kwargs): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
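        # tag every prompt context (including the highres-fix ones) as positive or
        # negative before sampling, so the hooked UNet forward can tell cond from
        # uncond batches via unmark_prompt_context()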
mark_prompt_context(kwargs.get('conditioning', []), positive=True) |
|
mark_prompt_context(kwargs.get('unconditional_conditioning', []), positive=False) |
|
mark_prompt_context(getattr(process, 'hr_c', []), positive=True) |
|
mark_prompt_context(getattr(process, 'hr_uc', []), positive=False) |
|
return process.sample_before_CN_hack(*args, **kwargs) |
|
|
|
|
|
def forward(self:UNetModel, x:Tensor, timesteps:Tensor=None, context:Tensor=None, **kwargs): |
|
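        # residual accumulators: 13 ControlNet injection points (12 encoder skips +
        # the middle block) and 4 T2I-Adapter injection points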
total_controlnet_embedding = [0.0] * 13 |
|
total_t2i_adapter_embedding = [0.0] * 4 |
|
require_inpaint_hijack = False |
|
is_in_high_res_fix = False |
|
batch_size = int(x.shape[0]) |
|
|
|
|
|
global from_hint_cond, to_hint_cond, from_control_tensors, to_control_tensors, mid_hint_cond, interp_alpha, interp_ip |
|
x: Tensor |
|
timesteps: Tensor |
|
context: Tensor |
|
kwargs: dict |
|
|
|
|
|
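        # recover the clean context and find out which batch indices hold the
        # unconditional (negative) prompts marked in process_sample above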
cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context) |
|
|
|
|
|
|
|
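        # highres-fix support: decide whether the low-res or the hi-res hint image
        # matches the current latent resolution (x8) and switch accordingly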
for param in outer.control_params: |
|
|
|
if param.used_hint_cond is None: |
|
param.used_hint_cond = param.hint_cond |
|
param.used_hint_cond_latent = None |
|
param.used_hint_inpaint_hijack = None |
|
|
|
|
|
if param.hr_hint_cond is not None and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4: |
|
_, _, h_lr, w_lr = param.hint_cond.shape |
|
_, _, h_hr, w_hr = param.hr_hint_cond.shape |
|
_, _, h, w = x.shape |
|
h, w = h * 8, w * 8 |
|
if abs(h - h_lr) < abs(h - h_hr): |
|
is_in_high_res_fix = False |
|
if param.used_hint_cond is not param.hint_cond: |
|
param.used_hint_cond = param.hint_cond |
|
param.used_hint_cond_latent = None |
|
param.used_hint_inpaint_hijack = None |
|
else: |
|
is_in_high_res_fix = True |
|
if param.used_hint_cond is not param.hr_hint_cond: |
|
param.used_hint_cond = param.hr_hint_cond |
|
param.used_hint_cond_latent = None |
|
param.used_hint_inpaint_hijack = None |
|
|
|
|
|
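        # ↓↓↓ ControlNet-Travel: on a keyframe (interp_alpha == 0) cache each unit's
        # hint cond (one per unit); on an in-between frame substitute the
        # pre-interpolated hint from mid_hint_cond ↓↓↓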
for i, param in enumerate(outer.control_params): |
|
if interp_alpha == 0.0: |
|
if len(to_hint_cond) < len(outer.control_params): |
|
to_hint_cond.append(param.used_hint_cond.clone().detach().cpu()) |
|
else: |
|
param.used_hint_cond = mid_hint_cond[i].to(x.device) |
|
|
|
|
|
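        # pre-compute the VAE latent of each hint image; only the reference
        # (AttentionInjection), colorfix and inpaint_only modes need it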
for param in outer.control_params: |
|
if param.used_hint_cond_latent is not None: |
|
continue |
|
if param.control_model_type not in [ControlModelType.AttentionInjection] \ |
|
and 'colorfix' not in param.preprocessor['name'] \ |
|
and 'inpaint_only' not in param.preprocessor['name']: |
|
continue |
|
param.used_hint_cond_latent = outer.call_vae_using_process(process, param.used_hint_cond, batch_size=batch_size) |
|
|
|
|
|
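        # T2I style adapter: concat the weighted style embedding onto the text context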
for param in outer.control_params: |
|
if param.guidance_stopped: |
|
continue |
|
|
|
if param.control_model_type not in [ControlModelType.T2I_StyleAdapter]: |
|
continue |
|
|
|
param.control_model.to(devices.get_device_for("controlnet")) |
|
control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context) |
|
control = torch.cat([control.clone() for _ in range(batch_size)], dim=0) |
|
control *= param.weight |
|
control *= cond_mark[:, :, :, 0] |
|
context = torch.cat([context, control.clone()], dim=1) |
|
|
|
|
|
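        # ControlNet / T2I-Adapter: run each control model and collect its residuals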
for param in outer.control_params: |
|
if param.guidance_stopped: |
|
continue |
|
|
|
if param.control_model_type not in [ControlModelType.ControlNet, ControlModelType.T2I_Adapter]: |
|
continue |
|
|
|
param.control_model.to(devices.get_device_for("controlnet")) |
|
|
|
x_in = x |
|
control_model = param.control_model.control_model |
|
|
|
if param.control_model_type == ControlModelType.ControlNet: |
|
if x.shape[1] != control_model.input_blocks[0][0].in_channels and x.shape[1] == 9: |
|
|
|
x_in = x[:, :4, ...] |
|
require_inpaint_hijack = True |
|
|
|
            assert param.used_hint_cond is not None, "ControlNet is enabled but no input image is given"
|
|
|
hint = param.used_hint_cond |
|
|
|
|
|
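            # ControlNet inpaint protocol: a 4-channel hint is RGB + mask; binarize
            # the mask and fill masked pixels with -1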
if hint.shape[1] == 4: |
|
c = hint[:, 0:3, :, :] |
|
m = hint[:, 3:4, :, :] |
|
m = (m > 0.5).float() |
|
hint = c * (1 - m) - m |
|
|
|
|
|
control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context) |
|
control_scales = ([param.weight] * 13) |
|
|
|
if outer.lowvram: |
|
param.control_model.to("cpu") |
|
|
|
if param.cfg_injection or param.global_average_pooling: |
|
if param.control_model_type == ControlModelType.T2I_Adapter: |
|
control = [torch.cat([c.clone() for _ in range(batch_size)], dim=0) for c in control] |
|
control = [c * cond_mark for c in control] |
|
|
|
high_res_fix_forced_soft_injection = False |
|
|
|
if is_in_high_res_fix: |
|
if 'canny' in param.preprocessor['name']: |
|
high_res_fix_forced_soft_injection = True |
|
if 'mlsd' in param.preprocessor['name']: |
|
high_res_fix_forced_soft_injection = True |
|
|
|
|
|
|
|
|
|
if param.soft_injection or high_res_fix_forced_soft_injection: |
|
|
|
if param.control_model_type == ControlModelType.T2I_Adapter: |
|
control_scales = [param.weight * x for x in (0.25, 0.62, 0.825, 1.0)] |
|
elif param.control_model_type == ControlModelType.ControlNet: |
|
control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)] |
|
|
|
if param.advanced_weighting is not None: |
|
control_scales = param.advanced_weighting |
|
|
|
control = [c * scale for c, scale in zip(control, control_scales)] |
|
if param.global_average_pooling: |
|
control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control] |
|
|
|
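            # route this unit's residuals into the matching accumulator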
for idx, item in enumerate(control): |
|
target = None |
|
if param.control_model_type == ControlModelType.ControlNet: |
|
target = total_controlnet_embedding |
|
if param.control_model_type == ControlModelType.T2I_Adapter: |
|
target = total_t2i_adapter_embedding |
|
if target is not None: |
|
target[idx] = item + target[idx] |
|
|
|
|
|
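        # inpaint UNet (9-channel input) driven by a plain 4-channel hint: rebuild the
        # mask + masked-image latent pair and splice it into x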
for param in outer.control_params: |
|
if param.used_hint_cond.shape[1] != 4: |
|
continue |
|
if x.shape[1] != 9: |
|
continue |
|
if param.used_hint_inpaint_hijack is None: |
|
mask_pixel = param.used_hint_cond[:, 3:4, :, :] |
|
image_pixel = param.used_hint_cond[:, 0:3, :, :] |
|
mask_pixel = (mask_pixel > 0.5).to(mask_pixel.dtype) |
|
masked_latent = outer.call_vae_using_process(process, image_pixel, batch_size, mask=mask_pixel) |
|
mask_latent = torch.nn.functional.max_pool2d(mask_pixel, (8, 8)) |
|
if mask_latent.shape[0] != batch_size: |
|
mask_latent = torch.cat([mask_latent.clone() for _ in range(batch_size)], dim=0) |
|
param.used_hint_inpaint_hijack = torch.cat([mask_latent, masked_latent], dim=1) |
|
                param.used_hint_inpaint_hijack = param.used_hint_inpaint_hijack.to(x.dtype).to(x.device)  # Tensor.to() is not in-place; keep the converted tensor
|
x = torch.cat([x[:, :4, :, :], param.used_hint_inpaint_hijack], dim=1) |
|
|
|
|
|
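        # medvram: poke the wrapped model to trigger the lowvram forward pre-hooks
        # that move weights back to the GPU (the argument-less call itself is
        # expected to fail and is swallowed)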
if shared.cmd_opts.medvram: |
|
try: |
|
|
|
outer.sd_ldm.model() |
|
except: |
|
pass |
|
|
|
|
|
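        # clear the reference-mode attention / group-norm banks before this UNet pass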
for module in outer.attn_module_list: |
|
module.bank = [] |
|
module.style_cfgs = [] |
|
for module in outer.gn_module_list: |
|
module.mean_bank = [] |
|
module.var_bank = [] |
|
module.style_cfgs = [] |
|
|
|
|
|
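        # reference modes: noise the reference latent to the current timestep, run it
        # through the original UNet with the banks in Write mode, then switch the
        # hacked attention / group-norm modules to Read for the real pass below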
for param in outer.control_params: |
|
if param.guidance_stopped: |
|
continue |
|
|
|
if param.used_hint_cond_latent is None: |
|
continue |
|
|
|
if param.control_model_type not in [ControlModelType.AttentionInjection]: |
|
continue |
|
|
|
ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps.float()).long()) |
|
|
|
|
|
if x.shape[1] == 9: |
|
ref_xt = torch.cat([ |
|
ref_xt, |
|
torch.zeros_like(ref_xt)[:, 0:1, :, :], |
|
param.used_hint_cond_latent |
|
], dim=1) |
|
|
|
outer.current_style_fidelity = float(param.preprocessor['threshold_a']) |
|
outer.current_style_fidelity = max(0.0, min(1.0, outer.current_style_fidelity)) |
|
|
|
if param.cfg_injection: |
|
outer.current_style_fidelity = 1.0 |
|
elif param.soft_injection or is_in_high_res_fix: |
|
outer.current_style_fidelity = 0.0 |
|
|
|
control_name = param.preprocessor['name'] |
|
|
|
if control_name in ['reference_only', 'reference_adain+attn']: |
|
outer.attention_auto_machine = AutoMachine.Write |
|
outer.attention_auto_machine_weight = param.weight |
|
|
|
if control_name in ['reference_adain', 'reference_adain+attn']: |
|
outer.gn_auto_machine = AutoMachine.Write |
|
outer.gn_auto_machine_weight = param.weight |
|
|
|
outer.original_forward( |
|
x=ref_xt.to(devices.dtype_unet), |
|
timesteps=timesteps.to(devices.dtype_unet), |
|
context=context.to(devices.dtype_unet) |
|
) |
|
|
|
outer.attention_auto_machine = AutoMachine.Read |
|
outer.gn_auto_machine = AutoMachine.Read |
|
|
|
|
|
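        # ↓↓↓ ControlNet-Travel: on keyframes stash detached CPU copies of the control
        # tensors for later fusion (entries disabled by skip_fuse_plan become None);
        # on in-between frames overwrite them with the weighted sum of the cached
        # from/to tensors ↓↓↓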
total_control = total_controlnet_embedding |
|
if interp_alpha == 0.0: |
|
tensors: List[Tensor] = [] |
|
for i, t in enumerate(total_control): |
|
if len(skip_fuse_plan) and skip_fuse_plan[i]: |
|
tensors.append(None) |
|
else: |
|
tensors.append(t.clone().detach().cpu()) |
|
to_control_tensors.append(tensors) |
|
else: |
|
device = total_control[0].device |
|
for i, (ctrlA, ctrlB) in enumerate(zip(from_control_tensors[interp_ip], to_control_tensors[interp_ip])): |
|
if ctrlA is not None and ctrlB is not None: |
|
ctrlC = weighted_sum(ctrlA.to(device), ctrlB.to(device), interp_alpha) |
|
|
|
total_control[i].data = ctrlC |
|
interp_ip += 1 |
|
|
|
|
|
        if isinstance(total_t2i_adapter_embedding[0], Tensor):     # a plain `!= 0` on a multi-element Tensor is ambiguous

            print(f'{LOG_PREFIX} warn: currently t2i_adapter is not supported; if you want this, file a feature request on Kahsolt/stable-diffusion-webui-prompt-travel')
|
|
|
|
|
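        # vanilla UNet forward, adding the accumulated residuals block by block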
hs = [] |
|
with th.no_grad(): |
|
t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False)) |
|
emb = self.time_embed(t_emb) |
|
h = x.type(self.dtype) |
|
for i, module in enumerate(self.input_blocks): |
|
h = module(h, emb, context) |
|
|
|
if (i + 1) % 3 == 0: |
|
h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack) |
|
|
|
hs.append(h) |
|
h = self.middle_block(h, emb, context) |
|
|
|
|
|
h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack) |
|
|
|
|
|
for i, module in enumerate(self.output_blocks): |
|
h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1) |
|
h = module(h, emb, context) |
|
|
|
|
|
h = h.type(x.dtype) |
|
h = self.out(h) |
|
|
|
|
|
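        # colorfix postprocess: predict x0 from eps, transplant the reference's
        # low-frequency colors via blur-swap (optionally re-sharpened), then turn the
        # composite back into an eps prediction and blend by unit weight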
for param in outer.control_params: |
|
if param.used_hint_cond_latent is None: |
|
continue |
|
if 'colorfix' not in param.preprocessor['name']: |
|
continue |
|
|
|
k = int(param.preprocessor['threshold_a']) |
|
if is_in_high_res_fix: |
|
k *= 2 |
|
|
|
|
|
xt = x[:, :4, :, :] |
|
|
|
x0_origin = param.used_hint_cond_latent |
|
t = torch.round(timesteps.float()).long() |
|
x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h) |
|
x0 = x0_prd - blur(x0_prd, k) + blur(x0_origin, k) |
|
|
|
if '+sharp' in param.preprocessor['name']: |
|
detail_weight = float(param.preprocessor['threshold_b']) * 0.01 |
|
neg = detail_weight * blur(x0, k) + (1 - detail_weight) * x0 |
|
x0 = cond_mark * x0 + (1 - cond_mark) * neg |
|
|
|
eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0) |
|
|
|
w = max(0.0, min(1.0, float(param.weight))) |
|
h = eps_prd * w + h * (1 - w) |
|
|
|
|
|
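        # inpaint_only postprocess: predict x0, pin the unmasked region to the
        # reference latent, then turn the composite back into an eps prediction and
        # blend by unit weight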
for param in outer.control_params: |
|
if param.used_hint_cond_latent is None: |
|
continue |
|
if 'inpaint_only' not in param.preprocessor['name']: |
|
continue |
|
if param.used_hint_cond.shape[1] != 4: |
|
continue |
|
|
|
|
|
xt = x[:, :4, :, :] |
|
|
|
mask = param.used_hint_cond[:, 3:4, :, :] |
|
mask = torch.nn.functional.max_pool2d(mask, (10, 10), stride=(8, 8), padding=1) |
|
|
|
x0_origin = param.used_hint_cond_latent |
|
t = torch.round(timesteps.float()).long() |
|
x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h) |
|
x0 = x0_prd * mask + x0_origin * (1 - mask) |
|
eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0) |
|
|
|
w = max(0.0, min(1.0, float(param.weight))) |
|
h = eps_prd * w + h * (1 - w) |
|
|
|
return h |
|
|
|
def forward_webui(*args, **kwargs): |
|
|
|
try: |
|
if shared.cmd_opts.lowvram: |
|
lowvram.send_everything_to_cpu() |
|
|
|
return forward(*args, **kwargs) |
|
finally: |
|
if self.lowvram: |
|
for param in self.control_params: |
|
if isinstance(param.control_model, torch.nn.Module): |
|
param.control_model.to("cpu") |
|
|
|
def hacked_basic_transformer_inner_forward(self, x, context=None): |
|
x_norm1 = self.norm1(x) |
|
self_attn1 = None |
|
if self.disable_self_attn: |
|
|
|
self_attn1 = self.attn1(x_norm1, context=context) |
|
else: |
|
|
|
self_attention_context = x_norm1 |
|
if outer.attention_auto_machine == AutoMachine.Write: |
|
if outer.attention_auto_machine_weight > self.attn_weight: |
|
self.bank.append(self_attention_context.detach().clone()) |
|
self.style_cfgs.append(outer.current_style_fidelity) |
|
if outer.attention_auto_machine == AutoMachine.Read: |
|
if len(self.bank) > 0: |
|
style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs)) |
|
self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1)) |
|
self_attn1_c = self_attn1_uc.clone() |
|
if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5: |
|
self_attn1_c[outer.current_uc_indices] = self.attn1( |
|
x_norm1[outer.current_uc_indices], |
|
context=self_attention_context[outer.current_uc_indices]) |
|
self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc |
|
self.bank = [] |
|
self.style_cfgs = [] |
|
if self_attn1 is None: |
|
self_attn1 = self.attn1(x_norm1, context=self_attention_context) |
|
|
|
x = self_attn1.to(x.dtype) + x |
|
x = self.attn2(self.norm2(x), context=context) + x |
|
x = self.ff(self.norm3(x)) + x |
|
return x |
|
|
|
def hacked_group_norm_forward(self, *args, **kwargs): |
|
eps = 1e-6 |
|
x = self.original_forward(*args, **kwargs) |
|
y = None |
|
if outer.gn_auto_machine == AutoMachine.Write: |
|
if outer.gn_auto_machine_weight > self.gn_weight: |
|
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) |
|
self.mean_bank.append(mean) |
|
self.var_bank.append(var) |
|
self.style_cfgs.append(outer.current_style_fidelity) |
|
if outer.gn_auto_machine == AutoMachine.Read: |
|
if len(self.mean_bank) > 0 and len(self.var_bank) > 0: |
|
style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs)) |
|
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) |
|
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 |
|
mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) |
|
var_acc = sum(self.var_bank) / float(len(self.var_bank)) |
|
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 |
|
y_uc = (((x - mean) / std) * std_acc) + mean_acc |
|
y_c = y_uc.clone() |
|
if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5: |
|
y_c[outer.current_uc_indices] = x.to(y_c.dtype)[outer.current_uc_indices] |
|
y = style_cfg * y_c + (1.0 - style_cfg) * y_uc |
|
self.mean_bank = [] |
|
self.var_bank = [] |
|
self.style_cfgs = [] |
|
if y is None: |
|
y = x |
|
return y.to(x.dtype) |
|
|
|
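    # install the hijacks: wrap process.sample and the UNet forward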
if getattr(process, 'sample_before_CN_hack', None) is None: |
|
process.sample_before_CN_hack = process.sample |
|
process.sample = process_sample |
|
|
|
model._original_forward = model.forward |
|
outer.original_forward = model.forward |
|
model.forward = forward_webui.__get__(model, UNetModel) |
|
|
|
all_modules = torch_dfs(model) |
|
|
|
attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)] |
|
attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0]) |
|
|
|
for i, module in enumerate(attn_modules): |
|
if getattr(module, '_original_inner_forward', None) is None: |
|
module._original_inner_forward = module._forward |
|
module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) |
|
module.bank = [] |
|
module.style_cfgs = [] |
|
module.attn_weight = float(i) / float(len(attn_modules)) |
|
|
|
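    # group-norm modules for reference_adain, with per-depth gn_weight schedules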
gn_modules = [model.middle_block] |
|
model.middle_block.gn_weight = 0 |
|
|
|
input_block_indices = [4, 5, 7, 8, 10, 11] |
|
for w, i in enumerate(input_block_indices): |
|
module = model.input_blocks[i] |
|
module.gn_weight = 1.0 - float(w) / float(len(input_block_indices)) |
|
gn_modules.append(module) |
|
|
|
output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7] |
|
for w, i in enumerate(output_block_indices): |
|
module = model.output_blocks[i] |
|
module.gn_weight = float(w) / float(len(output_block_indices)) |
|
gn_modules.append(module) |
|
|
|
for i, module in enumerate(gn_modules): |
|
if getattr(module, 'original_forward', None) is None: |
|
module.original_forward = module.forward |
|
module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module) |
|
module.mean_bank = [] |
|
module.var_bank = [] |
|
module.style_cfgs = [] |
|
module.gn_weight *= 2 |
|
|
|
outer.attn_module_list = attn_modules |
|
outer.gn_module_list = gn_modules |
|
|
|
scripts.script_callbacks.on_cfg_denoiser(self.guidance_schedule_handler) |
|
|
|
|
|
|
|
def reset_cuda(): |
|
devices.torch_gc() |
|
import gc; gc.collect() |
|
|
|
try: |
|
import os |
|
import psutil |
|
mem = psutil.Process(os.getpid()).memory_info() |
|
print(f'[Mem] rss: {mem.rss/2**30:.3f} GB, vms: {mem.vms/2**30:.3f} GB') |
|
from modules.shared import mem_mon as vram_mon |
|
free, total = vram_mon.cuda_mem_get_info() |
|
print(f'[VRAM] free: {free/2**30:.3f} GB, total: {total/2**30:.3f} GB') |
|
except: |
|
pass |
|
|
|
|
|
class Script(scripts.Script): |
|
|
|
def title(self): |
|
return 'ControlNet Travel' |
|
|
|
def describe(self): |
|
return 'Travel from one controlnet hint condition to another in the tensor space.' |
|
|
|
def show(self, is_img2img): |
|
return controlnet_found |
|
|
|
def ui(self, is_img2img): |
|
with gr.Row(variant='compact'): |
|
interp_meth = gr.Dropdown(label=LABEL_INTERP_METH, value=lambda: DEFAULT_INTERP_METH, choices=CHOICES_INTERP_METH) |
|
steps = gr.Text (label=LABEL_STEPS, value=lambda: DEFAULT_STEPS, max_lines=1) |
|
|
|
reset = gr.Button(value='Reset Cuda', variant='tool') |
|
reset.click(fn=reset_cuda, show_progress=False) |
|
|
|
with gr.Row(variant='compact'): |
|
ctrlnet_ref_dir = gr.Text(label=LABEL_CTRLNET_REF_DIR, value=lambda: DEFAULT_CTRLNET_REF_DIR, max_lines=1) |
|
|
|
with gr.Group(visible=DEFAULT_SKIP_FUSE) as tab_ext_skip_fuse: |
|
with gr.Row(variant='compact'): |
|
skip_in_0 = gr.Checkbox(label='in_0') |
|
skip_in_3 = gr.Checkbox(label='in_3') |
|
skip_out_0 = gr.Checkbox(label='out_0') |
|
skip_out_3 = gr.Checkbox(label='out_3') |
|
with gr.Row(variant='compact'): |
|
skip_in_1 = gr.Checkbox(label='in_1') |
|
skip_in_4 = gr.Checkbox(label='in_4') |
|
skip_out_1 = gr.Checkbox(label='out_1') |
|
skip_out_4 = gr.Checkbox(label='out_4') |
|
with gr.Row(variant='compact'): |
|
skip_in_2 = gr.Checkbox(label='in_2') |
|
skip_in_5 = gr.Checkbox(label='in_5') |
|
skip_out_2 = gr.Checkbox(label='out_2') |
|
skip_out_5 = gr.Checkbox(label='out_5') |
|
with gr.Row(variant='compact'): |
|
skip_mid = gr.Checkbox(label='mid') |
|
|
|
with gr.Row(variant='compact', visible=DEFAULT_UPSCALE) as tab_ext_upscale: |
|
upscale_meth = gr.Dropdown(label=LABEL_UPSCALE_METH, value=lambda: DEFAULT_UPSCALE_METH, choices=CHOICES_UPSCALER) |
|
upscale_ratio = gr.Slider (label=LABEL_UPSCALE_RATIO, value=lambda: DEFAULT_UPSCALE_RATIO, minimum=1.0, maximum=16.0, step=0.1) |
|
upscale_width = gr.Slider (label=LABEL_UPSCALE_WIDTH, value=lambda: DEFAULT_UPSCALE_WIDTH, minimum=0, maximum=2048, step=8) |
|
upscale_height = gr.Slider (label=LABEL_UPSCALE_HEIGHT, value=lambda: DEFAULT_UPSCALE_HEIGHT, minimum=0, maximum=2048, step=8) |
|
|
|
with gr.Row(variant='compact', visible=DEFAULT_VIDEO) as tab_ext_video: |
|
video_fmt = gr.Dropdown(label=LABEL_VIDEO_FMT, value=lambda: DEFAULT_VIDEO_FMT, choices=CHOICES_VIDEO_FMT) |
|
video_fps = gr.Number (label=LABEL_VIDEO_FPS, value=lambda: DEFAULT_VIDEO_FPS) |
|
video_pad = gr.Number (label=LABEL_VIDEO_PAD, value=lambda: DEFAULT_VIDEO_PAD, precision=0) |
|
video_pick = gr.Text (label=LABEL_VIDEO_PICK, value=lambda: DEFAULT_VIDEO_PICK, max_lines=1) |
|
|
|
with gr.Row(variant='compact') as tab_ext: |
|
ext_video = gr.Checkbox(label=LABEL_VIDEO, value=lambda: DEFAULT_VIDEO) |
|
ext_upscale = gr.Checkbox(label=LABEL_UPSCALE, value=lambda: DEFAULT_UPSCALE) |
|
ext_skip_fuse = gr.Checkbox(label=LABEL_SKIP_FUSE, value=lambda: DEFAULT_SKIP_FUSE) |
|
dbg_rife = gr.Checkbox(label=LABEL_DEBUG_RIFE, value=lambda: DEFAULT_DEBUG_RIFE) |
|
|
|
ext_video .change(gr_show, inputs=ext_video, outputs=tab_ext_video, show_progress=False) |
|
ext_upscale .change(gr_show, inputs=ext_upscale, outputs=tab_ext_upscale, show_progress=False) |
|
ext_skip_fuse.change(gr_show, inputs=ext_skip_fuse, outputs=tab_ext_skip_fuse, show_progress=False) |
|
|
|
skip_fuses = [ |
|
skip_in_0, |
|
skip_in_1, |
|
skip_in_2, |
|
skip_in_3, |
|
skip_in_4, |
|
skip_in_5, |
|
skip_mid, |
|
skip_out_0, |
|
skip_out_1, |
|
skip_out_2, |
|
skip_out_3, |
|
skip_out_4, |
|
skip_out_5, |
|
] |
|
return [ |
|
interp_meth, steps, ctrlnet_ref_dir, |
|
upscale_meth, upscale_ratio, upscale_width, upscale_height, |
|
video_fmt, video_fps, video_pad, video_pick, |
|
ext_video, ext_upscale, ext_skip_fuse, dbg_rife, |
|
*skip_fuses, |
|
] |
|
|
|
def run(self, p:Processing, |
|
interp_meth:str, steps:str, ctrlnet_ref_dir:str, |
|
upscale_meth:str, upscale_ratio:float, upscale_width:int, upscale_height:int, |
|
video_fmt:str, video_fps:float, video_pad:int, video_pick:str, |
|
ext_video:bool, ext_upscale:bool, ext_skip_fuse:bool, dbg_rife:bool, |
|
*skip_fuses:bool, |
|
): |
|
|
|
|
|
|
|
self.controlnet_script = None |
|
try: |
|
for script in p.scripts.alwayson_scripts: |
|
if hasattr(script, "latest_network") and script.title().lower() == "controlnet": |
|
script_args: Tuple[ControlNetUnit] = p.script_args[script.args_from:script.args_to] |
|
if not any([u.enabled for u in script_args]): return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not enabled') |
|
self.controlnet_script = script |
|
break |
|
except ImportError: |
|
return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not installed') |
|
except: |
|
print_exc() |
|
if not self.controlnet_script: return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not loaded') |
|
|
|
|
|
interp_meth: InterpMethod = InterpMethod(interp_meth) |
|
video_fmt: VideoFormat = VideoFormat (video_fmt) |
|
|
|
|
|
if ext_video: |
|
            if video_pad < 0:  return Processed(p, [], p.seed, f'video_pad must be >= 0, but got {video_pad}')

            if video_fps <= 0: return Processed(p, [], p.seed, f'video_fps must be > 0, but got {video_fps}')
|
try: video_slice = parse_slice(video_pick) |
|
except: return Processed(p, [], p.seed, 'syntax error in video_slice') |
|
if ext_skip_fuse: |
|
global skip_fuse_plan |
|
skip_fuse_plan = skip_fuses |
|
|
|
|
|
if not ctrlnet_ref_dir: return Processed(p, [], p.seed, f'invalid image folder path: {ctrlnet_ref_dir}') |
|
ctrlnet_ref_dir: Path = Path(ctrlnet_ref_dir) |
|
        if not ctrlnet_ref_dir.is_dir(): return Processed(p, [], p.seed, f'invalid image folder path: {ctrlnet_ref_dir} :(')
|
self.ctrlnet_ref_fps = [fp for fp in list(ctrlnet_ref_dir.iterdir()) if fp.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']] |
|
n_stages = len(self.ctrlnet_ref_fps) |
|
        if n_stages == 0: return Processed(p, [], p.seed, f'no image files (*.jpg/*.jpeg/*.png/*.bmp/*.webp) found in folder: {ctrlnet_ref_dir}')
|
if n_stages == 1: return Processed(p, [], p.seed, 'requires at least two images to travel between, but found only 1 :(') |
|
|
|
|
|
try: steps: List[int] = [int(s.strip()) for s in steps.strip().split(',')] |
|
except: return Processed(p, [], p.seed, f'cannot parse steps options: {steps}') |
|
if len(steps) == 1: steps = [steps[0]] * (n_stages - 1) |
|
        elif len(steps) != n_stages - 1: return Processed(p, [], p.seed, f'stage count mismatch: len(steps) ({len(steps)}) != n_stages - 1 ({n_stages - 1})')
|
n_frames = sum(steps) + n_stages |
|
if 'show_debug': |
|
print('n_stages:', n_stages) |
|
print('n_frames:', n_frames) |
|
print('steps:', steps) |
|
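        # pad with a dummy so that steps[i] aligns with stage i in run_linear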
steps.insert(0, -1) |
|
|
|
|
|
travel_path = os.path.join(p.outpath_samples, 'prompt_travel') |
|
os.makedirs(travel_path, exist_ok=True) |
|
travel_number = get_next_sequence_number(travel_path) |
|
self.log_dp = os.path.join(travel_path, f'{travel_number:05}') |
|
p.outpath_samples = self.log_dp |
|
os.makedirs(self.log_dp, exist_ok=True) |
|
self.tmp_dp = Path(self.log_dp) / 'ctrl_cond' |
|
self.tmp_fp = self.tmp_dp / 'tmp.png' |
|
|
|
|
|
p.n_iter = 1 |
|
p.batch_size = 1 |
|
|
|
|
|
p.seed = get_fixed_seed(p.seed) |
|
self.subseed = p.subseed |
|
if 'show_debug': |
|
print('seed:', p.seed) |
|
print('subseed:', p.subseed) |
|
print('subseed_strength:', p.subseed_strength) |
|
|
|
|
|
state.job_count = n_frames |
|
|
|
|
|
self.n_stages = n_stages |
|
self.steps = steps |
|
self.interp_meth = interp_meth |
|
self.dbg_rife = dbg_rife |
|
|
|
def upscale_image_callback(params:ImageSaveParams): |
|
params.image = upscale_image(params.image, p.width, p.height, upscale_meth, upscale_ratio, upscale_width, upscale_height) |
|
|
|
images: List[PILImage] = [] |
|
info: str = None |
|
try: |
|
if ext_upscale: on_before_image_saved(upscale_image_callback) |
|
|
|
self.UnetHook_hook_original = UnetHook.hook |
|
UnetHook.hook = hook_hijack |
|
|
|
[c.clear() for c in caches] |
|
images, info = self.run_linear(p) |
|
except: |
|
info = format_exc() |
|
print(info) |
|
finally: |
|
if self.tmp_fp.exists(): os.unlink(self.tmp_fp) |
|
[c.clear() for c in caches] |
|
|
|
UnetHook.hook = self.UnetHook_hook_original |
|
|
|
self.controlnet_script.input_image = None |
|
if self.controlnet_script.latest_network: |
|
self.controlnet_script.latest_network: UnetHook |
|
self.controlnet_script.latest_network.restore(p.sd_model.model.diffusion_model) |
|
self.controlnet_script.latest_network = None |
|
|
|
if ext_upscale: remove_callbacks_for_function(upscale_image_callback) |
|
|
|
reset_cuda() |
|
|
|
|
|
if ext_video: save_video(images, video_slice, video_pad, video_fps, video_fmt, os.path.join(self.log_dp, f'travel-{travel_number:05}')) |
|
|
|
return Processed(p, images, p.seed, info) |
|
|
|
def run_linear(self, p:Processing) -> RunResults: |
|
global from_hint_cond, to_hint_cond, from_control_tensors, to_control_tensors, interp_alpha, interp_ip |
|
|
|
images: List[PILImage] = [] |
|
info: str = None |
|
def process_p(append:bool=True) -> Optional[List[PILImage]]: |
|
nonlocal p, images, info |
|
proc = process_images(p) |
|
if not info: info = proc.info |
|
if append: images.extend(proc.images) |
|
else: return proc.images |
|
|
|
        ''' ↓↓↓ rife interp utils ↓↓↓ '''
|
def save_ctrl_cond(idx:int): |
|
self.tmp_dp.mkdir(exist_ok=True) |
|
for i, x in enumerate(to_hint_cond): |
|
x = x[0] |
|
if len(x.shape) == 3: |
|
if x.shape[0] == 1: x = x.squeeze_(0) |
|
elif x.shape[0] == 3: x = x.permute([1, 2, 0]) |
|
else: raise ValueError(f'unknown cond shape: {x.shape}') |
|
else: |
|
raise ValueError(f'unknown cond shape: {x.shape}') |
|
im = (x.detach().clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8) |
|
Image.fromarray(im).save(self.tmp_dp / f'{idx}-{i}.png') |
|
def rife_interp(i:int, j:int, k:int, alpha:float) -> Tensor: |
|
''' interp between i-th and j-th cond of the k-th ctrlnet set ''' |
|
fp0 = self.tmp_dp / f'{i}-{k}.png' |
|
fp1 = self.tmp_dp / f'{j}-{k}.png' |
|
fpo = self.tmp_dp / f'{i}-{j}-{alpha:.3f}.png' if self.dbg_rife else self.tmp_fp |
|
assert run_cmd(f'rife-ncnn-vulkan -m rife-v4 -s {alpha:.3f} -0 "{fp0}" -1 "{fp1}" -o "{fpo}"') |
|
x = torch.from_numpy(np.asarray(Image.open(fpo)) / 255.0) |
|
if len(x.shape) == 2: x = x.unsqueeze_(0) |
|
elif len(x.shape) == 3: x = x.permute([2, 0, 1]) |
|
else: raise ValueError(f'unknown cond shape: {x.shape}') |
|
x = x.unsqueeze(dim=0) |
|
return x |
|
        ''' ↑↑↑ rife interp utils ↑↑↑ '''
|
|
|
        ''' ↓↓↓ filename reorder utils ↓↓↓ '''
|
iframe = 0 |
|
def rename_image_filename(idx:int, param: ImageSaveParams): |
|
fn = param.filename |
|
stem, suffix = os.path.splitext(os.path.basename(fn)) |
|
param.filename = os.path.join(os.path.dirname(fn), f'{idx:05d}' + suffix) |
|
class on_before_image_saved_wrapper: |
|
def __init__(self, callback_fn): |
|
self.callback_fn = callback_fn |
|
def __enter__(self): |
|
on_before_image_saved(self.callback_fn) |
|
def __exit__(self, exc_type, exc_value, exc_traceback): |
|
remove_callbacks_for_function(self.callback_fn) |
|
        ''' ↑↑↑ filename reorder utils ↑↑↑ '''
|
|
|
|
|
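        # render the first keyframe and cache its ctrl conds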
setattr(p, 'init_images', [Image.open(self.ctrlnet_ref_fps[0])]) |
|
interp_alpha = 0.0 |
|
with on_before_image_saved_wrapper(partial(rename_image_filename, 0)): |
|
process_p() |
|
iframe += 1 |
|
save_ctrl_cond(0) |
|
|
|
|
|
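        # walk the stages: render each destination keyframe first (caching its control
        # tensors), then fill in the interpolated frames between the cached endpoints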
for i in range(1, self.n_stages): |
|
if state.interrupted: break |
|
|
|
|
|
from_hint_cond = [t for t in to_hint_cond] ; to_hint_cond .clear() |
|
from_control_tensors = [t for t in to_control_tensors] ; to_control_tensors.clear() |
|
setattr(p, 'init_images', [Image.open(self.ctrlnet_ref_fps[i])]) |
|
interp_alpha = 0.0 |
|
|
|
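            # the destination keyframe is rendered now but belongs after the
            # interpolated frames, so name it with its future index and hold it back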
with on_before_image_saved_wrapper(partial(rename_image_filename, iframe + self.steps[i])): |
|
cached_images = process_p(append=False) |
|
save_ctrl_cond(i) |
|
|
|
|
|
is_interrupted = False |
|
n_inter = self.steps[i] + 1 |
|
for t in range(1, n_inter): |
|
if state.interrupted: is_interrupted = True ; break |
|
|
|
interp_alpha = t / n_inter |
|
|
|
mid_hint_cond.clear() |
|
device = devices.get_device_for("controlnet") |
|
if self.interp_meth == InterpMethod.LINEAR: |
|
for hintA, hintB in zip(from_hint_cond, to_hint_cond): |
|
hintC = weighted_sum(hintA.to(device), hintB.to(device), interp_alpha) |
|
mid_hint_cond.append(hintC) |
|
elif self.interp_meth == InterpMethod.RIFE: |
|
dtype = to_hint_cond[0].dtype |
|
for k in range(len(to_hint_cond)): |
|
hintC = rife_interp(i-1, i, k, interp_alpha).to(device, dtype) |
|
mid_hint_cond.append(hintC) |
|
else: raise ValueError(f'unknown interp_meth: {self.interp_meth}') |
|
|
|
interp_ip = 0 |
|
with on_before_image_saved_wrapper(partial(rename_image_filename, iframe)): |
|
process_p() |
|
iframe += 1 |
|
|
|
|
|
images.extend(cached_images) |
|
iframe += 1 |
|
|
|
if is_interrupted: break |
|
|
|
return images, info |
|
|