File size: 37,321 Bytes
d9f3559 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 |
# This extension works with https://github.com/AUTOMATIC1111/stable-diffusion-webui
# version: v1.4.0
LOG_PREFIX = '[Prompt-Travel]'
import os
from pathlib import Path
from PIL.Image import Image as PILImage
from enum import Enum
from dataclasses import dataclass
from functools import partial
from typing import List, Tuple, Callable, Any, Optional, Generic, TypeVar
from traceback import print_exc, format_exc
import gradio as gr
import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
try:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from moviepy.editor import concatenate_videoclips, ImageClip
except ImportError:
print(f'{LOG_PREFIX} package moviepy not installed, will not be able to generate video')
import modules.scripts as scripts
from modules.script_callbacks import on_before_image_saved, ImageSaveParams, on_cfg_denoiser, CFGDenoiserParams, remove_callbacks_for_function
from modules.ui import gr_show
from modules.shared import state, opts, sd_upscalers
from modules.processing import process_images, get_fixed_seed
from modules.processing import Processed, StableDiffusionProcessing as Processing, StableDiffusionProcessingTxt2Img as ProcessingTxt2Img, StableDiffusionProcessingImg2Img as ProcessingImg2Img
from modules.images import resize_image
from modules.sd_samplers_common import single_sample_to_image
class Mode(Enum):
    ''' Travel mode; `Script.run()` dispatches to the method named `run_<value>`. '''
    LINEAR = 'linear'       # handled by `run_linear()`: interpolate between stage conditions
    REPLACE = 'replace'     # handled by `run_replace()`: progressively substitute condition values
class LerpMethod(Enum):
    ''' Interpolation function used by the linear travel mode. '''
    LERP = 'lerp'       # selects `weighted_sum()`
    SLERP = 'slerp'     # any non-LERP value selects `geometric_slerp()`
class ModeReplaceDim(Enum):
    ''' Granularity at which the replace travel mode measures and substitutes the condition. '''
    TOKEN = 'token'         # distance is the L1 difference averaged over axis 1
    CHANNEL = 'channel'     # distance is the L1 difference averaged over axis 0
    RANDOM = 'random'       # distance is the raw element-wise L1 difference
class ModeReplaceOrder(Enum):
    ''' Order in which `replace_until_match()` picks the differing positions to substitute. '''
    SIMILAR = 'similar'         # smallest distance first (ascending argsort)
    DIFFERENT = 'different'     # largest distance first (descending argsort)
    RANDOM = 'random'           # drawn with `np.random.choice` without replacement
class Gensis(Enum):
    ''' Frame genesis: what each generated frame starts from. '''
    FIXED = 'fixed'             # the processing object is reused unchanged for every frame
    SUCCESSIVE = 'successive'   # after each frame, the last image becomes the img2img init image
    EMBRYO = 'embryo'           # linear mode only: frames share one latent captured at `embryo_step`
class VideoFormat(Enum):
    ''' Container format written by `save_video()`; the value doubles as the file extension. '''
    MP4 = 'mp4'
    GIF = 'gif'
    WEBM = 'webm'
# NOTE: a non-empty string literal is always truthy; the `if` only serves as a foldable section label
if 'typing':
    T = TypeVar('T')

    # Mutable single-slot holder: lets callbacks hand a value back to the caller
    # (the callbacks below assign to `.value`); starts out as None.
    @dataclass
    class Ref(Generic[T]): value: T = None

    TensorRef = Ref[Tensor]                 # holder for a tensor
    StrRef = Ref[str]                       # holder for a string
    PILImages = List[PILImage]              # a list of PIL images
    RunResults = Tuple[PILImages, str]      # (images, info text)
# NOTE: a non-empty string literal is always truthy; the `if` only serves as a foldable section label
if 'consts':
    # Look up `customscript/prompt_travel.py/txt2img/<key>/value` in `opts.data`,
    # falling back to `value` when that entry is absent.
    __ = lambda key, value=None: opts.data.get(f'customscript/prompt_travel.py/txt2img/{key}/value', value)

    # Widget labels; they are also the `<key>` part of the lookup above.
    LABEL_MODE = 'Travel mode'
    LABEL_STEPS = 'Travel steps between stages'
    LABEL_GENESIS = 'Frame genesis'
    LABEL_DENOISE_W = 'Denoise strength'
    LABEL_EMBRYO_STEP = 'Denoise steps for embryo'
    LABEL_LERP_METH = 'Linear interp method'
    LABEL_REPLACE_DIM = 'Replace dimension'
    LABEL_REPLACE_ORDER = 'Replace order'
    LABEL_VIDEO = 'Ext. export video'
    LABEL_VIDEO_FPS = 'Video FPS'
    LABEL_VIDEO_FMT = 'Video file format'
    LABEL_VIDEO_PAD = 'Pad begin/end frames'
    LABEL_VIDEO_PICK = 'Pick frame by slice'
    LABEL_UPSCALE = 'Ext. upscale'
    LABEL_UPSCALE_METH = 'Upscaler'
    LABEL_UPSCALE_RATIO = 'Upscale ratio'
    LABEL_UPSCALE_WIDTH = 'Upscale width'
    LABEL_UPSCALE_HEIGHT = 'Upscale height'
    LABEL_DEPTH = 'Ext. depth-image-io (for depth2img models)'
    LABEL_DEPTH_IMG = 'Depth image file'

    # Initial widget values: the stored entry if present, else the literal given here.
    DEFAULT_MODE = __(LABEL_MODE, Mode.LINEAR.value)
    DEFAULT_STEPS = __(LABEL_STEPS, 30)
    DEFAULT_GENESIS = __(LABEL_GENESIS, Gensis.FIXED.value)
    DEFAULT_DENOISE_W = __(LABEL_DENOISE_W, 1.0)
    DEFAULT_EMBRYO_STEP = __(LABEL_EMBRYO_STEP, 8)
    DEFAULT_LERP_METH = __(LABEL_LERP_METH, LerpMethod.LERP.value)
    DEFAULT_REPLACE_DIM = __(LABEL_REPLACE_DIM, ModeReplaceDim.TOKEN.value)
    DEFAULT_REPLACE_ORDER = __(LABEL_REPLACE_ORDER, ModeReplaceOrder.RANDOM.value)
    DEFAULT_UPSCALE = __(LABEL_UPSCALE, False)
    DEFAULT_UPSCALE_METH = __(LABEL_UPSCALE_METH, 'Lanczos')
    DEFAULT_UPSCALE_RATIO = __(LABEL_UPSCALE_RATIO, 2.0)
    DEFAULT_UPSCALE_WIDTH = __(LABEL_UPSCALE_WIDTH, 0)
    DEFAULT_UPSCALE_HEIGHT = __(LABEL_UPSCALE_HEIGHT, 0)
    DEFAULT_VIDEO = __(LABEL_VIDEO, True)
    DEFAULT_VIDEO_FPS = __(LABEL_VIDEO_FPS, 10)
    DEFAULT_VIDEO_FMT = __(LABEL_VIDEO_FMT, VideoFormat.MP4.value)
    DEFAULT_VIDEO_PAD = __(LABEL_VIDEO_PAD, 0)
    DEFAULT_VIDEO_PICK = __(LABEL_VIDEO_PICK, '')
    DEFAULT_DEPTH = __(LABEL_DEPTH, False)

    # Choice lists for the radio/dropdown widgets, derived from the enums and the registered upscalers.
    CHOICES_MODE = [x.value for x in Mode]
    CHOICES_LERP_METH = [x.value for x in LerpMethod]
    CHOICES_GENESIS = [x.value for x in Gensis]
    CHOICES_REPLACE_DIM = [x.value for x in ModeReplaceDim]
    CHOICES_REPLACE_ORDER = [x.value for x in ModeReplaceOrder]
    CHOICES_UPSCALER = [x.name for x in sd_upscalers]
    CHOICES_VIDEO_FMT = [x.value for x in VideoFormat]

    # Threshold below which a distance counts as "no difference" (see `dist > EPS`).
    EPS = 1e-6
