# This extension works with https://github.com/AUTOMATIC1111/stable-diffusion-webui
# version: v1.4.0

LOG_PREFIX = '[Prompt-Travel]'

import os
from pathlib import Path
from PIL.Image import Image as PILImage
from enum import Enum
from dataclasses import dataclass
from functools import partial, wraps
from typing import List, Tuple, Callable, Any, Optional, Generic, TypeVar
from traceback import print_exc, format_exc

import gradio as gr
import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
try:
    from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
    from moviepy.editor import concatenate_videoclips, ImageClip
except ImportError:
    print(f'{LOG_PREFIX} package moviepy not installed, will not be able to generate video')

import modules.scripts as scripts
from modules.script_callbacks import on_before_image_saved, ImageSaveParams, on_cfg_denoiser, CFGDenoiserParams, remove_callbacks_for_function
from modules.ui import gr_show
from modules.shared import state, opts, sd_upscalers
from modules.processing import process_images, get_fixed_seed
from modules.processing import Processed, StableDiffusionProcessing as Processing, StableDiffusionProcessingTxt2Img as ProcessingTxt2Img, StableDiffusionProcessingImg2Img as ProcessingImg2Img
from modules.images import resize_image
from modules.sd_samplers_common import single_sample_to_image


class Mode(Enum):
    ''' How the condition travels from one prompt stage to the next. '''
    LINEAR = 'linear'
    REPLACE = 'replace'


class LerpMethod(Enum):
    ''' Interpolation function used in linear mode. '''
    LERP = 'lerp'
    SLERP = 'slerp'


class ModeReplaceDim(Enum):
    ''' Granularity at which replace mode swaps condition values. '''
    TOKEN = 'token'
    CHANNEL = 'channel'
    RANDOM = 'random'


class ModeReplaceOrder(Enum):
    ''' Order in which replace mode picks the differing positions. '''
    SIMILAR = 'similar'
    DIFFERENT = 'different'
    RANDOM = 'random'


class Gensis(Enum):
    ''' Where each frame starts from (class name spelling kept for compatibility). '''
    FIXED = 'fixed'
    SUCCESSIVE = 'successive'
    EMBRYO = 'embryo'


class VideoFormat(Enum):
    ''' Supported container formats for the exported video. '''
    MP4 = 'mp4'
    GIF = 'gif'
    WEBM = 'webm'


# NOTE: `if '<label>':` is always true; it is only used to group related lines.
if 'typing':
    T = TypeVar('T')

    @dataclass
    class Ref(Generic[T]):
        ''' Mutable single-slot holder, lets callbacks hand a value back to their caller. '''
        value: T = None

    TensorRef = Ref[Tensor]
    StrRef = Ref[str]
    PILImages = List[PILImage]
    RunResults = Tuple[PILImages, str]

if 'consts':
    # read the value stored in webui options for a UI label, falling back to `value`
    __ = lambda key, value=None: opts.data.get(f'customscript/prompt_travel.py/txt2img/{key}/value', value)

    LABEL_MODE = 'Travel mode'
    LABEL_STEPS = 'Travel steps between stages'
    LABEL_GENESIS = 'Frame genesis'
    LABEL_DENOISE_W = 'Denoise strength'
    LABEL_EMBRYO_STEP = 'Denoise steps for embryo'
    LABEL_LERP_METH = 'Linear interp method'
    LABEL_REPLACE_DIM = 'Replace dimension'
    LABEL_REPLACE_ORDER = 'Replace order'
    LABEL_VIDEO = 'Ext. export video'
    LABEL_VIDEO_FPS = 'Video FPS'
    LABEL_VIDEO_FMT = 'Video file format'
    LABEL_VIDEO_PAD = 'Pad begin/end frames'
    LABEL_VIDEO_PICK = 'Pick frame by slice'
    LABEL_UPSCALE = 'Ext. upscale'
    LABEL_UPSCALE_METH = 'Upscaler'
    LABEL_UPSCALE_RATIO = 'Upscale ratio'
    LABEL_UPSCALE_WIDTH = 'Upscale width'
    LABEL_UPSCALE_HEIGHT = 'Upscale height'
    LABEL_DEPTH = 'Ext. depth-image-io (for depth2img models)'
    LABEL_DEPTH_IMG = 'Depth image file'

    DEFAULT_MODE = __(LABEL_MODE, Mode.LINEAR.value)
    DEFAULT_STEPS = __(LABEL_STEPS, 30)
    DEFAULT_GENESIS = __(LABEL_GENESIS, Gensis.FIXED.value)
    DEFAULT_DENOISE_W = __(LABEL_DENOISE_W, 1.0)
    DEFAULT_EMBRYO_STEP = __(LABEL_EMBRYO_STEP, 8)
    DEFAULT_LERP_METH = __(LABEL_LERP_METH, LerpMethod.LERP.value)
    DEFAULT_REPLACE_DIM = __(LABEL_REPLACE_DIM, ModeReplaceDim.TOKEN.value)
    DEFAULT_REPLACE_ORDER = __(LABEL_REPLACE_ORDER, ModeReplaceOrder.RANDOM.value)
    DEFAULT_UPSCALE = __(LABEL_UPSCALE, False)
    DEFAULT_UPSCALE_METH = __(LABEL_UPSCALE_METH, 'Lanczos')
    DEFAULT_UPSCALE_RATIO = __(LABEL_UPSCALE_RATIO, 2.0)
    DEFAULT_UPSCALE_WIDTH = __(LABEL_UPSCALE_WIDTH, 0)
    DEFAULT_UPSCALE_HEIGHT = __(LABEL_UPSCALE_HEIGHT, 0)
    DEFAULT_VIDEO = __(LABEL_VIDEO, True)
    DEFAULT_VIDEO_FPS = __(LABEL_VIDEO_FPS, 10)
    DEFAULT_VIDEO_FMT = __(LABEL_VIDEO_FMT, VideoFormat.MP4.value)
    DEFAULT_VIDEO_PAD = __(LABEL_VIDEO_PAD, 0)
    DEFAULT_VIDEO_PICK = __(LABEL_VIDEO_PICK, '')
    DEFAULT_DEPTH = __(LABEL_DEPTH, False)

    CHOICES_MODE = [x.value for x in Mode]
    CHOICES_LERP_METH = [x.value for x in LerpMethod]
    CHOICES_GENESIS = [x.value for x in Gensis]
    CHOICES_REPLACE_DIM = [x.value for x in ModeReplaceDim]
    CHOICES_REPLACE_ORDER = [x.value for x in ModeReplaceOrder]
    CHOICES_UPSCALER = [x.name for x in sd_upscalers]
    CHOICES_VIDEO_FMT = [x.value for x in VideoFormat]

    EPS = 1e-6


def cond_align(condA:Tensor, condB:Tensor) -> Tuple[Tensor, Tensor]:
    ''' Zero-pad the shorter of two tensors (at the end of the second-to-last dim) by the difference of their `shape[0]`. '''
    d = condA.shape[0] - condB.shape[0]
    if d < 0:
        condA = F.pad(condA, (0, 0, 0, -d))
    elif d > 0:
        condB = F.pad(condB, (0, 0, 0, d))
    return condA, condB


def wrap_cond_align(fn:Callable[..., Tensor]):
    ''' Decorator: run `cond_align()` on the first two arguments before calling `fn`. '''
    @wraps(fn)
    def wrapper(condA:Tensor, condB:Tensor, *args, **kwargs) -> Tensor:
        condA, condB = cond_align(condA, condB)
        return fn(condA, condB, *args, **kwargs)
    return wrapper


@wrap_cond_align
def weighted_sum(condA:Tensor, condB:Tensor, alpha:float) -> Tensor:
    ''' linear interpolate on latent space of condition '''

    return (1 - alpha) * condA + (alpha) * condB


@wrap_cond_align
def geometric_slerp(condA:Tensor, condB:Tensor, alpha:float) -> Tensor:
    ''' spherical linear interpolation on latent space of condition, ref: https://en.wikipedia.org/wiki/Slerp '''

    A_n = condA / torch.norm(condA, dim=-1, keepdim=True)   # [T=77, D=768]
    B_n = condB / torch.norm(condB, dim=-1, keepdim=True)

    dot = (A_n * B_n).sum(dim=-1, keepdim=True)     # [T=77, D=1]
    omega = torch.acos(dot)                         # [T=77, D=1]
    so = torch.sin(omega)                           # [T=77, D=1]

    slerp = (torch.sin((1 - alpha) * omega) / so) * condA + (torch.sin(alpha * omega) / so) * condB

    mask: Tensor = dot > 0.9995                     # [T=77, D=1]
    if not mask.any():
        return slerp

    lerp = (1 - alpha) * condA + (alpha) * condB
    return torch.where(mask, lerp, slerp)           # use simple lerp when angle very close to avoid NaN


@wrap_cond_align
def replace_until_match(condA:Tensor, condB:Tensor, count:int, dist:Tensor, order:ModeReplaceOrder=ModeReplaceOrder.RANDOM) -> Tensor:
    '''
    value substitute on condition tensor; will inplace modify `dist`

    Picks up to `count` positions where `dist > EPS`, zeroes them in `dist`, and
    returns a tensor taking `condB` at every position whose `dist` is no longer
    above EPS and `condA` elsewhere.
    Raises ValueError on an unknown `order` or when `dist` cannot be broadcast to `condA`.
    '''

    def index_tensor_to_tuple(index:Tensor) -> Tuple[Tensor, ...]:
        return tuple([index[..., i] for i in range(index.shape[-1])])     # tuple([nDiff], ...)

    # mask: [T=77, D=768], [T=77] or [D=768]
    mask = dist > EPS
    # idx_diff: [nDiff, nDim=2] or [nDiff, nDim=1]
    idx_diff = torch.nonzero(mask)
    n_diff = len(idx_diff)

    if order == ModeReplaceOrder.RANDOM:
        # take everything that is left when fewer than `count` positions still differ
        sel = np.random.choice(range(n_diff), size=count, replace=False) if n_diff > count else slice(None)
    else:
        val_diff = dist[index_tensor_to_tuple(idx_diff)]    # [nDiff]

        if order == ModeReplaceOrder.SIMILAR:
            sorted_index = val_diff.argsort()
        elif order == ModeReplaceOrder.DIFFERENT:
            sorted_index = val_diff.argsort(descending=True)
        else:
            raise ValueError(f'unknown replace_order: {order}')

        sel = sorted_index[:count]

    idx_diff_sel = idx_diff[sel, ...]       # [cnt] => [cnt, nDim]
    idx_diff_sel_tp = index_tensor_to_tuple(idx_diff_sel)
    dist[idx_diff_sel_tp] = 0.0             # mark as replaced for the following calls
    mask[idx_diff_sel_tp] = False

    if mask.shape != condA.shape:           # cond.shape = [T=77, D=768]
        mask_len = mask.shape[0]
        if mask_len == condA.shape[0]:
            mask = mask.unsqueeze(1)
        elif mask_len == condA.shape[1]:
            mask = mask.unsqueeze(0)
        else:
            raise ValueError(f'unknown mask.shape: {mask.shape}')
        mask = mask.expand_as(condA)

    return mask * condA + ~mask * condB