def cond_align(condA:Tensor, condB:Tensor) -> Tuple[Tensor, Tensor]:
    ''' Zero-pad the shorter condition at the end of its token axis so both have equal length.

    The token axis is the second-to-last one, so `[T, D]` and batched `[B, T, D]`
    inputs are both handled. Returns the (possibly padded) pair in the same order.
    '''
    # The pad spec `(0, 0, 0, n)` always grows dim -2, so the mismatch must be
    # measured on that same dim (dim 0 is the batch axis for 3-D conditions).
    token_dim = -2 if min(condA.dim(), condB.dim()) >= 2 else 0
    d = condA.shape[token_dim] - condB.shape[token_dim]
    if   d < 0: condA = F.pad(condA, (0, 0, 0, -d))
    elif d > 0: condB = F.pad(condB, (0, 0, 0,  d))
    return condA, condB
def wrap_cond_align(fn:Callable[..., Tensor]):
    ''' Decorator: align the first two condition tensors with `cond_align()` before calling `fn`.

    Any further positional/keyword arguments are passed through untouched.
    '''
    from functools import wraps     # local import: only `partial` is imported at file level

    @wraps(fn)      # keep the name/docstring of the wrapped interpolation function
    def wrapper(condA:Tensor, condB:Tensor, *args, **kwargs) -> Tensor:
        condA, condB = cond_align(condA, condB)
        return fn(condA, condB, *args, **kwargs)
    return wrapper
@wrap_cond_align
def weighted_sum(condA:Tensor, condB:Tensor, alpha:float) -> Tensor:
    ''' Linear interpolation between two conditions: `alpha=0` gives `condA`, `alpha=1` gives `condB`. '''
    weight_a = 1 - alpha
    weight_b = alpha
    return weight_a * condA + weight_b * condB
@wrap_cond_align
def geometric_slerp(condA:Tensor, condB:Tensor, alpha:float) -> Tensor:
    ''' Spherical linear interpolation on latent space of condition, ref: https://en.wikipedia.org/wiki/Slerp

    The angle is computed per row over the last axis. Rows where slerp is ill-defined
    fall back to plain lerp, so the result never contains NaN/inf produced by the
    interpolation itself:
      - rows that are (nearly) parallel or anti-parallel, where sin(omega) ~ 0
      - all-zero rows, e.g. the zero padding added when the two conditions are aligned
    '''
    norm_a = torch.norm(condA, dim=-1, keepdim=True)
    norm_b = torch.norm(condB, dim=-1, keepdim=True)
    zero_row = (norm_a == 0) | (norm_b == 0)
    # divide by 1 instead of 0 on empty rows; those rows are served by the lerp branch anyway
    A_n = condA / torch.where(norm_a == 0, torch.ones_like(norm_a), norm_a)
    B_n = condB / torch.where(norm_b == 0, torch.ones_like(norm_b), norm_b)

    # rounding may push the cosine slightly outside acos' domain
    dot = (A_n * B_n).sum(dim=-1, keepdim=True).clamp(-1.0, 1.0)
    omega = torch.acos(dot)
    so = torch.sin(omega)

    degenerate: Tensor = (dot.abs() > 0.9995) | zero_row
    if not degenerate.any():
        return (torch.sin((1 - alpha) * omega) / so) * condA + (torch.sin(alpha * omega) / so) * condB

    # keep the discarded slerp lanes finite, then pick lerp for the degenerate rows
    so = torch.where(degenerate, torch.ones_like(so), so)
    slerp = (torch.sin((1 - alpha) * omega) / so) * condA + (torch.sin(alpha * omega) / so) * condB
    lerp = (1 - alpha) * condA + (alpha) * condB
    return torch.where(degenerate, lerp, slerp)