def get_next_sequence_number(path:str) -> int:
    """
    Determines and returns the next sequence number to use when saving into the specified directory.

    Only sub-directories whose name parses as an integer are considered.
    The sequence starts at 0.
    """

    result = -1
    for entry in Path(path).iterdir():
        if not entry.is_dir():
            continue
        try:
            result = max(result, int(entry.name))
        except ValueError:
            pass    # not a numbered folder
    return result + 1


def update_img2img_p(p:Processing, imgs:PILImages, denoising_strength:float=0.75) -> ProcessingImg2Img:
    '''
    Make `p` an img2img processing object that starts from `imgs`.

    An img2img `p` is modified in place and returned; a txt2img `p` is converted
    into a new `ProcessingImg2Img` inheriting the attributes listed in KNOWN_KEYS.
    Raises TypeError for any other type.
    '''

    if isinstance(p, ProcessingImg2Img):
        p.init_images = imgs
        p.denoising_strength = denoising_strength
        return p

    if isinstance(p, ProcessingTxt2Img):
        KNOWN_KEYS = [      # see `StableDiffusionProcessing.__init__()`
            'sd_model',
            'outpath_samples',
            'outpath_grids',
            'prompt',
            'styles',
            'seed',
            'subseed',
            'subseed_strength',
            'seed_resize_from_h',
            'seed_resize_from_w',
            'seed_enable_extras',
            'sampler_name',
            'batch_size',
            'n_iter',
            'steps',
            'cfg_scale',
            'width',
            'height',
            'restore_faces',
            'tiling',
            'do_not_save_samples',
            'do_not_save_grid',
            'extra_generation_params',
            'overlay_images',
            'negative_prompt',
            'eta',
            'do_not_reload_embeddings',
            #'denoising_strength',
            'ddim_discretize',
            's_min_uncond',
            's_churn',
            's_tmax',
            's_tmin',
            's_noise',
            'override_settings',
            'override_settings_restore_afterwards',
            'sampler_index',
            'script_args',
        ]
        kwargs = { k: getattr(p, k) for k in dir(p) if k in KNOWN_KEYS }    # inherit params
        return ProcessingImg2Img(
            init_images=imgs,
            denoising_strength=denoising_strength,
            **kwargs,
        )

    raise TypeError(f'unsupported processing type: {type(p).__name__}')


def parse_slice(picker:str) -> Optional[slice]:
    '''
    Parse a python-like slice expression: 'stop', 'start:stop' or 'start:stop:step'.

    Returns None for a blank string; empty segments become None.
    Raises ValueError on non-integer segments or more than three segments.
    '''

    if not picker.strip():
        return None

    to_int = lambda s: None if not s else int(s)
    segs = [to_int(x.strip()) for x in picker.strip().split(':')]

    start, stop, step = None, None, None
    if len(segs) == 1:
        stop, = segs
    elif len(segs) == 2:
        start, stop = segs
    elif len(segs) == 3:
        start, stop, step = segs
    else:
        raise ValueError

    return slice(start, stop, step)


def parse_resolution(width:int, height:int, upscale_ratio:float, upscale_width:int, upscale_height:int) -> Tuple[bool, Tuple[int, int]]:
    '''
    Decide the upscale target size.

    When both `upscale_width` and `upscale_height` are 0 the size is derived from
    `upscale_ratio`; otherwise a dimension given as 0 is derived from the other one
    keeping the source aspect ratio.
    Returns `(need_upscale, (target_width, target_height))`.
    '''

    if upscale_width == upscale_height == 0:
        if upscale_ratio == 1.0:
            return False, (width, height)
        return True, (round(width * upscale_ratio), round(height * upscale_ratio))

    if upscale_width == 0:
        upscale_width = round(width * upscale_height / height)
    if upscale_height == 0:
        upscale_height = round(height * upscale_width / width)
    # a resize is needed as soon as EITHER dimension differs from the source
    return (width != upscale_width or height != upscale_height), (upscale_width, upscale_height)


def upscale_image(img:PILImage, width:int, height:int, upscale_meth:str, upscale_ratio:float, upscale_width:int, upscale_height:int) -> PILImage:
    ''' Resize `img` to the size decided by `parse_resolution()`; returns `img` untouched when `upscale_meth` is 'None' or no resize is needed. '''

    if upscale_meth == 'None':
        return img

    need_upscale, (tgt_w, tgt_h) = parse_resolution(width, height, upscale_ratio, upscale_width, upscale_height)
    if need_upscale:
        if 'show_debug':
            print(f'>> upscale: ({width}, {height}) => ({tgt_w}, {tgt_h})')

        if max(tgt_w / width, tgt_h / height) > 4:      # must split into two rounds for NN model compatibility
            hf_w, hf_h = round(width * 4), round(height * 4)
            img = resize_image(0, img, hf_w, hf_h, upscaler_name=upscale_meth)
        img = resize_image(0, img, tgt_w, tgt_h, upscaler_name=upscale_meth)

    return img


def save_video(imgs:PILImages, video_slice:slice, video_pad:int, video_fps:float, video_fmt:VideoFormat, fbase:str):
    '''
    Export `imgs` as a video file `fbase` + extension.

    Does nothing when there are fewer than two frames or moviepy failed to import.
    `video_slice` (may be None) selects frames; `video_pad` repeats the first and
    last frame that many times. Errors are printed, never raised.
    '''

    if len(imgs) <= 1 or 'ImageSequenceClip' not in globals():
        return

    try:
        # arrange frames
        if video_slice:
            imgs = imgs[video_slice]
        if not imgs:
            print(f'{LOG_PREFIX} no frames selected by the slice, skip exporting video')
            return
        if video_pad > 0:
            imgs = [imgs[0]] * video_pad + imgs + [imgs[-1]] * video_pad

        # export video
        seq: List[np.ndarray] = [np.asarray(img) for img in imgs]
        try:
            clip = ImageSequenceClip(seq, fps=video_fps)
        except Exception:   # images may have different size (small probability due to upscaler)
            clip = concatenate_videoclips([ImageClip(img, duration=1/video_fps) for img in seq], method='compose')
            clip.fps = video_fps

        if video_fmt == VideoFormat.MP4:
            clip.write_videofile(fbase + '.mp4', verbose=False, audio=False)
        elif video_fmt == VideoFormat.WEBM:
            clip.write_videofile(fbase + '.webm', verbose=False, audio=False)
        elif video_fmt == VideoFormat.GIF:
            clip.write_gif(fbase + '.gif', loop=True)
    except Exception:
        print_exc()


class on_cfg_denoiser_wrapper:
    ''' Context manager: registers `callback_fn` via `on_cfg_denoiser()` on enter, removes it on exit. '''

    def __init__(self, callback_fn:Callable):
        self.callback_fn = callback_fn

    def __enter__(self):
        on_cfg_denoiser(self.callback_fn)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        remove_callbacks_for_function(self.callback_fn)


class p_steps_overrider:
    ''' Context manager: temporarily sets `p.steps`, restores the saved value on exit. '''

    def __init__(self, p:Processing, steps:int=1):
        self.p = p
        self.steps = steps
        self.steps_saved = self.p.steps

    def __enter__(self):
        self.p.steps = self.steps

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.p.steps = self.steps_saved


class p_save_samples_overrider:
    ''' Context manager: temporarily sets `p.do_not_save_samples = not save`, restores the saved value on exit. '''

    def __init__(self, p:Processing, save:bool=True):
        self.p = p
        self.save = save
        self.do_not_save_samples_saved = self.p.do_not_save_samples

    def __enter__(self):
        self.p.do_not_save_samples = not self.save

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.p.do_not_save_samples = self.do_not_save_samples_saved


def get_cond_callback(refs:List[TensorRef], params:CFGDenoiserParams):
    ''' Store `text_cond` / `text_uncond` of the first sampling step into `refs` (in that order). '''

    if params.sampling_step > 0:
        return
    values = [
        params.text_cond,       # [B=1, L= 77, D=768]
        params.text_uncond,     # [B=1, L=231, D=768]
    ]
    for ref, value in zip(refs, values):
        ref.value = value


def set_cond_callback(refs:List[TensorRef], params:CFGDenoiserParams):
    ''' Overwrite the data of `text_cond` / `text_uncond` with the tensors held by `refs` (in that order). '''

    values = [
        params.text_cond,       # [B=1, L= 77, D=768]
        params.text_uncond,     # [B=1, L=231, D=768]
    ]
    for ref, value in zip(refs, values):
        value.data = ref.value


def get_latent_callback(ref:TensorRef, embryo_step:int, params:CFGDenoiserParams):
    ''' Store `params.x` into `ref` when the sampler is at step `embryo_step`. '''

    if params.sampling_step != embryo_step:
        return
    ref.value = params.x


def set_latent_callback(ref:TensorRef, embryo_step:int, params:CFGDenoiserParams):
    ''' Overwrite the data of `params.x` with the tensor in `ref` when the sampler is at step `embryo_step`. '''

    if params.sampling_step != embryo_step:
        return
    params.x.data = ref.value


def switch_to_stage_binding_(self:'Script', i:int):
    ''' Point `self.p` at the prompts of stage `i` and re-apply the stashed subseed. '''

    if 'show_debug':
        print(f'[stage {i+1}/{self.n_stages}]')
        print(f'  pos prompt: {self.pos_prompts[i]}')
        if hasattr(self, 'neg_prompts'):
            print(f'  neg prompt: {self.neg_prompts[i]}')

    self.p.prompt = self.pos_prompts[i]
    if hasattr(self, 'neg_prompts'):
        self.p.negative_prompt = self.neg_prompts[i]
    self.p.subseed = self.subseed


def process_p_binding_(self:'Script', append:bool=True, save:bool=True) -> PILImages:
    '''
    Run `process_images()` on `self.p` and return the produced images.

    With `save` the first non-empty info text is kept in `self.info` and, when
    `append` is also set, the images are collected into `self.images`.
    In 'successive' genesis `self.p` is switched to img2img starting from the
    last collected image.
    '''

    assert hasattr(self, 'images') and hasattr(self, 'info'), 'unknown logic, "images" and "info" not initialized'

    with p_save_samples_overrider(self.p, save):
        proc = process_images(self.p)

    if save:
        if not self.info.value:
            self.info.value = proc.info
        if append:
            self.images.extend(proc.images)

    if self.genesis == Gensis.SUCCESSIVE:
        self.p = update_img2img_p(self.p, self.images[-1:], self.denoise_w)

    return proc.images