@wrap_cond_align
def replace_until_match(condA:Tensor, condB:Tensor, count:int, dist:Tensor, order:str=ModeReplaceOrder.RANDOM) -> Tensor:
    ''' Value substitution on the condition tensor; modifies `dist` in place.

    Positions where `dist > EPS` are "still different". Up to `count` of them are
    selected according to `order`, their `dist` entries are zeroed, and the result
    takes `condB` there. Positions left unselected (still `dist > EPS`) keep `condA`;
    every other position takes `condB`.

    Raises ValueError for an unknown `order`, or when `dist` cannot be matched to
    either axis of `condA`.
    '''

    # split an index matrix [n, nDim] into one index vector per dim, usable for advanced indexing
    def index_tensor_to_tuple(index:Tensor) -> Tuple[Tensor, ...]:
        return tuple([index[..., i] for i in range(index.shape[-1])])     # tuple([nDiff], ...)

    # mask: [T=77, D=768], [T=77] or [D=768]
    mask = dist > EPS
    # idx_diff: [nDiff, nDim=2] or [nDiff, nDim=1]
    idx_diff = torch.nonzero(mask)
    n_diff = len(idx_diff)

    if order == ModeReplaceOrder.RANDOM:
        # take everything that is left once no more than `count` positions remain
        sel = np.random.choice(range(n_diff), size=count, replace=False) if n_diff > count else slice(None)
    else:
        val_diff = dist[index_tensor_to_tuple(idx_diff)]    # [nDiff]

        if order == ModeReplaceOrder.SIMILAR:
            sorted_index = val_diff.argsort()
        elif order == ModeReplaceOrder.DIFFERENT:
            sorted_index = val_diff.argsort(descending=True)
        else: raise ValueError(f'unknown replace_order: {order}')

        sel = sorted_index[:count]

    idx_diff_sel = idx_diff[sel, ...]       # [cnt] => [cnt, nDim]
    idx_diff_sel_tp = index_tensor_to_tuple(idx_diff_sel)
    # mark the selected positions as done, both for the caller (`dist`) and for this call (`mask`)
    dist[idx_diff_sel_tp] = 0.0
    mask[idx_diff_sel_tp] = False

    # a 1-D mask is broadcast along whichever axis of the condition it matches in length
    if mask.shape != condA.shape:   # cond.shape = [T=77, D=768]
        mask_len = mask.shape[0]
        if   mask_len == condA.shape[0]: mask = mask.unsqueeze(1)
        elif mask_len == condA.shape[1]: mask = mask.unsqueeze(0)
        else: raise ValueError(f'unknown mask.shape: {mask.shape}')
        mask = mask.expand_as(condA)

    # True => keep condA (not replaced yet), False => take condB
    return mask * condA + ~mask * condB
def get_next_sequence_number(path:str) -> int:
    """ Determine the next sequence number to use for a new run folder under `path`.

    Only sub-directories whose name parses as an integer are considered; the result
    is the highest such number plus one. The sequence starts at 0, which is also
    returned when `path` has no numbered folders or does not exist yet.
    """
    root = Path(path)
    if not root.is_dir(): return 0

    highest = -1
    for entry in root.iterdir():
        if not entry.is_dir(): continue
        try: highest = max(highest, int(entry.name))
        except ValueError: pass     # not a numbered folder
    return highest + 1
def update_img2img_p(p:Processing, imgs:PILImages, denoising_strength:float=0.75) -> ProcessingImg2Img:
    ''' Turn `p` into an img2img processing that starts from `imgs`.

    An img2img `p` is updated in place and returned. A txt2img `p` is left alone and
    a new img2img object is built from its known constructor parameters. Any other
    type falls through and yields None.
    '''
    if isinstance(p, ProcessingTxt2Img) and not isinstance(p, ProcessingImg2Img):
        # see `StableDiffusionProcessing.__init__()`
        # NOTE: 'denoising_strength' is deliberately absent, it is passed explicitly below
        KNOWN_KEYS = [
            'sd_model', 'outpath_samples', 'outpath_grids',
            'prompt', 'styles',
            'seed', 'subseed', 'subseed_strength',
            'seed_resize_from_h', 'seed_resize_from_w', 'seed_enable_extras',
            'sampler_name', 'batch_size', 'n_iter', 'steps', 'cfg_scale',
            'width', 'height',
            'restore_faces', 'tiling',
            'do_not_save_samples', 'do_not_save_grid',
            'extra_generation_params', 'overlay_images',
            'negative_prompt', 'eta', 'do_not_reload_embeddings',
            'ddim_discretize',
            's_min_uncond', 's_churn', 's_tmax', 's_tmin', 's_noise',
            'override_settings', 'override_settings_restore_afterwards',
            'sampler_index', 'script_args',
        ]
        inherited = {}      # params carried over from the txt2img object
        for attr in dir(p):
            if attr in KNOWN_KEYS:
                inherited[attr] = getattr(p, attr)
        return ProcessingImg2Img(
            init_images=imgs,
            denoising_strength=denoising_strength,
            **inherited,
        )

    if isinstance(p, ProcessingImg2Img):
        p.init_images = imgs
        p.denoising_strength = denoising_strength
        return p
def parse_slice(picker:str) -> Optional[slice]:
    ''' Parse a Python-like slice expression such as `10`, `2:-1` or `::2`.

    Returns None for a blank string. A single number is taken as the stop value.
    Raises ValueError for a non-integer field or more than three fields.
    '''
    text = picker.strip()
    if not text: return None

    # an empty field between colons means "not given"
    fields = [int(tok) if tok else None for tok in (raw.strip() for raw in text.split(':'))]
    if len(fields) > 3: raise ValueError
    if len(fields) == 1: return slice(None, fields[0], None)
    return slice(*fields)
def parse_resolution(width:int, height:int, upscale_ratio:float, upscale_width:int, upscale_height:int) -> Tuple[bool, Tuple[int, int]]:
    ''' Work out the target resolution and whether a resize is needed at all.

    When both `upscale_width` and `upscale_height` are 0 the target is the source
    size times `upscale_ratio`. Otherwise the explicit size wins, and a side given
    as 0 is derived from the other one by keeping the source aspect ratio.

    Returns `(need_upscale, (target_width, target_height))`.
    '''
    if upscale_width == upscale_height == 0:
        if upscale_ratio == 1.0:
            return False, (width, height)
        else:
            return True, (round(width * upscale_ratio), round(height * upscale_ratio))
    else:
        if upscale_width  == 0: upscale_width  = round(width  * upscale_height / height)
        if upscale_height == 0: upscale_height = round(height * upscale_width  / width)
        # a resize is needed as soon as either side differs from the source
        return (width != upscale_width or height != upscale_height), (upscale_width, upscale_height)