class Script(scripts.Script):
    ''' webui custom script: renders a frame sequence travelling between the prompt lines. '''

    def title(self):
        return 'Prompt Travel'

    def describe(self):
        return 'Travel from one prompt to another in the text encoder latent space.'

    def show(self, is_img2img):
        return True

    def ui(self, is_img2img):
        ''' Build the gradio widgets; the returned list order must match the parameters of `run()`. '''

        # initial visibility follows the stored defaults, same rule as the change handlers below
        is_linear     = DEFAULT_MODE    == Mode.LINEAR.value
        is_successive = DEFAULT_GENESIS == Gensis.SUCCESSIVE.value
        is_embryo     = DEFAULT_GENESIS == Gensis.EMBRYO.value

        with gr.Row(variant='compact') as tab_mode:
            mode          = gr.Radio   (label=LABEL_MODE,          value=lambda: DEFAULT_MODE,          choices=CHOICES_MODE)
            lerp_meth     = gr.Dropdown(label=LABEL_LERP_METH,     value=lambda: DEFAULT_LERP_METH,     choices=CHOICES_LERP_METH,     visible=is_linear)
            replace_dim   = gr.Dropdown(label=LABEL_REPLACE_DIM,   value=lambda: DEFAULT_REPLACE_DIM,   choices=CHOICES_REPLACE_DIM,   visible=not is_linear)
            replace_order = gr.Dropdown(label=LABEL_REPLACE_ORDER, value=lambda: DEFAULT_REPLACE_ORDER, choices=CHOICES_REPLACE_ORDER, visible=not is_linear)

            def switch_mode(mode:str):
                show_meth = Mode(mode) == Mode.LINEAR
                show_repl = Mode(mode) == Mode.REPLACE
                return [gr_show(x) for x in [show_meth, show_repl, show_repl]]
            mode.change(switch_mode, inputs=[mode], outputs=[lerp_meth, replace_dim, replace_order], show_progress=False)

        with gr.Row(variant='compact') as tab_param:
            steps       = gr.Text    (label=LABEL_STEPS,       value=lambda: DEFAULT_STEPS, max_lines=1)
            genesis     = gr.Dropdown(label=LABEL_GENESIS,     value=lambda: DEFAULT_GENESIS, choices=CHOICES_GENESIS)
            denoise_w   = gr.Slider  (label=LABEL_DENOISE_W,   value=lambda: DEFAULT_DENOISE_W, minimum=0.0, maximum=1.0, visible=is_successive)
            embryo_step = gr.Text    (label=LABEL_EMBRYO_STEP, value=lambda: DEFAULT_EMBRYO_STEP, max_lines=1, visible=is_embryo)

            def switch_genesis(genesis:str):
                show_dw = Gensis(genesis) == Gensis.SUCCESSIVE  # show 'denoise_w' for 'successive'
                show_es = Gensis(genesis) == Gensis.EMBRYO      # show 'embryo_step' for 'embryo'
                return [gr_show(x) for x in [show_dw, show_es]]
            genesis.change(switch_genesis, inputs=[genesis], outputs=[denoise_w, embryo_step], show_progress=False)

        with gr.Row(variant='compact', visible=DEFAULT_DEPTH) as tab_ext_depth:
            depth_img = gr.Image(label=LABEL_DEPTH_IMG, source='upload', type='pil', image_mode=None)

        with gr.Row(variant='compact', visible=DEFAULT_UPSCALE) as tab_ext_upscale:
            upscale_meth   = gr.Dropdown(label=LABEL_UPSCALE_METH,   value=lambda: DEFAULT_UPSCALE_METH,   choices=CHOICES_UPSCALER)
            upscale_ratio  = gr.Slider  (label=LABEL_UPSCALE_RATIO,  value=lambda: DEFAULT_UPSCALE_RATIO,  minimum=1.0, maximum=16.0, step=0.1)
            upscale_width  = gr.Slider  (label=LABEL_UPSCALE_WIDTH,  value=lambda: DEFAULT_UPSCALE_WIDTH,  minimum=0, maximum=2048, step=8)
            upscale_height = gr.Slider  (label=LABEL_UPSCALE_HEIGHT, value=lambda: DEFAULT_UPSCALE_HEIGHT, minimum=0, maximum=2048, step=8)

        with gr.Row(variant='compact', visible=DEFAULT_VIDEO) as tab_ext_video:
            video_fmt  = gr.Dropdown(label=LABEL_VIDEO_FMT,  value=lambda: DEFAULT_VIDEO_FMT, choices=CHOICES_VIDEO_FMT)
            video_fps  = gr.Number  (label=LABEL_VIDEO_FPS,  value=lambda: DEFAULT_VIDEO_FPS)
            video_pad  = gr.Number  (label=LABEL_VIDEO_PAD,  value=lambda: DEFAULT_VIDEO_PAD, precision=0)
            video_pick = gr.Text    (label=LABEL_VIDEO_PICK, value=lambda: DEFAULT_VIDEO_PICK, max_lines=1)

        with gr.Row(variant='compact') as tab_ext:
            ext_video   = gr.Checkbox(label=LABEL_VIDEO,   value=lambda: DEFAULT_VIDEO)
            ext_upscale = gr.Checkbox(label=LABEL_UPSCALE, value=lambda: DEFAULT_UPSCALE)
            ext_depth   = gr.Checkbox(label=LABEL_DEPTH,   value=lambda: DEFAULT_DEPTH)

            ext_video  .change(gr_show, inputs=ext_video,   outputs=tab_ext_video,   show_progress=False)
            ext_upscale.change(gr_show, inputs=ext_upscale, outputs=tab_ext_upscale, show_progress=False)
            ext_depth  .change(gr_show, inputs=ext_depth,   outputs=tab_ext_depth,   show_progress=False)

        return [
            mode, lerp_meth, replace_dim, replace_order,
            steps, genesis, denoise_w, embryo_step,
            depth_img,
            upscale_meth, upscale_ratio, upscale_width, upscale_height,
            video_fmt, video_fps, video_pad, video_pick,
            ext_video, ext_upscale, ext_depth,
        ]

    def run(self, p:Processing,
            mode:str, lerp_meth:str, replace_dim:str, replace_order:str,
            steps:str, genesis:str, denoise_w:float, embryo_step:str,
            depth_img:PILImage,
            upscale_meth:str, upscale_ratio:float, upscale_width:int, upscale_height:int,
            video_fmt:str, video_fps:float, video_pad:int, video_pick:str,
            ext_video:bool, ext_upscale:bool, ext_depth:bool,
        ):
        '''
        Script entry: validate the UI values, prepare per-stage prompts and steps,
        then dispatch to `run_<mode>()`.

        Invalid parameters yield a `Processed` with no images and an error text as
        info. An exception raised by the runner is printed and returned as the
        info text together with whatever images were produced so far.
        '''

        # enum lookup
        mode          = Mode(mode)
        lerp_meth     = LerpMethod(lerp_meth)
        replace_dim   = ModeReplaceDim(replace_dim)
        replace_order = ModeReplaceOrder(replace_order)
        genesis       = Gensis(genesis)
        video_fmt     = VideoFormat(video_fmt)

        # Param check & type convert
        if ext_video:
            if video_pad <  0: return Processed(p, [], p.seed, f'video_pad must >= 0, but got {video_pad}')
            if video_fps <= 0: return Processed(p, [], p.seed, f'video_fps must > 0, but got {video_fps}')
            try: video_slice = parse_slice(video_pick)
            except ValueError: return Processed(p, [], p.seed, 'syntax error in video_slice')
        if genesis == Gensis.EMBRYO:
            try: embryo_val = float(embryo_step)
            except (TypeError, ValueError): return Processed(p, [], p.seed, f'embryo_step is not a number: {embryo_step}')
            if embryo_val <= 0: return Processed(p, [], p.seed, f'embryo_step must > 0, but got {embryo_step}')
            # a value below 1 is taken as a fraction of the sampling steps
            embryo_step = round(embryo_val * p.steps if embryo_val < 1.0 else embryo_val)
        if ext_depth and depth_img is None:
            return Processed(p, [], p.seed, 'depth image is not provided :(')

        # Prepare prompts & steps
        prompt_pos = p.prompt.strip()
        if not prompt_pos: return Processed(p, [], p.seed, 'positive prompt should not be empty :(')
        pos_prompts = [line.strip() for line in prompt_pos.split('\n') if line.strip()]
        if len(pos_prompts) == 1: return Processed(p, [], p.seed, 'should specify at least two lines of prompt to travel between :(')
        if genesis == Gensis.EMBRYO and len(pos_prompts) > 2: return Processed(p, [], p.seed, 'processing with "embryo" genesis takes exactly two lines of prompt :(')
        prompt_neg = p.negative_prompt.strip()
        neg_prompts = [line.strip() for line in prompt_neg.split('\n') if line.strip()]
        if len(neg_prompts) == 0: neg_prompts = ['']
        # the shorter prompt list is extended by repeating its last line
        n_stages = max(len(pos_prompts), len(neg_prompts))
        while len(pos_prompts) < n_stages: pos_prompts.append(pos_prompts[-1])
        while len(neg_prompts) < n_stages: neg_prompts.append(neg_prompts[-1])

        # `steps` may arrive as a non-str value (the stored default is an int), hence str()
        try: stage_steps: List[int] = [int(s.strip()) for s in str(steps).strip().split(',')]
        except ValueError: return Processed(p, [], p.seed, f'cannot parse steps option: {steps}')
        if len(stage_steps) == 1: stage_steps = [stage_steps[0]] * (n_stages - 1)
        elif len(stage_steps) != n_stages - 1: return Processed(p, [], p.seed, f'stage count mismatch: you have {n_stages} prompt stages, but specified {len(stage_steps)} steps; should assure len(steps) == len(stages) - 1')
        n_frames = sum(stage_steps) + n_stages
        if 'show_debug':
            print('n_stages:', n_stages)
            print('n_frames:', n_frames)
            print('steps:', stage_steps)
        stage_steps.insert(0, -1)   # fixup the first stage

        # Custom saving path
        travel_path = os.path.join(p.outpath_samples, 'prompt_travel')
        os.makedirs(travel_path, exist_ok=True)
        travel_number = get_next_sequence_number(travel_path)
        self.log_dp = os.path.join(travel_path, f'{travel_number:05}')
        p.outpath_samples = self.log_dp
        os.makedirs(self.log_dp, exist_ok=True)
        #self.log_fp = os.path.join(self.log_dp, 'log.txt')

        # Force batch count and batch size to 1
        p.n_iter     = 1
        p.batch_size = 1

        # Random unified const seed
        p.seed = get_fixed_seed(p.seed)     # fix it to assure all processes using the same major seed
        self.subseed = p.subseed            # stash it to allow using random subseed for each process (when -1)
        if 'show_debug':
            print('seed:',             p.seed)
            print('subseed:',          p.subseed)
            print('subseed_strength:', p.subseed_strength)

        # Start job
        state.job_count = n_frames

        # Pack parameters
        self.pos_prompts   = pos_prompts
        self.neg_prompts   = neg_prompts
        self.steps         = stage_steps
        self.genesis       = genesis
        self.denoise_w     = denoise_w
        self.embryo_step   = embryo_step
        self.lerp_meth     = lerp_meth
        self.replace_dim   = replace_dim
        self.replace_order = replace_order
        self.n_stages      = n_stages
        self.n_frames      = n_frames

        def upscale_image_callback(params:ImageSaveParams):
            params.image = upscale_image(params.image, p.width, p.height, upscale_meth, upscale_ratio, upscale_width, upscale_height)

        # Dispatch
        self.p: Processing = p
        self.images: PILImages = []
        self.info: StrRef = Ref()
        try:
            if ext_upscale: on_before_image_saved(upscale_image_callback)
            if ext_depth: self.ext_depth_preprocess(p, depth_img)

            runner = getattr(self, f'run_{mode.value}', None)
            if not runner: return Processed(p, [], p.seed, f'no runner found for mode: {mode.value}')
            runner()
        except Exception:
            e = format_exc()
            print(e)
            self.info.value = e
        finally:
            if ext_depth: self.ext_depth_postprocess(p, depth_img)
            if ext_upscale: remove_callbacks_for_function(upscale_image_callback)

        # Save video
        if ext_video: save_video(self.images, video_slice, video_pad, video_fps, video_fmt, os.path.join(self.log_dp, f'travel-{travel_number:05}'))

        return Processed(p, self.images, p.seed, self.info.value)

    def run_linear(self):
        ''' Interpolate both positive and negative conditions between consecutive stages with lerp/slerp. '''

        # dispatch for special case
        if self.genesis == Gensis.EMBRYO: return self.run_linear_embryo()

        lerp_fn = weighted_sum if self.lerp_meth == LerpMethod.LERP else geometric_slerp

        if 'auxiliary':
            switch_to_stage = partial(switch_to_stage_binding_, self)
            process_p = partial(process_p_binding_, self)

            from_pos_hidden:  TensorRef = Ref()
            from_neg_hidden:  TensorRef = Ref()
            to_pos_hidden:    TensorRef = Ref()
            to_neg_hidden:    TensorRef = Ref()
            inter_pos_hidden: TensorRef = Ref()
            inter_neg_hidden: TensorRef = Ref()

        # Step 1: draw the init image
        switch_to_stage(0)
        with on_cfg_denoiser_wrapper(partial(get_cond_callback, [from_pos_hidden, from_neg_hidden])):
            process_p()

        # travel through stages
        for i in range(1, self.n_stages):
            if state.interrupted: break

            state.job = f'{i}/{self.n_frames}'
            state.job_no = i + 1

            # only change target prompts
            switch_to_stage(i)
            with on_cfg_denoiser_wrapper(partial(get_cond_callback, [to_pos_hidden, to_neg_hidden])):
                if self.genesis == Gensis.FIXED:
                    imgs = process_p(append=False)      # stash it to make order right
                elif self.genesis == Gensis.SUCCESSIVE:
                    with p_steps_overrider(self.p, steps=1):    # ignore final image, only need cond
                        process_p(save=False, append=False)
                else: raise ValueError(f'invalid genesis: {self.genesis.value}')

            # Step 2: draw the interpolated images
            is_break_iter = False
            n_inter = self.steps[i]
            for t in range(1, n_inter + (1 if self.genesis == Gensis.SUCCESSIVE else 0)):
                if state.interrupted: is_break_iter = True ; break

                alpha = t / n_inter     # [1/T, 2/T, .. T-1/T] (+ [T/T])?
                inter_pos_hidden.value = lerp_fn(from_pos_hidden.value, to_pos_hidden.value, alpha)
                inter_neg_hidden.value = lerp_fn(from_neg_hidden.value, to_neg_hidden.value, alpha)
                with on_cfg_denoiser_wrapper(partial(set_cond_callback, [inter_pos_hidden, inter_neg_hidden])):
                    process_p()

            if is_break_iter: break

            # Step 3: append the final stage
            if self.genesis != Gensis.SUCCESSIVE: self.images.extend(imgs)
            # move to next stage
            from_pos_hidden.value, from_neg_hidden.value = to_pos_hidden.value, to_neg_hidden.value
            inter_pos_hidden.value, inter_neg_hidden.value = None, None

    def run_linear_embryo(self):
        ''' NOTE: this procedure has special logic, we separate it from run_linear() so far '''

        lerp_fn = weighted_sum if self.lerp_meth == LerpMethod.LERP else geometric_slerp
        n_frames = self.steps[1] + 2

        if 'auxiliary':
            switch_to_stage = partial(switch_to_stage_binding_, self)
            process_p = partial(process_p_binding_, self)

            from_pos_hidden:  TensorRef = Ref()
            to_pos_hidden:    TensorRef = Ref()
            inter_pos_hidden: TensorRef = Ref()
            embryo:           TensorRef = Ref()     # latent image, the common half-denoised prototype of all frames

        # Step 1: get starting & ending condition
        switch_to_stage(0)
        with on_cfg_denoiser_wrapper(partial(get_cond_callback, [from_pos_hidden])):
            with p_steps_overrider(self.p, steps=1):
                process_p(save=False)
        switch_to_stage(1)
        with on_cfg_denoiser_wrapper(partial(get_cond_callback, [to_pos_hidden])):
            with p_steps_overrider(self.p, steps=1):
                process_p(save=False)

        # Step 2: get the condition middle-point as embryo then hatch it halfway
        inter_pos_hidden.value = lerp_fn(from_pos_hidden.value, to_pos_hidden.value, 0.5)
        with on_cfg_denoiser_wrapper(partial(set_cond_callback, [inter_pos_hidden])):
            with on_cfg_denoiser_wrapper(partial(get_latent_callback, embryo, self.embryo_step)):
                process_p(save=False)
        try:
            img: PILImage = single_sample_to_image(embryo.value[0], approximation=-1)     # the data is duplicated, just get first item
            img.save(os.path.join(self.log_dp, 'embryo.png'))
        except Exception:
            # the preview is best-effort only, never abort the travel because of it
            print(f'{LOG_PREFIX} failed to save the embryo preview image')

        # Step 3: derive the embryo towards each interpolated condition
        # NOTE(review): this renders n_frames+1 images (i.e. steps + 3), one more than
        # the `n_frames` computed in run(); kept as-is, confirm whether intended
        for t in range(0, n_frames+1):
            if state.interrupted: break

            alpha = t / n_frames     # [0, 1/T, 2/T, .. T-1/T, 1]
            inter_pos_hidden.value = lerp_fn(from_pos_hidden.value, to_pos_hidden.value, alpha)
            with on_cfg_denoiser_wrapper(partial(set_cond_callback, [inter_pos_hidden])):
                with on_cfg_denoiser_wrapper(partial(set_latent_callback, embryo, self.embryo_step)):
                    process_p()

    def run_replace(self):
        ''' yet another replace method, but do replacing on the condition tensor by token dim or channel dim '''

        if self.genesis == Gensis.EMBRYO: raise NotImplementedError(f'genesis {self.genesis.value!r} is only supported in linear mode currently :(')

        if 'auxiliary':
            switch_to_stage = partial(switch_to_stage_binding_, self)
            process_p = partial(process_p_binding_, self)

            from_pos_hidden:  TensorRef = Ref()
            to_pos_hidden:    TensorRef = Ref()
            inter_pos_hidden: TensorRef = Ref()

        # Step 1: draw the init image
        switch_to_stage(0)
        with on_cfg_denoiser_wrapper(partial(get_cond_callback, [from_pos_hidden])):
            process_p()

        # travel through stages
        for i in range(1, self.n_stages):
            if state.interrupted: break

            state.job = f'{i}/{self.n_frames}'
            state.job_no = i + 1

            # only change target prompts
            switch_to_stage(i)
            with on_cfg_denoiser_wrapper(partial(get_cond_callback, [to_pos_hidden])):
                if self.genesis == Gensis.FIXED:
                    imgs = process_p(append=False)      # stash it to make order right
                elif self.genesis == Gensis.SUCCESSIVE:
                    with p_steps_overrider(self.p, steps=1):    # ignore final image, only need cond
                        process_p(save=False, append=False)
                else: raise ValueError(f'invalid genesis: {self.genesis.value}')

            # ========== ↓↓↓ major differences from run_linear() ↓↓↓ ==========

            # decide change portion in each iter
            L1 = torch.abs(from_pos_hidden.value - to_pos_hidden.value)
            if   self.replace_dim == ModeReplaceDim.RANDOM:
                dist = L1                  # [T=77, D=768]
            elif self.replace_dim == ModeReplaceDim.TOKEN:
                dist = L1.mean(axis=1)     # [T=77]
            elif self.replace_dim == ModeReplaceDim.CHANNEL:
                dist = L1.mean(axis=0)     # [D=768]
            else: raise ValueError(f'unknown replace_dim: {self.replace_dim}')
            mask = dist > EPS
            dist = torch.where(mask, dist, 0.0)
            n_diff = mask.sum().item()            # when value differs we have mask==True
            n_inter = self.steps[i] + 1
            replace_count = int(n_diff / n_inter) + 1    # => accumulative modifies [1/T, 2/T, .. T-1/T] of total cond

            # Step 2: draw the replaced images
            inter_pos_hidden.value = from_pos_hidden.value
            is_break_iter = False
            for _ in range(1, n_inter):
                if state.interrupted: is_break_iter = True ; break

                inter_pos_hidden.value = replace_until_match(inter_pos_hidden.value, to_pos_hidden.value, replace_count, dist=dist, order=self.replace_order)
                with on_cfg_denoiser_wrapper(partial(set_cond_callback, [inter_pos_hidden])):
                    process_p()

            # ========== ↑↑↑ major differences from run_linear() ↑↑↑ ==========

            if is_break_iter: break

            # Step 3: append the final stage
            if self.genesis != Gensis.SUCCESSIVE: self.images.extend(imgs)
            # move to next stage
            from_pos_hidden.value = to_pos_hidden.value
            inter_pos_hidden.value = None

    ''' ↓↓↓ extension support ↓↓↓ '''

    def ext_depth_preprocess(self, p:Processing, depth_img:PILImage):  # copy from repo `AnonymousCervine/depth-image-io-for-SDWebui`
        ''' Patch `p` so that its depth conditioning is computed from the user-supplied `depth_img`. '''

        from types import MethodType
        from einops import repeat, rearrange
        import modules.shared as shared
        import modules.devices as devices

        def sanitize_pil_image_mode(img):
            if img.mode in {'P', 'CMYK', 'HSV'}:
                img = img.convert(mode='RGB')
            return img

        def alt_depth_image_conditioning(self, source_image):
            # the encoded source image is only used for its spatial size
            with devices.autocast():
                conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
            depth_data = np.array(sanitize_pil_image_mode(depth_img))

            if len(np.shape(depth_data)) == 2:
                depth_data = rearrange(depth_data, "h w -> 1 1 h w")
            else:
                depth_data = rearrange(depth_data, "h w c -> c 1 1 h w")[0]     # keep the first channel only
            depth_data = torch.from_numpy(depth_data).to(device=shared.device).to(dtype=torch.float32)
            depth_data = repeat(depth_data, "1 ... -> n ...", n=self.batch_size)

            conditioning = torch.nn.functional.interpolate(
                depth_data,
                size=conditioning_image.shape[2:],
                mode="bicubic",
                align_corners=False,
            )
            # rescale to [-1, 1]
            (depth_min, depth_max) = torch.aminmax(conditioning)
            conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
            return conditioning

        p.depth2img_image_conditioning = MethodType(alt_depth_image_conditioning, p)

        def alt_txt2img_image_conditioning(self, x, width=None, height=None):
            fake_img = torch.zeros(1, 3, height or self.height, width or self.width).to(shared.device).type(self.sd_model.dtype)
            return self.depth2img_image_conditioning(fake_img)

        p.txt2img_image_conditioning = MethodType(alt_txt2img_image_conditioning, p)

    def ext_depth_postprocess(self, p:Processing, depth_img:PILImage):
        ''' Close the depth image, if one was supplied. '''

        if depth_img is not None:
            depth_img.close()