def upscale_image(img:PILImage, width:int, height:int, upscale_meth:str, upscale_ratio:float, upscale_width:int, upscale_height:int) -> PILImage:
    ''' Resize `img` to the resolution given by the upscale options, using upscaler `upscale_meth`.

    The image is returned untouched when the upscaler is 'None' or when
    `parse_resolution()` reports that no resize is needed.
    '''
    if upscale_meth == 'None': return img

    need_upscale, (tgt_w, tgt_h) = parse_resolution(width, height, upscale_ratio, upscale_width, upscale_height)
    if not need_upscale: return img

    if 'show_debug': print(f'>> upscale: ({width}, {height}) => ({tgt_w}, {tgt_h})')

    scale = max(tgt_w / width, tgt_h / height)
    if scale > 4:   # must split into two rounds for NN model compatibility
        mid_w, mid_h = round(width * 4), round(height * 4)
        img = resize_image(0, img, mid_w, mid_h, upscaler_name=upscale_meth)
    return resize_image(0, img, tgt_w, tgt_h, upscaler_name=upscale_meth)
def save_video(imgs:PILImages, video_slice:slice, video_pad:int, video_fps:float, video_fmt:VideoFormat, fbase:str):
    ''' Export the frames as a video file `fbase + '.<ext>'`; best-effort, never raises.

    Nothing is written when there are fewer than two frames or moviepy is not
    available. `video_slice` (optional) picks a subset of frames, `video_pad`
    repeats the first and last frame that many times. Failures are reported
    with a printed traceback.
    '''
    if len(imgs) <= 1 or 'ImageSequenceClip' not in globals(): return

    try:
        # arrange frames
        if video_slice: imgs = imgs[video_slice]
        if not imgs:
            # an over-restrictive slice would otherwise surface as an obscure IndexError / moviepy error
            print(f'{LOG_PREFIX} no frame left after applying the frame slice, skip exporting video')
            return
        if video_pad > 0: imgs = [imgs[0]] * video_pad + imgs + [imgs[-1]] * video_pad

        # export video
        seq: List[np.ndarray] = [np.asarray(img) for img in imgs]
        try:
            clip = ImageSequenceClip(seq, fps=video_fps)
        except Exception:   # images may have different size (small probability due to upscaler)
            clip = concatenate_videoclips([ImageClip(img, duration=1/video_fps) for img in seq], method='compose')
            clip.fps = video_fps
        if   video_fmt == VideoFormat.MP4:  clip.write_videofile(fbase + '.mp4',  verbose=False, audio=False)
        elif video_fmt == VideoFormat.WEBM: clip.write_videofile(fbase + '.webm', verbose=False, audio=False)
        elif video_fmt == VideoFormat.GIF:  clip.write_gif      (fbase + '.gif',  loop=True)
    except Exception: print_exc()   # do not let KeyboardInterrupt/SystemExit be swallowed here
class on_cfg_denoiser_wrapper:
    ''' Context manager: keep `callback_fn` registered via `on_cfg_denoiser` for the duration of the block. '''

    def __init__(self, callback_fn:Callable):
        self.callback_fn = callback_fn

    def __enter__(self):
        on_cfg_denoiser(self.callback_fn)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # returns None, so an exception raised inside the block still propagates
        remove_callbacks_for_function(self.callback_fn)
class p_steps_overrider:
    ''' Context manager: temporarily set `p.steps` to `steps`, restoring the previous value on exit. '''

    def __init__(self, p:Processing, steps:int=1):
        self.p = p
        self.steps = steps
        # NOTE: the value to restore is captured at construction time, not on `__enter__`
        self.steps_saved = self.p.steps

    def __enter__(self):
        self.p.steps = self.steps

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.p.steps = self.steps_saved
class p_save_samples_overrider:
    ''' Context manager: temporarily set `p.do_not_save_samples` to `not save`, restoring it on exit. '''

    def __init__(self, p:Processing, save:bool=True):
        self.p = p
        self.save = save
        # NOTE: the value to restore is captured at construction time, not on `__enter__`
        self.do_not_save_samples_saved = self.p.do_not_save_samples

    def __enter__(self):
        # the flag on `p` is negative ("do not save"), hence the inversion
        self.p.do_not_save_samples = not self.save

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.p.do_not_save_samples = self.do_not_save_samples_saved
def get_cond_callback(refs:List[TensorRef], params:CFGDenoiserParams):
    ''' cfg-denoiser callback: on sampling step 0, store the text conditions into `refs`.

    `refs[0]` receives the positive condition and `refs[1]` (when given) the negative one.
    '''
    if params.sampling_step > 0: return

    conds = (
        params.text_cond,       # [B=1, L= 77, D=768]
        params.text_uncond,     # [B=1, L=231, D=768]
    )
    for idx, holder in enumerate(refs):
        holder.value = conds[idx]
def set_cond_callback(refs:List[TensorRef], params:CFGDenoiserParams):
    ''' cfg-denoiser callback: on every call, overwrite the text conditions with the values in `refs`.

    `refs[0]` replaces the positive condition and `refs[1]` (when given) the negative one;
    the tensors are swapped through `.data`, so the `params` objects themselves stay the same.
    '''
    targets = (
        params.text_cond,       # [B=1, L= 77, D=768]
        params.text_uncond,     # [B=1, L=231, D=768]
    )
    for idx, holder in enumerate(refs):
        targets[idx].data = holder.value
def get_latent_callback(ref:TensorRef, embryo_step:int, params:CFGDenoiserParams):
    ''' cfg-denoiser callback: at sampling step `embryo_step`, store `params.x` into `ref`. '''
    if params.sampling_step == embryo_step:
        ref.value = params.x
def set_latent_callback(ref:TensorRef, embryo_step:int, params:CFGDenoiserParams):
    ''' cfg-denoiser callback: at sampling step `embryo_step`, overwrite `params.x` with the value in `ref`. '''
    if params.sampling_step == embryo_step:
        params.x.data = ref.value
def switch_to_stage_binding_(self:'Script', i:int):
    ''' Point `self.p` at the prompts of stage `i` and reset its subseed to the stashed one. '''
    has_neg = hasattr(self, 'neg_prompts')
    pos_prompt = self.pos_prompts[i]

    if 'show_debug':
        print(f'[stage {i+1}/{self.n_stages}]')
        print(f' pos prompt: {pos_prompt}')
        if has_neg:
            print(f' neg prompt: {self.neg_prompts[i]}')

    self.p.prompt = pos_prompt
    if has_neg:
        self.p.negative_prompt = self.neg_prompts[i]
    self.p.subseed = self.subseed
def process_p_binding_(self:'Script', append:bool=True, save:bool=True) -> PILImages:
    ''' Run `process_images` on `self.p` once and return the produced images.

    With `save` the samples are written to disk, the first non-empty info text is
    kept in `self.info`, and (with `append`) the images are added to `self.images`.
    In 'successive' genesis the processing is then re-armed as img2img on the
    last collected image.
    '''
    assert hasattr(self, 'images') and hasattr(self, 'info'), 'unknown logic, "images" and "info" not initialized'

    with p_save_samples_overrider(self.p, save):
        proc = process_images(self.p)
    frames = proc.images

    if save and not self.info.value:
        self.info.value = proc.info
    if save and append:
        self.images.extend(frames)

    if self.genesis == Gensis.SUCCESSIVE:
        last_frame = self.images[-1:]
        self.p = update_img2img_p(self.p, last_frame, self.denoise_w)
    return frames
class Script(scripts.Script):
    ''' The "Prompt Travel" custom script: renders a frame sequence that moves from one prompt line to the next. '''

    def title(self):
        return 'Prompt Travel'

    def describe(self):
        return 'Travel from one prompt to another in the text encoder latent space.'

    def show(self, is_img2img):
        return True

    def ui(self, is_img2img):
        ''' Build the script widgets; the returned list order must match the parameters of `run()`. '''
        with gr.Row(variant='compact') as tab_mode:
            mode          = gr.Radio   (label=LABEL_MODE,          value=lambda: DEFAULT_MODE,          choices=CHOICES_MODE)
            lerp_meth     = gr.Dropdown(label=LABEL_LERP_METH,     value=lambda: DEFAULT_LERP_METH,     choices=CHOICES_LERP_METH)
            replace_dim   = gr.Dropdown(label=LABEL_REPLACE_DIM,   value=lambda: DEFAULT_REPLACE_DIM,   choices=CHOICES_REPLACE_DIM,   visible=False)
            replace_order = gr.Dropdown(label=LABEL_REPLACE_ORDER, value=lambda: DEFAULT_REPLACE_ORDER, choices=CHOICES_REPLACE_ORDER, visible=False)

            def switch_mode(mode:str):
                show_meth = Mode(mode) == Mode.LINEAR
                show_repl = Mode(mode) == Mode.REPLACE
                return [gr_show(x) for x in [show_meth, show_repl, show_repl]]
            mode.change(switch_mode, inputs=[mode], outputs=[lerp_meth, replace_dim, replace_order], show_progress=False)

        with gr.Row(variant='compact') as tab_param:
            steps       = gr.Text    (label=LABEL_STEPS,       value=lambda: DEFAULT_STEPS, max_lines=1)
            genesis     = gr.Dropdown(label=LABEL_GENESIS,     value=lambda: DEFAULT_GENESIS, choices=CHOICES_GENESIS)
            denoise_w   = gr.Slider  (label=LABEL_DENOISE_W,   value=lambda: DEFAULT_DENOISE_W, minimum=0.0, maximum=1.0, visible=False)
            embryo_step = gr.Text    (label=LABEL_EMBRYO_STEP, value=lambda: DEFAULT_EMBRYO_STEP, max_lines=1, visible=False)

            def switch_genesis(genesis:str):
                show_dw = Gensis(genesis) == Gensis.SUCCESSIVE  # show 'denoise_w' for 'successive'
                show_es = Gensis(genesis) == Gensis.EMBRYO      # show 'embryo_step' for 'embryo'
                return [gr_show(x) for x in [show_dw, show_es]]
            genesis.change(switch_genesis, inputs=[genesis], outputs=[denoise_w, embryo_step], show_progress=False)

        with gr.Row(variant='compact', visible=DEFAULT_DEPTH) as tab_ext_depth:
            depth_img = gr.Image(label=LABEL_DEPTH_IMG, source='upload', type='pil', image_mode=None)

        with gr.Row(variant='compact', visible=DEFAULT_UPSCALE) as tab_ext_upscale:
            upscale_meth   = gr.Dropdown(label=LABEL_UPSCALE_METH,   value=lambda: DEFAULT_UPSCALE_METH,   choices=CHOICES_UPSCALER)
            upscale_ratio  = gr.Slider  (label=LABEL_UPSCALE_RATIO,  value=lambda: DEFAULT_UPSCALE_RATIO,  minimum=1.0, maximum=16.0, step=0.1)
            upscale_width  = gr.Slider  (label=LABEL_UPSCALE_WIDTH,  value=lambda: DEFAULT_UPSCALE_WIDTH,  minimum=0, maximum=2048, step=8)
            upscale_height = gr.Slider  (label=LABEL_UPSCALE_HEIGHT, value=lambda: DEFAULT_UPSCALE_HEIGHT, minimum=0, maximum=2048, step=8)

        with gr.Row(variant='compact', visible=DEFAULT_VIDEO) as tab_ext_video:
            video_fmt  = gr.Dropdown(label=LABEL_VIDEO_FMT,  value=lambda: DEFAULT_VIDEO_FMT, choices=CHOICES_VIDEO_FMT)
            video_fps  = gr.Number  (label=LABEL_VIDEO_FPS,  value=lambda: DEFAULT_VIDEO_FPS)
            video_pad  = gr.Number  (label=LABEL_VIDEO_PAD,  value=lambda: DEFAULT_VIDEO_PAD, precision=0)
            video_pick = gr.Text    (label=LABEL_VIDEO_PICK, value=lambda: DEFAULT_VIDEO_PICK, max_lines=1)

        with gr.Row(variant='compact') as tab_ext:
            ext_video   = gr.Checkbox(label=LABEL_VIDEO,   value=lambda: DEFAULT_VIDEO)
            ext_upscale = gr.Checkbox(label=LABEL_UPSCALE, value=lambda: DEFAULT_UPSCALE)
            ext_depth   = gr.Checkbox(label=LABEL_DEPTH,   value=lambda: DEFAULT_DEPTH)

            # each checkbox toggles the visibility of its option row
            ext_video  .change(gr_show, inputs=ext_video,   outputs=tab_ext_video,   show_progress=False)
            ext_upscale.change(gr_show, inputs=ext_upscale, outputs=tab_ext_upscale, show_progress=False)
            ext_depth  .change(gr_show, inputs=ext_depth,   outputs=tab_ext_depth,   show_progress=False)

        return [
            mode, lerp_meth, replace_dim, replace_order,
            steps, genesis, denoise_w, embryo_step,
            depth_img,
            upscale_meth, upscale_ratio, upscale_width, upscale_height,
            video_fmt, video_fps, video_pad, video_pick,
            ext_video, ext_upscale, ext_depth,
        ]

    def run(self, p:Processing,
            mode:str, lerp_meth:str, replace_dim:str, replace_order:str,
            steps:str, genesis:str, denoise_w:float, embryo_step:str,
            depth_img:PILImage,
            upscale_meth:str, upscale_ratio:float, upscale_width:int, upscale_height:int,
            video_fmt:str, video_fps:float, video_pad:int, video_pick:str,
            ext_video:bool, ext_upscale:bool, ext_depth:bool,
        ):
        ''' Script entry point: validate the options, render all frames, optionally export a video.

        The positive/negative prompts are split by line into stages. `steps` is either
        one integer or a comma-separated list with one entry per stage transition.
        Invalid options yield an empty `Processed` whose info text explains the problem.
        Errors raised while rendering are caught; their traceback becomes the info text
        and the frames collected so far are still returned.
        '''

        # enum lookup
        mode: Mode                      = Mode(mode)
        lerp_meth: LerpMethod           = LerpMethod(lerp_meth)
        replace_dim: ModeReplaceDim     = ModeReplaceDim(replace_dim)
        replace_order: ModeReplaceOrder = ModeReplaceOrder(replace_order)
        genesis: Gensis                 = Gensis(genesis)
        video_fmt: VideoFormat          = VideoFormat(video_fmt)

        # Param check & type convert
        if ext_video:
            if video_pad <  0: return Processed(p, [], p.seed, f'video_pad must >= 0, but got {video_pad}')
            if video_fps <= 0: return Processed(p, [], p.seed, f'video_fps must > 0, but got {video_fps}')
            try: video_slice = parse_slice(video_pick)
            except Exception: return Processed(p, [], p.seed, 'syntax error in video_slice')
        if genesis == Gensis.EMBRYO:
            try: x = float(embryo_step)
            except Exception: return Processed(p, [], p.seed, f'embryo_step is not a number: {embryo_step}')
            if x <= 0: return Processed(p, [], p.seed, f'embryo_step must > 0, but got {embryo_step}')
            # a value below 1 is a fraction of the sampling steps, otherwise an absolute step index
            embryo_step: int = round(x * p.steps if x < 1.0 else x) ; del x

        # Prepare prompts & steps
        prompt_pos = p.prompt.strip()
        if not prompt_pos: return Processed(p, [], p.seed, 'positive prompt should not be empty :(')
        pos_prompts = [line.strip() for line in prompt_pos.split('\n') if line.strip()]
        if len(pos_prompts) == 1: return Processed(p, [], p.seed, 'should specify at least two lines of prompt to travel between :(')
        if genesis == Gensis.EMBRYO and len(pos_prompts) > 2: return Processed(p, [], p.seed, 'processing with "embryo" genesis takes exactly two lines of prompt :(')
        prompt_neg = p.negative_prompt.strip()
        neg_prompts = [line.strip() for line in prompt_neg.split('\n') if line.strip()]
        if len(neg_prompts) == 0: neg_prompts = ['']
        # the shorter prompt list is extended by repeating its last line
        n_stages = max(len(pos_prompts), len(neg_prompts))
        while len(pos_prompts) < n_stages: pos_prompts.append(pos_prompts[-1])
        while len(neg_prompts) < n_stages: neg_prompts.append(neg_prompts[-1])

        try: steps: List[int] = [int(s.strip()) for s in steps.strip().split(',')]
        except Exception: return Processed(p, [], p.seed, f'cannot parse steps option: {steps}')
        if len(steps) == 1:
            steps = [steps[0]] * (n_stages - 1)
        elif len(steps) != n_stages - 1:
            return Processed(p, [], p.seed, f'stage count mismatch: you have {n_stages} prompt stages, but specified {len(steps)} steps; should assure len(steps) == len(stages) - 1')
        n_frames = sum(steps) + n_stages
        if 'show_debug':
            print('n_stages:', n_stages)
            print('n_frames:', n_frames)
            print('steps:', steps)
        steps.insert(0, -1)     # fixup the first stage, so that steps[i] belongs to the transition into stage i

        # Custom saving path
        travel_path = os.path.join(p.outpath_samples, 'prompt_travel')
        os.makedirs(travel_path, exist_ok=True)
        travel_number = get_next_sequence_number(travel_path)
        self.log_dp = os.path.join(travel_path, f'{travel_number:05}')
        p.outpath_samples = self.log_dp
        os.makedirs(self.log_dp, exist_ok=True)
        #self.log_fp = os.path.join(self.log_dp, 'log.txt')

        # Force batch count and batch size to 1
        p.n_iter     = 1
        p.batch_size = 1

        # Random unified const seed
        p.seed = get_fixed_seed(p.seed)     # fix it to assure all processes using the same major seed
        self.subseed = p.subseed            # stash it to allow using random subseed for each process (when -1)
        if 'show_debug':
            print('seed:',             p.seed)
            print('subseed:',          p.subseed)
            print('subseed_strength:', p.subseed_strength)

        # Start job
        state.job_count = n_frames

        # Pack parameters
        self.pos_prompts   = pos_prompts
        self.neg_prompts   = neg_prompts
        self.steps         = steps
        self.genesis       = genesis
        self.denoise_w     = denoise_w
        self.embryo_step   = embryo_step
        self.lerp_meth     = lerp_meth
        self.replace_dim   = replace_dim
        self.replace_order = replace_order
        self.n_stages      = n_stages
        self.n_frames      = n_frames

        def upscale_image_callback(params:ImageSaveParams):
            params.image = upscale_image(params.image, p.width, p.height, upscale_meth, upscale_ratio, upscale_width, upscale_height)

        # Dispatch
        self.p: Processing = p
        self.images: PILImages = []
        self.info: StrRef = Ref()
        try:
            if ext_upscale: on_before_image_saved(upscale_image_callback)
            if ext_depth: self.ext_depth_preprocess(p, depth_img)

            # default of None keeps the guard below reachable instead of raising AttributeError
            runner = getattr(self, f'run_{mode.value}', None)
            if not runner: return Processed(p, [], p.seed, f'no runner found for mode: {mode.value}')
            runner()
        except Exception:
            e = format_exc()
            print(e)
            self.info.value = e
        finally:
            if ext_depth: self.ext_depth_postprocess(p, depth_img)
            if ext_upscale: remove_callbacks_for_function(upscale_image_callback)

        # Save video
        if ext_video: save_video(self.images, video_slice, video_pad, video_fps, video_fmt, os.path.join(self.log_dp, f'travel-{travel_number:05}'))

        return Processed(p, self.images, p.seed, self.info.value)

    def run_linear(self):
        ''' Linear travel: interpolate the positive and negative conditions between consecutive stages. '''

        # dispatch for special case
        if self.genesis == Gensis.EMBRYO: return self.run_linear_embryo()

        lerp_fn = weighted_sum if self.lerp_meth == LerpMethod.LERP else geometric_slerp

        if 'auxiliary':
            switch_to_stage = partial(switch_to_stage_binding_, self)
            process_p = partial(process_p_binding_, self)

            from_pos_hidden:  TensorRef = Ref()
            from_neg_hidden:  TensorRef = Ref()
            to_pos_hidden:    TensorRef = Ref()
            to_neg_hidden:    TensorRef = Ref()
            inter_pos_hidden: TensorRef = Ref()
            inter_neg_hidden: TensorRef = Ref()

        # Step 1: draw the init image
        switch_to_stage(0)
        with on_cfg_denoiser_wrapper(partial(get_cond_callback, [from_pos_hidden, from_neg_hidden])):
            process_p()

        # travel through stages
        for i in range(1, self.n_stages):
            if state.interrupted: break

            state.job = f'{i}/{self.n_frames}'
            state.job_no = i + 1

            # only change target prompts
            switch_to_stage(i)
            with on_cfg_denoiser_wrapper(partial(get_cond_callback, [to_pos_hidden, to_neg_hidden])):
                if self.genesis == Gensis.FIXED:
                    imgs = process_p(append=False)              # stash it to make order right
                elif self.genesis == Gensis.SUCCESSIVE:
                    with p_steps_overrider(self.p, steps=1):    # ignore final image, only need cond
                        process_p(save=False, append=False)
                else: raise ValueError(f'invalid genesis: {self.genesis.value}')

            # Step 2: draw the interpolated images
            is_break_iter = False
            n_inter = self.steps[i]
            for t in range(1, n_inter + (1 if self.genesis == Gensis.SUCCESSIVE else 0)):
                if state.interrupted: is_break_iter = True ; break

                alpha = t / n_inter     # [1/T, 2/T, .. T-1/T] (+ [T/T])?
                inter_pos_hidden.value = lerp_fn(from_pos_hidden.value, to_pos_hidden.value, alpha)
                inter_neg_hidden.value = lerp_fn(from_neg_hidden.value, to_neg_hidden.value, alpha)
                with on_cfg_denoiser_wrapper(partial(set_cond_callback, [inter_pos_hidden, inter_neg_hidden])):
                    process_p()

            if is_break_iter: break

            # Step 3: append the final stage
            if self.genesis != Gensis.SUCCESSIVE: self.images.extend(imgs)
            # move to next stage
            from_pos_hidden.value, from_neg_hidden.value = to_pos_hidden.value, to_neg_hidden.value
            inter_pos_hidden.value, inter_neg_hidden.value = None, None

    def run_linear_embryo(self):
        ''' NOTE: this procedure has special logic, we separate it from run_linear() so far '''

        lerp_fn = weighted_sum if self.lerp_meth == LerpMethod.LERP else geometric_slerp
        n_frames = self.steps[1] + 2

        if 'auxiliary':
            switch_to_stage = partial(switch_to_stage_binding_, self)
            process_p = partial(process_p_binding_, self)

            from_pos_hidden:  TensorRef = Ref()
            to_pos_hidden:    TensorRef = Ref()
            inter_pos_hidden: TensorRef = Ref()
            embryo:           TensorRef = Ref()     # latent image, the common half-denoised prototype of all frames

        # Step 1: get starting & ending condition
        switch_to_stage(0)
        with on_cfg_denoiser_wrapper(partial(get_cond_callback, [from_pos_hidden])):
            with p_steps_overrider(self.p, steps=1):
                process_p(save=False)
        switch_to_stage(1)
        with on_cfg_denoiser_wrapper(partial(get_cond_callback, [to_pos_hidden])):
            with p_steps_overrider(self.p, steps=1):
                process_p(save=False)

        # Step 2: get the condition middle-point as embryo then hatch it halfway
        inter_pos_hidden.value = lerp_fn(from_pos_hidden.value, to_pos_hidden.value, 0.5)
        with on_cfg_denoiser_wrapper(partial(set_cond_callback, [inter_pos_hidden])):
            with on_cfg_denoiser_wrapper(partial(get_latent_callback, embryo, self.embryo_step)):
                process_p(save=False)
        try:
            img: PILImage = single_sample_to_image(embryo.value[0], approximation=-1)     # the data is duplicated, just get first item
            img.save(os.path.join(self.log_dp, 'embryo.png'))
        except Exception:
            # the preview is best-effort only; report and go on
            print(f'{LOG_PREFIX} failed to save the embryo preview image, ignored')

        # Step 3: derive the embryo towards each interpolated condition
        for t in range(0, n_frames+1):
            if state.interrupted: break

            alpha = t / n_frames    # [0, 1/T, 2/T, .. T-1/T, 1]
            inter_pos_hidden.value = lerp_fn(from_pos_hidden.value, to_pos_hidden.value, alpha)
            with on_cfg_denoiser_wrapper(partial(set_cond_callback, [inter_pos_hidden])):
                with on_cfg_denoiser_wrapper(partial(set_latent_callback, embryo, self.embryo_step)):
                    process_p()

    def run_replace(self):
        ''' yet another replace method, but do replacing on the condition tensor by token dim or channel dim '''

        if self.genesis == Gensis.EMBRYO: raise NotImplementedError(f'genesis {self.genesis.value!r} is only supported in linear mode currently :(')

        if 'auxiliary':
            switch_to_stage = partial(switch_to_stage_binding_, self)
            process_p = partial(process_p_binding_, self)

            from_pos_hidden:  TensorRef = Ref()
            to_pos_hidden:    TensorRef = Ref()
            inter_pos_hidden: TensorRef = Ref()

        # Step 1: draw the init image
        switch_to_stage(0)
        with on_cfg_denoiser_wrapper(partial(get_cond_callback, [from_pos_hidden])):
            process_p()

        # travel through stages
        for i in range(1, self.n_stages):
            if state.interrupted: break

            state.job = f'{i}/{self.n_frames}'
            state.job_no = i + 1

            # only change target prompts
            switch_to_stage(i)
            with on_cfg_denoiser_wrapper(partial(get_cond_callback, [to_pos_hidden])):
                if self.genesis == Gensis.FIXED:
                    imgs = process_p(append=False)              # stash it to make order right
                elif self.genesis == Gensis.SUCCESSIVE:
                    with p_steps_overrider(self.p, steps=1):    # ignore final image, only need cond
                        process_p(save=False, append=False)
                else: raise ValueError(f'invalid genesis: {self.genesis.value}')

            # ========== vvv major differences from run_linear() vvv ==========

            # decide change portion in each iter
            L1 = torch.abs(from_pos_hidden.value - to_pos_hidden.value)
            if   self.replace_dim == ModeReplaceDim.RANDOM:
                dist = L1                  # [T=77, D=768]
            elif self.replace_dim == ModeReplaceDim.TOKEN:
                dist = L1.mean(axis=1)     # [T=77]
            elif self.replace_dim == ModeReplaceDim.CHANNEL:
                dist = L1.mean(axis=0)     # [D=768]
            else: raise ValueError(f'unknown replace_dim: {self.replace_dim}')
            mask = dist > EPS
            dist = torch.where(mask, dist, 0.0)
            n_diff = mask.sum().item()                  # when value differs we have mask==True
            n_inter = self.steps[i] + 1
            replace_count = int(n_diff / n_inter) + 1   # => accumulative modifies [1/T, 2/T, .. T-1/T] of total cond

            # Step 2: draw the replaced images
            inter_pos_hidden.value = from_pos_hidden.value
            is_break_iter = False
            for _ in range(1, n_inter):
                if state.interrupted: is_break_iter = True ; break

                inter_pos_hidden.value = replace_until_match(inter_pos_hidden.value, to_pos_hidden.value, replace_count, dist=dist, order=self.replace_order)
                with on_cfg_denoiser_wrapper(partial(set_cond_callback, [inter_pos_hidden])):
                    process_p()

            # ========== ^^^ major differences from run_linear() ^^^ ==========

            if is_break_iter: break

            # Step 3: append the final stage
            if self.genesis != Gensis.SUCCESSIVE: self.images.extend(imgs)
            # move to next stage
            from_pos_hidden.value = to_pos_hidden.value
            inter_pos_hidden.value = None

    # ----- extension support -----

    def ext_depth_preprocess(self, p:Processing, depth_img:PILImage):  # copy from repo `AnonymousCervine/depth-image-io-for-SDWebui`
        ''' Patch `p` so that its depth conditioning is computed from the user-supplied `depth_img`. '''
        from types import MethodType
        from einops import repeat, rearrange
        import modules.shared as shared
        import modules.devices as devices

        def sanitize_pil_image_mode(img):
            if img.mode in {'P', 'CMYK', 'HSV'}:
                img = img.convert(mode='RGB')
            return img

        def alt_depth_image_conditioning(self, source_image):
            # the encoded source image is only used for its spatial size
            with devices.autocast():
                conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
            depth_data = np.array(sanitize_pil_image_mode(depth_img))

            if len(np.shape(depth_data)) == 2:
                depth_data = rearrange(depth_data, "h w -> 1 1 h w")
            else:
                depth_data = rearrange(depth_data, "h w c -> c 1 1 h w")[0]     # keep the first channel only
            depth_data = torch.from_numpy(depth_data).to(device=shared.device).to(dtype=torch.float32)
            depth_data = repeat(depth_data, "1 ... -> n ...", n=self.batch_size)

            conditioning = torch.nn.functional.interpolate(
                depth_data,
                size=conditioning_image.shape[2:],
                mode="bicubic",
                align_corners=False,
            )
            # rescale to [-1, 1]
            (depth_min, depth_max) = torch.aminmax(conditioning)
            conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
            return conditioning

        p.depth2img_image_conditioning = MethodType(alt_depth_image_conditioning, p)

        def alt_txt2img_image_conditioning(self, x, width=None, height=None):
            fake_img = torch.zeros(1, 3, height or self.height, width or self.width).to(shared.device).type(self.sd_model.dtype)
            return self.depth2img_image_conditioning(fake_img)

        p.txt2img_image_conditioning = MethodType(alt_txt2img_image_conditioning, p)

    def ext_depth_postprocess(self, p:Processing, depth_img:PILImage):
        ''' Release the depth image once the run is over. '''
        # runs inside a `finally`: no image may have been uploaded, and failing here would hide the real error
        if depth_img is not None: depth_img.close()
|