"),
- "local_preview": path + ".png",
- }
-
- def allowed_directories_for_previews(self):
- return [shared.cmd_opts.lora_dir]
-
diff --git a/extensions-builtin/ScuNET/preload.py b/extensions-builtin/ScuNET/preload.py
deleted file mode 100644
index f12c5b90ed2984ef16d8d8dd30d1ebef34cbf7c3..0000000000000000000000000000000000000000
--- a/extensions-builtin/ScuNET/preload.py
+++ /dev/null
@@ -1,6 +0,0 @@
-import os
-from modules import paths
-
-
-def preload(parser):
- parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
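
This preload hook (and the matching SwinIR one further down) only registers a command-line flag on the shared parser before startup. A minimal sketch of how such a hook behaves when wired into a plain argparse parser; the standalone models_path and the example invocation are illustrative, not the webui's actual startup code:

import argparse
import os

# "models_path" stands in for modules.paths.models_path used by the real hook.
models_path = os.path.join(os.getcwd(), "models")

def preload(parser):
    # Same shape as the deleted hook: register the flag with a per-upscaler default directory.
    parser.add_argument("--scunet-models-path", type=str,
                        help="Path to directory with ScuNET model file(s).",
                        default=os.path.join(models_path, "ScuNET"))

parser = argparse.ArgumentParser()
preload(parser)
args = parser.parse_args(["--scunet-models-path", "/data/upscalers/ScuNET"])
print(args.scunet_models_path)   # -> /data/upscalers/ScuNET
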
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
deleted file mode 100644
index e0fbf3a33747f447d396dd0d564e92c904cfabac..0000000000000000000000000000000000000000
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import os.path
-import sys
-import traceback
-
-import PIL.Image
-import numpy as np
-import torch
-from basicsr.utils.download_util import load_file_from_url
-
-import modules.upscaler
-from modules import devices, modelloader
-from scunet_model_arch import SCUNet as net
-
-
-class UpscalerScuNET(modules.upscaler.Upscaler):
- def __init__(self, dirname):
- self.name = "ScuNET"
- self.model_name = "ScuNET GAN"
- self.model_name2 = "ScuNET PSNR"
- self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
- self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
- self.user_path = dirname
- super().__init__()
- model_paths = self.find_models(ext_filter=[".pth"])
- scalers = []
- add_model2 = True
- for file in model_paths:
- if "http" in file:
- name = self.model_name
- else:
- name = modelloader.friendly_name(file)
- if name == self.model_name2 or file == self.model_url2:
- add_model2 = False
- try:
- scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
- scalers.append(scaler_data)
- except Exception:
- print(f"Error loading ScuNET model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- if add_model2:
- scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
- scalers.append(scaler_data2)
- self.scalers = scalers
-
- def do_upscale(self, img: PIL.Image, selected_file):
- torch.cuda.empty_cache()
-
- model = self.load_model(selected_file)
- if model is None:
- return img
-
- device = devices.get_device_for('scunet')
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
-
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- torch.cuda.empty_cache()
- return PIL.Image.fromarray(output, 'RGB')
-
- def load_model(self, path: str):
- device = devices.get_device_for('scunet')
- if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
- progress=True)
- else:
- filename = path
-        if filename is None or not os.path.exists(filename):
- print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
- return None
-
- model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
- model.load_state_dict(torch.load(filename), strict=True)
- model.eval()
- for k, v in model.named_parameters():
- v.requires_grad = False
- model = model.to(device)
-
- return model
-
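
The heart of do_upscale above is the PIL-to-tensor round trip: RGB to BGR, HWC to CHW, scale to [0, 1], add a batch dimension, then invert all of that on the way out. A standalone sketch of just that conversion, with an identity step in place of the SCUNet model (helper names are illustrative):

import numpy as np
import PIL.Image
import torch

def pil_to_tensor(img, device="cpu"):
    arr = np.array(img)[:, :, ::-1]           # RGB -> BGR, as in the deleted code
    arr = np.moveaxis(arr, 2, 0) / 255.0      # HWC -> CHW, scale to [0, 1]
    return torch.from_numpy(arr.copy()).float().unsqueeze(0).to(device)

def tensor_to_pil(t):
    out = t.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    out = (255.0 * np.moveaxis(out, 0, 2)).astype(np.uint8)[:, :, ::-1]  # CHW -> HWC, BGR -> RGB
    return PIL.Image.fromarray(np.ascontiguousarray(out), "RGB")

img = PIL.Image.new("RGB", (32, 32), (200, 30, 30))
with torch.no_grad():
    restored = tensor_to_pil(pil_to_tensor(img))   # identity "model" for illustration
assert restored.size == img.size
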
diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py
deleted file mode 100644
index 43ca8d36fe57a12dcad58e8b06ee2e0774494b0e..0000000000000000000000000000000000000000
--- a/extensions-builtin/ScuNET/scunet_model_arch.py
+++ /dev/null
@@ -1,265 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import torch
-import torch.nn as nn
-from einops import rearrange
-from einops.layers.torch import Rearrange
-from timm.models.layers import trunc_normal_, DropPath
-
-
-class WMSA(nn.Module):
- """ Self-attention module in Swin Transformer
- """
-
- def __init__(self, input_dim, output_dim, head_dim, window_size, type):
- super(WMSA, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.head_dim = head_dim
- self.scale = self.head_dim ** -0.5
- self.n_heads = input_dim // head_dim
- self.window_size = window_size
- self.type = type
- self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
-
- self.relative_position_params = nn.Parameter(
- torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
-
- self.linear = nn.Linear(self.input_dim, self.output_dim)
-
- trunc_normal_(self.relative_position_params, std=.02)
- self.relative_position_params = torch.nn.Parameter(
- self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
- 2).transpose(
- 0, 1))
-
- def generate_mask(self, h, w, p, shift):
- """ generating the mask of SW-MSA
- Args:
- shift: shift parameters in CyclicShift.
- Returns:
- attn_mask: should be (1 1 w p p),
- """
- # supporting square.
- attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
- if self.type == 'W':
- return attn_mask
-
- s = p - shift
- attn_mask[-1, :, :s, :, s:, :] = True
- attn_mask[-1, :, s:, :, :s, :] = True
- attn_mask[:, -1, :, :s, :, s:] = True
- attn_mask[:, -1, :, s:, :, :s] = True
- attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
- return attn_mask
-
- def forward(self, x):
- """ Forward pass of Window Multi-head Self-attention module.
- Args:
- x: input tensor with shape of [b h w c];
- attn_mask: attention mask, fill -inf where the value is True;
- Returns:
- output: tensor shape [b h w c]
- """
- if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
- x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
- h_windows = x.size(1)
- w_windows = x.size(2)
- # square validation
- # assert h_windows == w_windows
-
- x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
- qkv = self.embedding_layer(x)
- q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
- sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
- # Adding learnable relative embedding
- sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
- # Using Attn Mask to distinguish different subwindows.
- if self.type != 'W':
- attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
- sim = sim.masked_fill_(attn_mask, float("-inf"))
-
- probs = nn.functional.softmax(sim, dim=-1)
- output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
- output = rearrange(output, 'h b w p c -> b w p (h c)')
- output = self.linear(output)
- output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
-
- if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
- dims=(1, 2))
- return output
-
- def relative_embedding(self):
- cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
- relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
- # negative is allowed
- return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
-
-
-class Block(nn.Module):
- def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
- """ SwinTransformer Block
- """
- super(Block, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- assert type in ['W', 'SW']
- self.type = type
- if input_resolution <= window_size:
- self.type = 'W'
-
- self.ln1 = nn.LayerNorm(input_dim)
- self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.ln2 = nn.LayerNorm(input_dim)
- self.mlp = nn.Sequential(
- nn.Linear(input_dim, 4 * input_dim),
- nn.GELU(),
- nn.Linear(4 * input_dim, output_dim),
- )
-
- def forward(self, x):
- x = x + self.drop_path(self.msa(self.ln1(x)))
- x = x + self.drop_path(self.mlp(self.ln2(x)))
- return x
-
-
-class ConvTransBlock(nn.Module):
- def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
- """ SwinTransformer and Conv Block
- """
- super(ConvTransBlock, self).__init__()
- self.conv_dim = conv_dim
- self.trans_dim = trans_dim
- self.head_dim = head_dim
- self.window_size = window_size
- self.drop_path = drop_path
- self.type = type
- self.input_resolution = input_resolution
-
- assert self.type in ['W', 'SW']
- if self.input_resolution <= self.window_size:
- self.type = 'W'
-
- self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
- self.type, self.input_resolution)
- self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
- self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
-
- self.conv_block = nn.Sequential(
- nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
- nn.ReLU(True),
- nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
- )
-
- def forward(self, x):
- conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
- conv_x = self.conv_block(conv_x) + conv_x
- trans_x = Rearrange('b c h w -> b h w c')(trans_x)
- trans_x = self.trans_block(trans_x)
- trans_x = Rearrange('b h w c -> b c h w')(trans_x)
- res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
- x = x + res
-
- return x
-
-
-class SCUNet(nn.Module):
- # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
- def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
- super(SCUNet, self).__init__()
- if config is None:
- config = [2, 2, 2, 2, 2, 2, 2]
- self.config = config
- self.dim = dim
- self.head_dim = 32
- self.window_size = 8
-
- # drop path rate for each layer
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
-
- self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
-
- begin = 0
- self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution)
- for i in range(config[0])] + \
- [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
-
- begin += config[0]
- self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 2)
- for i in range(config[1])] + \
- [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
-
- begin += config[1]
- self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 4)
- for i in range(config[2])] + \
- [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
-
- begin += config[2]
- self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 8)
- for i in range(config[3])]
-
- begin += config[3]
- self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 4)
- for i in range(config[4])]
-
- begin += config[4]
- self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 2)
- for i in range(config[5])]
-
- begin += config[5]
- self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution)
- for i in range(config[6])]
-
- self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
-
- self.m_head = nn.Sequential(*self.m_head)
- self.m_down1 = nn.Sequential(*self.m_down1)
- self.m_down2 = nn.Sequential(*self.m_down2)
- self.m_down3 = nn.Sequential(*self.m_down3)
- self.m_body = nn.Sequential(*self.m_body)
- self.m_up3 = nn.Sequential(*self.m_up3)
- self.m_up2 = nn.Sequential(*self.m_up2)
- self.m_up1 = nn.Sequential(*self.m_up1)
- self.m_tail = nn.Sequential(*self.m_tail)
- # self.apply(self._init_weights)
-
- def forward(self, x0):
-
- h, w = x0.size()[-2:]
- paddingBottom = int(np.ceil(h / 64) * 64 - h)
- paddingRight = int(np.ceil(w / 64) * 64 - w)
- x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
-
- x1 = self.m_head(x0)
- x2 = self.m_down1(x1)
- x3 = self.m_down2(x2)
- x4 = self.m_down3(x3)
- x = self.m_body(x4)
- x = self.m_up3(x + x4)
- x = self.m_up2(x + x3)
- x = self.m_up1(x + x2)
- x = self.m_tail(x + x1)
-
- x = x[..., :h, :w]
-
- return x
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
\ No newline at end of file
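
SCUNet.forward above pads the input height and width up to the next multiple of 64 with replication padding and crops the output back to the original size, presumably so the three 2x downsampling stages and the window size of 8 always divide evenly. A small sketch of just that pad-and-crop arithmetic, with a no-op standing in for the U-Net body (the helper name is illustrative):

import numpy as np
import torch
import torch.nn as nn

def pad_to_multiple_of_64(x0):
    # Same padding arithmetic as SCUNet.forward: round H and W up to multiples of 64.
    h, w = x0.size()[-2:]
    pad_bottom = int(np.ceil(h / 64) * 64 - h)
    pad_right = int(np.ceil(w / 64) * 64 - w)
    return nn.ReplicationPad2d((0, pad_right, 0, pad_bottom))(x0), h, w

x = torch.randn(1, 3, 70, 130)
padded, h, w = pad_to_multiple_of_64(x)
assert padded.shape[-2:] == (128, 192)     # next multiples of 64
restored = padded[..., :h, :w]             # the same crop SCUNet.forward applies at the end
assert restored.shape == x.shape
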
diff --git a/extensions-builtin/SwinIR/preload.py b/extensions-builtin/SwinIR/preload.py
deleted file mode 100644
index 567e44bcaa1d40ca5dabed7744c94c5b2c68e87f..0000000000000000000000000000000000000000
--- a/extensions-builtin/SwinIR/preload.py
+++ /dev/null
@@ -1,6 +0,0 @@
-import os
-from modules import paths
-
-
-def preload(parser):
- parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
deleted file mode 100644
index e8783bca153954afd086536a6dee854ec5e17ba9..0000000000000000000000000000000000000000
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ /dev/null
@@ -1,178 +0,0 @@
-import contextlib
-import os
-
-import numpy as np
-import torch
-from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
-from tqdm import tqdm
-
-from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts, state
-from swinir_model_arch import SwinIR as net
-from swinir_model_arch_v2 import Swin2SR as net2
-from modules.upscaler import Upscaler, UpscalerData
-
-
-device_swinir = devices.get_device_for('swinir')
-
-
-class UpscalerSwinIR(Upscaler):
- def __init__(self, dirname):
- self.name = "SwinIR"
- self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
- "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
- "-L_x4_GAN.pth "
- self.model_name = "SwinIR 4x"
- self.user_path = dirname
- super().__init__()
- scalers = []
- model_files = self.find_models(ext_filter=[".pt", ".pth"])
- for model in model_files:
- if "http" in model:
- name = self.model_name
- else:
- name = modelloader.friendly_name(model)
- model_data = UpscalerData(name, model, self)
- scalers.append(model_data)
- self.scalers = scalers
-
- def do_upscale(self, img, model_file):
- model = self.load_model(model_file)
- if model is None:
- return img
- model = model.to(device_swinir, dtype=devices.dtype)
- img = upscale(img, model)
- try:
- torch.cuda.empty_cache()
-        except Exception:
- pass
- return img
-
- def load_model(self, path, scale=4):
- if "http" in path:
- dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
- filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
- else:
- filename = path
- if filename is None or not os.path.exists(filename):
- return None
- if filename.endswith(".v2.pth"):
- model = net2(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6],
- embed_dim=180,
- num_heads=[6, 6, 6, 6, 6, 6],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="1conv",
- )
- params = None
- else:
- model = net(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
- embed_dim=240,
- num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="3conv",
- )
- params = "params_ema"
-
- pretrained_model = torch.load(filename)
- if params is not None:
- model.load_state_dict(pretrained_model[params], strict=True)
- else:
- model.load_state_dict(pretrained_model, strict=True)
- return model
-
-
-def upscale(
- img,
- model,
- tile=None,
- tile_overlap=None,
- window_size=8,
- scale=4,
-):
- tile = tile or opts.SWIN_tile
- tile_overlap = tile_overlap or opts.SWIN_tile_overlap
-
-
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
- with torch.no_grad(), devices.autocast():
- _, _, h_old, w_old = img.size()
- h_pad = (h_old // window_size + 1) * window_size - h_old
- w_pad = (w_old // window_size + 1) * window_size - w_old
- img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
- img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
- output = inference(img, model, tile, tile_overlap, window_size, scale)
- output = output[..., : h_old * scale, : w_old * scale]
- output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
- if output.ndim == 3:
- output = np.transpose(
- output[[2, 1, 0], :, :], (1, 2, 0)
-            )  # CHW-RGB to HWC-BGR
- output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
- return Image.fromarray(output, "RGB")
-
-
-def inference(img, model, tile, tile_overlap, window_size, scale):
- # test the image tile by tile
- b, c, h, w = img.size()
- tile = min(tile, h, w)
- assert tile % window_size == 0, "tile size should be a multiple of window_size"
- sf = scale
-
- stride = tile - tile_overlap
- h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
- w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
- W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
-
- with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
- for h_idx in h_idx_list:
- if state.interrupted or state.skipped:
- break
-
- for w_idx in w_idx_list:
- if state.interrupted or state.skipped:
- break
-
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
- pbar.update(1)
- output = E.div_(W)
-
- return output
-
-
-def on_ui_settings():
- import gradio as gr
-
- shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
- shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
-
-
-script_callbacks.on_ui_settings(on_ui_settings)
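
The inference() helper above accumulates each upscaled tile into E and a hit count into W, then divides, so pixels covered by several overlapping tiles are averaged. A reduced sketch of the same accumulate-and-normalise scheme, using nearest-neighbour x2 upscaling in place of the SwinIR model (function names are illustrative):

import torch
import torch.nn.functional as F

def tiled_upscale(img, model, tile=48, tile_overlap=8, scale=2):
    # Sum overlapping output patches into E, count hits in W, divide at the end.
    b, c, h, w = img.size()
    tile = min(tile, h, w)
    stride = tile - tile_overlap
    h_idx = list(range(0, h - tile, stride)) + [h - tile]
    w_idx = list(range(0, w - tile, stride)) + [w - tile]
    E = torch.zeros(b, c, h * scale, w * scale)
    W = torch.zeros_like(E)
    for hi in h_idx:
        for wi in w_idx:
            out = model(img[..., hi:hi + tile, wi:wi + tile])
            E[..., hi * scale:(hi + tile) * scale, wi * scale:(wi + tile) * scale] += out
            W[..., hi * scale:(hi + tile) * scale, wi * scale:(wi + tile) * scale] += 1
    return E / W  # overlapping regions end up averaged

def nearest_x2(t):
    # Nearest-neighbour x2 upscaling stands in for the SwinIR model here.
    return F.interpolate(t, scale_factor=2, mode="nearest")

x = torch.rand(1, 3, 96, 96)
# Tiling a purely local model reproduces the full-image pass exactly.
assert torch.allclose(tiled_upscale(x, nearest_x2), nearest_x2(x))
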
diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py
deleted file mode 100644
index 863f42db6f50e5eac70931b8c0e6443f831a6018..0000000000000000000000000000000000000000
--- a/extensions-builtin/SwinIR/swinir_model_arch.py
+++ /dev/null
@@ -1,867 +0,0 @@
-# -----------------------------------------------------------------------------------
-# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
-# Originally Written by Ze Liu, Modified by Jingyun Liang.
-# -----------------------------------------------------------------------------------
-
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
-
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
-
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
-
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
-
- self.proj_drop = nn.Dropout(proj_drop)
-
- trunc_normal_(self.relative_position_bias_table, std=.02)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
-
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- # assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = self.norm1(x)
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
-        # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are a multiple of the window size)
- if self.input_resolution == x_size:
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
-
- # FFN
- x = shortcut + self.drop_path(x)
- x = x + self.drop_path(self.mlp(self.norm2(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
-
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(4 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.norm(x)
- x = self.reduction(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.dim
- flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- return flops
-
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint)
-
- if resi_connection == '1conv':
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
- norm_layer=None)
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
- norm_layer=None)
-
- def forward(self, x, x_size):
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- flops = 0
- H, W = self.img_size
- if self.norm is not None:
- flops += H * W * self.embed_dim
- return flops
-
-
-class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-class SwinIR(nn.Module):
- r""" SwinIR
- A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
- img_range: Image range. 1. or 255.
-        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
- window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
- **kwargs):
- super(SwinIR, self).__init__()
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers.append(layer)
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == '1conv':
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
- elif self.upsampler == 'nearest+conv':
- # for real-world SR (less artifacts)
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- if self.upscale == 4:
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- return x
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == 'nearest+conv':
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- if self.upscale == 4:
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
-
- return x[:, :, :H*self.upscale, :W*self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
- flops += layer.flops()
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops()
- return flops
-
-
-if __name__ == '__main__':
- upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
- model = SwinIR(upscale=2, img_size=(height, width),
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
- x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape)
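
window_partition and window_reverse above are exact inverses whenever H and W are multiples of the window size, which check_image_size guarantees via reflect padding. A quick round-trip check, reusing the same reshape/permute logic outside the deleted module:

import torch

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.randn(2, 16, 24, 60)            # B, H, W, C with H and W divisible by 8
windows = window_partition(x, 8)          # one row per 8x8 window: 2 * (2 * 3) windows
assert windows.shape == (12, 8, 8, 60)
assert torch.equal(window_reverse(windows, 8, 16, 24), x)
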
diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
deleted file mode 100644
index 0e28ae6eefa2f4bc6260b14760907c54ce633876..0000000000000000000000000000000000000000
--- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py
+++ /dev/null
@@ -1,1017 +0,0 @@
-# -----------------------------------------------------------------------------------
-# Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
-# Written by Conde and Choi et al.
-# -----------------------------------------------------------------------------------
-
-import math
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=[0, 0]):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.pretrained_window_size = pretrained_window_size
- self.num_heads = num_heads
-
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
-
- # mlp to generate continuous relative position bias
- self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
- nn.ReLU(inplace=True),
- nn.Linear(512, num_heads, bias=False))
-
- # get relative_coords_table
- relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
- relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
- relative_coords_table = torch.stack(
- torch.meshgrid([relative_coords_h,
- relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
- if pretrained_window_size[0] > 0:
- relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
- else:
- relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
- relative_coords_table *= 8 # normalize to -8, 8
- relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
- torch.abs(relative_coords_table) + 1.0) / np.log2(8)
-
- self.register_buffer("relative_coords_table", relative_coords_table)
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=False)
- if qkv_bias:
- self.q_bias = nn.Parameter(torch.zeros(dim))
- self.v_bias = nn.Parameter(torch.zeros(dim))
- else:
- self.q_bias = None
- self.v_bias = None
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv_bias = None
- if self.q_bias is not None:
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
- qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- # cosine attention
- attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
- logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
- attn = attn * logit_scale
-
- relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
- relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, ' \
- f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
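# Minimal illustrative sketch (not from the original file): the SwinV2-style
# scaled cosine attention that WindowAttention.forward uses, on random tensors.
# The qkv projection, relative position bias and attention mask are omitted;
# all shapes here are made-up examples.
import torch
import torch.nn.functional as F

B_, num_heads, N, head_dim = 4, 6, 64, 32
q = torch.randn(B_, num_heads, N, head_dim)
k = torch.randn(B_, num_heads, N, head_dim)
v = torch.randn(B_, num_heads, N, head_dim)
logit_scale = torch.log(10 * torch.ones(num_heads, 1, 1))  # learnable in the module

attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)  # cosine similarity
attn = attn * torch.clamp(logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
out = torch.softmax(attn, dim=-1) @ v
print(out.shape)  # torch.Size([4, 6, 64, 32])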
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
- Args:
- dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- pretrained_window_size (int): Window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
- pretrained_window_size=to_2tuple(pretrained_window_size))
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
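# Minimal illustrative sketch (not from the original file): building the SW-MSA
# attention mask the way calculate_mask does, for a hypothetical 8x8 feature map
# with window_size=4 and shift_size=2. Token pairs that come from different
# shifted regions within the same window receive -100.0.
import torch

H = W = 8
window_size, shift_size = 4, 2
img_mask = torch.zeros(1, H, W, 1)
cnt = 0
for h in (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)):
    for w in (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)):
        img_mask[:, h, w, :] = cnt
        cnt += 1

# inline window_partition of the (1, H, W, 1) mask
mask_windows = (img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
                .permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size))
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)  # torch.Size([4, 16, 16]) -> nW, window_size**2, window_size**2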
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- #assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
-        # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are a multiple of the window size)
- if self.input_resolution == x_size:
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
- x = shortcut + self.drop_path(self.norm1(x))
-
- # FFN
- x = x + self.drop_path(self.norm2(self.mlp(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(2 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.reduction(x)
- x = self.norm(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- flops += H * W * self.dim // 2
- return flops
-
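# Minimal illustrative sketch (not from the original file): the 2x2 strided
# slicing performed by PatchMerging.forward, on a tiny tensor. The Linear
# reduction and LayerNorm that follow in the module are omitted.
import torch

B, H, W, C = 1, 4, 4, 3
x = torch.arange(B * H * W * C, dtype=torch.float32).view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]  # even rows, even cols
x1 = x[:, 1::2, 0::2, :]  # odd rows,  even cols
x2 = x[:, 0::2, 1::2, :]  # even rows, odd cols
x3 = x[:, 1::2, 1::2, :]  # odd rows,  odd cols
merged = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)
print(merged.shape)  # torch.Size([1, 4, 12]) -> B, (H/2)*(W/2), 4*C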
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- pretrained_window_size (int): Local window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- pretrained_window_size=0):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer,
- pretrained_window_size=pretrained_window_size)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
- def _init_respostnorm(self):
- for blk in self.blocks:
- nn.init.constant_(blk.norm1.bias, 0)
- nn.init.constant_(blk.norm1.weight, 0)
- nn.init.constant_(blk.norm2.bias, 0)
- nn.init.constant_(blk.norm2.weight, 0)
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- B, C, H, W = x.shape
- # FIXME look at relaxing size constraints
- # assert H == self.img_size[0] and W == self.img_size[1],
- # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- Ho, Wo = self.patches_resolution
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
- if self.norm is not None:
- flops += Ho * Wo * self.embed_dim
- return flops
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint)
-
- if resi_connection == '1conv':
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- def forward(self, x, x_size):
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
-        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B, embed_dim, Ph, Pw
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
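# Minimal illustrative sketch (not from the original file): with patch_size=1,
# as Swin2SR uses below, PatchEmbed is essentially "flatten the spatial grid into
# tokens" and PatchUnEmbed is the inverse reshape. The 1x1 projection conv and
# optional norm in PatchEmbed are omitted here.
import torch

B, C, H, W = 1, 60, 48, 48
feat = torch.randn(B, C, H, W)
tokens = feat.flatten(2).transpose(1, 2)             # PatchEmbed:   B, H*W, C
restored = tokens.transpose(1, 2).view(B, C, H, W)   # PatchUnEmbed: B, C, H, W
assert torch.equal(restored, feat)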
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
-
-class Upsample_hf(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
-    """UpsampleOneStep module (the difference from Upsample is that it always has only one conv + one pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-
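# Minimal illustrative sketch (not from the original file): the conv + PixelShuffle
# pattern behind Upsample and UpsampleOneStep, checking output shapes for an x4
# (2^2) scale and an x3 one-step scale. num_feat and the input size are made up.
import torch
import torch.nn as nn

num_feat = 64
x = torch.randn(1, num_feat, 24, 24)

upsample_x4 = nn.Sequential(   # two (conv, shuffle) stages, as Upsample builds for scale 4
    nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2),
    nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2),
)
one_step_x3 = nn.Sequential(   # single stage with 3 output channels, as UpsampleOneStep builds
    nn.Conv2d(num_feat, (3 ** 2) * 3, 3, 1, 1), nn.PixelShuffle(3),
)
print(upsample_x4(x).shape)   # torch.Size([1, 64, 96, 96])
print(one_step_x3(x).shape)   # torch.Size([1, 3, 72, 72])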
-class Swin2SR(nn.Module):
- r""" Swin2SR
-    A PyTorch implementation of: `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
-        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
- img_range: Image range. 1. or 255.
-        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
- window_size=7, mlp_ratio=4., qkv_bias=True,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
- **kwargs):
- super(Swin2SR, self).__init__()
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers.append(layer)
-
- if self.upsampler == 'pixelshuffle_hf':
- self.layers_hf = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers_hf.append(layer)
-
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == '1conv':
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == 'pixelshuffle_aux':
- self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_after_aux = nn.Sequential(
- nn.Conv2d(3, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffle_hf':
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.upsample_hf = Upsample_hf(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- self.conv_before_upsample_hf = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
- elif self.upsampler == 'nearest+conv':
- # for real-world SR (less artifacts)
-            assert self.upscale == 4, 'only x4 is supported now.'
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- return x
-
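# Minimal illustrative sketch (not from the original file): the reflect padding
# check_image_size applies so H and W become multiples of window_size before the
# windowed attention runs. Input size and window_size are made-up examples.
import torch
import torch.nn.functional as F

window_size = 8
x = torch.randn(1, 3, 30, 21)
mod_pad_h = (window_size - x.shape[2] % window_size) % window_size   # 2
mod_pad_w = (window_size - x.shape[3] % window_size) % window_size   # 3
padded = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
print(padded.shape)  # torch.Size([1, 3, 32, 24])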
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward_features_hf(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers_hf:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == 'pixelshuffle_aux':
- bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
- bicubic = self.conv_bicubic(bicubic)
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- aux = self.conv_aux(x) # b, 3, LR_H, LR_W
- x = self.conv_after_aux(aux)
- x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
- x = self.conv_last(x)
- aux = aux / self.img_range + self.mean
- elif self.upsampler == 'pixelshuffle_hf':
- # for classical SR with HF
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x_before = self.conv_before_upsample(x)
- x_out = self.conv_last(self.upsample(x_before))
-
- x_hf = self.conv_first_hf(x_before)
- x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
- x_hf = self.conv_before_upsample_hf(x_hf)
- x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
- x = x_out + x_hf
- x_hf = x_hf / self.img_range + self.mean
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == 'nearest+conv':
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
- if self.upsampler == "pixelshuffle_aux":
- return x[:, :, :H*self.upscale, :W*self.upscale], aux
-
- elif self.upsampler == "pixelshuffle_hf":
- x_out = x_out / self.img_range + self.mean
- return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
-
- else:
- return x[:, :, :H*self.upscale, :W*self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
- flops += layer.flops()
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops()
- return flops
-
-
-if __name__ == '__main__':
- upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
- model = Swin2SR(upscale=2, img_size=(height, width),
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
- x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape)
\ No newline at end of file
diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
deleted file mode 100644
index 4a85c8ebf25110e911a6a1021fae6a014aa11000..0000000000000000000000000000000000000000
--- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
+++ /dev/null
@@ -1,110 +0,0 @@
-// Stable Diffusion WebUI - Bracket checker
-// Version 1.0
-// By Hingashi no Florin/Bwin4L
-// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
-// If there's a mismatch, the keyword counter turns red and, if you hover over it, a tooltip tells you what's wrong.
-
-function checkBrackets(evt, textArea, counterElt) {
- errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
- errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
- errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
-
- openBracketRegExp = /\(/g;
- closeBracketRegExp = /\)/g;
-
- openSquareBracketRegExp = /\[/g;
- closeSquareBracketRegExp = /\]/g;
-
- openCurlyBracketRegExp = /\{/g;
- closeCurlyBracketRegExp = /\}/g;
-
- totalOpenBracketMatches = 0;
- totalCloseBracketMatches = 0;
- totalOpenSquareBracketMatches = 0;
- totalCloseSquareBracketMatches = 0;
- totalOpenCurlyBracketMatches = 0;
- totalCloseCurlyBracketMatches = 0;
-
- openBracketMatches = textArea.value.match(openBracketRegExp);
- if(openBracketMatches) {
- totalOpenBracketMatches = openBracketMatches.length;
- }
-
- closeBracketMatches = textArea.value.match(closeBracketRegExp);
- if(closeBracketMatches) {
- totalCloseBracketMatches = closeBracketMatches.length;
- }
-
- openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
- if(openSquareBracketMatches) {
- totalOpenSquareBracketMatches = openSquareBracketMatches.length;
- }
-
- closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
- if(closeSquareBracketMatches) {
- totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
- }
-
- openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
- if(openCurlyBracketMatches) {
- totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
- }
-
- closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
- if(closeCurlyBracketMatches) {
- totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
- }
-
- if(totalOpenBracketMatches != totalCloseBracketMatches) {
- if(!counterElt.title.includes(errorStringParen)) {
- counterElt.title += errorStringParen;
- }
- } else {
- counterElt.title = counterElt.title.replace(errorStringParen, '');
- }
-
- if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
- if(!counterElt.title.includes(errorStringSquare)) {
- counterElt.title += errorStringSquare;
- }
- } else {
- counterElt.title = counterElt.title.replace(errorStringSquare, '');
- }
-
- if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
- if(!counterElt.title.includes(errorStringCurly)) {
- counterElt.title += errorStringCurly;
- }
- } else {
- counterElt.title = counterElt.title.replace(errorStringCurly, '');
- }
-
- if(counterElt.title != '') {
- counterElt.classList.add('error');
- } else {
- counterElt.classList.remove('error');
- }
-}
-
-function setupBracketChecking(id_prompt, id_counter){
- var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
- var counter = gradioApp().getElementById(id_counter)
- textarea.addEventListener("input", function(evt){
- checkBrackets(evt, textarea, counter)
- });
-}
-
-var shadowRootLoaded = setInterval(function() {
- var shadowRoot = document.querySelector('gradio-app').shadowRoot;
- if(! shadowRoot) return false;
-
- var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
- if(shadowTextArea.length < 1) return false;
-
- clearInterval(shadowRootLoaded);
-
- setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
- setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
-    setupBracketChecking('img2img_prompt', 'img2img_token_counter')
- setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
-}, 1000);
diff --git a/extensions/put extensions here.txt b/extensions/put extensions here.txt
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/extensions/sd-webui-controlnet b/extensions/sd-webui-controlnet
deleted file mode 160000
index 06a481385942fa2fe84aef24633862711021827f..0000000000000000000000000000000000000000
--- a/extensions/sd-webui-controlnet
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 06a481385942fa2fe84aef24633862711021827f
diff --git a/html/card-no-preview.png b/html/card-no-preview.png
deleted file mode 100644
index e2beb2692067db56ac5f7bd5bfc3d895d9063371..0000000000000000000000000000000000000000
Binary files a/html/card-no-preview.png and /dev/null differ
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
deleted file mode 100644
index 8a5e2fbd223e71abacca9a602bd1be154f5fb520..0000000000000000000000000000000000000000
--- a/html/extra-networks-card.html
+++ /dev/null
@@ -1,12 +0,0 @@
-
-
diff --git a/html/extra-networks-no-cards.html b/html/extra-networks-no-cards.html
deleted file mode 100644
index 389358d6c4b383fdc3c5686e029e7b3b1ae9a493..0000000000000000000000000000000000000000
--- a/html/extra-networks-no-cards.html
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
Nothing here. Add some content to the following directories:
-
-
-
-
diff --git a/html/footer.html b/html/footer.html
deleted file mode 100644
index bad87ff619df2488b9732017a1e5658e240adb2e..0000000000000000000000000000000000000000
--- a/html/footer.html
+++ /dev/null
@@ -1,13 +0,0 @@
-
-
-
-{versions}
-
diff --git a/html/image-update.svg b/html/image-update.svg
deleted file mode 100644
index 3abf12df0f7774c13203e3c49ec3544649df42f4..0000000000000000000000000000000000000000
--- a/html/image-update.svg
+++ /dev/null
@@ -1,7 +0,0 @@
-
-
-
-
-
-
-
diff --git a/html/licenses.html b/html/licenses.html
deleted file mode 100644
index 570630eb4ada6511ac5c8cd6c06f405a1b55c2bd..0000000000000000000000000000000000000000
--- a/html/licenses.html
+++ /dev/null
@@ -1,419 +0,0 @@
-
-
-
-Parts of CodeFormer code had to be copied to be compatible with GFPGAN.
-
-S-Lab License 1.0
-
-Copyright 2022 S-Lab
-
-Redistribution and use for non-commercial purpose in source and
-binary forms, with or without modification, are permitted provided
-that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in
- the documentation and/or other materials provided with the
- distribution.
-
-3. Neither the name of the copyright holder nor the names of its
- contributors may be used to endorse or promote products derived
- from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-In the event that redistribution and/or use for commercial purpose in
-source or binary forms, with or without modification is required,
-please contact the contributor(s) of the work.
-
-
-
-
-Code for architecture and reading models copied.
-
-MIT License
-
-Copyright (c) 2021 victorca25
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
-
-Some code is copied to support ESRGAN models.
-
-BSD 3-Clause License
-
-Copyright (c) 2021, Xintao Wang
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright notice, this
- list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright notice,
- this list of conditions and the following disclaimer in the documentation
- and/or other materials provided with the distribution.
-
-3. Neither the name of the copyright holder nor the names of its
- contributors may be used to endorse or promote products derived from
- this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-
-Some code for compatibility with OSX is taken from lstein's repository.
-
-MIT License
-
-Copyright (c) 2022 InvokeAI Team
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
-
-Code added by contributors, most likely copied from this repository.
-
-MIT License
-
-Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
-
-Some small amounts of code borrowed and reworked.
-
-MIT License
-
-Copyright (c) 2022 pharmapsychotic
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
-
-Code added by contributors, most likely copied from this repository.
-
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [2021] [SwinIR Authors]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
-
-
-The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.
-
-MIT License
-
-Copyright (c) 2023 Alex Birch
-Copyright (c) 2023 Amin Rezaei
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-
diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js
deleted file mode 100644
index 0f164b82c1a1f9dd2ad0e6a745bcdd7e652a53e6..0000000000000000000000000000000000000000
--- a/javascript/aspectRatioOverlay.js
+++ /dev/null
@@ -1,113 +0,0 @@
-
-let currentWidth = null;
-let currentHeight = null;
-let arFrameTimeout = setTimeout(function(){},0);
-
-function dimensionChange(e, is_width, is_height){
-
- if(is_width){
- currentWidth = e.target.value*1.0
- }
- if(is_height){
- currentHeight = e.target.value*1.0
- }
-
- var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
-
- if(!inImg2img){
- return;
- }
-
- var targetElement = null;
-
- var tabIndex = get_tab_index('mode_img2img')
- if(tabIndex == 0){ // img2img
- targetElement = gradioApp().querySelector('div[data-testid=image] img');
- } else if(tabIndex == 1){ //Sketch
- targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
- } else if(tabIndex == 2){ // Inpaint
- targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
- } else if(tabIndex == 3){ // Inpaint sketch
- targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
- }
-
-
- if(targetElement){
-
- var arPreviewRect = gradioApp().querySelector('#imageARPreview');
- if(!arPreviewRect){
- arPreviewRect = document.createElement('div')
- arPreviewRect.id = "imageARPreview";
- gradioApp().getRootNode().appendChild(arPreviewRect)
- }
-
-
-
- var viewportOffset = targetElement.getBoundingClientRect();
-
- viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
-
- scaledx = targetElement.naturalWidth*viewportscale
- scaledy = targetElement.naturalHeight*viewportscale
-
- cleintRectTop = (viewportOffset.top+window.scrollY)
- cleintRectLeft = (viewportOffset.left+window.scrollX)
- cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
- cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
-
- viewRectTop = cleintRectCentreY-(scaledy/2)
- viewRectLeft = cleintRectCentreX-(scaledx/2)
- arRectWidth = scaledx
- arRectHeight = scaledy
-
- arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
- arscaledx = currentWidth*arscale
- arscaledy = currentHeight*arscale
-
- arRectTop = cleintRectCentreY-(arscaledy/2)
- arRectLeft = cleintRectCentreX-(arscaledx/2)
- arRectWidth = arscaledx
- arRectHeight = arscaledy
-
- arPreviewRect.style.top = arRectTop+'px';
- arPreviewRect.style.left = arRectLeft+'px';
- arPreviewRect.style.width = arRectWidth+'px';
- arPreviewRect.style.height = arRectHeight+'px';
-
- clearTimeout(arFrameTimeout);
- arFrameTimeout = setTimeout(function(){
- arPreviewRect.style.display = 'none';
- },2000);
-
- arPreviewRect.style.display = 'block';
-
- }
-
-}
-
-
-onUiUpdate(function(){
- var arPreviewRect = gradioApp().querySelector('#imageARPreview');
- if(arPreviewRect){
- arPreviewRect.style.display = 'none';
- }
- var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
- if(inImg2img){
- let inputs = gradioApp().querySelectorAll('input');
- inputs.forEach(function(e){
- var is_width = e.parentElement.id == "img2img_width"
- var is_height = e.parentElement.id == "img2img_height"
-
- if((is_width || is_height) && !e.classList.contains('scrollwatch')){
- e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
- e.classList.add('scrollwatch')
- }
- if(is_width){
- currentWidth = e.value*1.0
- }
- if(is_height){
- currentHeight = e.value*1.0
- }
- })
- }
-});
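
A rough worked example of the two-step scaling in dimensionChange above, with illustrative numbers (not taken from the UI):

// the img2img preview is natively 1024x1024 but displayed in a 256x256 box,
// and the user has asked for a 512x768 generation
var viewportscale = Math.min(256 / 1024, 256 / 1024);  // 0.25
var scaledx = 1024 * viewportscale;                    // 256
var scaledy = 1024 * viewportscale;                    // 256
var arscale = Math.min(scaledx / 512, scaledy / 768);  // ~0.333
var arscaledx = 512 * arscale;                         // ~171 px: overlay width
var arscaledy = 768 * arscale;                         // 256 px: overlay height
// the overlay is centred on the preview and hidden again two seconds later
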
diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js
deleted file mode 100644
index 11bcce1bcbdc0ed5c1004fbd9d971d255645826b..0000000000000000000000000000000000000000
--- a/javascript/contextMenus.js
+++ /dev/null
@@ -1,177 +0,0 @@
-
-contextMenuInit = function(){
- let eventListenerApplied=false;
- let menuSpecs = new Map();
-
- const uid = function(){
- return Date.now().toString(36) + Math.random().toString(36).substr(2);
- }
-
- function showContextMenu(event,element,menuEntries){
- let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
- let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
-
- let oldMenu = gradioApp().querySelector('#context-menu')
- if(oldMenu){
- oldMenu.remove()
- }
-
- let tabButton = uiCurrentTab
- let baseStyle = window.getComputedStyle(tabButton)
-
- const contextMenu = document.createElement('nav')
- contextMenu.id = "context-menu"
- contextMenu.style.background = baseStyle.background
- contextMenu.style.color = baseStyle.color
- contextMenu.style.fontFamily = baseStyle.fontFamily
- contextMenu.style.top = posy+'px'
- contextMenu.style.left = posx+'px'
-
-
-
- const contextMenuList = document.createElement('ul')
- contextMenuList.className = 'context-menu-items';
- contextMenu.append(contextMenuList);
-
- menuEntries.forEach(function(entry){
- let contextMenuEntry = document.createElement('a')
- contextMenuEntry.innerHTML = entry['name']
- contextMenuEntry.addEventListener("click", function(e) {
- entry['func']();
- })
- contextMenuList.append(contextMenuEntry);
-
- })
-
- gradioApp().getRootNode().appendChild(contextMenu)
-
- let menuWidth = contextMenu.offsetWidth + 4;
- let menuHeight = contextMenu.offsetHeight + 4;
-
- let windowWidth = window.innerWidth;
- let windowHeight = window.innerHeight;
-
- if ( (windowWidth - posx) < menuWidth ) {
- contextMenu.style.left = windowWidth - menuWidth + "px";
- }
-
- if ( (windowHeight - posy) < menuHeight ) {
- contextMenu.style.top = windowHeight - menuHeight + "px";
- }
-
- }
-
- function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
-
- currentItems = menuSpecs.get(targetElementSelector)
-
- if(!currentItems){
- currentItems = []
- menuSpecs.set(targetElementSelector,currentItems);
- }
- let newItem = {'id':targetElementSelector+'_'+uid(),
- 'name':entryName,
- 'func':entryFunction,
- 'isNew':true}
-
- currentItems.push(newItem)
- return newItem['id']
- }
-
- function removeContextMenuOption(uid){
- menuSpecs.forEach(function(v,k) {
- let index = -1
- v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
- if(index>=0){
- v.splice(index, 1);
- }
- })
- }
-
- function addContextMenuEventListener(){
- if(eventListenerApplied){
- return;
- }
- gradioApp().addEventListener("click", function(e) {
- let source = e.composedPath()[0]
- if(source.id && source.id.indexOf('check_progress')>-1){
- return
- }
-
- let oldMenu = gradioApp().querySelector('#context-menu')
- if(oldMenu){
- oldMenu.remove()
- }
- });
- gradioApp().addEventListener("contextmenu", function(e) {
- let oldMenu = gradioApp().querySelector('#context-menu')
- if(oldMenu){
- oldMenu.remove()
- }
- menuSpecs.forEach(function(v,k) {
- if(e.composedPath()[0].matches(k)){
- showContextMenu(e,e.composedPath()[0],v)
- e.preventDefault()
- return
- }
- })
- });
- eventListenerApplied=true
-
- }
-
- return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
-}
-
-initResponse = contextMenuInit();
-appendContextMenuOption = initResponse[0];
-removeContextMenuOption = initResponse[1];
-addContextMenuEventListener = initResponse[2];
-
-(function(){
- //Start example Context Menu Items
- let generateOnRepeat = function(genbuttonid,interruptbuttonid){
- let genbutton = gradioApp().querySelector(genbuttonid);
- let interruptbutton = gradioApp().querySelector(interruptbuttonid);
- if(!interruptbutton.offsetParent){
- genbutton.click();
- }
- clearInterval(window.generateOnRepeatInterval)
- window.generateOnRepeatInterval = setInterval(function(){
- if(!interruptbutton.offsetParent){
- genbutton.click();
- }
- },
- 500)
- }
-
- appendContextMenuOption('#txt2img_generate','Generate forever',function(){
- generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
- })
- appendContextMenuOption('#img2img_generate','Generate forever',function(){
- generateOnRepeat('#img2img_generate','#img2img_interrupt');
- })
-
- let cancelGenerateForever = function(){
- clearInterval(window.generateOnRepeatInterval)
- }
-
- appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
- appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
- appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
- appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
-
- appendContextMenuOption('#roll','Roll three',
- function(){
- let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
- setTimeout(function(){rollbutton.click()},100)
- setTimeout(function(){rollbutton.click()},200)
- setTimeout(function(){rollbutton.click()},300)
- }
- )
-})();
-//End example Context Menu Items
-
-onUiUpdate(function(){
- addContextMenuEventListener()
-});
diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js
deleted file mode 100644
index fe00892481a8ff8b997759b21cb36eb364db788b..0000000000000000000000000000000000000000
--- a/javascript/dragdrop.js
+++ /dev/null
@@ -1,97 +0,0 @@
-// allows drag-dropping files into gradio image elements, and also pasting images from clipboard
-
-function isValidImageList( files ) {
- return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
-}
-
-function dropReplaceImage( imgWrap, files ) {
- if ( ! isValidImageList( files ) ) {
- return;
- }
-
- const tmpFile = files[0];
-
- imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
- const callback = () => {
- const fileInput = imgWrap.querySelector('input[type="file"]');
- if ( fileInput ) {
- if ( files.length === 0 ) {
- files = new DataTransfer();
- files.items.add(tmpFile);
- fileInput.files = files.files;
- } else {
- fileInput.files = files;
- }
- fileInput.dispatchEvent(new Event('change'));
- }
- };
-
- if ( imgWrap.closest('#pnginfo_image') ) {
- // special treatment for PNG Info tab, wait for fetch request to finish
- const oldFetch = window.fetch;
- window.fetch = async (input, options) => {
- const response = await oldFetch(input, options);
- if ( 'api/predict/' === input ) {
- const content = await response.text();
- window.fetch = oldFetch;
- window.requestAnimationFrame( () => callback() );
- return new Response(content, {
- status: response.status,
- statusText: response.statusText,
- headers: response.headers
- })
- }
- return response;
- };
- } else {
- window.requestAnimationFrame( () => callback() );
- }
-}
-
-window.document.addEventListener('dragover', e => {
- const target = e.composedPath()[0];
- const imgWrap = target.closest('[data-testid="image"]');
- if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
- return;
- }
- e.stopPropagation();
- e.preventDefault();
- e.dataTransfer.dropEffect = 'copy';
-});
-
-window.document.addEventListener('drop', e => {
- const target = e.composedPath()[0];
-    if ( !target.placeholder || target.placeholder.indexOf("Prompt") == -1 ) {
- return;
- }
- const imgWrap = target.closest('[data-testid="image"]');
- if ( !imgWrap ) {
- return;
- }
- e.stopPropagation();
- e.preventDefault();
- const files = e.dataTransfer.files;
- dropReplaceImage( imgWrap, files );
-});
-
-window.addEventListener('paste', e => {
- const files = e.clipboardData.files;
- if ( ! isValidImageList( files ) ) {
- return;
- }
-
- const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
- .filter(el => uiElementIsVisible(el));
- if ( ! visibleImageFields.length ) {
- return;
- }
-
- const firstFreeImageField = visibleImageFields
- .filter(el => el.querySelector('input[type=file]'))?.[0];
-
- dropReplaceImage(
- firstFreeImageField ?
- firstFreeImageField :
- visibleImageFields[visibleImageFields.length - 1]
- , files );
-});
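
A small, browser-only illustration of the guard shared by the drop and paste handlers above (the file names are made up):

const dt = new DataTransfer();
dt.items.add(new File([""], "example.png", { type: "image/png" }));
console.log(isValidImageList(dt.files));  // true: exactly one PNG/GIF/JPEG file
dt.items.add(new File([""], "second.jpg", { type: "image/jpeg" }));
console.log(isValidImageList(dt.files));  // false: more than one file
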
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
deleted file mode 100644
index 619bb1fa3d7a9aef3bc61a8d72b202b939039757..0000000000000000000000000000000000000000
--- a/javascript/edit-attention.js
+++ /dev/null
@@ -1,96 +0,0 @@
-function keyupEditAttention(event){
- let target = event.originalTarget || event.composedPath()[0];
- if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
- if (! (event.metaKey || event.ctrlKey)) return;
-
- let isPlus = event.key == "ArrowUp"
- let isMinus = event.key == "ArrowDown"
- if (!isPlus && !isMinus) return;
-
- let selectionStart = target.selectionStart;
- let selectionEnd = target.selectionEnd;
- let text = target.value;
-
- function selectCurrentParenthesisBlock(OPEN, CLOSE){
- if (selectionStart !== selectionEnd) return false;
-
- // Find opening parenthesis around current cursor
- const before = text.substring(0, selectionStart);
- let beforeParen = before.lastIndexOf(OPEN);
- if (beforeParen == -1) return false;
- let beforeParenClose = before.lastIndexOf(CLOSE);
- while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
- beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
- beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
- }
-
- // Find closing parenthesis around current cursor
- const after = text.substring(selectionStart);
- let afterParen = after.indexOf(CLOSE);
- if (afterParen == -1) return false;
- let afterParenOpen = after.indexOf(OPEN);
- while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
- afterParen = after.indexOf(CLOSE, afterParen + 1);
- afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
- }
- if (beforeParen === -1 || afterParen === -1) return false;
-
- // Set the selection to the text between the parenthesis
- const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
- const lastColon = parenContent.lastIndexOf(":");
- selectionStart = beforeParen + 1;
- selectionEnd = selectionStart + lastColon;
- target.setSelectionRange(selectionStart, selectionEnd);
- return true;
- }
-
- // If the user hasn't selected anything, let's select their current parenthesis block
- if(! selectCurrentParenthesisBlock('<', '>')){
- selectCurrentParenthesisBlock('(', ')')
- }
-
- event.preventDefault();
-
- closeCharacter = ')'
- delta = opts.keyedit_precision_attention
-
- if (selectionStart > 0 && text[selectionStart - 1] == '<'){
- closeCharacter = '>'
- delta = opts.keyedit_precision_extra
- } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
-
- // do not include spaces at the end
- while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
- selectionEnd -= 1;
- }
- if(selectionStart == selectionEnd){
- return
- }
-
- text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
-
- selectionStart += 1;
- selectionEnd += 1;
- }
-
- end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
- weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
- if (isNaN(weight)) return;
-
- weight += isPlus ? delta : -delta;
- weight = parseFloat(weight.toPrecision(12));
- if(String(weight).length == 1) weight += ".0"
-
- text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
-
- target.focus();
- target.value = text;
- target.selectionStart = selectionStart;
- target.selectionEnd = selectionEnd;
-
- updateInput(target)
-}
-
-addEventListener('keydown', (event) => {
- keyupEditAttention(event);
-});
\ No newline at end of file
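
Two worked examples of what keyupEditAttention does, assuming opts.keyedit_precision_attention is 0.1 (the value actually comes from settings, so treat it as an assumption):

// 1) "cat" is selected in              a cat on a mat
//    Ctrl+ArrowUp wraps and bumps it:  a (cat:1.1) on a mat
// 2) the cursor sits inside "cat" in   a (cat:1.2) on a mat
//    Ctrl+ArrowUp edits the weight:    a (cat:1.3) on a mat
// <...> extra-network blocks behave the same way, using opts.keyedit_precision_extra
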
diff --git a/javascript/extensions.js b/javascript/extensions.js
deleted file mode 100644
index c593cd2e5701db5a89f9b890bc952722ed5c3bbf..0000000000000000000000000000000000000000
--- a/javascript/extensions.js
+++ /dev/null
@@ -1,49 +0,0 @@
-
-function extensions_apply(_, _){
- var disable = []
- var update = []
-
- gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
- if(x.name.startsWith("enable_") && ! x.checked)
- disable.push(x.name.substr(7))
-
- if(x.name.startsWith("update_") && x.checked)
- update.push(x.name.substr(7))
- })
-
- restart_reload()
-
- return [JSON.stringify(disable), JSON.stringify(update)]
-}
-
-function extensions_check(){
- var disable = []
-
- gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
- if(x.name.startsWith("enable_") && ! x.checked)
- disable.push(x.name.substr(7))
- })
-
- gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
- x.innerHTML = "Loading..."
- })
-
-
- var id = randomId()
- requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
-
- })
-
- return [id, JSON.stringify(disable)]
-}
-
-function install_extension_from_index(button, url){
- button.disabled = "disabled"
- button.value = "Installing..."
-
- textarea = gradioApp().querySelector('#extension_to_install textarea')
- textarea.value = url
- updateInput(textarea)
-
- gradioApp().querySelector('#install_extension_button').click()
-}
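
For reference, a sketch of what extensions_apply returns to the backend; the checkbox names are hypothetical:

// Extensions tab checkboxes: enable_foo (unchecked), update_bar (checked)
// x.name.substr(7) strips the "enable_" / "update_" prefix, so the return value is
//   [ '["foo"]', '["bar"]' ]   // JSON strings: extensions to disable, extensions to update
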
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
deleted file mode 100644
index 17bf200047dfa881ddf7f705788271fd52c29ae8..0000000000000000000000000000000000000000
--- a/javascript/extraNetworks.js
+++ /dev/null
@@ -1,107 +0,0 @@
-
-function setupExtraNetworksForTab(tabname){
- gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
-
- var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
- var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
- var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
- var close = gradioApp().getElementById(tabname+'_extra_close')
-
- search.classList.add('search')
- tabs.appendChild(search)
- tabs.appendChild(refresh)
- tabs.appendChild(close)
-
- search.addEventListener("input", function(evt){
- searchTerm = search.value.toLowerCase()
-
- gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
- text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
- elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
- })
- });
-}
-
-var activePromptTextarea = {};
-
-function setupExtraNetworks(){
- setupExtraNetworksForTab('txt2img')
- setupExtraNetworksForTab('img2img')
-
- function registerPrompt(tabname, id){
- var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
-
- if (! activePromptTextarea[tabname]){
- activePromptTextarea[tabname] = textarea
- }
-
- textarea.addEventListener("focus", function(){
- activePromptTextarea[tabname] = textarea;
- });
- }
-
- registerPrompt('txt2img', 'txt2img_prompt')
- registerPrompt('txt2img', 'txt2img_neg_prompt')
- registerPrompt('img2img', 'img2img_prompt')
- registerPrompt('img2img', 'img2img_neg_prompt')
-}
-
-onUiLoaded(setupExtraNetworks)
-
-var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
-var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
-
-function tryToRemoveExtraNetworkFromPrompt(textarea, text){
- var m = text.match(re_extranet)
- if(! m) return false
-
- var partToSearch = m[1]
- var replaced = false
- var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
- m = found.match(re_extranet);
- if(m[1] == partToSearch){
- replaced = true;
- return ""
- }
- return found;
- })
-
- if(replaced){
- textarea.value = newTextareaText
- return true;
- }
-
- return false
-}
-
-function cardClicked(tabname, textToAdd, allowNegativePrompt){
- var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
-
- if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
- textarea.value = textarea.value + " " + textToAdd
- }
-
- updateInput(textarea)
-}
-
-function saveCardPreview(event, tabname, filename){
- var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
- var button = gradioApp().getElementById(tabname + '_save_preview')
-
- textarea.value = filename
- updateInput(textarea)
-
- button.click()
-
- event.stopPropagation()
- event.preventDefault()
-}
-
-function extraNetworksSearchButton(tabs_id, event){
- searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
- button = event.target
- text = button.classList.contains("search-all") ? "" : button.textContent.trim()
-
- searchTextarea.value = text
- updateInput(searchTextarea)
-}
\ No newline at end of file
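
A quick illustration of what the extra-network regexes above match; the Lora name is made up:

"masterpiece <lora:myLora:0.8>".match(re_extranet)[1]   // "lora:myLora"
// tryToRemoveExtraNetworkFromPrompt removes a tag with the same type:name from the prompt,
// so clicking a card that is already present toggles it off:
//   prompt "a photo <lora:myLora:0.8>" + card text "<lora:myLora:1.0>"  ->  "a photo"
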
diff --git a/javascript/generationParams.js b/javascript/generationParams.js
deleted file mode 100644
index 95f050939b72a8d09d62de8d725caf1e7d15d3c0..0000000000000000000000000000000000000000
--- a/javascript/generationParams.js
+++ /dev/null
@@ -1,33 +0,0 @@
-// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
-
-let txt2img_gallery, img2img_gallery, modal = undefined;
-onUiUpdate(function(){
- if (!txt2img_gallery) {
- txt2img_gallery = attachGalleryListeners("txt2img")
- }
- if (!img2img_gallery) {
- img2img_gallery = attachGalleryListeners("img2img")
- }
- if (!modal) {
- modal = gradioApp().getElementById('lightboxModal')
- modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
- }
-});
-
-let modalObserver = new MutationObserver(function(mutations) {
- mutations.forEach(function(mutationRecord) {
- let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
-        if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
- gradioApp().getElementById(selectedTab+"_generation_info_button").click()
- });
-});
-
-function attachGalleryListeners(tab_name) {
- gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
- gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
- gallery?.addEventListener('keydown', (e) => {
- if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
- gradioApp().getElementById(tab_name+"_generation_info_button").click()
- });
- return gallery;
-}
diff --git a/javascript/hints.js b/javascript/hints.js
deleted file mode 100644
index f1199009b181b83713bb1d136e41d4b29e183634..0000000000000000000000000000000000000000
--- a/javascript/hints.js
+++ /dev/null
@@ -1,146 +0,0 @@
-// mouseover tooltips for various UI elements
-
-titles = {
- "Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
- "Sampling method": "Which algorithm to use to produce the image",
- "GFPGAN": "Restore low quality faces using GFPGAN neural network",
-    "Euler a": "Euler Ancestral - very creative; you can get a completely different picture depending on step count, and setting steps higher than 30-40 does not help",
- "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
- "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
-
- "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
-    "Batch size": "How many images to create in a single batch (increases generation performance at the cost of higher VRAM usage)",
-    "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to the prompt - lower values produce more creative results",
-    "Seed": "A value that determines the output of the random number generator - if you create an image with the same parameters and seed as another image, you'll get the same result",
-    "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
-    "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomized",
-    "\u2199\ufe0f": "Read generation parameters from the prompt (or from the last generation if the prompt is empty) into the user interface.",
- "\u{1f4c2}": "Open images output directory",
- "\u{1f4be}": "Save style",
- "\u{1f5d1}": "Clear prompt",
- "\u{1f4cb}": "Apply selected styles to current prompt",
- "\u{1f4d2}": "Paste available values into the field",
- "\u{1f3b4}": "Show extra networks",
-
-
- "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
- "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
-
- "Just resize": "Resize image to target resolution. Unless height and width match, you will get incorrect aspect ratio.",
- "Crop and resize": "Resize the image so that entirety of target resolution is filled with the image. Crop parts that stick out.",
- "Resize and fill": "Resize the image so that entirety of image is inside target resolution. Fill empty space with image's colors.",
-
- "Mask blur": "How much to blur the mask before processing, in pixels.",
- "Masked content": "What to put inside the masked area before processing it with Stable Diffusion.",
- "fill": "fill it with colors of the image",
- "original": "keep whatever was there originally",
- "latent noise": "fill it with latent space noise",
- "latent nothing": "fill it with latent space zeroes",
- "Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
-
-    "Denoising strength": "Determines how little respect the algorithm should have for the image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take fewer steps than the Sampling Steps slider specifies.",
- "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",
-
-    "Skip": "Stop processing the current image and continue with the next one.",
- "Interrupt": "Stop processing images and return any results accumulated so far.",
- "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
-
- "X values": "Separate values for X axis using commas.",
- "Y values": "Separate values for Y axis using commas.",
-
- "None": "Do not do anything special",
- "Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)",
- "X/Y/Z plot": "Create grid(s) where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
- "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
-
- "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
- "Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order",
-
- "Tiling": "Produce an image that can be tiled.",
- "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",
-
- "Variation seed": "Seed of a different picture to be mixed into the generation.",
- "Variation strength": "How strong of a variation to produce. At 0, there will be no effect. At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).",
- "Resize seed from height": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
- "Resize seed from width": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
-
- "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
-
-    "Images filename pattern": "Use the following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
-    "Directory name pattern": "Use the following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
- "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
-
- "Loopback": "Process an image, use it as an input, repeat.",
- "Loops": "How many times to repeat processing an image and using it as input for the next iteration",
-
- "Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
- "Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
- "Apply style": "Insert selected styles into prompt fields",
- "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style uses that as a placeholder for your prompt when you use the style in the future.",
-
- "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
- "Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
-
- "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
-
-    "Eta noise seed delta": "If this value is non-zero, it will be added to the seed and used to initialize the RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
- "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
-
-    "Filename word regex": "This regular expression will be used to extract words from the filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
- "Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
-
- "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.",
-
- "Weighted sum": "Result = A * (1 - M) + B * M",
- "Add difference": "Result = A + (B - C) * M",
- "No interpolation": "Result = A",
-
- "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
- "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
-
- "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
-
- "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resolution and lower quality.",
- "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resolution and extremely low quality.",
-
- "Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
- "Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.",
- "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.",
- "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.",
- "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
- "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
-    "Discard weights with matching name": "Regular expression; if a weight's name matches it, that weight is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
-    "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in the order listed."
-}
-
-
-onUiUpdate(function(){
- gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
- tooltip = titles[span.textContent];
-
- if(!tooltip){
- tooltip = titles[span.value];
- }
-
- if(!tooltip){
- for (const c of span.classList) {
- if (c in titles) {
- tooltip = titles[c];
- break;
- }
- }
- }
-
- if(tooltip){
- span.title = tooltip;
- }
- })
-
- gradioApp().querySelectorAll('select').forEach(function(select){
- if (select.onchange != null) return;
-
- select.onchange = function(){
- select.title = titles[select.value] || "";
- }
- })
-})
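
The lookup above tries, in order, an element's text content, its value, and finally its CSS classes; for example:

// <button>Interrupt</button>               -> titles["Interrupt"]  (matched by textContent)
// a <select> whose current value is "DDIM" -> titles["DDIM"]       (matched by value)
// <span class="vram">...</span>            -> titles["vram"]       (matched by a class name)
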
diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js
deleted file mode 100644
index 0629475f894fb5c3facdc1e0a09bc770350ee904..0000000000000000000000000000000000000000
--- a/javascript/hires_fix.js
+++ /dev/null
@@ -1,22 +0,0 @@
-
-function setInactive(elem, inactive){
- if(inactive){
- elem.classList.add('inactive')
- } else{
- elem.classList.remove('inactive')
- }
-}
-
-function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
- hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
- hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
- hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
-
- gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
-
- setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
- setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
- setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
-
- return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
-}
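
A concrete reading of the logic above, with opts.use_old_hires_fix_width_height disabled and illustrative arguments:

// onCalcResolutionHires(true, 512, 512, 2, 0, 0)
//   -> "Upscale by" stays active; both resize fields are marked inactive (they are 0)
// onCalcResolutionHires(true, 512, 512, 2, 1024, 0)
//   -> "Upscale by" is greyed out and "Resize width to" becomes the controlling field
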
diff --git a/javascript/imageMaskFix.js b/javascript/imageMaskFix.js
deleted file mode 100644
index 9fe7a60309c95b4921360fb09d5bee2b2bd2a73c..0000000000000000000000000000000000000000
--- a/javascript/imageMaskFix.js
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
- * @see https://github.com/gradio-app/gradio/issues/1721
- */
-window.addEventListener( 'resize', () => imageMaskResize());
-function imageMaskResize() {
- const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
- if ( ! canvases.length ) {
- canvases_fixed = false;
- window.removeEventListener( 'resize', imageMaskResize );
- return;
- }
-
- const wrapper = canvases[0].closest('.touch-none');
- const previewImage = wrapper.previousElementSibling;
-
- if ( ! previewImage.complete ) {
- previewImage.addEventListener( 'load', () => imageMaskResize());
- return;
- }
-
- const w = previewImage.width;
- const h = previewImage.height;
- const nw = previewImage.naturalWidth;
- const nh = previewImage.naturalHeight;
- const portrait = nh > nw;
- const factor = portrait;
-
- const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
- const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
-
- wrapper.style.width = `${wW}px`;
- wrapper.style.height = `${wH}px`;
- wrapper.style.left = `0px`;
- wrapper.style.top = `0px`;
-
- canvases.forEach( c => {
- c.style.width = c.style.height = '';
- c.style.maxWidth = '100%';
- c.style.maxHeight = '100%';
- c.style.objectFit = 'contain';
- });
- }
-
- onUiUpdate(() => imageMaskResize());
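
A worked example of the wrapper sizing above, using an illustrative portrait image:

// previewImage shown at 512x512 CSS px, natural size 512x768 (portrait)
// wW = Math.min(512, 512/768*512)  ->  ~341 px
// wH = Math.min(512, 512/768*768)  ->   512 px
// the wrapper shrinks to ~341x512, so the inpainting canvases line up with the
// letterboxed preview instead of stretching over the empty bars around it
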
diff --git a/javascript/imageParams.js b/javascript/imageParams.js
deleted file mode 100644
index 67404a89ba6084a065ab5ac188e01ed29952113b..0000000000000000000000000000000000000000
--- a/javascript/imageParams.js
+++ /dev/null
@@ -1,19 +0,0 @@
-window.onload = (function(){
- window.addEventListener('drop', e => {
- const target = e.composedPath()[0];
- const idx = selected_gallery_index();
- if (target.placeholder.indexOf("Prompt") == -1) return;
-
- let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
-
- e.stopPropagation();
- e.preventDefault();
- const imgParent = gradioApp().getElementById(prompt_target);
- const files = e.dataTransfer.files;
- const fileInput = imgParent.querySelector('input[type="file"]');
- if ( fileInput ) {
- fileInput.files = files;
- fileInput.dispatchEvent(new Event('change'));
- }
- });
-});
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
deleted file mode 100644
index aac2ee82383881bd9d59a264d2cd2c823c2187c4..0000000000000000000000000000000000000000
--- a/javascript/imageviewer.js
+++ /dev/null
@@ -1,285 +0,0 @@
-// A full size 'lightbox' preview modal shown when left clicking on gallery previews
-function closeModal() {
- gradioApp().getElementById("lightboxModal").style.display = "none";
-}
-
-function showModal(event) {
- const source = event.target || event.srcElement;
- const modalImage = gradioApp().getElementById("modalImage")
- const lb = gradioApp().getElementById("lightboxModal")
- modalImage.src = source.src
- if (modalImage.style.display === 'none') {
- lb.style.setProperty('background-image', 'url(' + source.src + ')');
- }
- lb.style.display = "block";
- lb.focus()
-
- const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
- const tabImg2Img = gradioApp().getElementById("tab_img2img")
- // show the save button in modal only on txt2img or img2img tabs
- if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
- gradioApp().getElementById("modal_save").style.display = "inline"
- } else {
- gradioApp().getElementById("modal_save").style.display = "none"
- }
- event.stopPropagation()
-}
-
-function negmod(n, m) {
- return ((n % m) + m) % m;
-}
-
-function updateOnBackgroundChange() {
- const modalImage = gradioApp().getElementById("modalImage")
- if (modalImage && modalImage.offsetParent) {
- let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
- let currentButton = null
- allcurrentButtons.forEach(function(elem) {
- if (elem.parentElement.offsetParent) {
- currentButton = elem;
- }
- })
-
- if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
- modalImage.src = currentButton.children[0].src;
- if (modalImage.style.display === 'none') {
- modal.style.setProperty('background-image', `url(${modalImage.src})`)
- }
- }
- }
-}
-
-function modalImageSwitch(offset) {
- var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
- var galleryButtons = []
- allgalleryButtons.forEach(function(elem) {
- if (elem.parentElement.offsetParent) {
- galleryButtons.push(elem);
- }
- })
-
- if (galleryButtons.length > 1) {
- var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
- var currentButton = null
- allcurrentButtons.forEach(function(elem) {
- if (elem.parentElement.offsetParent) {
- currentButton = elem;
- }
- })
-
- var result = -1
- galleryButtons.forEach(function(v, i) {
- if (v == currentButton) {
- result = i
- }
- })
-
- if (result != -1) {
- nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
- nextButton.click()
- const modalImage = gradioApp().getElementById("modalImage");
- const modal = gradioApp().getElementById("lightboxModal");
- modalImage.src = nextButton.children[0].src;
- if (modalImage.style.display === 'none') {
- modal.style.setProperty('background-image', `url(${modalImage.src})`)
- }
- setTimeout(function() {
- modal.focus()
- }, 10)
- }
- }
-}
-
-function saveImage(){
- const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
- const tabImg2Img = gradioApp().getElementById("tab_img2img")
- const saveTxt2Img = "save_txt2img"
- const saveImg2Img = "save_img2img"
- if (tabTxt2Img.style.display != "none") {
- gradioApp().getElementById(saveTxt2Img).click()
- } else if (tabImg2Img.style.display != "none") {
- gradioApp().getElementById(saveImg2Img).click()
- } else {
- console.error("missing implementation for saving modal of this type")
- }
-}
-
-function modalSaveImage(event) {
- saveImage()
- event.stopPropagation()
-}
-
-function modalNextImage(event) {
- modalImageSwitch(1)
- event.stopPropagation()
-}
-
-function modalPrevImage(event) {
- modalImageSwitch(-1)
- event.stopPropagation()
-}
-
-function modalKeyHandler(event) {
- switch (event.key) {
- case "s":
- saveImage()
- break;
- case "ArrowLeft":
- modalPrevImage(event)
- break;
- case "ArrowRight":
- modalNextImage(event)
- break;
- case "Escape":
- closeModal();
- break;
- }
-}
-
-function showGalleryImage() {
- setTimeout(function() {
- fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
-
- if (fullImg_preview != null) {
- fullImg_preview.forEach(function function_name(e) {
- if (e.dataset.modded)
- return;
- e.dataset.modded = true;
- if(e && e.parentElement.tagName == 'DIV'){
- e.style.cursor='pointer'
- e.style.userSelect='none'
-
-                var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
-
-                // For Firefox, listening for click first switches to the next image and then shows the lightbox.
-                // If you know how to fix this without switching to the mousedown event, please do.
-                // For other browsers the event is click, to make it possible to drag the picture.
- var event = isFirefox ? 'mousedown' : 'click'
-
- e.addEventListener(event, function (evt) {
- if(!opts.js_modal_lightbox || evt.button != 0) return;
- modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
- evt.preventDefault()
- showModal(evt)
- }, true);
- }
- });
- }
-
- }, 100);
-}
-
-function modalZoomSet(modalImage, enable) {
- if (enable) {
- modalImage.classList.add('modalImageFullscreen');
- } else {
- modalImage.classList.remove('modalImageFullscreen');
- }
-}
-
-function modalZoomToggle(event) {
- modalImage = gradioApp().getElementById("modalImage");
- modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
- event.stopPropagation()
-}
-
-function modalTileImageToggle(event) {
- const modalImage = gradioApp().getElementById("modalImage");
- const modal = gradioApp().getElementById("lightboxModal");
- const isTiling = modalImage.style.display === 'none';
- if (isTiling) {
- modalImage.style.display = 'block';
- modal.style.setProperty('background-image', 'none')
- } else {
- modalImage.style.display = 'none';
- modal.style.setProperty('background-image', `url(${modalImage.src})`)
- }
-
- event.stopPropagation()
-}
-
-function galleryImageHandler(e) {
- if (e && e.parentElement.tagName == 'BUTTON') {
- e.onclick = showGalleryImage;
- }
-}
-
-onUiUpdate(function() {
- fullImg_preview = gradioApp().querySelectorAll('img.w-full')
- if (fullImg_preview != null) {
- fullImg_preview.forEach(galleryImageHandler);
- }
- updateOnBackgroundChange();
-})
-
-document.addEventListener("DOMContentLoaded", function() {
- const modalFragment = document.createDocumentFragment();
- const modal = document.createElement('div')
- modal.onclick = closeModal;
- modal.id = "lightboxModal";
- modal.tabIndex = 0
- modal.addEventListener('keydown', modalKeyHandler, true)
-
- const modalControls = document.createElement('div')
- modalControls.className = 'modalControls gradio-container';
- modal.append(modalControls);
-
- const modalZoom = document.createElement('span')
- modalZoom.className = 'modalZoom cursor';
- modalZoom.innerHTML = '⤡'
- modalZoom.addEventListener('click', modalZoomToggle, true)
- modalZoom.title = "Toggle zoomed view";
- modalControls.appendChild(modalZoom)
-
- const modalTileImage = document.createElement('span')
- modalTileImage.className = 'modalTileImage cursor';
- modalTileImage.innerHTML = '⊞'
- modalTileImage.addEventListener('click', modalTileImageToggle, true)
- modalTileImage.title = "Preview tiling";
- modalControls.appendChild(modalTileImage)
-
- const modalSave = document.createElement("span")
- modalSave.className = "modalSave cursor"
- modalSave.id = "modal_save"
- modalSave.innerHTML = "🖫"
- modalSave.addEventListener("click", modalSaveImage, true)
- modalSave.title = "Save Image(s)"
- modalControls.appendChild(modalSave)
-
- const modalClose = document.createElement('span')
- modalClose.className = 'modalClose cursor';
- modalClose.innerHTML = '×'
- modalClose.onclick = closeModal;
- modalClose.title = "Close image viewer";
- modalControls.appendChild(modalClose)
-
- const modalImage = document.createElement('img')
- modalImage.id = 'modalImage';
- modalImage.onclick = closeModal;
- modalImage.tabIndex = 0
- modalImage.addEventListener('keydown', modalKeyHandler, true)
- modal.appendChild(modalImage)
-
- const modalPrev = document.createElement('a')
- modalPrev.className = 'modalPrev';
- modalPrev.innerHTML = '❮'
- modalPrev.tabIndex = 0
- modalPrev.addEventListener('click', modalPrevImage, true);
- modalPrev.addEventListener('keydown', modalKeyHandler, true)
- modal.appendChild(modalPrev)
-
- const modalNext = document.createElement('a')
- modalNext.className = 'modalNext';
- modalNext.innerHTML = '❯'
- modalNext.tabIndex = 0
- modalNext.addEventListener('click', modalNextImage, true);
- modalNext.addEventListener('keydown', modalKeyHandler, true)
-
- modal.appendChild(modalNext)
-
-
- gradioApp().getRootNode().appendChild(modal)
-
- document.body.appendChild(modalFragment);
-
-});
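
The negmod helper above is what makes the prev/next arrows wrap around; for example:

negmod(-1, 5)  // 4 -> "prev" from the first of five thumbnails jumps to the last
negmod(5, 5)   // 0 -> "next" from the last thumbnail jumps back to the first
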
diff --git a/javascript/localization.js b/javascript/localization.js
deleted file mode 100644
index 1a5a1dbb699b8307b7d3af0d93b5a5c1260e4ddb..0000000000000000000000000000000000000000
--- a/javascript/localization.js
+++ /dev/null
@@ -1,165 +0,0 @@
-
-// localization = {} -- the dict with translations is created by the backend
-
-ignore_ids_for_localization={
- setting_sd_hypernetwork: 'OPTION',
- setting_sd_model_checkpoint: 'OPTION',
- setting_realesrgan_enabled_models: 'OPTION',
- modelmerger_primary_model_name: 'OPTION',
- modelmerger_secondary_model_name: 'OPTION',
- modelmerger_tertiary_model_name: 'OPTION',
- train_embedding: 'OPTION',
- train_hypernetwork: 'OPTION',
- txt2img_styles: 'OPTION',
- img2img_styles: 'OPTION',
- setting_random_artist_categories: 'SPAN',
- setting_face_restoration_model: 'SPAN',
- setting_realesrgan_enabled_models: 'SPAN',
- extras_upscaler_1: 'SPAN',
- extras_upscaler_2: 'SPAN',
-}
-
-re_num = /^[\.\d]+$/
-re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
-
-original_lines = {}
-translated_lines = {}
-
-function textNodesUnder(el){
- var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
- while(n=walk.nextNode()) a.push(n);
- return a;
-}
-
-function canBeTranslated(node, text){
- if(! text) return false;
- if(! node.parentElement) return false;
-
- parentType = node.parentElement.nodeName
- if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
-
- if (parentType=='OPTION' || parentType=='SPAN'){
- pnode = node
- for(var level=0; level<4; level++){
- pnode = pnode.parentElement
- if(! pnode) break;
-
- if(ignore_ids_for_localization[pnode.id] == parentType) return false;
- }
- }
-
- if(re_num.test(text)) return false;
- if(re_emoji.test(text)) return false;
- return true
-}
-
-function getTranslation(text){
- if(! text) return undefined
-
- if(translated_lines[text] === undefined){
- original_lines[text] = 1
- }
-
- tl = localization[text]
- if(tl !== undefined){
- translated_lines[tl] = 1
- }
-
- return tl
-}
-
-function processTextNode(node){
- text = node.textContent.trim()
-
- if(! canBeTranslated(node, text)) return
-
- tl = getTranslation(text)
- if(tl !== undefined){
- node.textContent = tl
- }
-}
-
-function processNode(node){
- if(node.nodeType == 3){
- processTextNode(node)
- return
- }
-
- if(node.title){
- tl = getTranslation(node.title)
- if(tl !== undefined){
- node.title = tl
- }
- }
-
- if(node.placeholder){
- tl = getTranslation(node.placeholder)
- if(tl !== undefined){
- node.placeholder = tl
- }
- }
-
- textNodesUnder(node).forEach(function(node){
- processTextNode(node)
- })
-}
-
-function dumpTranslations(){
- dumped = {}
- if (localization.rtl) {
- dumped.rtl = true
- }
-
- Object.keys(original_lines).forEach(function(text){
- if(dumped[text] !== undefined) return
-
- dumped[text] = localization[text] || text
- })
-
- return dumped
-}
-
-onUiUpdate(function(m){
- m.forEach(function(mutation){
- mutation.addedNodes.forEach(function(node){
- processNode(node)
- })
- });
-})
-
-
-document.addEventListener("DOMContentLoaded", function() {
- processNode(gradioApp())
-
- if (localization.rtl) { // if the language is from right to left,
- (new MutationObserver((mutations, observer) => { // wait for the style to load
- mutations.forEach(mutation => {
- mutation.addedNodes.forEach(node => {
- if (node.tagName === 'STYLE') {
- observer.disconnect();
-
- for (const x of node.sheet.rules) { // find all rtl media rules
- if (Array.from(x.media || []).includes('rtl')) {
- x.media.appendMedium('all'); // enable them
- }
- }
- }
- })
- });
- })).observe(gradioApp(), { childList: true });
- }
-})
-
-function download_localization() {
- text = JSON.stringify(dumpTranslations(), null, 4)
-
- var element = document.createElement('a');
- element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
- element.setAttribute('download', "localization.json");
- element.style.display = 'none';
- document.body.appendChild(element);
-
- element.click();
-
- document.body.removeChild(element);
-}
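
For context, localization is a plain text-to-text dictionary injected by the backend; the translations below are invented examples:

// localization = { "Sampling steps": "Samplingschritte", "Batch size": "Batchgröße", "rtl": false }
// processNode() swaps matching text nodes, titles and placeholders in place, and
// dumpTranslations() (used by download_localization) returns roughly
//   { "Sampling steps": "Samplingschritte", "Seed": "Seed", ... }
// i.e. every string seen in the UI, mapped to its translation or to itself
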
diff --git a/javascript/notification.js b/javascript/notification.js
deleted file mode 100644
index 040a3afac2019fe2d3532122b8317560d5935814..0000000000000000000000000000000000000000
--- a/javascript/notification.js
+++ /dev/null
@@ -1,49 +0,0 @@
-// Monitors the gallery and sends a browser notification when the leading image is new.
-
-let lastHeadImg = null;
-
-notificationButton = null
-
-onUiUpdate(function(){
- if(notificationButton == null){
- notificationButton = gradioApp().getElementById('request_notifications')
-
- if(notificationButton != null){
- notificationButton.addEventListener('click', function (evt) {
- Notification.requestPermission();
- },true);
- }
- }
-
- const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
-
- if (galleryPreviews == null) return;
-
- const headImg = galleryPreviews[0]?.src;
-
- if (headImg == null || headImg == lastHeadImg) return;
-
- lastHeadImg = headImg;
-
- // play notification sound if available
- gradioApp().querySelector('#audio_notification audio')?.play();
-
- if (document.hasFocus()) return;
-
- // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated.
- const imgs = new Set(Array.from(galleryPreviews).map(img => img.src));
-
- const notification = new Notification(
- 'Stable Diffusion',
- {
- body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
- icon: headImg,
- image: headImg,
- }
- );
-
- notification.onclick = function(_){
- parent.focus();
- this.close();
- };
-});
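
For example, assuming opts.return_grid is the usual boolean "show grid in results" option (it coerces to 1 below):

// five distinct preview images in the gallery (four results plus the grid):
//   body: "Generated 4 images"
// a single preview:
//   body: "Generated 1 image"
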
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
deleted file mode 100644
index ff6d757bae88f5f622767376e5315b9acf8271cd..0000000000000000000000000000000000000000
--- a/javascript/progressbar.js
+++ /dev/null
@@ -1,243 +0,0 @@
-// code related to showing and updating progressbar shown as the image is being made
-
-
-galleries = {}
-storedGallerySelections = {}
-galleryObservers = {}
-
-function rememberGallerySelection(id_gallery){
- storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery)
-}
-
-function getGallerySelectedIndex(id_gallery){
- let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
- let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
-
- let currentlySelectedIndex = -1
- galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } })
-
- return currentlySelectedIndex
-}
-
-// this is a workaround for https://github.com/gradio-app/gradio/issues/2984
-function check_gallery(id_gallery){
- let gallery = gradioApp().getElementById(id_gallery)
-    // if the gallery hasn't changed, there is no need to set up the observer again.
- if (gallery && galleries[id_gallery] !== gallery){
- galleries[id_gallery] = gallery;
- if(galleryObservers[id_gallery]){
- galleryObservers[id_gallery].disconnect();
- }
-
- storedGallerySelections[id_gallery] = -1
-
- galleryObservers[id_gallery] = new MutationObserver(function (){
- let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
- let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
- let currentlySelectedIndex = getGallerySelectedIndex(id_gallery)
- prevSelectedIndex = storedGallerySelections[id_gallery]
- storedGallerySelections[id_gallery] = -1
-
- if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
- // automatically re-open previously selected index (if exists)
- activeElement = gradioApp().activeElement;
- let scrollX = window.scrollX;
- let scrollY = window.scrollY;
-
- galleryButtons[prevSelectedIndex].click();
- showGalleryImage();
-
- // When the gallery button is clicked, it gains focus and scrolls itself into view
- // We need to scroll back to the previous position
- setTimeout(function (){
- window.scrollTo(scrollX, scrollY);
- }, 50);
-
- if(activeElement){
- // i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
- // if someone has a better solution please by all means
- setTimeout(function (){
- activeElement.focus({
- preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it
- })
- }, 1);
- }
- }
- })
- galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
- }
-}
-
-onUiUpdate(function(){
- check_gallery('txt2img_gallery')
- check_gallery('img2img_gallery')
-})
-
-function request(url, data, handler, errorHandler){
- var xhr = new XMLHttpRequest();
- var url = url;
- xhr.open("POST", url, true);
- xhr.setRequestHeader("Content-Type", "application/json");
- xhr.onreadystatechange = function () {
- if (xhr.readyState === 4) {
- if (xhr.status === 200) {
- try {
- var js = JSON.parse(xhr.responseText);
- handler(js)
- } catch (error) {
- console.error(error);
- errorHandler()
- }
- } else{
- errorHandler()
- }
- }
- };
- var js = JSON.stringify(data);
- xhr.send(js);
-}
-
-function pad2(x){
- return x<10 ? '0'+x : x
-}
-
-function formatTime(secs){
- if(secs > 3600){
- return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60)
- } else if(secs > 60){
- return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60)
- } else{
- return Math.floor(secs) + "s"
- }
-}
-
-function setTitle(progress){
- var title = 'Stable Diffusion'
-
- if(opts.show_progress_in_title && progress){
- title = '[' + progress.trim() + '] ' + title;
- }
-
- if(document.title != title){
- document.title = title;
- }
-}
-
-
-function randomId(){
- return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
-}
-
-// Starts sending progress requests to the "/internal/progress" URI, creating a progress bar above the progressbarContainer
-// element and a live preview inside the gallery element. Cleans up everything it created and calls atEnd when the task is over.
-// Calls onProgress every time there is a progress update.
-function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
- var dateStart = new Date()
- var wasEverActive = false
- var parentProgressbar = progressbarContainer.parentNode
- var parentGallery = gallery ? gallery.parentNode : null
-
- var divProgress = document.createElement('div')
- divProgress.className='progressDiv'
- divProgress.style.display = opts.show_progressbar ? "" : "none"
- var divInner = document.createElement('div')
- divInner.className='progress'
-
- divProgress.appendChild(divInner)
- parentProgressbar.insertBefore(divProgress, progressbarContainer)
-
- if(parentGallery){
- var livePreview = document.createElement('div')
- livePreview.className='livePreview'
- parentGallery.insertBefore(livePreview, gallery)
- }
-
- var removeProgressBar = function(){
- setTitle("")
- parentProgressbar.removeChild(divProgress)
- if(parentGallery) parentGallery.removeChild(livePreview)
- atEnd()
- }
-
- var fun = function(id_task, id_live_preview){
- request("./internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
- if(res.completed){
- removeProgressBar()
- return
- }
-
- var rect = progressbarContainer.getBoundingClientRect()
-
- if(rect.width){
- divProgress.style.width = rect.width + "px";
- }
-
- progressText = ""
-
- divInner.style.width = ((res.progress || 0) * 100.0) + '%'
- divInner.style.background = res.progress ? "" : "transparent"
-
- if(res.progress > 0){
- progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
- }
-
- if(res.eta){
- progressText += " ETA: " + formatTime(res.eta)
- }
-
-
- setTitle(progressText)
-
- if(res.textinfo && res.textinfo.indexOf("\n") == -1){
- progressText = res.textinfo + " " + progressText
- }
-
- divInner.textContent = progressText
-
- var elapsedFromStart = (new Date() - dateStart) / 1000
-
- if(res.active) wasEverActive = true;
-
- if(! res.active && wasEverActive){
- removeProgressBar()
- return
- }
-
- if(elapsedFromStart > 5 && !res.queued && !res.active){
- removeProgressBar()
- return
- }
-
-
- if(res.live_preview && gallery){
- var rect = gallery.getBoundingClientRect()
- if(rect.width){
- livePreview.style.width = rect.width + "px"
- livePreview.style.height = rect.height + "px"
- }
-
- var img = new Image();
- img.onload = function() {
- livePreview.appendChild(img)
- if(livePreview.childElementCount > 2){
- livePreview.removeChild(livePreview.firstElementChild)
- }
- }
- img.src = res.live_preview;
- }
-
-
- if(onProgress){
- onProgress(res)
- }
-
- setTimeout(() => {
- fun(id_task, res.id_live_preview);
- }, opts.live_preview_refresh_period || 500)
- }, function(){
- removeProgressBar()
- })
- }
-
- fun(id_task, 0)
-}
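requestProgress above defines the client side of a simple polling protocol: POST {"id_task", "id_live_preview"} to ./internal/progress and read back fields such as progress, eta, textinfo, active, queued, completed and live_preview. A minimal sketch of the same loop in Python, assuming the requests package and the default local server address (both are assumptions, not part of this diff):

import time
import requests

BASE_URL = "http://127.0.0.1:7860"  # assumed default webui address

def poll_progress(id_task, interval=0.5):
    # Mirror the polling loop in javascript/progressbar.js until the task completes.
    id_live_preview = 0
    while True:
        res = requests.post(
            f"{BASE_URL}/internal/progress",
            json={"id_task": id_task, "id_live_preview": id_live_preview},
        ).json()

        if res.get("completed"):
            return res

        percent = (res.get("progress") or 0) * 100
        eta = res.get("eta")
        print(f"{percent:.0f}%" + (f" ETA: {eta:.0f}s" if eta else ""))

        # remember the last live preview we have seen, like the JS does
        id_live_preview = res.get("id_live_preview", id_live_preview)
        time.sleep(interval)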
diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js
deleted file mode 100644
index 0354b860ca96c8e9d473e3223d7cdda9e978a879..0000000000000000000000000000000000000000
--- a/javascript/textualInversion.js
+++ /dev/null
@@ -1,17 +0,0 @@
-
-
-
-function start_training_textual_inversion(){
- gradioApp().querySelector('#ti_error').innerHTML=''
-
- var id = randomId()
- requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
- gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
- })
-
- var res = args_to_array(arguments)
-
- res[0] = id
-
- return res
-}
diff --git a/javascript/ui.js b/javascript/ui.js
deleted file mode 100644
index b7a8268a8fcdf9821cb3af31efea9e0283da1bfe..0000000000000000000000000000000000000000
--- a/javascript/ui.js
+++ /dev/null
@@ -1,338 +0,0 @@
-// various functions for interaction with ui.py not large enough to warrant putting them in separate files
-
-function set_theme(theme){
- gradioURL = window.location.href
- if (!gradioURL.includes('?__theme=')) {
- window.location.replace(gradioURL + '?__theme=' + theme);
- }
-}
-
-function selected_gallery_index(){
- var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item')
- var button = gradioApp().querySelector('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item.\\!ring-2')
-
- var result = -1
- buttons.forEach(function(v, i){ if(v==button) { result = i } })
-
- return result
-}
-
-function extract_image_from_gallery(gallery){
- if(gallery.length == 1){
- return [gallery[0]]
- }
-
- index = selected_gallery_index()
-
- if (index < 0 || index >= gallery.length){
- return [null]
- }
-
- return [gallery[index]];
-}
-
-function args_to_array(args){
-    res = []
-    for(var i=0;i<args.length;i++){
-        res.push(args[i])
-    }
-    return res
-}
-
-promptTokecountUpdateFuncs = {}
-
-onUiUpdate(function(){
-    function registerTextarea(id, id_counter, id_button){
-        var prompt = gradioApp().getElementById(id)
-        var counter = gradioApp().getElementById(id_counter)
-        var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
-
- if(counter.parentElement == prompt.parentElement){
- return
- }
-
- prompt.parentElement.insertBefore(counter, prompt)
- counter.classList.add("token-counter")
- prompt.parentElement.style.position = "relative"
-
- promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
- textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
- }
-
- registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
- registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button')
- registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
- registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
-
- show_all_pages = gradioApp().getElementById('settings_show_all_pages')
- settings_tabs = gradioApp().querySelector('#settings div')
- if(show_all_pages && settings_tabs){
- settings_tabs.appendChild(show_all_pages)
- show_all_pages.onclick = function(){
- gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
- elem.style.display = "block";
- })
- }
- }
-})
-
-onOptionsChanged(function(){
- elem = gradioApp().getElementById('sd_checkpoint_hash')
- sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
- shorthash = sd_checkpoint_hash.substr(0,10)
-
- if(elem && elem.textContent != shorthash){
- elem.textContent = shorthash
- elem.title = sd_checkpoint_hash
- elem.href = "https://google.com/search?q=" + sd_checkpoint_hash
- }
-})
-
-let txt2img_textarea, img2img_textarea = undefined;
-let wait_time = 800
-let token_timeouts = {};
-
-function update_txt2img_tokens(...args) {
- update_token_counter("txt2img_token_button")
- if (args.length == 2)
- return args[0]
- return args;
-}
-
-function update_img2img_tokens(...args) {
- update_token_counter("img2img_token_button")
- if (args.length == 2)
- return args[0]
- return args;
-}
-
-function update_token_counter(button_id) {
- if (token_timeouts[button_id])
- clearTimeout(token_timeouts[button_id]);
- token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
-}
-
-function restart_reload(){
- document.body.innerHTML='Reloading... ';
- setTimeout(function(){location.reload()},2000)
-
- return []
-}
-
-// Simulate an `input` DOM event for a Gradio Textbox component. Needed after you edit its contents in JavaScript; otherwise your edits
-// will only be visible on the web page and will not be sent to Python.
-function updateInput(target){
- let e = new Event("input", { bubbles: true })
- Object.defineProperty(e, "target", {value: target})
- target.dispatchEvent(e);
-}
-
-
-var desiredCheckpointName = null;
-function selectCheckpoint(name){
- desiredCheckpointName = name;
- gradioApp().getElementById('change_checkpoint').click()
-}
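update_token_counter above is a debounce: every keystroke cancels the pending timeout and schedules a new click of the hidden counter button after wait_time (800 ms), so the token count is only recomputed once typing pauses. A standalone sketch of the same pattern in Python, using only the standard library (the names are illustrative):

import threading

WAIT_SECONDS = 0.8   # same 800 ms delay as wait_time in ui.js
_pending = {}        # key -> currently scheduled timer

def debounce(key, callback):
    # Cancel any timer already scheduled for this key and start a fresh one,
    # so a burst of events results in a single callback after the pause.
    timer = _pending.get(key)
    if timer is not None:
        timer.cancel()
    timer = threading.Timer(WAIT_SECONDS, callback)
    _pending[key] = timer
    timer.start()

# usage: call on every keystroke; the callback fires once typing stops
debounce("txt2img_token_button", lambda: print("recount tokens"))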
diff --git a/launch.py b/launch.py
deleted file mode 100644
index a68bb3a91f88a3af64d6d0c7c12404e15283fdc0..0000000000000000000000000000000000000000
--- a/launch.py
+++ /dev/null
@@ -1,361 +0,0 @@
-# this scripts installs necessary requirements and launches main program in webui.py
-import subprocess
-import os
-import sys
-import importlib.util
-import shlex
-import platform
-import argparse
-import json
-
-dir_repos = "repositories"
-dir_extensions = "extensions"
-python = sys.executable
-git = os.environ.get('GIT', "git")
-index_url = os.environ.get('INDEX_URL', "")
-stored_commit_hash = None
-skip_install = False
-
-
-def check_python_version():
- is_windows = platform.system() == "Windows"
- major = sys.version_info.major
- minor = sys.version_info.minor
- micro = sys.version_info.micro
-
- if is_windows:
- supported_minors = [10]
- else:
- supported_minors = [7, 8, 9, 10, 11]
-
- if not (major == 3 and minor in supported_minors):
- import modules.errors
-
- modules.errors.print_error_explanation(f"""
-INCOMPATIBLE PYTHON VERSION
-
-This program is tested with Python 3.10.6, but you have {major}.{minor}.{micro}.
-If you encounter an error with the message "RuntimeError: Couldn't install torch.",
-or any other error regarding an unsuccessful package (library) installation,
-please downgrade (or upgrade) to the latest version of Python 3.10,
-and delete the current Python installation and the "venv" folder in WebUI's directory.
-
-You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
-
-{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
-
-Use --skip-python-version-check to suppress this warning.
-""")
-
-
-def commit_hash():
- global stored_commit_hash
-
- if stored_commit_hash is not None:
- return stored_commit_hash
-
- try:
- stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
- except Exception:
- stored_commit_hash = ""
-
- return stored_commit_hash
-
-
-def extract_arg(args, name):
- return [x for x in args if x != name], name in args
-
-
-def extract_opt(args, name):
- opt = None
- is_present = False
- if name in args:
- is_present = True
- idx = args.index(name)
- del args[idx]
- if idx < len(args) and args[idx][0] != "-":
- opt = args[idx]
- del args[idx]
- return args, is_present, opt
-
-
-def run(command, desc=None, errdesc=None, custom_env=None, live=False):
- if desc is not None:
- print(desc)
-
- if live:
- result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
- if result.returncode != 0:
- raise RuntimeError(f"""{errdesc or 'Error running command'}.
-Command: {command}
-Error code: {result.returncode}""")
-
- return ""
-
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
-
- if result.returncode != 0:
-
- message = f"""{errdesc or 'Error running command'}.
-Command: {command}
-Error code: {result.returncode}
-stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''}
-stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else ''}
-"""
- raise RuntimeError(message)
-
- return result.stdout.decode(encoding="utf8", errors="ignore")
-
-
-def check_run(command):
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
- return result.returncode == 0
-
-
-def is_installed(package):
- try:
- spec = importlib.util.find_spec(package)
- except ModuleNotFoundError:
- return False
-
- return spec is not None
-
-
-def repo_dir(name):
- return os.path.join(dir_repos, name)
-
-
-def run_python(code, desc=None, errdesc=None):
- return run(f'"{python}" -c "{code}"', desc, errdesc)
-
-
-def run_pip(args, desc=None):
- if skip_install:
- return
-
- index_url_line = f' --index-url {index_url}' if index_url != '' else ''
- return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
-
-
-def check_run_python(code):
- return check_run(f'"{python}" -c "{code}"')
-
-
-def git_clone(url, dir, name, commithash=None):
- # TODO clone into temporary dir and move if successful
-
- if os.path.exists(dir):
- if commithash is None:
- return
-
- current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
- if current_hash == commithash:
- return
-
- run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
- run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
- return
-
- run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
-
- if commithash is not None:
-        run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")
-
-
-def version_check(commit):
- try:
- import requests
- commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
- if commit != "" and commits['commit']['sha'] != commit:
- print("--------------------------------------------------------")
- print("| You are not up to date with the most recent release. |")
- print("| Consider running `git pull` to update. |")
- print("--------------------------------------------------------")
- elif commits['commit']['sha'] == commit:
- print("You are up to date with the most recent release.")
- else:
- print("Not a git clone, can't perform version check.")
- except Exception as e:
- print("version check failed", e)
-
-
-def run_extension_installer(extension_dir):
- path_installer = os.path.join(extension_dir, "install.py")
- if not os.path.isfile(path_installer):
- return
-
- try:
- env = os.environ.copy()
- env['PYTHONPATH'] = os.path.abspath(".")
-
- print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
- except Exception as e:
- print(e, file=sys.stderr)
-
-
-def list_extensions(settings_file):
- settings = {}
-
- try:
- if os.path.isfile(settings_file):
- with open(settings_file, "r", encoding="utf8") as file:
- settings = json.load(file)
- except Exception as e:
- print(e, file=sys.stderr)
-
- disabled_extensions = set(settings.get('disabled_extensions', []))
-
- return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
-
-
-def run_extensions_installers(settings_file):
- if not os.path.isdir(dir_extensions):
- return
-
- for dirname_extension in list_extensions(settings_file):
- run_extension_installer(os.path.join(dir_extensions, dirname_extension))
-
-
-def prepare_environment():
- global skip_install
-
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
- requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
- commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
-
- xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
- gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
- clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
- openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
-
- stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
- taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
- k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
- codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
- blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
-
- stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
- taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
- k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
- codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
- blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
-
- sys.argv += shlex.split(commandline_args)
-
- parser = argparse.ArgumentParser(add_help=False)
- parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
- args, _ = parser.parse_known_args(sys.argv)
-
- sys.argv, _ = extract_arg(sys.argv, '-f')
- sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
- sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
- sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
- sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
- sys.argv, update_check = extract_arg(sys.argv, '--update-check')
- sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
- sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
- xformers = '--xformers' in sys.argv
- ngrok = '--ngrok' in sys.argv
-
- if not skip_python_version_check:
- check_python_version()
-
- commit = commit_hash()
-
- print(f"Python {sys.version}")
- print(f"Commit hash: {commit}")
-
- if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
- run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
-
- if not skip_torch_cuda_test:
- run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
-
- if not is_installed("gfpgan"):
- run_pip(f"install {gfpgan_package}", "gfpgan")
-
- if not is_installed("clip"):
- run_pip(f"install {clip_package}", "clip")
-
- if not is_installed("open_clip"):
- run_pip(f"install {openclip_package}", "open_clip")
-
- if (not is_installed("xformers") or reinstall_xformers) and xformers:
- if platform.system() == "Windows":
- if platform.python_version().startswith("3.10"):
- run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
- else:
- print("Installation of xformers is not supported in this version of Python.")
- print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
- if not is_installed("xformers"):
- exit(0)
- elif platform.system() == "Linux":
- run_pip(f"install {xformers_package}", "xformers")
-
- if not is_installed("pyngrok") and ngrok:
- run_pip("install pyngrok", "ngrok")
-
- os.makedirs(dir_repos, exist_ok=True)
-
- git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
- git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
- git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
- git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
- git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
-
- if not is_installed("lpips"):
- run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
-
- run_pip(f"install -r {requirements_file}", "requirements for Web UI")
-
- run_extensions_installers(settings_file=args.ui_settings_file)
-
- if update_check:
- version_check(commit)
-
- if "--exit" in sys.argv:
- print("Exiting because of --exit argument")
- exit(0)
-
- if run_tests:
- exitcode = tests(test_dir)
- exit(exitcode)
-
-
-def tests(test_dir):
- if "--api" not in sys.argv:
- sys.argv.append("--api")
- if "--ckpt" not in sys.argv:
- sys.argv.append("--ckpt")
- sys.argv.append("./test/test_files/empty.pt")
- if "--skip-torch-cuda-test" not in sys.argv:
- sys.argv.append("--skip-torch-cuda-test")
- if "--disable-nan-check" not in sys.argv:
- sys.argv.append("--disable-nan-check")
-
- print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
-
- os.environ['COMMANDLINE_ARGS'] = ""
- with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
- proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
-
- import test.server_poll
- exitcode = test.server_poll.run_tests(proc, test_dir)
-
- print(f"Stopping Web UI process with id {proc.pid}")
- proc.kill()
- return exitcode
-
-
-def start():
- print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
- import webui
- if '--nowebui' in sys.argv:
- webui.api_only()
- else:
- webui.webui()
-
-
-if __name__ == "__main__":
- prepare_environment()
- start()
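extract_arg and extract_opt above let launch.py consume its own bootstrap flags before the remaining arguments reach webui.py: the first strips a boolean flag, the second also captures the value that follows it. A short usage sketch against those two functions (the argument values are made up for illustration):

argv = ["launch.py", "--xformers", "--tests", "testdir", "--listen"]

# extract_arg: remove a boolean flag and report whether it was present
argv, use_xformers = extract_arg(argv, "--xformers")
print(argv, use_xformers)          # ['launch.py', '--tests', 'testdir', '--listen'] True

# extract_opt: remove a flag plus its value (when the next token is not another flag)
argv, run_tests, test_dir = extract_opt(argv, "--tests")
print(argv, run_tests, test_dir)   # ['launch.py', '--listen'] True testdir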
diff --git a/localizations/Put localization files here.txt b/localizations/Put localization files here.txt
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt b/models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/models/Stable-diffusion/model.ckpt b/models/Stable-diffusion/model.ckpt
deleted file mode 100644
index bdc7ea15824c46103a08ba376b47eebf09f36895..0000000000000000000000000000000000000000
--- a/models/Stable-diffusion/model.ckpt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4c86efd062ea772972dd068c3870773f1fd9a37ba9a288ee79ec848448ce7785
-size 2132791380
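The checkpoint itself lives in Git LFS, so the removed file is only the three-line pointer: the spec version, a sha256 object id, and the size in bytes. A tiny parser for that pointer format (a generic sketch, not part of the webui code):

def parse_lfs_pointer(text):
    # Each line of a git-lfs pointer is "<key> <value>".
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    return {
        "version": fields["version"],
        "sha256": fields["oid"].split(":", 1)[1],
        "size_bytes": int(fields["size"]),
    }

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:4c86efd062ea772972dd068c3870773f1fd9a37ba9a288ee79ec848448ce7785
size 2132791380"""
print(parse_lfs_pointer(pointer)["size_bytes"])  # 2132791380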
diff --git a/models/VAE-approx/model.pt b/models/VAE-approx/model.pt
deleted file mode 100644
index 09c6b8f7fda5e15495c6203ca323d6573745d0af..0000000000000000000000000000000000000000
--- a/models/VAE-approx/model.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4f88c9078bb2238cdd0d8864671dd33e3f42e091e41f08903f3c15e4a54a9b39
-size 213777
diff --git a/models/VAE/Put VAE here.txt b/models/VAE/Put VAE here.txt
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/models/cldm_v15.yaml b/models/cldm_v15.yaml
deleted file mode 100644
index fde1825577acd46dc90d8d7c6730e22be762fccb..0000000000000000000000000000000000000000
--- a/models/cldm_v15.yaml
+++ /dev/null
@@ -1,79 +0,0 @@
-model:
- target: cldm.cldm.ControlLDM
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: "jpg"
- cond_stage_key: "txt"
- control_key: "hint"
- image_size: 64
- channels: 4
- cond_stage_trainable: false
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
- only_mid_control: False
-
- control_stage_config:
- target: cldm.cldm.ControlNet
- params:
- image_size: 32 # unused
- in_channels: 4
- hint_channels: 3
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: True
- legacy: False
-
- unet_config:
- target: cldm.cldm.ControlledUnetModel
- params:
- image_size: 32 # unused
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: True
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
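Like the other ldm-style configs, the YAML above is a tree of target/params pairs: target is a dotted import path and params are the keyword arguments used to construct it. The ldm codebase ships its own instantiate_from_config helper for this; the sketch below is a standalone approximation that assumes omegaconf (and the referenced ldm/cldm packages) are importable:

import importlib

from omegaconf import OmegaConf

def instantiate_from_config(config):
    # Import the class named by "target" and build it with "params" as kwargs.
    module_name, class_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("params", {}))

config = OmegaConf.load("models/cldm_v15.yaml")
# e.g. build just the autoencoder described by first_stage_config above
first_stage = instantiate_from_config(config.model.params.first_stage_config)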
diff --git a/models/cldm_v21.yaml b/models/cldm_v21.yaml
deleted file mode 100644
index fc65193647e476e108fce5977f11250d55919106..0000000000000000000000000000000000000000
--- a/models/cldm_v21.yaml
+++ /dev/null
@@ -1,85 +0,0 @@
-model:
- target: cldm.cldm.ControlLDM
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: "jpg"
- cond_stage_key: "txt"
- control_key: "hint"
- image_size: 64
- channels: 4
- cond_stage_trainable: false
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
- only_mid_control: False
-
- control_stage_config:
- target: cldm.cldm.ControlNet
- params:
- use_checkpoint: True
- image_size: 32 # unused
- in_channels: 4
- hint_channels: 3
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_head_channels: 64 # need to fix for flash-attn
- use_spatial_transformer: True
- use_linear_in_transformer: True
- transformer_depth: 1
- context_dim: 1024
- legacy: False
-
- unet_config:
- target: cldm.cldm.ControlledUnetModel
- params:
- use_checkpoint: True
- image_size: 32 # unused
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_head_channels: 64 # need to fix for flash-attn
- use_spatial_transformer: True
- use_linear_in_transformer: True
- transformer_depth: 1
- context_dim: 1024
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- #attn_type: "vanilla-xformers"
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
- params:
- freeze: True
- layer: "penultimate"
diff --git a/models/deepbooru/Put your deepbooru release project folder here.txt b/models/deepbooru/Put your deepbooru release project folder here.txt
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/models/image_adapter_v14.yaml b/models/image_adapter_v14.yaml
deleted file mode 100644
index 439d33cc53a349c9b8c1a0091cbd3643359216d5..0000000000000000000000000000000000000000
--- a/models/image_adapter_v14.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-model:
- target: tencentarc.t21_adapter
- params:
- channels: [320, 640, 1280, 1280]
- nums_rb: 2
- ksize: 1
- sk: true
- cin: 192
- use_conv: false
\ No newline at end of file
diff --git a/models/sketch_adapter_v14.yaml b/models/sketch_adapter_v14.yaml
deleted file mode 100644
index 686c5f172bf941ffaaee58b912245d6ffb36f4d3..0000000000000000000000000000000000000000
--- a/models/sketch_adapter_v14.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-model:
- target: tencentarc.t21_adapter
- params:
- channels: [320, 640, 1280, 1280]
- nums_rb: 2
- ksize: 1
- sk: true
- cin: 64
- use_conv: false
\ No newline at end of file
diff --git a/modules/api/api.py b/modules/api/api.py
deleted file mode 100644
index 5a9ac5f1aa745e4dd8c9ed5a107dd840f05c0ba6..0000000000000000000000000000000000000000
--- a/modules/api/api.py
+++ /dev/null
@@ -1,551 +0,0 @@
-import base64
-import io
-import time
-import datetime
-import uvicorn
-from threading import Lock
-from io import BytesIO
-from gradio.processing_utils import decode_base64_to_file
-from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
-from fastapi.security import HTTPBasic, HTTPBasicCredentials
-from secrets import compare_digest
-
-import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
-from modules.api.models import *
-from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
-from modules.textual_inversion.preprocess import preprocess
-from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
-from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list
-from modules.sd_models_config import find_checkpoint_config_near_filename
-from modules.realesrgan_model import get_realesrgan_models
-from modules import devices
-from typing import List
-import piexif
-import piexif.helper
-
-def upscaler_to_index(name: str):
- try:
- return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
- except:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
-
-def script_name_to_index(name, scripts):
- try:
- return [script.title().lower() for script in scripts].index(name.lower())
- except:
- raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
-
-def validate_sampler_name(name):
- config = sd_samplers.all_samplers_map.get(name, None)
- if config is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
-
- return name
-
-def setUpscalers(req: dict):
- reqDict = vars(req)
- reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
- reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
- return reqDict
-
-def decode_base64_to_image(encoding):
- if encoding.startswith("data:image/"):
- encoding = encoding.split(";")[1].split(",")[1]
- try:
- image = Image.open(BytesIO(base64.b64decode(encoding)))
- return image
- except Exception as err:
- raise HTTPException(status_code=500, detail="Invalid encoded image")
-
-def encode_pil_to_base64(image):
- with io.BytesIO() as output_bytes:
-
- if opts.samples_format.lower() == 'png':
- use_metadata = False
- metadata = PngImagePlugin.PngInfo()
- for key, value in image.info.items():
- if isinstance(key, str) and isinstance(value, str):
- metadata.add_text(key, value)
- use_metadata = True
- image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
-
- elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
- parameters = image.info.get('parameters', None)
- exif_bytes = piexif.dump({
- "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
- })
- if opts.samples_format.lower() in ("jpg", "jpeg"):
- image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
- else:
- image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
-
- else:
- raise HTTPException(status_code=500, detail="Invalid image format")
-
- bytes_data = output_bytes.getvalue()
-
- return base64.b64encode(bytes_data)
-
-def api_middleware(app: FastAPI):
- @app.middleware("http")
- async def log_and_time(req: Request, call_next):
- ts = time.time()
- res: Response = await call_next(req)
- duration = str(round(time.time() - ts, 4))
- res.headers["X-Process-Time"] = duration
- endpoint = req.scope.get('path', 'err')
- if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
- print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
- t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
- code = res.status_code,
- ver = req.scope.get('http_version', '0.0'),
- cli = req.scope.get('client', ('0:0.0.0', 0))[0],
- prot = req.scope.get('scheme', 'err'),
- method = req.scope.get('method', 'err'),
- endpoint = endpoint,
- duration = duration,
- ))
- return res
-
-
-class Api:
- def __init__(self, app: FastAPI, queue_lock: Lock):
- if shared.cmd_opts.api_auth:
- self.credentials = dict()
- for auth in shared.cmd_opts.api_auth.split(","):
- user, password = auth.split(":")
- self.credentials[user] = password
-
- self.router = APIRouter()
- self.app = app
- self.queue_lock = queue_lock
- api_middleware(self.app)
- self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
- self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
- self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
- self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
- self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
- self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
- self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
- self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
- self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
- self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
- self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
- self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
- self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
- self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
- self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
- self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
- self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
- self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
- self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
- self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
- self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
- self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
- self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
- self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
- self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
- self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
- self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
-
- def add_api_route(self, path: str, endpoint, **kwargs):
- if shared.cmd_opts.api_auth:
- return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
- return self.app.add_api_route(path, endpoint, **kwargs)
-
- def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
- if credentials.username in self.credentials:
- if compare_digest(credentials.password, self.credentials[credentials.username]):
- return True
-
- raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
-
- def get_script(self, script_name, script_runner):
- if script_name is None:
- return None, None
-
- if not script_runner.scripts:
- script_runner.initialize_scripts(False)
- ui.create_ui()
-
- script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
- script = script_runner.selectable_scripts[script_idx]
- return script, script_idx
-
- def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
- script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
-
- populate = txt2imgreq.copy(update={ # Override __init__ params
- "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
- "do_not_save_samples": True,
- "do_not_save_grid": True
- }
- )
- if populate.sampler_name:
- populate.sampler_index = None # prevent a warning later on
-
- args = vars(populate)
- args.pop('script_name', None)
-
- with self.queue_lock:
- p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
-
- shared.state.begin()
- if script is not None:
- p.outpath_grids = opts.outdir_txt2img_grids
- p.outpath_samples = opts.outdir_txt2img_samples
- p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
- processed = scripts.scripts_txt2img.run(p, *p.script_args)
- else:
- processed = process_images(p)
- shared.state.end()
-
- b64images = list(map(encode_pil_to_base64, processed.images))
-
- return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
-
- def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
- init_images = img2imgreq.init_images
- if init_images is None:
- raise HTTPException(status_code=404, detail="Init image not found")
-
- script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)
-
- mask = img2imgreq.mask
- if mask:
- mask = decode_base64_to_image(mask)
-
- populate = img2imgreq.copy(update={ # Override __init__ params
- "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
- "do_not_save_samples": True,
- "do_not_save_grid": True,
- "mask": mask
- }
- )
- if populate.sampler_name:
- populate.sampler_index = None # prevent a warning later on
-
- args = vars(populate)
-        args.pop('include_init_images', None)  # this is meant to be handled by "exclude": True in the model, but that does not work for a reason I cannot determine.
- args.pop('script_name', None)
-
- with self.queue_lock:
- p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
- p.init_images = [decode_base64_to_image(x) for x in init_images]
-
- shared.state.begin()
- if script is not None:
- p.outpath_grids = opts.outdir_img2img_grids
- p.outpath_samples = opts.outdir_img2img_samples
- p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
- processed = scripts.scripts_img2img.run(p, *p.script_args)
- else:
- processed = process_images(p)
- shared.state.end()
-
- b64images = list(map(encode_pil_to_base64, processed.images))
-
- if not img2imgreq.include_init_images:
- img2imgreq.init_images = None
- img2imgreq.mask = None
-
- return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
-
- def extras_single_image_api(self, req: ExtrasSingleImageRequest):
- reqDict = setUpscalers(req)
-
- reqDict['image'] = decode_base64_to_image(reqDict['image'])
-
- with self.queue_lock:
- result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
-
- return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
-
- def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
- reqDict = setUpscalers(req)
-
- def prepareFiles(file):
- file = decode_base64_to_file(file.data, file_path=file.name)
- file.orig_name = file.name
- return file
-
- reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
- reqDict.pop('imageList')
-
- with self.queue_lock:
- result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
-
- return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
-
- def pnginfoapi(self, req: PNGInfoRequest):
- if(not req.image.strip()):
- return PNGInfoResponse(info="")
-
- image = decode_base64_to_image(req.image.strip())
- if image is None:
- return PNGInfoResponse(info="")
-
- geninfo, items = images.read_info_from_image(image)
- if geninfo is None:
- geninfo = ""
-
- items = {**{'parameters': geninfo}, **items}
-
- return PNGInfoResponse(info=geninfo, items=items)
-
- def progressapi(self, req: ProgressRequest = Depends()):
- # copy from check_progress_call of ui.py
-
- if shared.state.job_count == 0:
- return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
-
-        # avoid dividing by zero
- progress = 0.01
-
- if shared.state.job_count > 0:
- progress += shared.state.job_no / shared.state.job_count
- if shared.state.sampling_steps > 0:
- progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
-
- time_since_start = time.time() - shared.state.time_start
- eta = (time_since_start/progress)
- eta_relative = eta-time_since_start
-
- progress = min(progress, 1)
-
- shared.state.set_current_image()
-
- current_image = None
- if shared.state.current_image and not req.skip_current_image:
- current_image = encode_pil_to_base64(shared.state.current_image)
-
- return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
-
- def interrogateapi(self, interrogatereq: InterrogateRequest):
- image_b64 = interrogatereq.image
- if image_b64 is None:
- raise HTTPException(status_code=404, detail="Image not found")
-
- img = decode_base64_to_image(image_b64)
- img = img.convert('RGB')
-
- # Override object param
- with self.queue_lock:
- if interrogatereq.model == "clip":
- processed = shared.interrogator.interrogate(img)
- elif interrogatereq.model == "deepdanbooru":
- processed = deepbooru.model.tag(img)
- else:
- raise HTTPException(status_code=404, detail="Model not found")
-
- return InterrogateResponse(caption=processed)
-
- def interruptapi(self):
- shared.state.interrupt()
-
- return {}
-
- def skip(self):
- shared.state.skip()
-
- def get_config(self):
- options = {}
- for key in shared.opts.data.keys():
- metadata = shared.opts.data_labels.get(key)
- if(metadata is not None):
- options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
- else:
- options.update({key: shared.opts.data.get(key, None)})
-
- return options
-
- def set_config(self, req: Dict[str, Any]):
- for k, v in req.items():
- shared.opts.set(k, v)
-
- shared.opts.save(shared.config_filename)
- return
-
- def get_cmd_flags(self):
- return vars(shared.cmd_opts)
-
- def get_samplers(self):
- return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
-
- def get_upscalers(self):
- return [
- {
- "name": upscaler.name,
- "model_name": upscaler.scaler.model_name,
- "model_path": upscaler.data_path,
- "model_url": None,
- "scale": upscaler.scale,
- }
- for upscaler in shared.sd_upscalers
- ]
-
- def get_sd_models(self):
- return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
-
- def get_hypernetworks(self):
- return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
-
- def get_face_restorers(self):
- return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
-
- def get_realesrgan_models(self):
- return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
-
- def get_prompt_styles(self):
- styleList = []
- for k in shared.prompt_styles.styles:
- style = shared.prompt_styles.styles[k]
- styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
-
- return styleList
-
- def get_embeddings(self):
- db = sd_hijack.model_hijack.embedding_db
-
- def convert_embedding(embedding):
- return {
- "step": embedding.step,
- "sd_checkpoint": embedding.sd_checkpoint,
- "sd_checkpoint_name": embedding.sd_checkpoint_name,
- "shape": embedding.shape,
- "vectors": embedding.vectors,
- }
-
- def convert_embeddings(embeddings):
- return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
-
- return {
- "loaded": convert_embeddings(db.word_embeddings),
- "skipped": convert_embeddings(db.skipped_embeddings),
- }
-
- def refresh_checkpoints(self):
- shared.refresh_checkpoints()
-
- def create_embedding(self, args: dict):
- try:
- shared.state.begin()
- filename = create_embedding(**args) # create empty embedding
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
- shared.state.end()
- return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
- except AssertionError as e:
- shared.state.end()
- return TrainResponse(info = "create embedding error: {error}".format(error = e))
-
- def create_hypernetwork(self, args: dict):
- try:
- shared.state.begin()
- filename = create_hypernetwork(**args) # create empty embedding
- shared.state.end()
- return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
- except AssertionError as e:
- shared.state.end()
- return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
-
- def preprocess(self, args: dict):
- try:
- shared.state.begin()
- preprocess(**args) # quick operation unless blip/booru interrogation is enabled
- shared.state.end()
- return PreprocessResponse(info = 'preprocess complete')
- except KeyError as e:
- shared.state.end()
- return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
- except AssertionError as e:
- shared.state.end()
- return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
- except FileNotFoundError as e:
- shared.state.end()
- return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
-
- def train_embedding(self, args: dict):
- try:
- shared.state.begin()
- apply_optimizations = shared.opts.training_xattention_optimizations
- error = None
- filename = ''
- if not apply_optimizations:
- sd_hijack.undo_optimizations()
- try:
- embedding, filename = train_embedding(**args) # can take a long time to complete
- except Exception as e:
- error = e
- finally:
- if not apply_optimizations:
- sd_hijack.apply_optimizations()
- shared.state.end()
- return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
- except AssertionError as msg:
- shared.state.end()
- return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
-
- def train_hypernetwork(self, args: dict):
- try:
- shared.state.begin()
- shared.loaded_hypernetworks = []
- apply_optimizations = shared.opts.training_xattention_optimizations
- error = None
- filename = ''
- if not apply_optimizations:
- sd_hijack.undo_optimizations()
- try:
- hypernetwork, filename = train_hypernetwork(**args)
- except Exception as e:
- error = e
- finally:
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
- if not apply_optimizations:
- sd_hijack.apply_optimizations()
- shared.state.end()
-            return TrainResponse(info="train hypernetwork complete: filename: {filename} error: {error}".format(filename=filename, error=error))
- except AssertionError as msg:
- shared.state.end()
-            return TrainResponse(info="train hypernetwork error: {error}".format(error=msg))
-
- def get_memory(self):
- try:
- import os, psutil
- process = psutil.Process(os.getpid())
-            res = process.memory_info()  # only rss is guaranteed to be cross-platform, so we don't rely on the other values
-            ram_total = 100 * res.rss / process.memory_percent()  # total memory is estimated this way because reading the actual value is not cross-platform safe
- ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total }
- except Exception as err:
- ram = { 'error': f'{err}' }
- try:
- import torch
- if torch.cuda.is_available():
- s = torch.cuda.mem_get_info()
- system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] }
- s = dict(torch.cuda.memory_stats(shared.device))
- allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] }
- reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] }
- active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] }
- inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] }
- warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }
- cuda = {
- 'system': system,
- 'active': active,
- 'allocated': allocated,
- 'reserved': reserved,
- 'inactive': inactive,
- 'events': warnings,
- }
- else:
- cuda = { 'error': 'unavailable' }
- except Exception as err:
- cuda = { 'error': f'{err}' }
- return MemoryResponse(ram = ram, cuda = cuda)
-
- def launch(self, server_name, port):
- self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port)
diff --git a/modules/api/models.py b/modules/api/models.py
deleted file mode 100644
index cba43d3b1807d547acda33256faf5db05dd216a6..0000000000000000000000000000000000000000
--- a/modules/api/models.py
+++ /dev/null
@@ -1,269 +0,0 @@
-import inspect
-from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional
-from typing_extensions import Literal
-from inflection import underscore
-from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
-from modules.shared import sd_upscalers, opts, parser
-from typing import Dict, List
-
-API_NOT_ALLOWED = [
- "self",
- "kwargs",
- "sd_model",
- "outpath_samples",
- "outpath_grids",
- "sampler_index",
- "do_not_save_samples",
- "do_not_save_grid",
- "extra_generation_params",
- "overlay_images",
- "do_not_reload_embeddings",
- "seed_enable_extras",
- "prompt_for_display",
- "sampler_noise_scheduler_override",
- "ddim_discretize"
-]
-
-class ModelDef(BaseModel):
- """Assistance Class for Pydantic Dynamic Model Generation"""
-
- field: str
- field_alias: str
- field_type: Any
- field_value: Any
- field_exclude: bool = False
-
-
-class PydanticModelGenerator:
- """
- Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
- source_data is a snapshot of the default values produced by the class
- params are the names of the actual keys required by __init__
- """
-
- def __init__(
- self,
- model_name: str = None,
- class_instance = None,
- additional_fields = None,
- ):
- def field_type_generator(k, v):
- # field_type = str if not overrides.get(k) else overrides[k]["type"]
- # print(k, v.annotation, v.default)
- field_type = v.annotation
-
- return Optional[field_type]
-
- def merge_class_params(class_):
- all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
- parameters = {}
- for classes in all_classes:
- parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
- return parameters
-
-
- self._model_name = model_name
- self._class_data = merge_class_params(class_instance)
-
- self._model_def = [
- ModelDef(
- field=underscore(k),
- field_alias=k,
- field_type=field_type_generator(k, v),
- field_value=v.default
- )
- for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
- ]
-
- for fields in additional_fields:
- self._model_def.append(ModelDef(
- field=underscore(fields["key"]),
- field_alias=fields["key"],
- field_type=fields["type"],
- field_value=fields["default"],
- field_exclude=fields["exclude"] if "exclude" in fields else False))
-
- def generate_model(self):
- """
- Creates a pydantic BaseModel
- from the json and overrides provided at initialization
- """
- fields = {
- d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
- }
- DynamicModel = create_model(self._model_name, **fields)
- DynamicModel.__config__.allow_population_by_field_name = True
- DynamicModel.__config__.allow_mutation = True
- return DynamicModel
-
-StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
- "StableDiffusionProcessingTxt2Img",
- StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
-).generate_model()
-
-StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
- "StableDiffusionProcessingImg2Img",
- StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
-).generate_model()
-
-class TextToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: dict
- info: str
-
-class ImageToImageResponse(BaseModel):
- images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: dict
- info: str
-
-class ExtrasBaseRequest(BaseModel):
- resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
- show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
- gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
- codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
- codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
- upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
- upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
- upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
- upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
- upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use; it has to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
- upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use; it has to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
- extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of the secondary upscaler; values should be between 0 and 1.")
- upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
-
-class ExtraBaseResponse(BaseModel):
- html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
-
-class ExtrasSingleImageRequest(ExtrasBaseRequest):
- image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
-
-class ExtrasSingleImageResponse(ExtraBaseResponse):
- image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
-
-class FileData(BaseModel):
- data: str = Field(title="File data", description="Base64 representation of the file")
- name: str = Field(title="File name")
-
-class ExtrasBatchImagesRequest(ExtrasBaseRequest):
- imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
-
-class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: List[str] = Field(title="Images", description="The generated images in base64 format.")
-
-class PNGInfoRequest(BaseModel):
- image: str = Field(title="Image", description="The base64 encoded PNG image")
-
-class PNGInfoResponse(BaseModel):
- info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
- items: dict = Field(title="Items", description="An object containing all the info the image had")
-
-class ProgressRequest(BaseModel):
- skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
-
-class ProgressResponse(BaseModel):
- progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
- eta_relative: float = Field(title="ETA in secs")
- state: dict = Field(title="State", description="The current state snapshot")
- current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
- textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
-
-class InterrogateRequest(BaseModel):
- image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
- model: str = Field(default="clip", title="Model", description="The interrogate model used.")
-
-class InterrogateResponse(BaseModel):
- caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
-
-class TrainResponse(BaseModel):
- info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
-
-class CreateResponse(BaseModel):
- info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
-
-class PreprocessResponse(BaseModel):
- info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
-
-fields = {}
-for key, metadata in opts.data_labels.items():
- value = opts.data.get(key)
- optType = opts.typemap.get(type(metadata.default), type(value))
-
- if (metadata is not None):
- fields.update({key: (Optional[optType], Field(
- default=metadata.default, description=metadata.label))})
- else:
- fields.update({key: (Optional[optType], Field())})
-
-OptionsModel = create_model("Options", **fields)
-
-flags = {}
-_options = vars(parser)['_option_string_actions']
-for key in _options:
- if _options[key].dest != 'help':
- flag = _options[key]
- _type = str
- if _options[key].default is not None: _type = type(_options[key].default)
- flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
-
-FlagsModel = create_model("Flags", **flags)
-
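The flags loop above applies the same idea to command-line options. A self-contained sketch with a hypothetical parser (the real one lives in the webui's shared module) shows how every registered argparse option except `--help` becomes a typed, documented field:

```python
# Sketch under the assumption of a stand-in parser; mirrors the loop above,
# including its use of argparse's private _option_string_actions registry.
import argparse
from pydantic import create_model, Field

parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=7860, help="port to listen on")
parser.add_argument("--api", action="store_true", help="enable the API")

flags = {}
for action in parser._option_string_actions.values():
    if action.dest == "help":
        continue
    field_type = type(action.default) if action.default is not None else str
    flags[action.dest] = (field_type, Field(default=action.default, description=action.help))

DemoFlags = create_model("DemoFlags", **flags)
print(DemoFlags().dict())  # {'port': 7860, 'api': False}
```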
-class SamplerItem(BaseModel):
- name: str = Field(title="Name")
- aliases: List[str] = Field(title="Aliases")
- options: Dict[str, str] = Field(title="Options")
-
-class UpscalerItem(BaseModel):
- name: str = Field(title="Name")
- model_name: Optional[str] = Field(title="Model Name")
- model_path: Optional[str] = Field(title="Path")
- model_url: Optional[str] = Field(title="URL")
- scale: Optional[float] = Field(title="Scale")
-
-class SDModelItem(BaseModel):
- title: str = Field(title="Title")
- model_name: str = Field(title="Model Name")
- hash: Optional[str] = Field(title="Short hash")
- sha256: Optional[str] = Field(title="sha256 hash")
- filename: str = Field(title="Filename")
- config: Optional[str] = Field(title="Config file")
-
-class HypernetworkItem(BaseModel):
- name: str = Field(title="Name")
- path: Optional[str] = Field(title="Path")
-
-class FaceRestorerItem(BaseModel):
- name: str = Field(title="Name")
- cmd_dir: Optional[str] = Field(title="Path")
-
-class RealesrganItem(BaseModel):
- name: str = Field(title="Name")
- path: Optional[str] = Field(title="Path")
- scale: Optional[int] = Field(title="Scale")
-
-class PromptStyleItem(BaseModel):
- name: str = Field(title="Name")
- prompt: Optional[str] = Field(title="Prompt")
- negative_prompt: Optional[str] = Field(title="Negative Prompt")
-
-class ArtistItem(BaseModel):
- name: str = Field(title="Name")
- score: float = Field(title="Score")
- category: str = Field(title="Category")
-
-class EmbeddingItem(BaseModel):
- step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
- sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
- sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
- shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
- vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
-
-class EmbeddingsResponse(BaseModel):
- loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
- skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
-
-class MemoryResponse(BaseModel):
- ram: dict = Field(title="RAM", description="System memory stats")
- cuda: dict = Field(title="CUDA", description="NVIDIA CUDA memory stats")
diff --git a/modules/call_queue.py b/modules/call_queue.py
deleted file mode 100644
index 92097c15eb1327d12de87f4cc8c17e0482357919..0000000000000000000000000000000000000000
--- a/modules/call_queue.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import html
-import sys
-import threading
-import traceback
-import time
-
-from modules import shared, progress
-
-queue_lock = threading.Lock()
-
-
-def wrap_queued_call(func):
- def f(*args, **kwargs):
- with queue_lock:
- res = func(*args, **kwargs)
-
- return res
-
- return f
-
-
-def wrap_gradio_gpu_call(func, extra_outputs=None):
- def f(*args, **kwargs):
-
- # if the first argument is a string that says "task(...)", it is treated as a job id
- if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
- id_task = args[0]
- progress.add_task_to_queue(id_task)
- else:
- id_task = None
-
- with queue_lock:
- shared.state.begin()
- progress.start_task(id_task)
-
- try:
- res = func(*args, **kwargs)
- finally:
- progress.finish_task(id_task)
-
- shared.state.end()
-
- return res
-
- return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
-
-
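The wrappers above all funnel work through a single `queue_lock`, so only one generation job touches the GPU at a time no matter how many UI callbacks fire concurrently. A standalone sketch of that pattern, using illustrative names rather than the webui API:

```python
# Minimal sketch of the queueing idea: a shared lock serialises "GPU" calls
# issued from multiple threads. Names (demo_lock, fake_render) are stand-ins.
import threading
import time

demo_lock = threading.Lock()

def queued(func):
    def wrapper(*args, **kwargs):
        with demo_lock:          # only one caller may be inside at a time
            return func(*args, **kwargs)
    return wrapper

@queued
def fake_render(job):
    time.sleep(0.05)             # stand-in for a GPU-bound call
    print(f"job {job} done")     # always runs one-at-a-time thanks to the lock

threads = [threading.Thread(target=fake_render, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```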
-def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
- def f(*args, extra_outputs_array=extra_outputs, **kwargs):
- run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
- if run_memmon:
- shared.mem_mon.monitor()
- t = time.perf_counter()
-
- try:
- res = list(func(*args, **kwargs))
- except Exception as e:
- # When printing out our debug argument list, cap the output at 128 KiB of text (1 MiB / 8)
- max_debug_str_len = 131072 # (1024*1024)/8
-
- print("Error completing request", file=sys.stderr)
- argStr = f"Arguments: {str(args)} {str(kwargs)}"
- print(argStr[:max_debug_str_len], file=sys.stderr)
- if len(argStr) > max_debug_str_len:
- print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
-
- print(traceback.format_exc(), file=sys.stderr)
-
- shared.state.job = ""
- shared.state.job_count = 0
-
- if extra_outputs_array is None:
- extra_outputs_array = [None, '']
-
- res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
-
- shared.state.skipped = False
- shared.state.interrupted = False
- shared.state.job_count = 0
-
- if not add_stats:
- return tuple(res)
-
- elapsed = time.perf_counter() - t
- elapsed_m = int(elapsed // 60)
- elapsed_s = elapsed % 60
- elapsed_text = f"{elapsed_s:.2f}s"
- if elapsed_m > 0:
- elapsed_text = f"{elapsed_m}m "+elapsed_text
-
- if run_memmon:
- mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
- active_peak = mem_stats['active_peak']
- reserved_peak = mem_stats['reserved_peak']
- sys_peak = mem_stats['system_peak']
- sys_total = mem_stats['total']
- sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
-
- vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
- else:
- vram_html = ''
-
- # last item is always HTML
- res[-1] += f"<div class='performance'><p class='time'>Time taken: {elapsed_text}</p>{vram_html}</div>"
-
- return tuple(res)
-
- return f
-
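One detail worth calling out from the stats block above: `-(v // -(1024*1024))` is ceiling division, converting bytes to MiB rounded up. A quick standalone check:

```python
def bytes_to_mib_ceil(v: int) -> int:
    # identical to the -(v // -(1024*1024)) expression used in the stats code above
    return -(v // -(1024 * 1024))

print(bytes_to_mib_ceil(1), bytes_to_mib_ceil(1024**2), bytes_to_mib_ceil(1024**2 + 1))  # 1 1 2
```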
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
deleted file mode 100644
index 11dcc3ee76511218c64977c2ecbb306cecd892c3..0000000000000000000000000000000000000000
--- a/modules/codeformer/codeformer_arch.py
+++ /dev/null
@@ -1,278 +0,0 @@
-# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
-
-import math
-import numpy as np
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-from typing import Optional, List
-
-from modules.codeformer.vqgan_arch import *
-from basicsr.utils import get_root_logger
-from basicsr.utils.registry import ARCH_REGISTRY
-
-def calc_mean_std(feat, eps=1e-5):
- """Calculate mean and std for adaptive_instance_normalization.
-
- Args:
- feat (Tensor): 4D tensor.
- eps (float): A small value added to the variance to avoid
- divide-by-zero. Default: 1e-5.
- """
- size = feat.size()
- assert len(size) == 4, 'The input feature should be 4D tensor.'
- b, c = size[:2]
- feat_var = feat.view(b, c, -1).var(dim=2) + eps
- feat_std = feat_var.sqrt().view(b, c, 1, 1)
- feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
- return feat_mean, feat_std
-
-
-def adaptive_instance_normalization(content_feat, style_feat):
- """Adaptive instance normalization.
-
- Adjust the reference features to have color and illumination similar to
- those in the degraded features.
-
- Args:
- content_feat (Tensor): The reference feature.
- style_feat (Tensor): The degraded features.
- """
- size = content_feat.size()
- style_mean, style_std = calc_mean_std(style_feat)
- content_mean, content_std = calc_mean_std(content_feat)
- normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
- return normalized_feat * style_std.expand(size) + style_mean.expand(size)
-
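A quick numerical check of the AdaIN operation above, re-implemented inline so it runs standalone: after normalization, the content features take on the per-channel mean and standard deviation of the style features.

```python
# Standalone re-statement of calc_mean_std + adaptive_instance_normalization.
import torch

def demo_adain(content, style, eps=1e-5):
    b, c = content.shape[:2]
    c_mean = content.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    c_std = (content.view(b, c, -1).var(dim=2) + eps).sqrt().view(b, c, 1, 1)
    s_mean = style.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    s_std = (style.view(b, c, -1).var(dim=2) + eps).sqrt().view(b, c, 1, 1)
    return (content - c_mean) / c_std * s_std + s_mean

content = torch.randn(1, 4, 8, 8)
style = torch.randn(1, 4, 8, 8) * 3 + 1
out = demo_adain(content, style)
print(out.view(1, 4, -1).mean(dim=2))    # ~ per-channel means of `style`
print(style.view(1, 4, -1).mean(dim=2))
```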
-
-class PositionEmbeddingSine(nn.Module):
- """
- This is a more standard version of the position embedding, very similar to the one
- used in the "Attention Is All You Need" paper, generalized to work on images.
- """
-
- def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
- super().__init__()
- self.num_pos_feats = num_pos_feats
- self.temperature = temperature
- self.normalize = normalize
- if scale is not None and normalize is False:
- raise ValueError("normalize should be True if scale is passed")
- if scale is None:
- scale = 2 * math.pi
- self.scale = scale
-
- def forward(self, x, mask=None):
- if mask is None:
- mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
- not_mask = ~mask
- y_embed = not_mask.cumsum(1, dtype=torch.float32)
- x_embed = not_mask.cumsum(2, dtype=torch.float32)
- if self.normalize:
- eps = 1e-6
- y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
- x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-
- dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
- dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
-
- pos_x = x_embed[:, :, :, None] / dim_t
- pos_y = y_embed[:, :, :, None] / dim_t
- pos_x = torch.stack(
- (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
- ).flatten(3)
- pos_y = torch.stack(
- (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
- ).flatten(3)
- pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
- return pos
-
-def _get_activation_fn(activation):
- """Return an activation function given a string"""
- if activation == "relu":
- return F.relu
- if activation == "gelu":
- return F.gelu
- if activation == "glu":
- return F.glu
- raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
-
-
-class TransformerSALayer(nn.Module):
- def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
- super().__init__()
- self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
- # Implementation of Feedforward model - MLP
- self.linear1 = nn.Linear(embed_dim, dim_mlp)
- self.dropout = nn.Dropout(dropout)
- self.linear2 = nn.Linear(dim_mlp, embed_dim)
-
- self.norm1 = nn.LayerNorm(embed_dim)
- self.norm2 = nn.LayerNorm(embed_dim)
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
-
- self.activation = _get_activation_fn(activation)
-
- def with_pos_embed(self, tensor, pos: Optional[Tensor]):
- return tensor if pos is None else tensor + pos
-
- def forward(self, tgt,
- tgt_mask: Optional[Tensor] = None,
- tgt_key_padding_mask: Optional[Tensor] = None,
- query_pos: Optional[Tensor] = None):
-
- # self attention
- tgt2 = self.norm1(tgt)
- q = k = self.with_pos_embed(tgt2, query_pos)
- tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
- key_padding_mask=tgt_key_padding_mask)[0]
- tgt = tgt + self.dropout1(tgt2)
-
- # ffn
- tgt2 = self.norm2(tgt)
- tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
- tgt = tgt + self.dropout2(tgt2)
- return tgt
-
-class Fuse_sft_block(nn.Module):
- def __init__(self, in_ch, out_ch):
- super().__init__()
- self.encode_enc = ResBlock(2*in_ch, out_ch)
-
- self.scale = nn.Sequential(
- nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
-
- self.shift = nn.Sequential(
- nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
-
- def forward(self, enc_feat, dec_feat, w=1):
- enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
- scale = self.scale(enc_feat)
- shift = self.shift(enc_feat)
- residual = w * (dec_feat * scale + shift)
- out = dec_feat + residual
- return out
-
-
-@ARCH_REGISTRY.register()
-class CodeFormer(VQAutoEncoder):
- def __init__(self, dim_embd=512, n_head=8, n_layers=9,
- codebook_size=1024, latent_size=256,
- connect_list=['32', '64', '128', '256'],
- fix_modules=['quantize','generator']):
- super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
-
- if fix_modules is not None:
- for module in fix_modules:
- for param in getattr(self, module).parameters():
- param.requires_grad = False
-
- self.connect_list = connect_list
- self.n_layers = n_layers
- self.dim_embd = dim_embd
- self.dim_mlp = dim_embd*2
-
- self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
- self.feat_emb = nn.Linear(256, self.dim_embd)
-
- # transformer
- self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
- for _ in range(self.n_layers)])
-
- # logits_predict head
- self.idx_pred_layer = nn.Sequential(
- nn.LayerNorm(dim_embd),
- nn.Linear(dim_embd, codebook_size, bias=False))
-
- self.channels = {
- '16': 512,
- '32': 256,
- '64': 256,
- '128': 128,
- '256': 128,
- '512': 64,
- }
-
- # after second residual block for > 16, before attn layer for ==16
- self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
- # after first residual block for > 16, before attn layer for ==16
- self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
-
- # fuse_convs_dict
- self.fuse_convs_dict = nn.ModuleDict()
- for f_size in self.connect_list:
- in_ch = self.channels[f_size]
- self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
-
- def _init_weights(self, module):
- if isinstance(module, (nn.Linear, nn.Embedding)):
- module.weight.data.normal_(mean=0.0, std=0.02)
- if isinstance(module, nn.Linear) and module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.LayerNorm):
- module.bias.data.zero_()
- module.weight.data.fill_(1.0)
-
- def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
- # ################### Encoder #####################
- enc_feat_dict = {}
- out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
- for i, block in enumerate(self.encoder.blocks):
- x = block(x)
- if i in out_list:
- enc_feat_dict[str(x.shape[-1])] = x.clone()
-
- lq_feat = x
- # ################# Transformer ###################
- # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
- pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
- # BCHW -> BC(HW) -> (HW)BC
- feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
- query_emb = feat_emb
- # Transformer encoder
- for layer in self.ft_layers:
- query_emb = layer(query_emb, query_pos=pos_emb)
-
- # output logits
- logits = self.idx_pred_layer(query_emb) # (hw)bn
- logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
-
- if code_only: # for training stage II
- # logits doesn't need softmax before cross_entropy loss
- return logits, lq_feat
-
- # ################# Quantization ###################
- # if self.training:
- # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
- # # b(hw)c -> bc(hw) -> bchw
- # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
- # ------------
- soft_one_hot = F.softmax(logits, dim=2)
- _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
- quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
- # preserve gradients
- # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
-
- if detach_16:
- quant_feat = quant_feat.detach() # for training stage III
- if adain:
- quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
-
- # ################## Generator ####################
- x = quant_feat
- fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
-
- for i, block in enumerate(self.generator.blocks):
- x = block(x)
- if i in fuse_list: # fuse after i-th block
- f_size = str(x.shape[-1])
- if w>0:
- x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
- out = x
- # logits doesn't need softmax before cross_entropy loss
- return out, logits, lq_feat
\ No newline at end of file
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
deleted file mode 100644
index e729368383aa2d8c224289284ec5489d554f9a33..0000000000000000000000000000000000000000
--- a/modules/codeformer/vqgan_arch.py
+++ /dev/null
@@ -1,437 +0,0 @@
-# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
-
-'''
-VQGAN code, adapted from the original created by the Unleashing Transformers authors:
-https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
-
-'''
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import copy
-from basicsr.utils import get_root_logger
-from basicsr.utils.registry import ARCH_REGISTRY
-
-def normalize(in_channels):
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
-
-@torch.jit.script
-def swish(x):
- return x*torch.sigmoid(x)
-
-
-# Define VQVAE classes
-class VectorQuantizer(nn.Module):
- def __init__(self, codebook_size, emb_dim, beta):
- super(VectorQuantizer, self).__init__()
- self.codebook_size = codebook_size # number of embeddings
- self.emb_dim = emb_dim # dimension of embedding
- self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
- self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
- self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
-
- def forward(self, z):
- # reshape z -> (batch, height, width, channel) and flatten
- z = z.permute(0, 2, 3, 1).contiguous()
- z_flattened = z.view(-1, self.emb_dim)
-
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
- d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
-
- mean_distance = torch.mean(d)
- # find closest encodings
- # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
- min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
- # [0-1], higher score, higher confidence
- min_encoding_scores = torch.exp(-min_encoding_scores/10)
-
- min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
- min_encodings.scatter_(1, min_encoding_indices, 1)
-
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
- # compute loss for embedding
- loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
- # preserve gradients
- z_q = z + (z_q - z).detach()
-
- # perplexity
- e_mean = torch.mean(min_encodings, dim=0)
- perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
- # reshape back to match original input shape
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
-
- return z_q, loss, {
- "perplexity": perplexity,
- "min_encodings": min_encodings,
- "min_encoding_indices": min_encoding_indices,
- "min_encoding_scores": min_encoding_scores,
- "mean_distance": mean_distance
- }
-
- def get_codebook_feat(self, indices, shape):
- # input indices: batch*token_num -> (batch*token_num)*1
- # shape: batch, height, width, channel
- indices = indices.view(-1,1)
- min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
- min_encodings.scatter_(1, indices, 1)
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
-
- if shape is not None: # reshape back to match original input shape
- z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
-
- return z_q
-
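The distance computation in `VectorQuantizer.forward` above relies on the identity ||z - e||² = ||z||² + ||e||² - 2·z·e, so distances to every codebook entry fall out of a single matrix product. A standalone check against a naive reference:

```python
# Verifies the broadcasted distance trick used to pick the nearest codebook entry.
import torch

z = torch.randn(10, 4)         # 10 flattened latents, emb_dim = 4
codebook = torch.randn(16, 4)  # 16 codebook entries

d_fast = (z ** 2).sum(dim=1, keepdim=True) + (codebook ** 2).sum(dim=1) - 2 * z @ codebook.t()
d_naive = torch.cdist(z, codebook) ** 2

print(torch.allclose(d_fast, d_naive, atol=1e-5))  # True
print(d_fast.argmin(dim=1))                        # nearest codebook index per latent
```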
-
-class GumbelQuantizer(nn.Module):
- def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
- super().__init__()
- self.codebook_size = codebook_size # number of embeddings
- self.emb_dim = emb_dim # dimension of embedding
- self.straight_through = straight_through
- self.temperature = temp_init
- self.kl_weight = kl_weight
- self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
- self.embed = nn.Embedding(codebook_size, emb_dim)
-
- def forward(self, z):
- hard = self.straight_through if self.training else True
-
- logits = self.proj(z)
-
- soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
-
- z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
-
- # + kl divergence to the prior loss
- qy = F.softmax(logits, dim=1)
- diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
- min_encoding_indices = soft_one_hot.argmax(dim=1)
-
- return z_q, diff, {
- "min_encoding_indices": min_encoding_indices
- }
-
-
-class Downsample(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
-
- def forward(self, x):
- pad = (0, 1, 0, 1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- return x
-
-
-class Upsample(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-
- def forward(self, x):
- x = F.interpolate(x, scale_factor=2.0, mode="nearest")
- x = self.conv(x)
-
- return x
-
-
-class ResBlock(nn.Module):
- def __init__(self, in_channels, out_channels=None):
- super(ResBlock, self).__init__()
- self.in_channels = in_channels
- self.out_channels = in_channels if out_channels is None else out_channels
- self.norm1 = normalize(in_channels)
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
- self.norm2 = normalize(out_channels)
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
- if self.in_channels != self.out_channels:
- self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
- def forward(self, x_in):
- x = x_in
- x = self.norm1(x)
- x = swish(x)
- x = self.conv1(x)
- x = self.norm2(x)
- x = swish(x)
- x = self.conv2(x)
- if self.in_channels != self.out_channels:
- x_in = self.conv_out(x_in)
-
- return x + x_in
-
-
-class AttnBlock(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = normalize(in_channels)
- self.q = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.k = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.v = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.proj_out = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q.shape
- q = q.reshape(b, c, h*w)
- q = q.permute(0, 2, 1)
- k = k.reshape(b, c, h*w)
- w_ = torch.bmm(q, k)
- w_ = w_ * (int(c)**(-0.5))
- w_ = F.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b, c, h*w)
- w_ = w_.permute(0, 2, 1)
- h_ = torch.bmm(v, w_)
- h_ = h_.reshape(b, c, h, w)
-
- h_ = self.proj_out(h_)
-
- return x+h_
-
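The two batched matmuls in `AttnBlock.forward` implement plain scaled dot-product attention over spatial positions: for each position i, softmax over j of q_i·k_j/√c is applied to the values v_j. A standalone check that the bmm formulation matches an einsum reference:

```python
# Numerical sanity check of the reshape/permute/bmm attention arithmetic above.
import torch
import torch.nn.functional as F

b, c, h, w = 1, 8, 4, 4
q, k, v = torch.randn(b, c, h, w), torch.randn(b, c, h, w), torch.randn(b, c, h, w)

qf = q.reshape(b, c, h * w).permute(0, 2, 1)   # (b, hw, c)
kf = k.reshape(b, c, h * w)                    # (b, c, hw)
vf = v.reshape(b, c, h * w)                    # (b, c, hw)

attn = F.softmax(torch.bmm(qf, kf) * c ** -0.5, dim=2)             # (b, hw_q, hw_k)
out_bmm = torch.bmm(vf, attn.permute(0, 2, 1)).reshape(b, c, h, w)  # as in AttnBlock

out_ref = torch.einsum("bij,bcj->bci", attn, vf).reshape(b, c, h, w)
print(torch.allclose(out_bmm, out_ref, atol=1e-6))  # True
```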
-
-class Encoder(nn.Module):
- def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
- super().__init__()
- self.nf = nf
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.attn_resolutions = attn_resolutions
-
- curr_res = self.resolution
- in_ch_mult = (1,)+tuple(ch_mult)
-
- blocks = []
- # initial convolution
- blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
-
- # residual and downsampling blocks, with attention on smaller res (16x16)
- for i in range(self.num_resolutions):
- block_in_ch = nf * in_ch_mult[i]
- block_out_ch = nf * ch_mult[i]
- for _ in range(self.num_res_blocks):
- blocks.append(ResBlock(block_in_ch, block_out_ch))
- block_in_ch = block_out_ch
- if curr_res in attn_resolutions:
- blocks.append(AttnBlock(block_in_ch))
-
- if i != self.num_resolutions - 1:
- blocks.append(Downsample(block_in_ch))
- curr_res = curr_res // 2
-
- # non-local attention block
- blocks.append(ResBlock(block_in_ch, block_in_ch))
- blocks.append(AttnBlock(block_in_ch))
- blocks.append(ResBlock(block_in_ch, block_in_ch))
-
- # normalise and convert to latent size
- blocks.append(normalize(block_in_ch))
- blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
- self.blocks = nn.ModuleList(blocks)
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x)
-
- return x
-
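With the values the `CodeFormer` class above passes to `VQAutoEncoder` (resolution 512, nf=64, ch_mult=[1, 2, 2, 4, 4, 8], attention only at resolution 16), the encoder's channel/resolution schedule can be traced in a few lines; the result matches the `self.channels` table defined in `CodeFormer.__init__`.

```python
# Trace of the Encoder's per-level channels and spatial resolution; attention
# blocks are only inserted once curr_res reaches 16.
nf, ch_mult, resolution, attn_resolutions = 64, [1, 2, 2, 4, 4, 8], 512, [16]
in_ch_mult = (1,) + tuple(ch_mult)

curr_res = resolution
for i in range(len(ch_mult)):
    note = "  [attn]" if curr_res in attn_resolutions else ""
    print(f"level {i}: {nf * in_ch_mult[i]:4d} -> {nf * ch_mult[i]:4d} channels at {curr_res}x{curr_res}{note}")
    if i != len(ch_mult) - 1:
        curr_res //= 2
```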
-
-class Generator(nn.Module):
- def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
- super().__init__()
- self.nf = nf
- self.ch_mult = ch_mult
- self.num_resolutions = len(self.ch_mult)
- self.num_res_blocks = res_blocks
- self.resolution = img_size
- self.attn_resolutions = attn_resolutions
- self.in_channels = emb_dim
- self.out_channels = 3
- block_in_ch = self.nf * self.ch_mult[-1]
- curr_res = self.resolution // 2 ** (self.num_resolutions-1)
-
- blocks = []
- # initial conv
- blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
-
- # non-local attention block
- blocks.append(ResBlock(block_in_ch, block_in_ch))
- blocks.append(AttnBlock(block_in_ch))
- blocks.append(ResBlock(block_in_ch, block_in_ch))
-
- for i in reversed(range(self.num_resolutions)):
- block_out_ch = self.nf * self.ch_mult[i]
-
- for _ in range(self.num_res_blocks):
- blocks.append(ResBlock(block_in_ch, block_out_ch))
- block_in_ch = block_out_ch
-
- if curr_res in self.attn_resolutions:
- blocks.append(AttnBlock(block_in_ch))
-
- if i != 0:
- blocks.append(Upsample(block_in_ch))
- curr_res = curr_res * 2
-
- blocks.append(normalize(block_in_ch))
- blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
-
- self.blocks = nn.ModuleList(blocks)
-
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x)
-
- return x
-
-
-@ARCH_REGISTRY.register()
-class VQAutoEncoder(nn.Module):
- def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
- beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
- super().__init__()
- logger = get_root_logger()
- self.in_channels = 3
- self.nf = nf
- self.n_blocks = res_blocks
- self.codebook_size = codebook_size
- self.embed_dim = emb_dim
- self.ch_mult = ch_mult
- self.resolution = img_size
- self.attn_resolutions = attn_resolutions
- self.quantizer_type = quantizer
- self.encoder = Encoder(
- self.in_channels,
- self.nf,
- self.embed_dim,
- self.ch_mult,
- self.n_blocks,
- self.resolution,
- self.attn_resolutions
- )
- if self.quantizer_type == "nearest":
- self.beta = beta #0.25
- self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
- elif self.quantizer_type == "gumbel":
- self.gumbel_num_hiddens = emb_dim
- self.straight_through = gumbel_straight_through
- self.kl_weight = gumbel_kl_weight
- self.quantize = GumbelQuantizer(
- self.codebook_size,
- self.embed_dim,
- self.gumbel_num_hiddens,
- self.straight_through,
- self.kl_weight
- )
- self.generator = Generator(
- self.nf,
- self.embed_dim,
- self.ch_mult,
- self.n_blocks,
- self.resolution,
- self.attn_resolutions
- )
-
- if model_path is not None:
- chkpt = torch.load(model_path, map_location='cpu')
- if 'params_ema' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
- logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
- elif 'params' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
- logger.info(f'vqgan is loaded from: {model_path} [params]')
- else:
- raise ValueError('Wrong params!')
-
-
- def forward(self, x):
- x = self.encoder(x)
- quant, codebook_loss, quant_stats = self.quantize(x)
- x = self.generator(quant)
- return x, codebook_loss, quant_stats
-
-
-
-# patch based discriminator
-@ARCH_REGISTRY.register()
-class VQGANDiscriminator(nn.Module):
- def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
- super().__init__()
-
- layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
- ndf_mult = 1
- ndf_mult_prev = 1
- for n in range(1, n_layers): # gradually increase the number of filters
- ndf_mult_prev = ndf_mult
- ndf_mult = min(2 ** n, 8)
- layers += [
- nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
- nn.BatchNorm2d(ndf * ndf_mult),
- nn.LeakyReLU(0.2, True)
- ]
-
- ndf_mult_prev = ndf_mult
- ndf_mult = min(2 ** n_layers, 8)
-
- layers += [
- nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
- nn.BatchNorm2d(ndf * ndf_mult),
- nn.LeakyReLU(0.2, True)
- ]
-
- layers += [
- nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
- self.main = nn.Sequential(*layers)
-
- if model_path is not None:
- chkpt = torch.load(model_path, map_location='cpu')
- if 'params_d' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
- elif 'params' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
- else:
- raise ValueError('Wrong params!')
-
- def forward(self, x):
- return self.main(x)
\ No newline at end of file
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
deleted file mode 100644
index 01fb7bd8486f48f12510b2ce52040902bb181545..0000000000000000000000000000000000000000
--- a/modules/codeformer_model.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import os
-import sys
-import traceback
-
-import cv2
-import torch
-
-import modules.face_restoration
-import modules.shared
-from modules import shared, devices, modelloader
-from modules.paths import models_path
-
-# the CodeFormer authors chose to bundle a modified basicsr library with their project, which makes
-# it impossible to use it alongside other libraries that also depend on basicsr, such as GFPGAN.
-# I am making a choice to include some files from codeformer to work around this issue.
-model_dir = "Codeformer"
-model_path = os.path.join(models_path, model_dir)
-model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-
-have_codeformer = False
-codeformer = None
-
-
-def setup_model(dirname):
- global model_path
- if not os.path.exists(model_path):
- os.makedirs(model_path)
-
- path = modules.paths.paths.get("CodeFormer", None)
- if path is None:
- return
-
- try:
- from torchvision.transforms.functional import normalize
- from modules.codeformer.codeformer_arch import CodeFormer
- from basicsr.utils.download_util import load_file_from_url
- from basicsr.utils import imwrite, img2tensor, tensor2img
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
- from facelib.detection.retinaface import retinaface
- from modules.shared import cmd_opts
-
- net_class = CodeFormer
-
- class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
- def name(self):
- return "CodeFormer"
-
- def __init__(self, dirname):
- self.net = None
- self.face_helper = None
- self.cmd_dir = dirname
-
- def create_models(self):
-
- if self.net is not None and self.face_helper is not None:
- self.net.to(devices.device_codeformer)
- return self.net, self.face_helper
- model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
- if len(model_paths) != 0:
- ckpt_path = model_paths[0]
- else:
- print("Unable to load codeformer model.")
- return None, None
- net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
- checkpoint = torch.load(ckpt_path)['params_ema']
- net.load_state_dict(checkpoint)
- net.eval()
-
- if hasattr(retinaface, 'device'):
- retinaface.device = devices.device_codeformer
- face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
-
- self.net = net
- self.face_helper = face_helper
-
- return net, face_helper
-
- def send_model_to(self, device):
- self.net.to(device)
- self.face_helper.face_det.to(device)
- self.face_helper.face_parse.to(device)
-
- def restore(self, np_image, w=None):
- np_image = np_image[:, :, ::-1]
-
- original_resolution = np_image.shape[0:2]
-
- self.create_models()
- if self.net is None or self.face_helper is None:
- return np_image
-
- self.send_model_to(devices.device_codeformer)
-
- self.face_helper.clean_all()
- self.face_helper.read_image(np_image)
- self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
- self.face_helper.align_warp_face()
-
- for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
-
- try:
- with torch.no_grad():
- output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
- del output
- torch.cuda.empty_cache()
- except Exception as error:
- print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
-
- restored_face = restored_face.astype('uint8')
- self.face_helper.add_restored_face(restored_face)
-
- self.face_helper.get_inverse_affine(None)
-
- restored_img = self.face_helper.paste_faces_to_input_image()
- restored_img = restored_img[:, :, ::-1]
-
- if original_resolution != restored_img.shape[0:2]:
- restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
-
- self.face_helper.clean_all()
-
- if shared.opts.face_restoration_unload:
- self.send_model_to(devices.cpu)
-
- return restored_img
-
- global have_codeformer
- have_codeformer = True
-
- global codeformer
- codeformer = FaceRestorerCodeFormer(dirname)
- shared.face_restorers.append(codeformer)
-
- except Exception:
- print("Error setting up CodeFormer:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- # sys.path = stored_sys_path
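For reference, a stripped-down, hypothetical sketch of the registration pattern `setup_model()` follows above: build the restorer lazily, keep module-level state, and append an instance to a shared registry that the UI iterates over. The names here are illustrative stand-ins, not the webui API.

```python
# Sketch only: mirrors the "define class inside setup, append to registry" shape.
class FaceRestoration:
    def name(self):
        return "None"

    def restore(self, np_image):
        return np_image

face_restorers = []              # stands in for shared.face_restorers

class DemoRestorer(FaceRestoration):
    def __init__(self):
        self.net = None          # created on first use, mirroring create_models()

    def name(self):
        return "Demo"

    def restore(self, np_image):
        return np_image          # a real restorer would crop, enhance and paste faces

face_restorers.append(DemoRestorer())
print([r.name() for r in face_restorers])  # ['Demo']
```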
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
deleted file mode 100644
index 122fce7f569dbd28f9c6d83af874bb3efed34a5e..0000000000000000000000000000000000000000
--- a/modules/deepbooru.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import os
-import re
-
-import torch
-from PIL import Image
-import numpy as np
-
-from modules import modelloader, paths, deepbooru_model, devices, images, shared
-
-re_special = re.compile(r'([\\()])')
-
-
-class DeepDanbooru:
- def __init__(self):
- self.model = None
-
- def load(self):
- if self.model is not None:
- return
-
- files = modelloader.load_models(
- model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
- model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
- ext_filter=[".pt"],
- download_name='model-resnet_custom_v3.pt',
- )
-
- self.model = deepbooru_model.DeepDanbooruModel()
- self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
-
- self.model.eval()
- self.model.to(devices.cpu, devices.dtype)
-
- def start(self):
- self.load()
- self.model.to(devices.device)
-
- def stop(self):
- if not shared.opts.interrogate_keep_models_in_memory:
- self.model.to(devices.cpu)
- devices.torch_gc()
-
- def tag(self, pil_image):
- self.start()
- res = self.tag_multi(pil_image)
- self.stop()
-
- return res
-
- def tag_multi(self, pil_image, force_disable_ranks=False):
- threshold = shared.opts.interrogate_deepbooru_score_threshold
- use_spaces = shared.opts.deepbooru_use_spaces
- use_escape = shared.opts.deepbooru_escape
- alpha_sort = shared.opts.deepbooru_sort_alpha
- include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
-
- pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
- a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
-
- with torch.no_grad(), devices.autocast():
- x = torch.from_numpy(a).to(devices.device)
- y = self.model(x)[0].detach().cpu().numpy()
-
- probability_dict = {}
-
- for tag, probability in zip(self.model.tags, y):
- if probability < threshold:
- continue
-
- if tag.startswith("rating:"):
- continue
-
- probability_dict[tag] = probability
-
- if alpha_sort:
- tags = sorted(probability_dict)
- else:
- tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
-
- res = []
-
- filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
-
- for tag in [x for x in tags if x not in filtertags]:
- probability = probability_dict[tag]
- tag_outformat = tag
- if use_spaces:
- tag_outformat = tag_outformat.replace('_', ' ')
- if use_escape:
- tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
- if include_ranks:
- tag_outformat = f"({tag_outformat}:{probability:.3f})"
-
- res.append(tag_outformat)
-
- return ", ".join(res)
-
-
-model = DeepDanbooru()
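A standalone sketch of the tag post-processing in `tag_multi()` above, with made-up probabilities: filter by threshold, drop `rating:` tags, sort by confidence, then optionally replace underscores, escape parentheses, and attach ranks.

```python
# Illustrative values only; the real thresholds and flags come from shared.opts.
import re

re_special = re.compile(r'([\\()])')
probs = {"1girl": 0.98, "rating:safe": 0.9, "long_hair": 0.72, "smile_(happy)": 0.41}
threshold, use_spaces, use_escape, include_ranks = 0.5, True, True, True

tags = []
for tag, p in sorted(probs.items(), key=lambda kv: -kv[1]):
    if p < threshold or tag.startswith("rating:"):
        continue
    out = tag.replace('_', ' ') if use_spaces else tag
    if use_escape:
        out = re.sub(re_special, r'\\\1', out)
    if include_ranks:
        out = f"({out}:{p:.3f})"
    tags.append(out)

print(", ".join(tags))  # (1girl:0.980), (long hair:0.720)
```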
diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py
deleted file mode 100644
index 83d2ff0902f965ac3c69d830203ad36d0b067089..0000000000000000000000000000000000000000
--- a/modules/deepbooru_model.py
+++ /dev/null
@@ -1,678 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from modules import devices
-
-# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
-
-
-class DeepDanbooruModel(nn.Module):
- def __init__(self):
- super(DeepDanbooruModel, self).__init__()
-
- self.tags = []
-
- self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
- self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
- self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
- self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
- self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
- self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
- self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
- self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
- self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
- self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
- self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
- self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
- self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
- self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
- self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
- self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
- self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
- self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
- self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
- self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
- self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
- self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
- self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
- self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
- self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
- self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
- self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
- self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
- self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
- self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
- self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
- self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
- self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
- self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
- self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
- self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
- self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
- self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
- self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
- self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
- self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
- self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
- self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
- self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
- self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
- self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
- self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
-
- def forward(self, *inputs):
- t_358, = inputs
- t_359 = t_358.permute(*[0, 3, 1, 2])
- t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
- t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
- t_361 = F.relu(t_360)
- t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
- t_362 = self.n_MaxPool_0(t_361)
- t_363 = self.n_Conv_1(t_362)
- t_364 = self.n_Conv_2(t_362)
- t_365 = F.relu(t_364)
- t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
- t_366 = self.n_Conv_3(t_365_padded)
- t_367 = F.relu(t_366)
- t_368 = self.n_Conv_4(t_367)
- t_369 = torch.add(t_368, t_363)
- t_370 = F.relu(t_369)
- t_371 = self.n_Conv_5(t_370)
- t_372 = F.relu(t_371)
- t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
- t_373 = self.n_Conv_6(t_372_padded)
- t_374 = F.relu(t_373)
- t_375 = self.n_Conv_7(t_374)
- t_376 = torch.add(t_375, t_370)
- t_377 = F.relu(t_376)
- t_378 = self.n_Conv_8(t_377)
- t_379 = F.relu(t_378)
- t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
- t_380 = self.n_Conv_9(t_379_padded)
- t_381 = F.relu(t_380)
- t_382 = self.n_Conv_10(t_381)
- t_383 = torch.add(t_382, t_377)
- t_384 = F.relu(t_383)
- t_385 = self.n_Conv_11(t_384)
- t_386 = self.n_Conv_12(t_384)
- t_387 = F.relu(t_386)
- t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
- t_388 = self.n_Conv_13(t_387_padded)
- t_389 = F.relu(t_388)
- t_390 = self.n_Conv_14(t_389)
- t_391 = torch.add(t_390, t_385)
- t_392 = F.relu(t_391)
- t_393 = self.n_Conv_15(t_392)
- t_394 = F.relu(t_393)
- t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
- t_395 = self.n_Conv_16(t_394_padded)
- t_396 = F.relu(t_395)
- t_397 = self.n_Conv_17(t_396)
- t_398 = torch.add(t_397, t_392)
- t_399 = F.relu(t_398)
- t_400 = self.n_Conv_18(t_399)
- t_401 = F.relu(t_400)
- t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
- t_402 = self.n_Conv_19(t_401_padded)
- t_403 = F.relu(t_402)
- t_404 = self.n_Conv_20(t_403)
- t_405 = torch.add(t_404, t_399)
- t_406 = F.relu(t_405)
- t_407 = self.n_Conv_21(t_406)
- t_408 = F.relu(t_407)
- t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
- t_409 = self.n_Conv_22(t_408_padded)
- t_410 = F.relu(t_409)
- t_411 = self.n_Conv_23(t_410)
- t_412 = torch.add(t_411, t_406)
- t_413 = F.relu(t_412)
- t_414 = self.n_Conv_24(t_413)
- t_415 = F.relu(t_414)
- t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
- t_416 = self.n_Conv_25(t_415_padded)
- t_417 = F.relu(t_416)
- t_418 = self.n_Conv_26(t_417)
- t_419 = torch.add(t_418, t_413)
- t_420 = F.relu(t_419)
- t_421 = self.n_Conv_27(t_420)
- t_422 = F.relu(t_421)
- t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
- t_423 = self.n_Conv_28(t_422_padded)
- t_424 = F.relu(t_423)
- t_425 = self.n_Conv_29(t_424)
- t_426 = torch.add(t_425, t_420)
- t_427 = F.relu(t_426)
- t_428 = self.n_Conv_30(t_427)
- t_429 = F.relu(t_428)
- t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
- t_430 = self.n_Conv_31(t_429_padded)
- t_431 = F.relu(t_430)
- t_432 = self.n_Conv_32(t_431)
- t_433 = torch.add(t_432, t_427)
- t_434 = F.relu(t_433)
- t_435 = self.n_Conv_33(t_434)
- t_436 = F.relu(t_435)
- t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
- t_437 = self.n_Conv_34(t_436_padded)
- t_438 = F.relu(t_437)
- t_439 = self.n_Conv_35(t_438)
- t_440 = torch.add(t_439, t_434)
- t_441 = F.relu(t_440)
- t_442 = self.n_Conv_36(t_441)
- t_443 = self.n_Conv_37(t_441)
- t_444 = F.relu(t_443)
- t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
- t_445 = self.n_Conv_38(t_444_padded)
- t_446 = F.relu(t_445)
- t_447 = self.n_Conv_39(t_446)
- t_448 = torch.add(t_447, t_442)
- t_449 = F.relu(t_448)
- t_450 = self.n_Conv_40(t_449)
- t_451 = F.relu(t_450)
- t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
- t_452 = self.n_Conv_41(t_451_padded)
- t_453 = F.relu(t_452)
- t_454 = self.n_Conv_42(t_453)
- t_455 = torch.add(t_454, t_449)
- t_456 = F.relu(t_455)
- t_457 = self.n_Conv_43(t_456)
- t_458 = F.relu(t_457)
- t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
- t_459 = self.n_Conv_44(t_458_padded)
- t_460 = F.relu(t_459)
- t_461 = self.n_Conv_45(t_460)
- t_462 = torch.add(t_461, t_456)
- t_463 = F.relu(t_462)
- t_464 = self.n_Conv_46(t_463)
- t_465 = F.relu(t_464)
- t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
- t_466 = self.n_Conv_47(t_465_padded)
- t_467 = F.relu(t_466)
- t_468 = self.n_Conv_48(t_467)
- t_469 = torch.add(t_468, t_463)
- t_470 = F.relu(t_469)
- t_471 = self.n_Conv_49(t_470)
- t_472 = F.relu(t_471)
- t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
- t_473 = self.n_Conv_50(t_472_padded)
- t_474 = F.relu(t_473)
- t_475 = self.n_Conv_51(t_474)
- t_476 = torch.add(t_475, t_470)
- t_477 = F.relu(t_476)
- t_478 = self.n_Conv_52(t_477)
- t_479 = F.relu(t_478)
- t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
- t_480 = self.n_Conv_53(t_479_padded)
- t_481 = F.relu(t_480)
- t_482 = self.n_Conv_54(t_481)
- t_483 = torch.add(t_482, t_477)
- t_484 = F.relu(t_483)
- t_485 = self.n_Conv_55(t_484)
- t_486 = F.relu(t_485)
- t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
- t_487 = self.n_Conv_56(t_486_padded)
- t_488 = F.relu(t_487)
- t_489 = self.n_Conv_57(t_488)
- t_490 = torch.add(t_489, t_484)
- t_491 = F.relu(t_490)
- t_492 = self.n_Conv_58(t_491)
- t_493 = F.relu(t_492)
- t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
- t_494 = self.n_Conv_59(t_493_padded)
- t_495 = F.relu(t_494)
- t_496 = self.n_Conv_60(t_495)
- t_497 = torch.add(t_496, t_491)
- t_498 = F.relu(t_497)
- t_499 = self.n_Conv_61(t_498)
- t_500 = F.relu(t_499)
- t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
- t_501 = self.n_Conv_62(t_500_padded)
- t_502 = F.relu(t_501)
- t_503 = self.n_Conv_63(t_502)
- t_504 = torch.add(t_503, t_498)
- t_505 = F.relu(t_504)
- t_506 = self.n_Conv_64(t_505)
- t_507 = F.relu(t_506)
- t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
- t_508 = self.n_Conv_65(t_507_padded)
- t_509 = F.relu(t_508)
- t_510 = self.n_Conv_66(t_509)
- t_511 = torch.add(t_510, t_505)
- t_512 = F.relu(t_511)
- t_513 = self.n_Conv_67(t_512)
- t_514 = F.relu(t_513)
- t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
- t_515 = self.n_Conv_68(t_514_padded)
- t_516 = F.relu(t_515)
- t_517 = self.n_Conv_69(t_516)
- t_518 = torch.add(t_517, t_512)
- t_519 = F.relu(t_518)
- t_520 = self.n_Conv_70(t_519)
- t_521 = F.relu(t_520)
- t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
- t_522 = self.n_Conv_71(t_521_padded)
- t_523 = F.relu(t_522)
- t_524 = self.n_Conv_72(t_523)
- t_525 = torch.add(t_524, t_519)
- t_526 = F.relu(t_525)
- t_527 = self.n_Conv_73(t_526)
- t_528 = F.relu(t_527)
- t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
- t_529 = self.n_Conv_74(t_528_padded)
- t_530 = F.relu(t_529)
- t_531 = self.n_Conv_75(t_530)
- t_532 = torch.add(t_531, t_526)
- t_533 = F.relu(t_532)
- t_534 = self.n_Conv_76(t_533)
- t_535 = F.relu(t_534)
- t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
- t_536 = self.n_Conv_77(t_535_padded)
- t_537 = F.relu(t_536)
- t_538 = self.n_Conv_78(t_537)
- t_539 = torch.add(t_538, t_533)
- t_540 = F.relu(t_539)
- t_541 = self.n_Conv_79(t_540)
- t_542 = F.relu(t_541)
- t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
- t_543 = self.n_Conv_80(t_542_padded)
- t_544 = F.relu(t_543)
- t_545 = self.n_Conv_81(t_544)
- t_546 = torch.add(t_545, t_540)
- t_547 = F.relu(t_546)
- t_548 = self.n_Conv_82(t_547)
- t_549 = F.relu(t_548)
- t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
- t_550 = self.n_Conv_83(t_549_padded)
- t_551 = F.relu(t_550)
- t_552 = self.n_Conv_84(t_551)
- t_553 = torch.add(t_552, t_547)
- t_554 = F.relu(t_553)
- t_555 = self.n_Conv_85(t_554)
- t_556 = F.relu(t_555)
- t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
- t_557 = self.n_Conv_86(t_556_padded)
- t_558 = F.relu(t_557)
- t_559 = self.n_Conv_87(t_558)
- t_560 = torch.add(t_559, t_554)
- t_561 = F.relu(t_560)
- t_562 = self.n_Conv_88(t_561)
- t_563 = F.relu(t_562)
- t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
- t_564 = self.n_Conv_89(t_563_padded)
- t_565 = F.relu(t_564)
- t_566 = self.n_Conv_90(t_565)
- t_567 = torch.add(t_566, t_561)
- t_568 = F.relu(t_567)
- t_569 = self.n_Conv_91(t_568)
- t_570 = F.relu(t_569)
- t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
- t_571 = self.n_Conv_92(t_570_padded)
- t_572 = F.relu(t_571)
- t_573 = self.n_Conv_93(t_572)
- t_574 = torch.add(t_573, t_568)
- t_575 = F.relu(t_574)
- t_576 = self.n_Conv_94(t_575)
- t_577 = F.relu(t_576)
- t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
- t_578 = self.n_Conv_95(t_577_padded)
- t_579 = F.relu(t_578)
- t_580 = self.n_Conv_96(t_579)
- t_581 = torch.add(t_580, t_575)
- t_582 = F.relu(t_581)
- t_583 = self.n_Conv_97(t_582)
- t_584 = F.relu(t_583)
- t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
- t_585 = self.n_Conv_98(t_584_padded)
- t_586 = F.relu(t_585)
- t_587 = self.n_Conv_99(t_586)
- t_588 = self.n_Conv_100(t_582)
- t_589 = torch.add(t_587, t_588)
- t_590 = F.relu(t_589)
- t_591 = self.n_Conv_101(t_590)
- t_592 = F.relu(t_591)
- t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
- t_593 = self.n_Conv_102(t_592_padded)
- t_594 = F.relu(t_593)
- t_595 = self.n_Conv_103(t_594)
- t_596 = torch.add(t_595, t_590)
- t_597 = F.relu(t_596)
- t_598 = self.n_Conv_104(t_597)
- t_599 = F.relu(t_598)
- t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
- t_600 = self.n_Conv_105(t_599_padded)
- t_601 = F.relu(t_600)
- t_602 = self.n_Conv_106(t_601)
- t_603 = torch.add(t_602, t_597)
- t_604 = F.relu(t_603)
- t_605 = self.n_Conv_107(t_604)
- t_606 = F.relu(t_605)
- t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
- t_607 = self.n_Conv_108(t_606_padded)
- t_608 = F.relu(t_607)
- t_609 = self.n_Conv_109(t_608)
- t_610 = torch.add(t_609, t_604)
- t_611 = F.relu(t_610)
- t_612 = self.n_Conv_110(t_611)
- t_613 = F.relu(t_612)
- t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
- t_614 = self.n_Conv_111(t_613_padded)
- t_615 = F.relu(t_614)
- t_616 = self.n_Conv_112(t_615)
- t_617 = torch.add(t_616, t_611)
- t_618 = F.relu(t_617)
- t_619 = self.n_Conv_113(t_618)
- t_620 = F.relu(t_619)
- t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
- t_621 = self.n_Conv_114(t_620_padded)
- t_622 = F.relu(t_621)
- t_623 = self.n_Conv_115(t_622)
- t_624 = torch.add(t_623, t_618)
- t_625 = F.relu(t_624)
- t_626 = self.n_Conv_116(t_625)
- t_627 = F.relu(t_626)
- t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
- t_628 = self.n_Conv_117(t_627_padded)
- t_629 = F.relu(t_628)
- t_630 = self.n_Conv_118(t_629)
- t_631 = torch.add(t_630, t_625)
- t_632 = F.relu(t_631)
- t_633 = self.n_Conv_119(t_632)
- t_634 = F.relu(t_633)
- t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
- t_635 = self.n_Conv_120(t_634_padded)
- t_636 = F.relu(t_635)
- t_637 = self.n_Conv_121(t_636)
- t_638 = torch.add(t_637, t_632)
- t_639 = F.relu(t_638)
- t_640 = self.n_Conv_122(t_639)
- t_641 = F.relu(t_640)
- t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
- t_642 = self.n_Conv_123(t_641_padded)
- t_643 = F.relu(t_642)
- t_644 = self.n_Conv_124(t_643)
- t_645 = torch.add(t_644, t_639)
- t_646 = F.relu(t_645)
- t_647 = self.n_Conv_125(t_646)
- t_648 = F.relu(t_647)
- t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
- t_649 = self.n_Conv_126(t_648_padded)
- t_650 = F.relu(t_649)
- t_651 = self.n_Conv_127(t_650)
- t_652 = torch.add(t_651, t_646)
- t_653 = F.relu(t_652)
- t_654 = self.n_Conv_128(t_653)
- t_655 = F.relu(t_654)
- t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
- t_656 = self.n_Conv_129(t_655_padded)
- t_657 = F.relu(t_656)
- t_658 = self.n_Conv_130(t_657)
- t_659 = torch.add(t_658, t_653)
- t_660 = F.relu(t_659)
- t_661 = self.n_Conv_131(t_660)
- t_662 = F.relu(t_661)
- t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
- t_663 = self.n_Conv_132(t_662_padded)
- t_664 = F.relu(t_663)
- t_665 = self.n_Conv_133(t_664)
- t_666 = torch.add(t_665, t_660)
- t_667 = F.relu(t_666)
- t_668 = self.n_Conv_134(t_667)
- t_669 = F.relu(t_668)
- t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
- t_670 = self.n_Conv_135(t_669_padded)
- t_671 = F.relu(t_670)
- t_672 = self.n_Conv_136(t_671)
- t_673 = torch.add(t_672, t_667)
- t_674 = F.relu(t_673)
- t_675 = self.n_Conv_137(t_674)
- t_676 = F.relu(t_675)
- t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
- t_677 = self.n_Conv_138(t_676_padded)
- t_678 = F.relu(t_677)
- t_679 = self.n_Conv_139(t_678)
- t_680 = torch.add(t_679, t_674)
- t_681 = F.relu(t_680)
- t_682 = self.n_Conv_140(t_681)
- t_683 = F.relu(t_682)
- t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
- t_684 = self.n_Conv_141(t_683_padded)
- t_685 = F.relu(t_684)
- t_686 = self.n_Conv_142(t_685)
- t_687 = torch.add(t_686, t_681)
- t_688 = F.relu(t_687)
- t_689 = self.n_Conv_143(t_688)
- t_690 = F.relu(t_689)
- t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
- t_691 = self.n_Conv_144(t_690_padded)
- t_692 = F.relu(t_691)
- t_693 = self.n_Conv_145(t_692)
- t_694 = torch.add(t_693, t_688)
- t_695 = F.relu(t_694)
- t_696 = self.n_Conv_146(t_695)
- t_697 = F.relu(t_696)
- t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
- t_698 = self.n_Conv_147(t_697_padded)
- t_699 = F.relu(t_698)
- t_700 = self.n_Conv_148(t_699)
- t_701 = torch.add(t_700, t_695)
- t_702 = F.relu(t_701)
- t_703 = self.n_Conv_149(t_702)
- t_704 = F.relu(t_703)
- t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
- t_705 = self.n_Conv_150(t_704_padded)
- t_706 = F.relu(t_705)
- t_707 = self.n_Conv_151(t_706)
- t_708 = torch.add(t_707, t_702)
- t_709 = F.relu(t_708)
- t_710 = self.n_Conv_152(t_709)
- t_711 = F.relu(t_710)
- t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
- t_712 = self.n_Conv_153(t_711_padded)
- t_713 = F.relu(t_712)
- t_714 = self.n_Conv_154(t_713)
- t_715 = torch.add(t_714, t_709)
- t_716 = F.relu(t_715)
- t_717 = self.n_Conv_155(t_716)
- t_718 = F.relu(t_717)
- t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
- t_719 = self.n_Conv_156(t_718_padded)
- t_720 = F.relu(t_719)
- t_721 = self.n_Conv_157(t_720)
- t_722 = torch.add(t_721, t_716)
- t_723 = F.relu(t_722)
- t_724 = self.n_Conv_158(t_723)
- t_725 = self.n_Conv_159(t_723)
- t_726 = F.relu(t_725)
- t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
- t_727 = self.n_Conv_160(t_726_padded)
- t_728 = F.relu(t_727)
- t_729 = self.n_Conv_161(t_728)
- t_730 = torch.add(t_729, t_724)
- t_731 = F.relu(t_730)
- t_732 = self.n_Conv_162(t_731)
- t_733 = F.relu(t_732)
- t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
- t_734 = self.n_Conv_163(t_733_padded)
- t_735 = F.relu(t_734)
- t_736 = self.n_Conv_164(t_735)
- t_737 = torch.add(t_736, t_731)
- t_738 = F.relu(t_737)
- t_739 = self.n_Conv_165(t_738)
- t_740 = F.relu(t_739)
- t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
- t_741 = self.n_Conv_166(t_740_padded)
- t_742 = F.relu(t_741)
- t_743 = self.n_Conv_167(t_742)
- t_744 = torch.add(t_743, t_738)
- t_745 = F.relu(t_744)
- t_746 = self.n_Conv_168(t_745)
- t_747 = self.n_Conv_169(t_745)
- t_748 = F.relu(t_747)
- t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
- t_749 = self.n_Conv_170(t_748_padded)
- t_750 = F.relu(t_749)
- t_751 = self.n_Conv_171(t_750)
- t_752 = torch.add(t_751, t_746)
- t_753 = F.relu(t_752)
- t_754 = self.n_Conv_172(t_753)
- t_755 = F.relu(t_754)
- t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
- t_756 = self.n_Conv_173(t_755_padded)
- t_757 = F.relu(t_756)
- t_758 = self.n_Conv_174(t_757)
- t_759 = torch.add(t_758, t_753)
- t_760 = F.relu(t_759)
- t_761 = self.n_Conv_175(t_760)
- t_762 = F.relu(t_761)
- t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
- t_763 = self.n_Conv_176(t_762_padded)
- t_764 = F.relu(t_763)
- t_765 = self.n_Conv_177(t_764)
- t_766 = torch.add(t_765, t_760)
- t_767 = F.relu(t_766)
- t_768 = self.n_Conv_178(t_767)
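-        # the last 1x1 conv produces one channel per tag; average-pool the spatial dims and apply sigmoid to get per-tag probabilities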
- t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
- t_770 = torch.squeeze(t_769, 3)
- t_770 = torch.squeeze(t_770, 2)
- t_771 = torch.sigmoid(t_770)
- return t_771
-
- def load_state_dict(self, state_dict, **kwargs):
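-        # the checkpoint stores the tag list under 'tags'; keep it on the model and pass only the actual weights to PyTorch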
- self.tags = state_dict.get('tags', [])
-
- super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
-
diff --git a/modules/devices.py b/modules/devices.py
deleted file mode 100644
index 52c3e7cd773f9c89857dfce14b37d63cb6329fac..0000000000000000000000000000000000000000
--- a/modules/devices.py
+++ /dev/null
@@ -1,152 +0,0 @@
-import sys
-import contextlib
-import torch
-from modules import errors
-
-if sys.platform == "darwin":
- from modules import mac_specific
-
-
-def has_mps() -> bool:
- if sys.platform != "darwin":
- return False
- else:
- return mac_specific.has_mps
-
-def extract_device_id(args, name):
- for x in range(len(args)):
- if name in args[x]:
- return args[x + 1]
-
- return None
-
-
-def get_cuda_device_string():
- from modules import shared
-
- if shared.cmd_opts.device_id is not None:
- return f"cuda:{shared.cmd_opts.device_id}"
-
- return "cuda"
-
-
-def get_optimal_device_name():
- if torch.cuda.is_available():
- return get_cuda_device_string()
-
- if has_mps():
- return "mps"
-
- return "cpu"
-
-
-def get_optimal_device():
- return torch.device(get_optimal_device_name())
-
-
-def get_device_for(task):
- from modules import shared
-
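-    # tasks listed in the --use-cpu command line option are forced onto the CPU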
- if task in shared.cmd_opts.use_cpu:
- return cpu
-
- return get_optimal_device()
-
-
-def torch_gc():
- if torch.cuda.is_available():
- with torch.cuda.device(get_cuda_device_string()):
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
-
-
-def enable_tf32():
- if torch.cuda.is_available():
-
- # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
- # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-        if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(torch.cuda.device_count())):
- torch.backends.cudnn.benchmark = True
-
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
-
-
-
-errors.run(enable_tf32, "Enabling TF32")
-
-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
-unet_needs_upcast = False
-
-
-def cond_cast_unet(input):
- return input.to(dtype_unet) if unet_needs_upcast else input
-
-
-def cond_cast_float(input):
- return input.float() if unet_needs_upcast else input
-
-
-def randn(seed, shape):
- torch.manual_seed(seed)
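-    # on MPS, create the noise on the CPU and move it over so that seeded noise stays consistent with other backends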
- if device.type == 'mps':
- return torch.randn(shape, device=cpu).to(device)
- return torch.randn(shape, device=device)
-
-
-def randn_without_seed(shape):
- if device.type == 'mps':
- return torch.randn(shape, device=cpu).to(device)
- return torch.randn(shape, device=device)
-
-
-def autocast(disable=False):
- from modules import shared
-
- if disable:
- return contextlib.nullcontext()
-
- if dtype == torch.float32 or shared.cmd_opts.precision == "full":
- return contextlib.nullcontext()
-
- return torch.autocast("cuda")
-
-
-def without_autocast(disable=False):
- return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
-
-
-class NansException(Exception):
- pass
-
-
-def test_for_nans(x, where):
- from modules import shared
-
- if shared.cmd_opts.disable_nan_check:
- return
-
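-    # only treat the result as broken if every element of the tensor is NaN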
- if not torch.all(torch.isnan(x)).item():
- return
-
- if where == "unet":
- message = "A tensor with all NaNs was produced in Unet."
-
- if not shared.cmd_opts.no_half:
- message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
-
- elif where == "vae":
- message = "A tensor with all NaNs was produced in VAE."
-
- if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
- message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
- else:
- message = "A tensor with all NaNs was produced."
-
- message += " Use --disable-nan-check commandline argument to disable this check."
-
- raise NansException(message)
diff --git a/modules/errors.py b/modules/errors.py
deleted file mode 100644
index f6b80dbbde7947511b58eae309f3b077c8c09fb5..0000000000000000000000000000000000000000
--- a/modules/errors.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import sys
-import traceback
-
-
-def print_error_explanation(message):
- lines = message.strip().split("\n")
- max_len = max([len(x) for x in lines])
-
- print('=' * max_len, file=sys.stderr)
- for line in lines:
- print(line, file=sys.stderr)
- print('=' * max_len, file=sys.stderr)
-
-
-def display(e: Exception, task):
- print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- message = str(e)
- if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
- print_error_explanation("""
-The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
-See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
- """)
-
-
-already_displayed = {}
-
-
-def display_once(e: Exception, task):
- if task in already_displayed:
- return
-
- display(e, task)
-
- already_displayed[task] = 1
-
-
-def run(code, task):
- try:
- code()
- except Exception as e:
-        display(e, task)
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
deleted file mode 100644
index 9a9c38f1f64bfa2e96f20bd09666174ab16cfea1..0000000000000000000000000000000000000000
--- a/modules/esrgan_model.py
+++ /dev/null
@@ -1,233 +0,0 @@
-import os
-
-import numpy as np
-import torch
-from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
-
-import modules.esrgan_model_arch as arch
-from modules import shared, modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import opts
-
-
-
-def mod2normal(state_dict):
- # this code is copied from https://github.com/victorca25/iNNfer
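-    # remaps new-style ESRGAN keys (conv_first, RRDB_trunk.*, trunk_conv, upconv*, HRconv, conv_last) to the sequential 'model.N' names used by arch.RRDBNet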
- if 'conv_first.weight' in state_dict:
- crt_net = {}
- items = []
- for k, v in state_dict.items():
- items.append(k)
-
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
- for k in items.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[ori_k] = state_dict[k]
- items.remove(k)
-
- crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
- crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
- crt_net['model.3.weight'] = state_dict['upconv1.weight']
- crt_net['model.3.bias'] = state_dict['upconv1.bias']
- crt_net['model.6.weight'] = state_dict['upconv2.weight']
- crt_net['model.6.bias'] = state_dict['upconv2.bias']
- crt_net['model.8.weight'] = state_dict['HRconv.weight']
- crt_net['model.8.bias'] = state_dict['HRconv.bias']
- crt_net['model.10.weight'] = state_dict['conv_last.weight']
- crt_net['model.10.bias'] = state_dict['conv_last.bias']
- state_dict = crt_net
- return state_dict
-
-
-def resrgan2normal(state_dict, nb=23):
- # this code is copied from https://github.com/victorca25/iNNfer
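-    # remaps Real-ESRGAN keys (conv_first, body.N.rdbM.*, conv_body, conv_up*, conv_hr, conv_last) to the 'model.N' names; nb is the number of RRDB blocks in the trunk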
- if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
- re8x = 0
- crt_net = {}
- items = []
- for k, v in state_dict.items():
- items.append(k)
-
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
- for k in items.copy():
- if "rdb" in k:
- ori_k = k.replace('body.', 'model.1.sub.')
- ori_k = ori_k.replace('.rdb', '.RDB')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[ori_k] = state_dict[k]
- items.remove(k)
-
- crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
- crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
- crt_net['model.3.weight'] = state_dict['conv_up1.weight']
- crt_net['model.3.bias'] = state_dict['conv_up1.bias']
- crt_net['model.6.weight'] = state_dict['conv_up2.weight']
- crt_net['model.6.bias'] = state_dict['conv_up2.bias']
-
- if 'conv_up3.weight' in state_dict:
- # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
- re8x = 3
- crt_net['model.9.weight'] = state_dict['conv_up3.weight']
- crt_net['model.9.bias'] = state_dict['conv_up3.bias']
-
- crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
- crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
- crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
- crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
-
- state_dict = crt_net
- return state_dict
-
-
-def infer_params(state_dict):
- # this code is copied from https://github.com/victorca25/iNNfer
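-    # infers RRDBNet hyperparameters from the weights alone: counts upsampling convs for the scale, reads nf/in_nc/out_nc from tensor shapes, and detects the ESRGAN+ 'plus' variant from conv1x1 keys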
- scale2x = 0
- scalemin = 6
- n_uplayer = 0
- plus = False
-
- for block in list(state_dict):
- parts = block.split(".")
- n_parts = len(parts)
- if n_parts == 5 and parts[2] == "sub":
- nb = int(parts[3])
- elif n_parts == 3:
- part_num = int(parts[1])
- if (part_num > scalemin
- and parts[0] == "model"
- and parts[2] == "weight"):
- scale2x += 1
- if part_num > n_uplayer:
- n_uplayer = part_num
- out_nc = state_dict[block].shape[0]
- if not plus and "conv1x1" in block:
- plus = True
-
- nf = state_dict["model.0.weight"].shape[0]
- in_nc = state_dict["model.0.weight"].shape[1]
- out_nc = out_nc
- scale = 2 ** scale2x
-
- return in_nc, out_nc, nf, nb, plus, scale
-
-
-class UpscalerESRGAN(Upscaler):
- def __init__(self, dirname):
- self.name = "ESRGAN"
- self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
- self.model_name = "ESRGAN_4x"
- self.scalers = []
- self.user_path = dirname
- super().__init__()
- model_paths = self.find_models(ext_filter=[".pt", ".pth"])
-        if len(model_paths) == 0:
-            scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
-            self.scalers.append(scaler_data)
- for file in model_paths:
- if "http" in file:
- name = self.model_name
- else:
- name = modelloader.friendly_name(file)
-
- scaler_data = UpscalerData(name, file, self, 4)
- self.scalers.append(scaler_data)
-
- def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
- return img
- model.to(devices.device_esrgan)
- img = esrgan_upscale(model, img)
- return img
-
- def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
- file_name="%s.pth" % self.model_name,
- progress=True)
- else:
- filename = path
-        if filename is None or not os.path.exists(filename):
- print("Unable to load %s from %s" % (self.model_path, filename))
- return None
-
- state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
-
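-        # Real-ESRGAN 'compact' checkpoints wrap their weights in params/params_ema; those are SRVGGNetCompact models and are built and returned here directly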
- if "params_ema" in state_dict:
- state_dict = state_dict["params_ema"]
- elif "params" in state_dict:
- state_dict = state_dict["params"]
- num_conv = 16 if "realesr-animevideov3" in filename else 32
- model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
- model.load_state_dict(state_dict)
- model.eval()
- return model
-
- if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
- nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
- state_dict = resrgan2normal(state_dict, nb)
- elif "conv_first.weight" in state_dict:
- state_dict = mod2normal(state_dict)
- elif "model.0.weight" not in state_dict:
- raise Exception("The file is not a recognized ESRGAN model.")
-
- in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
-
- model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
- model.load_state_dict(state_dict)
- model.eval()
-
- return model
-
-
-def upscale_without_tiling(model, img):
- img = np.array(img)
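-    # PIL gives RGB HWC uint8; the model expects BGR CHW float in [0, 1], and the output is converted back the same way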
- img = img[:, :, ::-1]
- img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- return Image.fromarray(output, 'RGB')
-
-
-def esrgan_upscale(model, img):
- if opts.ESRGAN_tile == 0:
- return upscale_without_tiling(model, img)
-
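-    # split the image into overlapping tiles, upscale each tile separately, then stitch the results back together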
- grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
- newtiles = []
- scale_factor = 1
-
- for y, h, row in grid.tiles:
- newrow = []
- for tiledata in row:
- x, w, tile = tiledata
-
- output = upscale_without_tiling(model, tile)
- scale_factor = output.width // tile.width
-
- newrow.append([x * scale_factor, w * scale_factor, output])
- newtiles.append([y * scale_factor, h * scale_factor, newrow])
-
- newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
- output = images.combine_grid(newgrid)
- return output
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
deleted file mode 100644
index 1b52b0f5e9605f919edccd0c4bedc9163af61761..0000000000000000000000000000000000000000
--- a/modules/esrgan_model_arch.py
+++ /dev/null
@@ -1,464 +0,0 @@
-# this file is adapted from https://github.com/victorca25/iNNfer
-
-from collections import OrderedDict
-import math
-import functools
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-####################
-# RRDBNet Generator
-####################
-
-class RRDBNet(nn.Module):
- def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
- act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
- finalact=None, gaussian_noise=False, plus=False):
- super(RRDBNet, self).__init__()
- n_upscale = int(math.log(upscale, 2))
- if upscale == 3:
- n_upscale = 1
-
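-        # Real-ESRGAN x1/x2 models take pixel-unshuffled input; the number of input channels tells forward() which unshuffle factor to use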
- self.resrgan_scale = 0
- if in_nc % 16 == 0:
- self.resrgan_scale = 1
- elif in_nc != 4 and in_nc % 4 == 0:
- self.resrgan_scale = 2
-
- fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
- rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
- LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
-
- if upsample_mode == 'upconv':
- upsample_block = upconv_block
- elif upsample_mode == 'pixelshuffle':
- upsample_block = pixelshuffle_block
- else:
- raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
- if upscale == 3:
- upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
- else:
- upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
- HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
- HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
-
- outact = act(finalact) if finalact else None
-
- self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
- *upsampler, HR_conv0, HR_conv1, outact)
-
- def forward(self, x, outm=None):
- if self.resrgan_scale == 1:
- feat = pixel_unshuffle(x, scale=4)
- elif self.resrgan_scale == 2:
- feat = pixel_unshuffle(x, scale=2)
- else:
- feat = x
-
- return self.model(feat)
-
-
-class RRDB(nn.Module):
- """
- Residual in Residual Dense Block
- (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
- """
-
- def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
- spectral_norm=False, gaussian_noise=False, plus=False):
- super(RRDB, self).__init__()
- # This is for backwards compatibility with existing models
- if nr == 3:
- self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- else:
- RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
- self.RDBs = nn.Sequential(*RDB_list)
-
- def forward(self, x):
- if hasattr(self, 'RDB1'):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- else:
- out = self.RDBs(x)
- return out * 0.2 + x
-
-
-class ResidualDenseBlock_5C(nn.Module):
- """
- Residual Dense Block
- The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
- Modified options that can be used:
- - "Partial Convolution based Padding" arXiv:1811.11718
- - "Spectral normalization" arXiv:1802.05957
- - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
- {Rakotonirina} and A. {Rasoanaivo}
- """
-
- def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
- spectral_norm=False, gaussian_noise=False, plus=False):
- super(ResidualDenseBlock_5C, self).__init__()
-
- self.noise = GaussianNoise() if gaussian_noise else None
- self.conv1x1 = conv1x1(nf, gc) if plus else None
-
- self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- if mode == 'CNA':
- last_act = None
- else:
- last_act = act_type
- self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
-
- def forward(self, x):
- x1 = self.conv1(x)
- x2 = self.conv2(torch.cat((x, x1), 1))
- if self.conv1x1:
- x2 = x2 + self.conv1x1(x)
- x3 = self.conv3(torch.cat((x, x1, x2), 1))
- x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
- if self.conv1x1:
- x4 = x4 + x2
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- if self.noise:
- return self.noise(x5.mul(0.2) + x)
- else:
- return x5 * 0.2 + x
-
-
-####################
-# ESRGANplus
-####################
-
-class GaussianNoise(nn.Module):
- def __init__(self, sigma=0.1, is_relative_detach=False):
- super().__init__()
- self.sigma = sigma
- self.is_relative_detach = is_relative_detach
- self.noise = torch.tensor(0, dtype=torch.float)
-
- def forward(self, x):
- if self.training and self.sigma != 0:
- self.noise = self.noise.to(x.device)
- scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
- sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
- x = x + sampled_noise
- return x
-
-def conv1x1(in_planes, out_planes, stride=1):
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-####################
-# SRVGGNetCompact
-####################
-
-class SRVGGNetCompact(nn.Module):
- """A compact VGG-style network structure for super-resolution.
- This class is copied from https://github.com/xinntao/Real-ESRGAN
- """
-
- def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
- super(SRVGGNetCompact, self).__init__()
- self.num_in_ch = num_in_ch
- self.num_out_ch = num_out_ch
- self.num_feat = num_feat
- self.num_conv = num_conv
- self.upscale = upscale
- self.act_type = act_type
-
- self.body = nn.ModuleList()
- # the first conv
- self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
- # the first activation
- if act_type == 'relu':
- activation = nn.ReLU(inplace=True)
- elif act_type == 'prelu':
- activation = nn.PReLU(num_parameters=num_feat)
- elif act_type == 'leakyrelu':
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation)
-
- # the body structure
- for _ in range(num_conv):
- self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
- # activation
- if act_type == 'relu':
- activation = nn.ReLU(inplace=True)
- elif act_type == 'prelu':
- activation = nn.PReLU(num_parameters=num_feat)
- elif act_type == 'leakyrelu':
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation)
-
- # the last conv
- self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
- # upsample
- self.upsampler = nn.PixelShuffle(upscale)
-
- def forward(self, x):
- out = x
-        for layer in self.body:
-            out = layer(out)
-
- out = self.upsampler(out)
- # add the nearest upsampled image, so that the network learns the residual
- base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
- out += base
- return out
-
-
-####################
-# Upsampler
-####################
-
-class Upsample(nn.Module):
- r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
- The input data is assumed to be of the form
- `minibatch x channels x [optional depth] x [optional height] x width`.
- """
-
- def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
- super(Upsample, self).__init__()
- if isinstance(scale_factor, tuple):
- self.scale_factor = tuple(float(factor) for factor in scale_factor)
- else:
- self.scale_factor = float(scale_factor) if scale_factor else None
- self.mode = mode
- self.size = size
- self.align_corners = align_corners
-
- def forward(self, x):
- return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
-
- def extra_repr(self):
- if self.scale_factor is not None:
- info = 'scale_factor=' + str(self.scale_factor)
- else:
- info = 'size=' + str(self.size)
- info += ', mode=' + self.mode
- return info
-
-
-def pixel_unshuffle(x, scale):
- """ Pixel unshuffle.
- Args:
- x (Tensor): Input feature with shape (b, c, hh, hw).
- scale (int): Downsample ratio.
- Returns:
- Tensor: the pixel unshuffled feature.
- """
- b, c, hh, hw = x.size()
- out_channel = c * (scale**2)
- assert hh % scale == 0 and hw % scale == 0
- h = hh // scale
- w = hw // scale
- x_view = x.view(b, c, h, scale, w, scale)
- return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
-
-
-def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
- """
- Pixel shuffle layer
- (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
- Neural Network, CVPR17)
- """
- conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
- pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
- pixel_shuffle = nn.PixelShuffle(upscale_factor)
-
- n = norm(norm_type, out_nc) if norm_type else None
- a = act(act_type) if act_type else None
- return sequential(conv, pixel_shuffle, n, a)
-
-
-def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
- """ Upconv layer """
- upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
- upsample = Upsample(scale_factor=upscale_factor, mode=mode)
- conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
- pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
- return sequential(upsample, conv)
-
-
-####################
-# Basic blocks
-####################
-
-
-def make_layer(basic_block, num_basic_block, **kwarg):
- """Make layers by stacking the same blocks.
- Args:
- basic_block (nn.module): nn.module class for basic block. (block)
- num_basic_block (int): number of blocks. (n_layers)
- Returns:
- nn.Sequential: Stacked blocks in nn.Sequential.
- """
- layers = []
- for _ in range(num_basic_block):
- layers.append(basic_block(**kwarg))
- return nn.Sequential(*layers)
-
-
-def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
- """ activation helper """
- act_type = act_type.lower()
- if act_type == 'relu':
- layer = nn.ReLU(inplace)
- elif act_type in ('leakyrelu', 'lrelu'):
- layer = nn.LeakyReLU(neg_slope, inplace)
- elif act_type == 'prelu':
- layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
- elif act_type == 'tanh': # [-1, 1] range output
- layer = nn.Tanh()
- elif act_type == 'sigmoid': # [0, 1] range output
- layer = nn.Sigmoid()
- else:
- raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
- return layer
-
-
-class Identity(nn.Module):
- def __init__(self, *kwargs):
- super(Identity, self).__init__()
-
- def forward(self, x, *kwargs):
- return x
-
-
-def norm(norm_type, nc):
- """ Return a normalization layer """
- norm_type = norm_type.lower()
- if norm_type == 'batch':
- layer = nn.BatchNorm2d(nc, affine=True)
- elif norm_type == 'instance':
- layer = nn.InstanceNorm2d(nc, affine=False)
- elif norm_type == 'none':
-        layer = Identity()
- else:
- raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
- return layer
-
-
-def pad(pad_type, padding):
- """ padding layer helper """
- pad_type = pad_type.lower()
- if padding == 0:
- return None
- if pad_type == 'reflect':
- layer = nn.ReflectionPad2d(padding)
- elif pad_type == 'replicate':
- layer = nn.ReplicationPad2d(padding)
- elif pad_type == 'zero':
- layer = nn.ZeroPad2d(padding)
- else:
- raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
- return layer
-
-
-def get_valid_padding(kernel_size, dilation):
- kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
- padding = (kernel_size - 1) // 2
- return padding
-
-
-class ShortcutBlock(nn.Module):
- """ Elementwise sum the output of a submodule to its input """
- def __init__(self, submodule):
- super(ShortcutBlock, self).__init__()
- self.sub = submodule
-
- def forward(self, x):
- output = x + self.sub(x)
- return output
-
- def __repr__(self):
- return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
-
-
-def sequential(*args):
- """ Flatten Sequential. It unwraps nn.Sequential. """
- if len(args) == 1:
- if isinstance(args[0], OrderedDict):
- raise NotImplementedError('sequential does not support OrderedDict input.')
- return args[0] # No sequential is needed.
- modules = []
- for module in args:
- if isinstance(module, nn.Sequential):
- for submodule in module.children():
- modules.append(submodule)
- elif isinstance(module, nn.Module):
- modules.append(module)
- return nn.Sequential(*modules)
-
-
-def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
- spectral_norm=False):
- """ Conv layer with padding, normalization, activation """
- assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
- padding = get_valid_padding(kernel_size, dilation)
- p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
- padding = padding if pad_type == 'zero' else 0
-
- if convtype=='PartialConv2D':
- c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- elif convtype=='DeformConv2D':
- c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- elif convtype=='Conv3D':
- c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- else:
- c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
-
- if spectral_norm:
- c = nn.utils.spectral_norm(c)
-
- a = act(act_type) if act_type else None
- if 'CNA' in mode:
- n = norm(norm_type, out_nc) if norm_type else None
- return sequential(p, c, n, a)
- elif mode == 'NAC':
- if norm_type is None and act_type is not None:
- a = act(act_type, inplace=False)
- n = norm(norm_type, in_nc) if norm_type else None
- return sequential(n, a, p, c)
diff --git a/modules/extensions.py b/modules/extensions.py
deleted file mode 100644
index 3eef9eaf65d750bf06e1957cee7d2d468d10caa5..0000000000000000000000000000000000000000
--- a/modules/extensions.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import os
-import sys
-import traceback
-
-import time
-import git
-
-from modules import paths, shared
-
-extensions = []
-extensions_dir = os.path.join(paths.data_path, "extensions")
-extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
-
-if not os.path.exists(extensions_dir):
- os.makedirs(extensions_dir)
-
-def active():
- return [x for x in extensions if x.enabled]
-
-
-class Extension:
- def __init__(self, name, path, enabled=True, is_builtin=False):
- self.name = name
- self.path = path
- self.enabled = enabled
- self.status = ''
- self.can_update = False
- self.is_builtin = is_builtin
- self.version = ''
-
- repo = None
- try:
- if os.path.exists(os.path.join(path, ".git")):
- repo = git.Repo(path)
- except Exception:
- print(f"Error reading github repository info from {path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- if repo is None or repo.bare:
- self.remote = None
- else:
- try:
- self.remote = next(repo.remote().urls, None)
- self.status = 'unknown'
- head = repo.head.commit
- ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
- self.version = f'{head.hexsha[:8]} ({ts})'
-
- except Exception:
- self.remote = None
-
- def list_files(self, subdir, extension):
- from modules import scripts
-
- dirpath = os.path.join(self.path, subdir)
- if not os.path.isdir(dirpath):
- return []
-
- res = []
- for filename in sorted(os.listdir(dirpath)):
- res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
-
- res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
-
- return res
-
- def check_updates(self):
- repo = git.Repo(self.path)
- for fetch in repo.remote().fetch("--dry-run"):
- if fetch.flags != fetch.HEAD_UPTODATE:
- self.can_update = True
- self.status = "behind"
- return
-
- self.can_update = False
- self.status = "latest"
-
- def fetch_and_reset_hard(self):
- repo = git.Repo(self.path)
-        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
-        # which happens because WSL2 Docker sets 755 file permissions instead of 644.
- repo.git.fetch('--all')
- repo.git.reset('--hard', 'origin')
-
-
-def list_extensions():
- extensions.clear()
-
- if not os.path.isdir(extensions_dir):
- return
-
-    extension_paths = []
-    for dirname in [extensions_dir, extensions_builtin_dir]:
-        if not os.path.isdir(dirname):
-            return
-
-        for extension_dirname in sorted(os.listdir(dirname)):
-            path = os.path.join(dirname, extension_dirname)
-            if not os.path.isdir(path):
-                continue
-
-            extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
-
-    for dirname, path, is_builtin in extension_paths:
- extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
- extensions.append(extension)
-
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
deleted file mode 100644
index 1978673d7d98d2df5fba6784ae513411733c0c49..0000000000000000000000000000000000000000
--- a/modules/extra_networks.py
+++ /dev/null
@@ -1,147 +0,0 @@
-import re
-from collections import defaultdict
-
-from modules import errors
-
-extra_network_registry = {}
-
-
-def initialize():
- extra_network_registry.clear()
-
-
-def register_extra_network(extra_network):
- extra_network_registry[extra_network.name] = extra_network
-
-
-class ExtraNetworkParams:
- def __init__(self, items=None):
- self.items = items or []
-
-
-class ExtraNetwork:
- def __init__(self, name):
- self.name = name
-
- def activate(self, p, params_list):
- """
- Called by processing on every run. Whatever the extra network is meant to do should be activated here.
- Passes arguments related to this extra network in params_list.
- User passes arguments by specifying this in his prompt:
-
-        <name:arg1:arg2:arg3>
-
- Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
- separated by colon.
-
-        Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list -
-        in this case, all effects of this extra network should be disabled.
-
- Can be called multiple times before deactivate() - each new call should override the previous call completely.
-
- For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
-
-        > "1girl, <hypernet:agm:1.1> <hypernet:ray>"
-
- params_list will be:
-
- [
- ExtraNetworkParams(items=["agm", "1.1"]),
- ExtraNetworkParams(items=["ray"])
- ]
-
- """
- raise NotImplementedError
-
- def deactivate(self, p):
- """
- Called at the end of processing for housekeeping. No need to do anything here.
- """
-
- raise NotImplementedError
-
-
-def activate(p, extra_network_data):
- """call activate for extra networks in extra_network_data in specified order, then call
- activate for all remaining registered networks with an empty argument list"""
-
- for extra_network_name, extra_network_args in extra_network_data.items():
- extra_network = extra_network_registry.get(extra_network_name, None)
- if extra_network is None:
- print(f"Skipping unknown extra network: {extra_network_name}")
- continue
-
- try:
- extra_network.activate(p, extra_network_args)
- except Exception as e:
- errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
-
- for extra_network_name, extra_network in extra_network_registry.items():
- args = extra_network_data.get(extra_network_name, None)
- if args is not None:
- continue
-
- try:
- extra_network.activate(p, [])
- except Exception as e:
- errors.display(e, f"activating extra network {extra_network_name}")
-
-
-def deactivate(p, extra_network_data):
- """call deactivate for extra networks in extra_network_data in specified order, then call
- deactivate for all remaining registered networks"""
-
- for extra_network_name, extra_network_args in extra_network_data.items():
- extra_network = extra_network_registry.get(extra_network_name, None)
- if extra_network is None:
- continue
-
- try:
- extra_network.deactivate(p)
- except Exception as e:
- errors.display(e, f"deactivating extra network {extra_network_name}")
-
- for extra_network_name, extra_network in extra_network_registry.items():
- args = extra_network_data.get(extra_network_name, None)
- if args is not None:
- continue
-
- try:
- extra_network.deactivate(p)
- except Exception as e:
- errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
-
-
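-# matches extra network tags such as <hypernet:agm:1.1> in the prompt: group 1 is the network name, group 2 the colon-separated arguments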
-re_extra_net = re.compile(r"<(\w+):([^>]+)>")
-
-
-def parse_prompt(prompt):
- res = defaultdict(list)
-
- def found(m):
- name = m.group(1)
- args = m.group(2)
-
- res[name].append(ExtraNetworkParams(items=args.split(":")))
-
- return ""
-
- prompt = re.sub(re_extra_net, found, prompt)
-
- return prompt, res
-
-
-def parse_prompts(prompts):
- res = []
- extra_data = None
-
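-    # extra network data is taken from the first prompt only; all prompts in the batch are assumed to use the same networks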
- for prompt in prompts:
- updated_prompt, parsed_extra_data = parse_prompt(prompt)
-
- if extra_data is None:
- extra_data = parsed_extra_data
-
- res.append(updated_prompt)
-
- return res, extra_data
-
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
deleted file mode 100644
index d3a4d7adcb7d87d3401752db49657b37735da8bc..0000000000000000000000000000000000000000
--- a/modules/extra_networks_hypernet.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from modules import extra_networks, shared
-from modules.hypernetworks import hypernetwork
-
-
-class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
- def __init__(self):
- super().__init__('hypernet')
-
- def activate(self, p, params_list):
- additional = shared.opts.sd_hypernetwork
-
- if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
-            p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
- params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
-
- names = []
- multipliers = []
- for params in params_list:
- assert len(params.items) > 0
-
- names.append(params.items[0])
- multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
-
- hypernetwork.load_hypernetworks(names, multipliers)
-
- def deactivate(self, p):
- pass
diff --git a/modules/extras.py b/modules/extras.py
deleted file mode 100644
index bd637679342491c4b53e9afc37f9083ab050f55a..0000000000000000000000000000000000000000
--- a/modules/extras.py
+++ /dev/null
@@ -1,258 +0,0 @@
-import os
-import re
-import shutil
-
-
-import torch
-import tqdm
-
-from modules import shared, images, sd_models, sd_vae, sd_models_config
-from modules.ui_common import plaintext_to_html
-import gradio as gr
-import safetensors.torch
-
-
-def run_pnginfo(image):
- if image is None:
- return '', '', ''
-
- geninfo, items = images.read_info_from_image(image)
- items = {**{'parameters': geninfo}, **items}
-
- info = ''
- for key, text in items.items():
-        info += f"""
-<div>
-<p><b>{plaintext_to_html(str(key))}</b></p>
-<p>{plaintext_to_html(str(text))}</p>
-</div>
-""".strip()+"\n"
-
- if len(info) == 0:
- message = "Nothing found in the image."
-        info = f"<div><p>{message}</p></div>"
-
- return '', geninfo, info
-
-
-def create_config(ckpt_result, config_source, a, b, c):
- def config(x):
- res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
- return res if res != shared.sd_default_config else None
-
- if config_source == 0:
- cfg = config(a) or config(b) or config(c)
- elif config_source == 1:
- cfg = config(b)
- elif config_source == 2:
- cfg = config(c)
- else:
- cfg = None
-
- if cfg is None:
- return
-
- filename, _ = os.path.splitext(ckpt_result)
- checkpoint_filename = filename + ".yaml"
-
- print("Copying config:")
- print(" from:", cfg)
- print(" to:", checkpoint_filename)
- shutil.copyfile(cfg, checkpoint_filename)
-
-
-checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
-
-
-def to_half(tensor, enable):
- if enable and tensor.dtype == torch.float:
- return tensor.half()
-
- return tensor
-
-
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
- shared.state.begin()
- shared.state.job = 'model-merge'
-
- def fail(message):
- shared.state.textinfo = message
- shared.state.end()
- return [*[gr.update() for _ in range(4)], message]
-
- def weighted_sum(theta0, theta1, alpha):
- return ((1 - alpha) * theta0) + (alpha * theta1)
-
- def get_difference(theta1, theta2):
- return theta1 - theta2
-
- def add_difference(theta0, theta1_2_diff, alpha):
- return theta0 + (alpha * theta1_2_diff)
-
- def filename_weighted_sum():
- a = primary_model_info.model_name
- b = secondary_model_info.model_name
- Ma = round(1 - multiplier, 2)
- Mb = round(multiplier, 2)
-
- return f"{Ma}({a}) + {Mb}({b})"
-
- def filename_add_difference():
- a = primary_model_info.model_name
- b = secondary_model_info.model_name
- c = tertiary_model_info.model_name
- M = round(multiplier, 2)
-
- return f"{a} + {M}({b} - {c})"
-
- def filename_nothing():
- return primary_model_info.model_name
-
- theta_funcs = {
- "Weighted sum": (filename_weighted_sum, None, weighted_sum),
- "Add difference": (filename_add_difference, get_difference, add_difference),
- "No interpolation": (filename_nothing, None, None),
- }
- filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
- shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
-
- if not primary_model_name:
- return fail("Failed: Merging requires a primary model.")
-
- primary_model_info = sd_models.checkpoints_list[primary_model_name]
-
- if theta_func2 and not secondary_model_name:
- return fail("Failed: Merging requires a secondary model.")
-
- secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
-
- if theta_func1 and not tertiary_model_name:
- return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
-
- tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
-
- result_is_inpainting_model = False
- result_is_instruct_pix2pix_model = False
-
- if theta_func2:
- shared.state.textinfo = f"Loading B"
- print(f"Loading {secondary_model_info.filename}...")
- theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cuda')
- else:
- theta_1 = None
-
- if theta_func1:
- shared.state.textinfo = f"Loading C"
- print(f"Loading {tertiary_model_info.filename}...")
- theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cuda')
-
- shared.state.textinfo = 'Merging B and C'
- shared.state.sampling_steps = len(theta_1.keys())
- for key in tqdm.tqdm(theta_1.keys()):
- if key in checkpoint_dict_skip_on_merge:
- continue
-
- if 'model' in key:
- if key in theta_2:
- t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
- theta_1[key] = theta_func1(theta_1[key], t2)
- else:
- theta_1[key] = torch.zeros_like(theta_1[key])
-
- shared.state.sampling_step += 1
- del theta_2
-
- shared.state.nextjob()
-
- shared.state.textinfo = f"Loading {primary_model_info.filename}..."
- print(f"Loading {primary_model_info.filename}...")
- theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cuda')
-
- print("Merging...")
- shared.state.textinfo = 'Merging A and B'
- shared.state.sampling_steps = len(theta_0.keys())
- for key in tqdm.tqdm(theta_0.keys()):
- if theta_1 and 'model' in key and key in theta_1:
-
- if key in checkpoint_dict_skip_on_merge:
- continue
-
- a = theta_0[key]
- b = theta_1[key]
-
- # this enables merging an inpainting model (A) with another one (B);
- # where a normal model has 4 channels for the latent space, an inpainting model has
- # another 4 channels for the unmasked picture's latent space, plus one channel for the mask, for a total of 9
- if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
- if a.shape[1] == 4 and b.shape[1] == 9:
- raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
- if a.shape[1] == 4 and b.shape[1] == 8:
- raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
-
- if a.shape[1] == 8 and b.shape[1] == 4:  # if we have an Instruct-Pix2Pix model...
- theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)  # merge only the channels the models have in common, otherwise we get a dimension mismatch error
- result_is_instruct_pix2pix_model = True
- else:
- assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
- theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
- result_is_inpainting_model = True
- else:
- theta_0[key] = theta_func2(a, b, multiplier)
-
- theta_0[key] = to_half(theta_0[key], save_as_half)
-
- shared.state.sampling_step += 1
-
- del theta_1
-
- bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
- if bake_in_vae_filename is not None:
- print(f"Baking in VAE from {bake_in_vae_filename}")
- shared.state.textinfo = 'Baking in VAE'
- vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
-
- for key in vae_dict.keys():
- theta_0_key = 'first_stage_model.' + key
- if theta_0_key in theta_0:
- theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
-
- del vae_dict
-
- if save_as_half and not theta_func2:
- for key in theta_0.keys():
- theta_0[key] = to_half(theta_0[key], save_as_half)
-
- if discard_weights:
- regex = re.compile(discard_weights)
- for key in list(theta_0):
- if re.search(regex, key):
- theta_0.pop(key, None)
-
- ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
-
- filename = filename_generator() if custom_name == '' else custom_name
- filename += ".inpainting" if result_is_inpainting_model else ""
- filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
- filename += "." + checkpoint_format
-
- output_modelname = os.path.join(ckpt_dir, filename)
-
- shared.state.nextjob()
- shared.state.textinfo = "Saving"
- print(f"Saving to {output_modelname}...")
-
- _, extension = os.path.splitext(output_modelname)
- if extension.lower() == ".safetensors":
- safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
- else:
- torch.save(theta_0, output_modelname)
-
- sd_models.list_models()
-
- create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
-
- print(f"Checkpoint saved to {output_modelname}.")
- shared.state.textinfo = "Checkpoint saved"
- shared.state.end()
-
- return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
diff --git a/modules/face_restoration.py b/modules/face_restoration.py
deleted file mode 100644
index 4ae53d21bef3e0783481d7c3cf3a9b2fedc4c092..0000000000000000000000000000000000000000
--- a/modules/face_restoration.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from modules import shared
-
-
-class FaceRestoration:
- def name(self):
- return "None"
-
- def restore(self, np_image):
- return np_image
-
-
-def restore_faces(np_image):
- face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
- if len(face_restorers) == 0:
- return np_image
-
- face_restorer = face_restorers[0]
-
- return face_restorer.restore(np_image)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
deleted file mode 100644
index 89dc23bff5f4be8a1feea9d2429b37878192f1e1..0000000000000000000000000000000000000000
--- a/modules/generation_parameters_copypaste.py
+++ /dev/null
@@ -1,402 +0,0 @@
-import base64
-import html
-import io
-import math
-import os
-import re
-from pathlib import Path
-
-import gradio as gr
-from modules.paths import data_path
-from modules import shared, ui_tempdir, script_callbacks
-import tempfile
-from PIL import Image
-
-re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
-re_param = re.compile(re_param_code)
-re_imagesize = re.compile(r"^(\d+)x(\d+)$")
-re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
-type_of_gr_update = type(gr.update())
-
-paste_fields = {}
-registered_param_bindings = []
-
-
-class ParamBinding:
- def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
- self.paste_button = paste_button
- self.tabname = tabname
- self.source_text_component = source_text_component
- self.source_image_component = source_image_component
- self.source_tabname = source_tabname
- self.override_settings_component = override_settings_component
-
-
-def reset():
- paste_fields.clear()
-
-
-def quote(text):
- if ',' not in str(text):
- return text
-
- text = str(text)
- text = text.replace('\\', '\\\\')
- text = text.replace('"', '\\"')
- return f'"{text}"'
-
-
-def image_from_url_text(filedata):
- if filedata is None:
- return None
-
- if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
- filedata = filedata[0]
-
- if type(filedata) == dict and filedata.get("is_file", False):
- filename = filedata["name"]
- is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
- assert is_in_right_dir, 'trying to open image file outside of allowed directories'
-
- return Image.open(filename)
-
- if type(filedata) == list:
- if len(filedata) == 0:
- return None
-
- filedata = filedata[0]
-
- if filedata.startswith("data:image/png;base64,"):
- filedata = filedata[len("data:image/png;base64,"):]
-
- filedata = base64.decodebytes(filedata.encode('utf-8'))
- image = Image.open(io.BytesIO(filedata))
- return image
-
-
-def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
- paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
-
- # backwards compatibility for existing extensions
- import modules.ui
- if tabname == 'txt2img':
- modules.ui.txt2img_paste_fields = fields
- elif tabname == 'img2img':
- modules.ui.img2img_paste_fields = fields
-
-
-def create_buttons(tabs_list):
- buttons = {}
- for tab in tabs_list:
- buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
- return buttons
-
-
-def bind_buttons(buttons, send_image, send_generate_info):
- """old function for backwards compatibility; do not use this, use register_paste_params_button"""
- for tabname, button in buttons.items():
- source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
- source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
-
- register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
-
-
-def register_paste_params_button(binding: ParamBinding):
- registered_param_bindings.append(binding)
-
-
-def connect_paste_params_buttons():
- binding: ParamBinding
- for binding in registered_param_bindings:
- destination_image_component = paste_fields[binding.tabname]["init_img"]
- fields = paste_fields[binding.tabname]["fields"]
- override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
-
- destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
- destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
-
- if binding.source_image_component and destination_image_component:
- if isinstance(binding.source_image_component, gr.Gallery):
- func = send_image_and_dimensions if destination_width_component else image_from_url_text
- jsfunc = "extract_image_from_gallery"
- else:
- func = send_image_and_dimensions if destination_width_component else lambda x: x
- jsfunc = None
-
- binding.paste_button.click(
- fn=func,
- _js=jsfunc,
- inputs=[binding.source_image_component],
- outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
- )
-
- if binding.source_text_component is not None and fields is not None:
- connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
-
- if binding.source_tabname is not None and fields is not None:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
- binding.paste_button.click(
- fn=lambda *x: x,
- inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
- outputs=[field for field, name in fields if name in paste_field_names],
- )
-
- binding.paste_button.click(
- fn=None,
- _js=f"switch_to_{binding.tabname}",
- inputs=None,
- outputs=None,
- )
-
-
-def send_image_and_dimensions(x):
- if isinstance(x, Image.Image):
- img = x
- else:
- img = image_from_url_text(x)
-
- if shared.opts.send_size and isinstance(img, Image.Image):
- w = img.width
- h = img.height
- else:
- w = gr.update()
- h = gr.update()
-
- return img, w, h
-
-
-
-def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
- """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
-
- Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
- parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
-
- If the infotext has no hash, then a hypernet with the same name will be selected instead.
- """
- hypernet_name = hypernet_name.lower()
- if hypernet_hash is not None:
- # Try to match the hash in the name
- for hypernet_key in shared.hypernetworks.keys():
- result = re_hypernet_hash.search(hypernet_key)
- if result is not None and result[1] == hypernet_hash:
- return hypernet_key
- else:
- # Fall back to a hypernet with the same name
- for hypernet_key in shared.hypernetworks.keys():
- if hypernet_key.lower().startswith(hypernet_name):
- return hypernet_key
-
- return None
-
-
-def restore_old_hires_fix_params(res):
- """for infotexts that specify old First pass size parameter, convert it into
- width, height, and hr scale"""
-
- firstpass_width = res.get('First pass size-1', None)
- firstpass_height = res.get('First pass size-2', None)
-
- if shared.opts.use_old_hires_fix_width_height:
- hires_width = int(res.get("Hires resize-1", 0))
- hires_height = int(res.get("Hires resize-2", 0))
-
- if hires_width and hires_height:
- res['Size-1'] = hires_width
- res['Size-2'] = hires_height
- return
-
- if firstpass_width is None or firstpass_height is None:
- return
-
- firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
- width = int(res.get("Size-1", 512))
- height = int(res.get("Size-2", 512))
-
- if firstpass_width == 0 or firstpass_height == 0:
- from modules import processing
- firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
-
- res['Size-1'] = firstpass_width
- res['Size-2'] = firstpass_height
- res['Hires resize-1'] = width
- res['Hires resize-2'] = height
-
-
-def parse_generation_parameters(x: str):
- """parses generation parameters string, the one you see in text field under the picture in UI:
-```
-girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
-Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
-Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
-```
-
- returns a dict with field values
- """
-
- res = {}
-
- prompt = ""
- negative_prompt = ""
-
- done_with_prompt = False
-
- *lines, lastline = x.strip().split("\n")
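- # if the last line does not contain at least 3 "key: value" pairs, treat it as part of the prompt rather than as the parameters line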
- if len(re_param.findall(lastline)) < 3:
- lines.append(lastline)
- lastline = ''
-
- for i, line in enumerate(lines):
- line = line.strip()
- if line.startswith("Negative prompt:"):
- done_with_prompt = True
- line = line[16:].strip()
-
- if done_with_prompt:
- negative_prompt += ("" if negative_prompt == "" else "\n") + line
- else:
- prompt += ("" if prompt == "" else "\n") + line
-
- res["Prompt"] = prompt
- res["Negative prompt"] = negative_prompt
-
- for k, v in re_param.findall(lastline):
- v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
- m = re_imagesize.match(v)
- if m is not None:
- res[k+"-1"] = m.group(1)
- res[k+"-2"] = m.group(2)
- else:
- res[k] = v
-
- # Missing CLIP skip means it was set to 1 (the default)
- if "Clip skip" not in res:
- res["Clip skip"] = "1"
-
- hypernet = res.get("Hypernet", None)
- if hypernet is not None:
- res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
-
- if "Hires resize-1" not in res:
- res["Hires resize-1"] = 0
- res["Hires resize-2"] = 0
-
- restore_old_hires_fix_params(res)
-
- return res
-
-
-settings_map = {}
-
-infotext_to_setting_name_mapping = [
- ('Clip skip', 'CLIP_stop_at_last_layers', ),
- ('Conditional mask weight', 'inpainting_mask_weight'),
- ('Model hash', 'sd_model_checkpoint'),
- ('ENSD', 'eta_noise_seed_delta'),
- ('Noise multiplier', 'initial_noise_multiplier'),
- ('Eta', 'eta_ancestral'),
- ('Eta DDIM', 'eta_ddim'),
- ('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
-]
-
-
-def create_override_settings_dict(text_pairs):
- """creates processing's override_settings parameters from gradio's multiselect
-
- Example input:
- ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
-
- Example output:
- {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
- """
-
- res = {}
-
- params = {}
- for pair in text_pairs:
- k, v = pair.split(":", maxsplit=1)
-
- params[k] = v.strip()
-
- for param_name, setting_name in infotext_to_setting_name_mapping:
- value = params.get(param_name, None)
-
- if value is None:
- continue
-
- res[setting_name] = shared.opts.cast_value(setting_name, value)
-
- return res
-
-
-def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
- def paste_func(prompt):
- if not prompt and not shared.cmd_opts.hide_ui_dir_config:
- filename = os.path.join(data_path, "params.txt")
- if os.path.exists(filename):
- with open(filename, "r", encoding="utf8") as file:
- prompt = file.read()
-
- params = parse_generation_parameters(prompt)
- script_callbacks.infotext_pasted_callback(prompt, params)
- res = []
-
- for output, key in paste_fields:
- if callable(key):
- v = key(params)
- else:
- v = params.get(key, None)
-
- if v is None:
- res.append(gr.update())
- elif isinstance(v, type_of_gr_update):
- res.append(v)
- else:
- try:
- valtype = type(output.value)
-
- if valtype == bool and v == "False":
- val = False
- else:
- val = valtype(v)
-
- res.append(gr.update(value=val))
- except Exception:
- res.append(gr.update())
-
- return res
-
- if override_settings_component is not None:
- def paste_settings(params):
- vals = {}
-
- for param_name, setting_name in infotext_to_setting_name_mapping:
- v = params.get(param_name, None)
- if v is None:
- continue
-
- if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
- continue
-
- v = shared.opts.cast_value(setting_name, v)
- current_value = getattr(shared.opts, setting_name, None)
-
- if v == current_value:
- continue
-
- vals[param_name] = v
-
- vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
-
- return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
-
- paste_fields = paste_fields + [(override_settings_component, paste_settings)]
-
- button.click(
- fn=paste_func,
- _js=f"recalculate_prompts_{tabname}",
- inputs=[input_comp],
- outputs=[x[0] for x in paste_fields],
- )
-
-
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
deleted file mode 100644
index fbe6215a6493c4e0690db689afbdc76e0e69f0b7..0000000000000000000000000000000000000000
--- a/modules/gfpgan_model.py
+++ /dev/null
@@ -1,116 +0,0 @@
-import os
-import sys
-import traceback
-
-import facexlib
-import gfpgan
-
-import modules.face_restoration
-from modules import paths, shared, devices, modelloader
-
-model_dir = "GFPGAN"
-user_path = None
-model_path = os.path.join(paths.models_path, model_dir)
-model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-have_gfpgan = False
-loaded_gfpgan_model = None
-
-
-def gfpgann():
- global loaded_gfpgan_model
- global model_path
- if loaded_gfpgan_model is not None:
- loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
- return loaded_gfpgan_model
-
- if gfpgan_constructor is None:
- return None
-
- models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
- if len(models) == 1 and "http" in models[0]:
- model_file = models[0]
- elif len(models) != 0:
- latest_file = max(models, key=os.path.getctime)
- model_file = latest_file
- else:
- print("Unable to load gfpgan model!")
- return None
- if hasattr(facexlib.detection.retinaface, 'device'):
- facexlib.detection.retinaface.device = devices.device_gfpgan
- model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
- loaded_gfpgan_model = model
-
- return model
-
-
-def send_model_to(model, device):
- model.gfpgan.to(device)
- model.face_helper.face_det.to(device)
- model.face_helper.face_parse.to(device)
-
-
-def gfpgan_fix_faces(np_image):
- model = gfpgann()
- if model is None:
- return np_image
-
- send_model_to(model, devices.device_gfpgan)
-
- np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- model.face_helper.clean_all()
-
- if shared.opts.face_restoration_unload:
- send_model_to(model, devices.cpu)
-
- return np_image
-
-
-gfpgan_constructor = None
-
-
-def setup_model(dirname):
- global model_path
- if not os.path.exists(model_path):
- os.makedirs(model_path)
-
- try:
- from gfpgan import GFPGANer
- from facexlib import detection, parsing
- global user_path
- global have_gfpgan
- global gfpgan_constructor
-
- load_file_from_url_orig = gfpgan.utils.load_file_from_url
- facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
- facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
-
- def my_load_file_from_url(**kwargs):
- return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
-
- def facex_load_file_from_url(**kwargs):
- return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
-
- def facex_load_file_from_url2(**kwargs):
- return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
-
- gfpgan.utils.load_file_from_url = my_load_file_from_url
- facexlib.detection.load_file_from_url = facex_load_file_from_url
- facexlib.parsing.load_file_from_url = facex_load_file_from_url2
- user_path = dirname
- have_gfpgan = True
- gfpgan_constructor = GFPGANer
-
- class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
- def name(self):
- return "GFPGAN"
-
- def restore(self, np_image):
- return gfpgan_fix_faces(np_image)
-
- shared.face_restorers.append(FaceRestorerGFPGAN())
- except Exception:
- print("Error setting up GFPGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/hashes.py b/modules/hashes.py
deleted file mode 100644
index 83272a0787db52a0a6f0dfe2e3c8f1f8bf6fa625..0000000000000000000000000000000000000000
--- a/modules/hashes.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import hashlib
-import json
-import os.path
-
-import filelock
-
-from modules import shared
-from modules.paths import data_path
-
-
-cache_filename = os.path.join(data_path, "cache.json")
-cache_data = None
-
-
-def dump_cache():
- with filelock.FileLock(cache_filename+".lock"):
- with open(cache_filename, "w", encoding="utf8") as file:
- json.dump(cache_data, file, indent=4)
-
-
-def cache(subsection):
- global cache_data
-
- if cache_data is None:
- with filelock.FileLock(cache_filename+".lock"):
- if not os.path.isfile(cache_filename):
- cache_data = {}
- else:
- with open(cache_filename, "r", encoding="utf8") as file:
- cache_data = json.load(file)
-
- s = cache_data.get(subsection, {})
- cache_data[subsection] = s
-
- return s
-
-
-def calculate_sha256(filename):
- hash_sha256 = hashlib.sha256()
- blksize = 1024 * 1024
-
- with open(filename, "rb") as f:
- for chunk in iter(lambda: f.read(blksize), b""):
- hash_sha256.update(chunk)
-
- return hash_sha256.hexdigest()
-
-
-def sha256_from_cache(filename, title):
- hashes = cache("hashes")
- ondisk_mtime = os.path.getmtime(filename)
-
- if title not in hashes:
- return None
-
- cached_sha256 = hashes[title].get("sha256", None)
- cached_mtime = hashes[title].get("mtime", 0)
-
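- # a file modified on disk after the cache entry was written invalidates the stored hash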
- if ondisk_mtime > cached_mtime or cached_sha256 is None:
- return None
-
- return cached_sha256
-
-
-def sha256(filename, title):
- hashes = cache("hashes")
-
- sha256_value = sha256_from_cache(filename, title)
- if sha256_value is not None:
- return sha256_value
-
- if shared.cmd_opts.no_hashing:
- return None
-
- print(f"Calculating sha256 for {filename}: ", end='')
- sha256_value = calculate_sha256(filename)
- print(f"{sha256_value}")
-
- hashes[title] = {
- "mtime": os.path.getmtime(filename),
- "sha256": sha256_value,
- }
-
- dump_cache()
-
- return sha256_value
-
-
-
-
-
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
deleted file mode 100644
index f6ef42d5a02704e28170a89b92eceaac0bc4b1dd..0000000000000000000000000000000000000000
--- a/modules/hypernetworks/hypernetwork.py
+++ /dev/null
@@ -1,811 +0,0 @@
-import csv
-import datetime
-import glob
-import html
-import os
-import sys
-import traceback
-import inspect
-
-import modules.textual_inversion.dataset
-import torch
-import tqdm
-from einops import rearrange, repeat
-from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
-from modules.textual_inversion import textual_inversion, logging
-from modules.textual_inversion.learn_schedule import LearnRateScheduler
-from torch import einsum
-from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
-
-from collections import defaultdict, deque
-from statistics import stdev, mean
-
-
-optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
-
-class HypernetworkModule(torch.nn.Module):
- activation_dict = {
- "linear": torch.nn.Identity,
- "relu": torch.nn.ReLU,
- "leakyrelu": torch.nn.LeakyReLU,
- "elu": torch.nn.ELU,
- "swish": torch.nn.Hardswish,
- "tanh": torch.nn.Tanh,
- "sigmoid": torch.nn.Sigmoid,
- }
- activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
-
- def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
- add_layer_norm=False, activate_output=False, dropout_structure=None):
- super().__init__()
-
- self.multiplier = 1.0
-
- assert layer_structure is not None, "layer_structure must not be None"
- assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
- assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
-
- linears = []
- for i in range(len(layer_structure) - 1):
-
- # Add a fully-connected layer
- linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
-
- # Add an activation func except last layer
- if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
- pass
- elif activation_func in self.activation_dict:
- linears.append(self.activation_dict[activation_func]())
- else:
- raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
-
- # Add layer normalization
- if add_layer_norm:
- linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
-
- # Everything should be now parsed into dropout structure, and applied here.
- # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
- if dropout_structure is not None and dropout_structure[i+1] > 0:
- assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
- linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
- # Example: for layer structure [1, 2, 1] there is no dropout when last_layer_dropout is False; for [1, 2, 2, 1] the dropout structure is [0, 0.3, 0, 0] when it is False and [0, 0.3, 0.3, 0] when it is True.
-
- self.linear = torch.nn.Sequential(*linears)
-
- if state_dict is not None:
- self.fix_old_state_dict(state_dict)
- self.load_state_dict(state_dict)
- else:
- for layer in self.linear:
- if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
- w, b = layer.weight.data, layer.bias.data
- if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
- normal_(w, mean=0.0, std=0.01)
- normal_(b, mean=0.0, std=0)
- elif weight_init == 'XavierUniform':
- xavier_uniform_(w)
- zeros_(b)
- elif weight_init == 'XavierNormal':
- xavier_normal_(w)
- zeros_(b)
- elif weight_init == 'KaimingUniform':
- kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
- zeros_(b)
- elif weight_init == 'KaimingNormal':
- kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
- zeros_(b)
- else:
- raise KeyError(f"Key {weight_init} is not defined as initialization!")
- self.to(devices.device)
-
- def fix_old_state_dict(self, state_dict):
- changes = {
- 'linear1.bias': 'linear.0.bias',
- 'linear1.weight': 'linear.0.weight',
- 'linear2.bias': 'linear.1.bias',
- 'linear2.weight': 'linear.1.weight',
- }
-
- for fr, to in changes.items():
- x = state_dict.get(fr, None)
- if x is None:
- continue
-
- del state_dict[fr]
- state_dict[to] = x
-
- def forward(self, x):
- return x + self.linear(x) * (self.multiplier if not self.training else 1)
-
- def trainables(self):
- layer_structure = []
- for layer in self.linear:
- if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
- layer_structure += [layer.weight, layer.bias]
- return layer_structure
-
-
-# layer_structure: sequence whose length determines the dropout list; use_dropout: master on/off switch; last_layer_dropout: kept for backwards compatibility.
-def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
- if layer_structure is None:
- layer_structure = [1, 2, 1]
- if not use_dropout:
- return [0] * len(layer_structure)
- dropout_values = [0]
- dropout_values.extend([0.3] * (len(layer_structure) - 3))
- if last_layer_dropout:
- dropout_values.append(0.3)
- else:
- dropout_values.append(0)
- dropout_values.append(0)
- return dropout_values
-
-
-class Hypernetwork:
- filename = None
- name = None
-
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
- self.filename = None
- self.name = name
- self.layers = {}
- self.step = 0
- self.sd_checkpoint = None
- self.sd_checkpoint_name = None
- self.layer_structure = layer_structure
- self.activation_func = activation_func
- self.weight_init = weight_init
- self.add_layer_norm = add_layer_norm
- self.use_dropout = use_dropout
- self.activate_output = activate_output
- self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
- self.dropout_structure = kwargs.get('dropout_structure', None)
- if self.dropout_structure is None:
- self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
- self.optimizer_name = None
- self.optimizer_state_dict = None
- self.optional_info = None
-
- for size in enable_sizes or []:
- self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
- )
- self.eval()
-
- def weights(self):
- res = []
- for k, layers in self.layers.items():
- for layer in layers:
- res += layer.parameters()
- return res
-
- def train(self, mode=True):
- for k, layers in self.layers.items():
- for layer in layers:
- layer.train(mode=mode)
- for param in layer.parameters():
- param.requires_grad = mode
-
- def to(self, device):
- for k, layers in self.layers.items():
- for layer in layers:
- layer.to(device)
-
- return self
-
- def set_multiplier(self, multiplier):
- for k, layers in self.layers.items():
- for layer in layers:
- layer.multiplier = multiplier
-
- return self
-
- def eval(self):
- for k, layers in self.layers.items():
- for layer in layers:
- layer.eval()
- for param in layer.parameters():
- param.requires_grad = False
-
- def save(self, filename):
- state_dict = {}
- optimizer_saved_dict = {}
-
- for k, v in self.layers.items():
- state_dict[k] = (v[0].state_dict(), v[1].state_dict())
-
- state_dict['step'] = self.step
- state_dict['name'] = self.name
- state_dict['layer_structure'] = self.layer_structure
- state_dict['activation_func'] = self.activation_func
- state_dict['is_layer_norm'] = self.add_layer_norm
- state_dict['weight_initialization'] = self.weight_init
- state_dict['sd_checkpoint'] = self.sd_checkpoint
- state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
- state_dict['activate_output'] = self.activate_output
- state_dict['use_dropout'] = self.use_dropout
- state_dict['dropout_structure'] = self.dropout_structure
- state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
- state_dict['optional_info'] = self.optional_info if self.optional_info else None
-
- if self.optimizer_name is not None:
- optimizer_saved_dict['optimizer_name'] = self.optimizer_name
-
- torch.save(state_dict, filename)
- if shared.opts.save_optimizer_state and self.optimizer_state_dict:
- optimizer_saved_dict['hash'] = self.shorthash()
- optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
- torch.save(optimizer_saved_dict, filename + '.optim')
-
- def load(self, filename):
- self.filename = filename
- if self.name is None:
- self.name = os.path.splitext(os.path.basename(filename))[0]
-
- state_dict = torch.load(filename, map_location='cpu')
-
- self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
- self.optional_info = state_dict.get('optional_info', None)
- self.activation_func = state_dict.get('activation_func', None)
- self.weight_init = state_dict.get('weight_initialization', 'Normal')
- self.add_layer_norm = state_dict.get('is_layer_norm', False)
- self.dropout_structure = state_dict.get('dropout_structure', None)
- self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
- self.activate_output = state_dict.get('activate_output', True)
- self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
- # Dropout structure should have the same length as layer structure; every value should be in [0, 1), and the last value must be 0.
- if self.dropout_structure is None:
- self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
-
- if shared.opts.print_hypernet_extra:
- if self.optional_info is not None:
- print(f" INFO:\n {self.optional_info}\n")
-
- print(f" Layer structure: {self.layer_structure}")
- print(f" Activation function: {self.activation_func}")
- print(f" Weight initialization: {self.weight_init}")
- print(f" Layer norm: {self.add_layer_norm}")
- print(f" Dropout usage: {self.use_dropout}" )
- print(f" Activate last layer: {self.activate_output}")
- print(f" Dropout structure: {self.dropout_structure}")
-
- optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
-
- if self.shorthash() == optimizer_saved_dict.get('hash', None):
- self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
- else:
- self.optimizer_state_dict = None
- if self.optimizer_state_dict:
- self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
- if shared.opts.print_hypernet_extra:
- print("Loaded existing optimizer from checkpoint")
- print(f"Optimizer name is {self.optimizer_name}")
- else:
- self.optimizer_name = "AdamW"
- if shared.opts.print_hypernet_extra:
- print("No saved optimizer exists in checkpoint")
-
- for size, sd in state_dict.items():
- if type(size) == int:
- self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.activate_output, self.dropout_structure),
- HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.activate_output, self.dropout_structure),
- )
-
- self.name = state_dict.get('name', self.name)
- self.step = state_dict.get('step', 0)
- self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
- self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
- self.eval()
-
- def shorthash(self):
- sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
-
- return sha256[0:10] if sha256 else None
-
-
-def list_hypernetworks(path):
- res = {}
- for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
- name = os.path.splitext(os.path.basename(filename))[0]
- # Prevent a hypothetical "None.pt" from being listed.
- if name != "None":
- res[name] = filename
- return res
-
-
-def load_hypernetwork(name):
- path = shared.hypernetworks.get(name, None)
-
- if path is None:
- return None
-
- hypernetwork = Hypernetwork()
-
- try:
- hypernetwork.load(path)
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- return None
-
- return hypernetwork
-
-
-def load_hypernetworks(names, multipliers=None):
- already_loaded = {}
-
- for hypernetwork in shared.loaded_hypernetworks:
- if hypernetwork.name in names:
- already_loaded[hypernetwork.name] = hypernetwork
-
- shared.loaded_hypernetworks.clear()
-
- for i, name in enumerate(names):
- hypernetwork = already_loaded.get(name, None)
- if hypernetwork is None:
- hypernetwork = load_hypernetwork(name)
-
- if hypernetwork is None:
- continue
-
- hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
- shared.loaded_hypernetworks.append(hypernetwork)
-
-
-def find_closest_hypernetwork_name(search: str):
- if not search:
- return None
- search = search.lower()
- applicable = [name for name in shared.hypernetworks if search in name.lower()]
- if not applicable:
- return None
- applicable = sorted(applicable, key=lambda name: len(name))
- return applicable[0]
-
-
-def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
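- # hypernetwork modules are stored per context embedding width; if no module matches this context's width, pass it through unchanged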
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
-
- if hypernetwork_layers is None:
- return context_k, context_v
-
- if layer is not None:
- layer.hyper_k = hypernetwork_layers[0]
- layer.hyper_v = hypernetwork_layers[1]
-
- context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
- context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
- return context_k, context_v
-
-
-def apply_hypernetworks(hypernetworks, context, layer=None):
- context_k = context
- context_v = context
- for hypernetwork in hypernetworks:
- context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
-
- return context_k, context_v
-
-
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
- k = self.to_k(context_k)
- v = self.to_v(context_v)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
- if mask is not None:
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
-
- # attention, what we cannot get enough of
- attn = sim.softmax(dim=-1)
-
- out = einsum('b i j, b j d -> b i d', attn, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return self.to_out(out)
-
-
-def stack_conds(conds):
- if len(conds) == 1:
- return torch.stack(conds)
-
- # same as in reconstruct_multicond_batch
- token_count = max([x.shape[0] for x in conds])
- for i in range(len(conds)):
- if conds[i].shape[0] != token_count:
- last_vector = conds[i][-1:]
- last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
- conds[i] = torch.vstack([conds[i], last_vector_repeated])
-
- return torch.stack(conds)
-
-
-def statistics(data):
- if len(data) < 2:
- std = 0
- else:
- std = stdev(data)
- total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
- recent_data = data[-32:]
- if len(recent_data) < 2:
- std = 0
- else:
- std = stdev(recent_data)
- recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
- return total_information, recent_information
-
-
-def report_statistics(loss_info:dict):
- keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
- for key in keys:
- try:
- print("Loss statistics for file " + key)
- info, recent = statistics(list(loss_info[key]))
- print(info)
- print(recent)
- except Exception as e:
- print(e)
-
-
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
- # Remove illegal characters from name.
- name = "".join( x for x in name if (x.isalnum() or x in "._- "))
- assert name, "Name cannot be empty!"
-
- fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
- if not overwrite_old:
- assert not os.path.exists(fn), f"file {fn} already exists"
-
- if type(layer_structure) == str:
- layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
-
- if use_dropout and dropout_structure and type(dropout_structure) == str:
- dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
- else:
- dropout_structure = [0] * len(layer_structure)
-
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
- name=name,
- enable_sizes=[int(x) for x in enable_sizes],
- layer_structure=layer_structure,
- activation_func=activation_func,
- weight_init=weight_init,
- add_layer_norm=add_layer_norm,
- use_dropout=use_dropout,
- dropout_structure=dropout_structure
- )
- hypernet.save(fn)
-
- shared.reload_hypernetworks()
-
-
-def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
- from modules import images
-
- save_hypernetwork_every = save_hypernetwork_every or 0
- create_image_every = create_image_every or 0
- template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
- textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
- template_file = template_file.path
-
- path = shared.hypernetworks.get(hypernetwork_name, None)
- hypernetwork = Hypernetwork()
- hypernetwork.load(path)
- shared.loaded_hypernetworks = [hypernetwork]
-
- shared.state.job = "train-hypernetwork"
- shared.state.textinfo = "Initializing hypernetwork training..."
- shared.state.job_count = steps
-
- hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
- filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
-
- log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
- unload = shared.opts.unload_models_when_training
-
- if save_hypernetwork_every > 0:
- hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
- os.makedirs(hypernetwork_dir, exist_ok=True)
- else:
- hypernetwork_dir = None
-
- if create_image_every > 0:
- images_dir = os.path.join(log_directory, "images")
- os.makedirs(images_dir, exist_ok=True)
- else:
- images_dir = None
-
- checkpoint = sd_models.select_checkpoint()
-
- initial_step = hypernetwork.step or 0
- if initial_step >= steps:
- shared.state.textinfo = "Model has already been trained beyond specified max steps"
- return hypernetwork, filename
-
- scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
-
- clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
- if clip_grad:
- clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
-
- if shared.opts.training_enable_tensorboard:
- tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
-
- # dataset loading may take a while, so input validations and early returns should be done before this
- shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
-
- pin_memory = shared.opts.pin_memory
-
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
-
- if shared.opts.save_training_settings_to_txt:
- saved_params = dict(
- model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
- **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
- )
- logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
-
- latent_sampling_method = ds.latent_sampling_method
-
- dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
-
- old_parallel_processing_allowed = shared.parallel_processing_allowed
-
- if unload:
- shared.parallel_processing_allowed = False
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
-
- weights = hypernetwork.weights()
- hypernetwork.train()
-
- # Use the optimizer type saved with the hypernetwork; this could also be exposed as a UI option.
- if hypernetwork.optimizer_name in optimizer_dict:
- optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
- optimizer_name = hypernetwork.optimizer_name
- else:
- print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
- optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
- optimizer_name = 'AdamW'
-
- if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
- try:
- optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
- except RuntimeError as e:
- print("Cannot resume from saved optimizer!")
- print(e)
-
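- # mixed-precision training: the GradScaler scales the loss to avoid gradient underflow in fp16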
- scaler = torch.cuda.amp.GradScaler()
-
- batch_size = ds.batch_size
- gradient_step = ds.gradient_step
- # number of images processed = n steps * batch_size * gradient_step
- steps_per_epoch = len(ds) // batch_size // gradient_step
- max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
- loss_step = 0
- _loss_step = 0 #internal
- # size = len(ds.indexes)
- # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
- loss_logging = deque(maxlen=len(ds) * 3) # ideally a configurable parameter; currently keeps 3 epochs' worth of losses (3 * dataset size)
- # losses = torch.zeros((size,))
- # previous_mean_losses = [0]
- # previous_mean_loss = 0
- # print("Mean loss of {} elements".format(size))
-
- steps_without_grad = 0
-
- last_saved_file = ""
- last_saved_image = ""
- forced_filename = ""
-
- pbar = tqdm.tqdm(total=steps - initial_step)
- try:
- sd_hijack_checkpoint.add()
-
- for i in range((steps-initial_step) * gradient_step):
- if scheduler.finished:
- break
- if shared.state.interrupted:
- break
- for j, batch in enumerate(dl):
- # works as a drop_last=True for gradient accumulation
- if j == max_steps_per_epoch:
- break
- scheduler.apply(optimizer, hypernetwork.step)
- if scheduler.finished:
- break
- if shared.state.interrupted:
- break
-
- if clip_grad:
- clip_grad_sched.step(hypernetwork.step)
-
- with devices.autocast():
- x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
- if use_weight:
- w = batch.weight.to(devices.device, non_blocking=pin_memory)
- if tag_drop_out != 0 or shuffle_tags:
- shared.sd_model.cond_stage_model.to(devices.device)
- c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
- shared.sd_model.cond_stage_model.to(devices.cpu)
- else:
- c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
- if use_weight:
- loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
- del w
- else:
- loss = shared.sd_model.forward(x, c)[0] / gradient_step
- del x
- del c
-
- _loss_step += loss.item()
- scaler.scale(loss).backward()
-
- # go back until we reach gradient accumulation steps
- if (j + 1) % gradient_step != 0:
- continue
- loss_logging.append(_loss_step)
- if clip_grad:
- clip_grad(weights, clip_grad_sched.learn_rate)
-
- scaler.step(optimizer)
- scaler.update()
- hypernetwork.step += 1
- pbar.update()
- optimizer.zero_grad(set_to_none=True)
- loss_step = _loss_step
- _loss_step = 0
-
- steps_done = hypernetwork.step + 1
-
- epoch_num = hypernetwork.step // steps_per_epoch
- epoch_step = hypernetwork.step % steps_per_epoch
-
- description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
- pbar.set_description(description)
- if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
- # Before saving, change name to match current checkpoint.
- hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
- hypernetwork.optimizer_name = optimizer_name
- if shared.opts.save_optimizer_state:
- hypernetwork.optimizer_state_dict = optimizer.state_dict()
- save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
- hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
-
-
-
- if shared.opts.training_enable_tensorboard:
- epoch_num = hypernetwork.step // len(ds)
- epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
- mean_loss = sum(loss_logging) / len(loss_logging)
- textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
-
- textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
- "loss": f"{loss_step:.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{hypernetwork_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
- hypernetwork.eval()
- rng_state = torch.get_rng_state()
- cuda_rng_state = None
- if torch.cuda.is_available():
- cuda_rng_state = torch.cuda.get_rng_state_all()
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
-
- p.disable_extra_networks = True
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = batch.cond_text[0]
- p.steps = 20
- p.width = training_width
- p.height = training_height
-
- preview_text = p.prompt
-
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images) > 0 else None
-
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
- torch.set_rng_state(rng_state)
- if torch.cuda.is_available():
- torch.cuda.set_rng_state_all(cuda_rng_state)
- hypernetwork.train()
- if image is not None:
- shared.state.assign_current_image(image)
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
- textual_inversion.tensorboard_add_image(tensorboard_writer,
- f"Validation at epoch {epoch_num}", image,
- hypernetwork.step)
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
-
- shared.state.job_no = hypernetwork.step
-
- shared.state.textinfo = f"""
-<p>
-Loss: {loss_step:.7f}<br/>
-Step: {steps_done}<br/>
-Last prompt: {html.escape(batch.cond_text[0])}<br/>
-Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
-Last saved image: {html.escape(last_saved_image)}<br/>
-</p>
-"""
- except Exception:
- print(traceback.format_exc(), file=sys.stderr)
- finally:
- pbar.leave = False
- pbar.close()
- hypernetwork.eval()
- #report_statistics(loss_dict)
- sd_hijack_checkpoint.remove()
-
-
-
- filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
- hypernetwork.optimizer_name = optimizer_name
- if shared.opts.save_optimizer_state:
- hypernetwork.optimizer_state_dict = optimizer.state_dict()
- save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-
- del optimizer
- hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
- shared.parallel_processing_allowed = old_parallel_processing_allowed
-
- return hypernetwork, filename
-
-def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
- old_hypernetwork_name = hypernetwork.name
- old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
- old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
- try:
- hypernetwork.sd_checkpoint = checkpoint.shorthash
- hypernetwork.sd_checkpoint_name = checkpoint.model_name
- hypernetwork.name = hypernetwork_name
- hypernetwork.save(filename)
- except:
- hypernetwork.sd_checkpoint = old_sd_checkpoint
- hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
- hypernetwork.name = old_hypernetwork_name
- raise
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
deleted file mode 100644
index 76599f5adeceffa93c38bfb6fd85132e79d85f1c..0000000000000000000000000000000000000000
--- a/modules/hypernetworks/ui.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import html
-import os
-import re
-
-import gradio as gr
-import modules.hypernetworks.hypernetwork
-from modules import devices, sd_hijack, shared
-
-not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
-
-
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
- filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
-
- return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
-
-
-def train_hypernetwork(*args):
- shared.loaded_hypernetworks = []
-
- assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
-
- try:
- sd_hijack.undo_optimizations()
-
- hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
-
- res = f"""
-Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
-Hypernetwork saved to {html.escape(filename)}
-"""
- return res, ""
- except Exception:
- raise
- finally:
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
- sd_hijack.apply_optimizations()
-
diff --git a/modules/images.py b/modules/images.py
deleted file mode 100644
index 5b80c23e11e3ca9a6156bcc99c664721ec41ccc1..0000000000000000000000000000000000000000
--- a/modules/images.py
+++ /dev/null
@@ -1,669 +0,0 @@
-import datetime
-import sys
-import traceback
-
-import pytz
-import io
-import math
-import os
-from collections import namedtuple
-import re
-
-import numpy as np
-import piexif
-import piexif.helper
-from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
-from fonts.ttf import Roboto
-import string
-import json
-import hashlib
-
-from modules import sd_samplers, shared, script_callbacks, errors
-from modules.shared import opts, cmd_opts
-
-LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
-
-
-def image_grid(imgs, batch_size=1, rows=None):
- if rows is None:
- if opts.n_rows > 0:
- rows = opts.n_rows
- elif opts.n_rows == 0:
- rows = batch_size
- elif opts.grid_prevent_empty_spots:
- rows = math.floor(math.sqrt(len(imgs)))
- while len(imgs) % rows != 0:
- rows -= 1
- else:
- rows = math.sqrt(len(imgs))
- rows = round(rows)
- if rows > len(imgs):
- rows = len(imgs)
-
- cols = math.ceil(len(imgs) / rows)
-
- params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
- script_callbacks.image_grid_callback(params)
-
- w, h = imgs[0].size
- grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
-
- for i, img in enumerate(params.imgs):
- grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
-
- return grid
-
-
-Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
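-# Grid.tiles is a list of rows [y, tile_h, row_images], and each row_images entry is
-# [x, tile_w, tile]; split_grid() below cuts overlapping tiles and combine_grid()
-# blends the overlapping strips back together with linear gradient masks.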
-
-
-def split_grid(image, tile_w=512, tile_h=512, overlap=64):
- w = image.width
- h = image.height
-
- non_overlap_width = tile_w - overlap
- non_overlap_height = tile_h - overlap
-
- cols = math.ceil((w - overlap) / non_overlap_width)
- rows = math.ceil((h - overlap) / non_overlap_height)
-
- dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
- dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
-
- grid = Grid([], tile_w, tile_h, w, h, overlap)
- for row in range(rows):
- row_images = []
-
- y = int(row * dy)
-
- if y + tile_h >= h:
- y = h - tile_h
-
- for col in range(cols):
- x = int(col * dx)
-
- if x + tile_w >= w:
- x = w - tile_w
-
- tile = image.crop((x, y, x + tile_w, y + tile_h))
-
- row_images.append([x, tile_w, tile])
-
- grid.tiles.append([y, tile_h, row_images])
-
- return grid
-
-
-def combine_grid(grid):
- def make_mask_image(r):
- r = r * 255 / grid.overlap
- r = r.astype(np.uint8)
- return Image.fromarray(r, 'L')
-
- mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
- mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
-
- combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
- for y, h, row in grid.tiles:
- combined_row = Image.new("RGB", (grid.image_w, h))
- for x, w, tile in row:
- if x == 0:
- combined_row.paste(tile, (0, 0))
- continue
-
- combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
- combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
-
- if y == 0:
- combined_image.paste(combined_row, (0, 0))
- continue
-
- combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
- combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
-
- return combined_image
-
-
-class GridAnnotation:
- def __init__(self, text='', is_active=True):
- self.text = text
- self.is_active = is_active
- self.size = None
-
-
-def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
- def wrap(drawing, text, font, line_length):
- lines = ['']
- for word in text.split():
- line = f'{lines[-1]} {word}'.strip()
- if drawing.textlength(line, font=font) <= line_length:
- lines[-1] = line
- else:
- lines.append(word)
- return lines
-
- def get_font(fontsize):
- try:
- return ImageFont.truetype(opts.font or Roboto, fontsize)
- except Exception:
- return ImageFont.truetype(Roboto, fontsize)
-
- def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
- for i, line in enumerate(lines):
- fnt = initial_fnt
- fontsize = initial_fontsize
- while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
- fontsize -= 1
- fnt = get_font(fontsize)
- drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
-
- if not line.is_active:
- drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
-
- draw_y += line.size[1] + line_spacing
-
- fontsize = (width + height) // 25
- line_spacing = fontsize // 2
-
- fnt = get_font(fontsize)
-
- color_active = (0, 0, 0)
- color_inactive = (153, 153, 153)
-
- pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
-
- cols = im.width // width
- rows = im.height // height
-
- assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
- assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
-
- calc_img = Image.new("RGB", (1, 1), "white")
- calc_d = ImageDraw.Draw(calc_img)
-
- for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
- items = [] + texts
- texts.clear()
-
- for line in items:
- wrapped = wrap(calc_d, line.text, fnt, allowed_width)
- texts += [GridAnnotation(x, line.is_active) for x in wrapped]
-
- for line in texts:
- bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
- line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
- line.allowed_width = allowed_width
-
- hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
- ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
-
- pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
-
- result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
-
- for row in range(rows):
- for col in range(cols):
- cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
- result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
-
- d = ImageDraw.Draw(result)
-
- for col in range(cols):
- x = pad_left + (width + margin) * col + width / 2
- y = pad_top / 2 - hor_text_heights[col] / 2
-
- draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
-
- for row in range(rows):
- x = pad_left / 2
- y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
-
- draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
-
- return result
-
-
-def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
- prompts = all_prompts[1:]
- boundary = math.ceil(len(prompts) / 2)
-
- prompts_horiz = prompts[:boundary]
- prompts_vert = prompts[boundary:]
-
- hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
- ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
-
- return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
-
-
-def resize_image(resize_mode, im, width, height, upscaler_name=None):
- """
- Resizes an image with the specified resize_mode, width, and height.
-
- Args:
- resize_mode: The mode to use when resizing the image.
- 0: Resize the image to the specified width and height.
- 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
- 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling the empty space with data from the image's edges.
- im: The image to resize.
- width: The width to resize the image to.
- height: The height to resize the image to.
- upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
- """
-
- upscaler_name = upscaler_name or opts.upscaler_for_img2img
-
- def resize(im, w, h):
- if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
- return im.resize((w, h), resample=LANCZOS)
-
- scale = max(w / im.width, h / im.height)
-
- if scale > 1.0:
- upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
- assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
-
- upscaler = upscalers[0]
- im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
-
- if im.width != w or im.height != h:
- im = im.resize((w, h), resample=LANCZOS)
-
- return im
-
- if resize_mode == 0:
- res = resize(im, width, height)
-
- elif resize_mode == 1:
- ratio = width / height
- src_ratio = im.width / im.height
-
- src_w = width if ratio > src_ratio else im.width * height // im.height
- src_h = height if ratio <= src_ratio else im.height * width // im.width
-
- resized = resize(im, src_w, src_h)
- res = Image.new("RGB", (width, height))
- res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
-
- else:
- ratio = width / height
- src_ratio = im.width / im.height
-
- src_w = width if ratio < src_ratio else im.width * height // im.height
- src_h = height if ratio >= src_ratio else im.height * width // im.width
-
- resized = resize(im, src_w, src_h)
- res = Image.new("RGB", (width, height))
- res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
-
- if ratio < src_ratio:
- fill_height = height // 2 - src_h // 2
- res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
- res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
- elif ratio > src_ratio:
- fill_width = width // 2 - src_w // 2
- res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
- res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
-
- return res
-
-
-invalid_filename_chars = '<>:"/\\|?*\n'
-invalid_filename_prefix = ' '
-invalid_filename_postfix = ' .'
-re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
-re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
-re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
-max_filename_part_length = 128
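-# re_pattern splits a filename pattern into literal text plus optional [token] groups;
-# re_pattern_arg peels trailing <arg> qualifiers off a token, e.g. "datetime<%Y%m%d>".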
-
-
-def sanitize_filename_part(text, replace_spaces=True):
- if text is None:
- return None
-
- if replace_spaces:
- text = text.replace(' ', '_')
-
- text = text.translate({ord(x): '_' for x in invalid_filename_chars})
- text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
- text = text.rstrip(invalid_filename_postfix)
- return text
-
-
-class FilenameGenerator:
- replacements = {
- 'seed': lambda self: self.seed if self.seed is not None else '',
- 'steps': lambda self: self.p and self.p.steps,
- 'cfg': lambda self: self.p and self.p.cfg_scale,
- 'width': lambda self: self.image.width,
- 'height': lambda self: self.image.height,
- 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
- 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
- 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
- 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
- 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
- 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
- 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
- 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
- 'prompt': lambda self: sanitize_filename_part(self.prompt),
- 'prompt_no_styles': lambda self: self.prompt_no_style(),
- 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
- 'prompt_words': lambda self: self.prompt_words(),
- }
- default_time_format = '%Y%m%d%H%M%S'
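- # Illustrative: FilenameGenerator(p, seed, prompt, image).apply("[seed]-[prompt_words]")
- # expands each [token] via the `replacements` table above; unknown tokens are kept as literal text.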
-
- def __init__(self, p, seed, prompt, image):
- self.p = p
- self.seed = seed
- self.prompt = prompt
- self.image = image
-
- def prompt_no_style(self):
- if self.p is None or self.prompt is None:
- return None
-
- prompt_no_style = self.prompt
- for style in shared.prompt_styles.get_style_prompts(self.p.styles):
- if len(style) > 0:
- for part in style.split("{prompt}"):
- prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
-
- prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
-
- return sanitize_filename_part(prompt_no_style, replace_spaces=False)
-
- def prompt_words(self):
- words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
- if len(words) == 0:
- words = ["empty"]
- return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
-
- def datetime(self, *args):
- time_datetime = datetime.datetime.now()
-
- time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
- try:
- time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
- except pytz.exceptions.UnknownTimeZoneError as _:
- time_zone = None
-
- time_zone_time = time_datetime.astimezone(time_zone)
- try:
- formatted_time = time_zone_time.strftime(time_format)
- except (ValueError, TypeError) as _:
- formatted_time = time_zone_time.strftime(self.default_time_format)
-
- return sanitize_filename_part(formatted_time, replace_spaces=False)
-
- def apply(self, x):
- res = ''
-
- for m in re_pattern.finditer(x):
- text, pattern = m.groups()
- res += text
-
- if pattern is None:
- continue
-
- pattern_args = []
- while True:
- m = re_pattern_arg.match(pattern)
- if m is None:
- break
-
- pattern, arg = m.groups()
- pattern_args.insert(0, arg)
-
- fun = self.replacements.get(pattern.lower())
- if fun is not None:
- try:
- replacement = fun(self, *pattern_args)
- except Exception:
- replacement = None
- print(f"Error adding [{pattern}] to filename", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- if replacement is not None:
- res += str(replacement)
- continue
-
- res += f'[{pattern}]'
-
- return res
-
-
-def get_next_sequence_number(path, basename):
- """
- Determines and returns the next sequence number to use when saving an image in the specified directory.
-
- The sequence starts at 0.
- """
- result = -1
- if basename != '':
- basename = basename + "-"
-
- prefix_length = len(basename)
- for p in os.listdir(path):
- if p.startswith(basename):
- l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
- try:
- result = max(int(l[0]), result)
- except ValueError:
- pass
-
- return result + 1
-
-
-def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
- """Save an image.
-
- Args:
- image (`PIL.Image`):
- The image to be saved.
- path (`str`):
- The directory to save the image. Note, the option `save_to_dirs` will make the image be saved into a subdirectory.
- basename (`str`):
- The base filename which will be applied to `filename pattern`.
- seed, prompt, short_filename,
- extension (`str`):
- Image file extension, default is `png`.
- pnginfo_section_name (`str`):
- Specify the name of the section in which `info` will be saved.
- info (`str` or `PngImagePlugin.iTXt`):
- PNG info chunks.
- existing_info (`dict`):
- Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
- no_prompt (bool):
- If true, prevents saving into a prompt-derived subdirectory (see `save_to_dirs` below).
- p (`StableDiffusionProcessing`)
- forced_filename (`str`):
- If specified, `basename` and filename pattern will be ignored.
- save_to_dirs (bool):
- If true, the image will be saved into a subdirectory of `path`.
-
- Returns: (fullfn, txt_fullfn)
- fullfn (`str`):
- The full path of the saved image.
- txt_fullfn (`str` or None):
- If a text file is saved for this image, this will be its full path. Otherwise None.
- """
- namegen = FilenameGenerator(p, seed, prompt, image)
-
- if save_to_dirs is None:
- save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
-
- if save_to_dirs:
- dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
- path = os.path.join(path, dirname)
-
- os.makedirs(path, exist_ok=True)
-
- if forced_filename is None:
- if short_filename or seed is None:
- file_decoration = ""
- elif opts.save_to_dirs:
- file_decoration = opts.samples_filename_pattern or "[seed]"
- else:
- file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
-
- add_number = opts.save_images_add_number or file_decoration == ''
-
- if file_decoration != "" and add_number:
- file_decoration = "-" + file_decoration
-
- file_decoration = namegen.apply(file_decoration) + suffix
-
- if add_number:
- basecount = get_next_sequence_number(path, basename)
- fullfn = None
- for i in range(500):
- fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
- fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
- if not os.path.exists(fullfn):
- break
- else:
- fullfn = os.path.join(path, f"{file_decoration}.{extension}")
- else:
- fullfn = os.path.join(path, f"{forced_filename}.{extension}")
-
- pnginfo = existing_info or {}
- if info is not None:
- pnginfo[pnginfo_section_name] = info
-
- params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
- script_callbacks.before_image_saved_callback(params)
-
- image = params.image
- fullfn = params.filename
- info = params.pnginfo.get(pnginfo_section_name, None)
-
- def _atomically_save_image(image_to_save, filename_without_extension, extension):
- # save image with .tmp extension to avoid race condition when another process detects new image in the directory
- temp_file_path = filename_without_extension + ".tmp"
- image_format = Image.registered_extensions()[extension]
-
- if extension.lower() == '.png':
- pnginfo_data = PngImagePlugin.PngInfo()
- if opts.enable_pnginfo:
- for k, v in params.pnginfo.items():
- pnginfo_data.add_text(k, str(v))
-
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
-
- elif extension.lower() in (".jpg", ".jpeg", ".webp"):
- if image_to_save.mode == 'RGBA':
- image_to_save = image_to_save.convert("RGB")
- elif image_to_save.mode == 'I;16':
- image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")  # 0.0038910505836576 ~= 255/65535, scales 16-bit pixel values down to the 8-bit range
-
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
-
- if opts.enable_pnginfo and info is not None:
- exif_bytes = piexif.dump({
- "Exif": {
- piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
- },
- })
-
- piexif.insert(exif_bytes, temp_file_path)
- else:
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
-
- # atomically rename the file with correct extension
- os.replace(temp_file_path, filename_without_extension + extension)
-
- fullfn_without_extension, extension = os.path.splitext(params.filename)
- _atomically_save_image(image, fullfn_without_extension, extension)
-
- image.already_saved_as = fullfn
-
- oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
- if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
- ratio = image.width / image.height
-
- if oversize and ratio > 1:
- image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
- elif oversize:
- image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
-
- try:
- _atomically_save_image(image, fullfn_without_extension, ".jpg")
- except Exception as e:
- errors.display(e, "saving image as downscaled JPG")
-
- if opts.save_txt and info is not None:
- txt_fullfn = f"{fullfn_without_extension}.txt"
- with open(txt_fullfn, "w", encoding="utf8") as file:
- file.write(info + "\n")
- else:
- txt_fullfn = None
-
- script_callbacks.image_saved_callback(params)
-
- return fullfn, txt_fullfn
-
-
-def read_info_from_image(image):
- items = image.info or {}
-
- geninfo = items.pop('parameters', None)
-
- if "exif" in items:
- exif = piexif.load(items["exif"])
- exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
- try:
- exif_comment = piexif.helper.UserComment.load(exif_comment)
- except ValueError:
- exif_comment = exif_comment.decode('utf8', errors="ignore")
-
- if exif_comment:
- items['exif comment'] = exif_comment
- geninfo = exif_comment
-
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
- 'loop', 'background', 'timestamp', 'duration']:
- items.pop(field, None)
-
- if items.get("Software", None) == "NovelAI":
- try:
- json_info = json.loads(items["Comment"])
- sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
-
- geninfo = f"""{items["Description"]}
-Negative prompt: {json_info["uc"]}
-Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
- except Exception:
- print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- return geninfo, items
-
-
-def image_data(data):
- try:
- image = Image.open(io.BytesIO(data))
- textinfo, _ = read_info_from_image(image)
- return textinfo, None
- except Exception:
- pass
-
- try:
- text = data.decode('utf8')
- assert len(text) < 10000
- return text, None
-
- except Exception:
- pass
-
- return '', None
-
-
-def flatten(img, bgcolor):
- """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
-
- if img.mode == "RGBA":
- background = Image.new('RGBA', img.size, bgcolor)
- background.paste(img, mask=img)
- img = background
-
- return img.convert('RGB')
diff --git a/modules/img2img.py b/modules/img2img.py
deleted file mode 100644
index c973b7708dd4381b9cf4bdd7055497a35c39c57d..0000000000000000000000000000000000000000
--- a/modules/img2img.py
+++ /dev/null
@@ -1,184 +0,0 @@
-import math
-import os
-import sys
-import traceback
-
-import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
-
-from modules import devices, sd_samplers
-from modules.generation_parameters_copypaste import create_override_settings_dict
-from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
-from modules.shared import opts, state
-import modules.shared as shared
-import modules.processing as processing
-from modules.ui import plaintext_to_html
-import modules.images as images
-import modules.scripts
-
-
-def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
- processing.fix_seed(p)
-
- images = shared.listfiles(input_dir)
-
- is_inpaint_batch = False
- if inpaint_mask_dir:
- inpaint_masks = shared.listfiles(inpaint_mask_dir)
- is_inpaint_batch = len(inpaint_masks) > 0
- if is_inpaint_batch:
- print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
-
- print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
-
- save_normally = output_dir == ''
-
- p.do_not_save_grid = True
- p.do_not_save_samples = not save_normally
-
- state.job_count = len(images) * p.n_iter
-
- for i, image in enumerate(images):
- state.job = f"{i+1} out of {len(images)}"
- if state.skipped:
- state.skipped = False
-
- if state.interrupted:
- break
-
- img = Image.open(image)
- # Use the EXIF orientation of photos taken by smartphones.
- img = ImageOps.exif_transpose(img)
- p.init_images = [img] * p.batch_size
-
- if is_inpaint_batch:
- # try to find corresponding mask for an image using simple filename matching
- mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
- # if not found use first one ("same mask for all images" use-case)
- if mask_image_path not in inpaint_masks:
- mask_image_path = inpaint_masks[0]
- mask_image = Image.open(mask_image_path)
- p.image_mask = mask_image
-
- proc = modules.scripts.scripts_img2img.run(p, *args)
- if proc is None:
- proc = process_images(p)
-
- for n, processed_image in enumerate(proc.images):
- filename = os.path.basename(image)
-
- if n > 0:
- left, right = os.path.splitext(filename)
- filename = f"{left}-{n}{right}"
-
- if not save_normally:
- os.makedirs(output_dir, exist_ok=True)
- if processed_image.mode == 'RGBA':
- processed_image = processed_image.convert("RGB")
- processed_image.save(os.path.join(output_dir, filename))
-
-
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
- override_settings = create_override_settings_dict(override_settings_texts)
-
- is_batch = mode == 5
-
- if mode == 0: # img2img
- image = init_img.convert("RGB")
- mask = None
- elif mode == 1: # img2img sketch
- image = sketch.convert("RGB")
- mask = None
- elif mode == 2: # inpaint
- image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
- alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
- image = image.convert("RGB")
- elif mode == 3: # inpaint sketch
- image = inpaint_color_sketch
- orig = inpaint_color_sketch_orig or inpaint_color_sketch
- pred = np.any(np.array(image) != np.array(orig), axis=-1)
- mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
- mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
- blur = ImageFilter.GaussianBlur(mask_blur)
- image = Image.composite(image.filter(blur), orig, mask.filter(blur))
- image = image.convert("RGB")
- elif mode == 4: # inpaint upload mask
- image = init_img_inpaint
- mask = init_mask_inpaint
- else:
- image = None
- mask = None
-
- # Use the EXIF orientation of photos taken by smartphones.
- if image is not None:
- image = ImageOps.exif_transpose(image)
-
- assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
-
- p = StableDiffusionProcessingImg2Img(
- sd_model=shared.sd_model,
- outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
- outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
- prompt=prompt,
- negative_prompt=negative_prompt,
- styles=prompt_styles,
- seed=seed,
- subseed=subseed,
- subseed_strength=subseed_strength,
- seed_resize_from_h=seed_resize_from_h,
- seed_resize_from_w=seed_resize_from_w,
- seed_enable_extras=seed_enable_extras,
- sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
- batch_size=batch_size,
- n_iter=n_iter,
- steps=steps,
- cfg_scale=cfg_scale,
- width=width,
- height=height,
- restore_faces=restore_faces,
- tiling=tiling,
- init_images=[image],
- mask=mask,
- mask_blur=mask_blur,
- inpainting_fill=inpainting_fill,
- resize_mode=resize_mode,
- denoising_strength=denoising_strength,
- image_cfg_scale=image_cfg_scale,
- inpaint_full_res=inpaint_full_res,
- inpaint_full_res_padding=inpaint_full_res_padding,
- inpainting_mask_invert=inpainting_mask_invert,
- override_settings=override_settings,
- )
-
- p.scripts = modules.scripts.scripts_img2img
- p.script_args = args
-
- if shared.cmd_opts.enable_console_prompts:
- print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
-
- p.extra_generation_params["Mask blur"] = mask_blur
-
- if is_batch:
- assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
-
- process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
-
- processed = Processed(p, [], p.seed, "")
- else:
- processed = modules.scripts.scripts_img2img.run(p, *args)
- if processed is None:
- processed = process_images(p)
-
- p.close()
-
- shared.total_tqdm.clear()
-
- generation_info_js = processed.js()
- if opts.samples_log_stdout:
- print(generation_info_js)
-
- if opts.do_not_show_images:
- processed.images = []
-
- return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/import_hook.py b/modules/import_hook.py
deleted file mode 100644
index 28c67dfa897abec5eeb4cfac3da79458d6fee278..0000000000000000000000000000000000000000
--- a/modules/import_hook.py
+++ /dev/null
@@ -1,5 +0,0 @@
-import sys
-
-# this will break any attempt to import xformers, which will prevent the stable diffusion repo from trying to use it
-if "--xformers" not in "".join(sys.argv):
- sys.modules["xformers"] = None
diff --git a/modules/interrogate.py b/modules/interrogate.py
deleted file mode 100644
index cbb806832818134b02452d905f115c3535e07366..0000000000000000000000000000000000000000
--- a/modules/interrogate.py
+++ /dev/null
@@ -1,227 +0,0 @@
-import os
-import shutil
-import sys
-import traceback
-from collections import namedtuple
-from pathlib import Path
-import re
-
-import torch
-import torch.hub
-
-from torchvision import transforms
-from torchvision.transforms.functional import InterpolationMode
-
-import modules.shared as shared
-from modules import devices, paths, shared, lowvram, modelloader, errors
-
-blip_image_eval_size = 384
-clip_model_name = 'ViT-L/14'
-
-Category = namedtuple("Category", ["name", "topn", "items"])
-
-re_topn = re.compile(r"\.top(\d+)\.")
-
-def category_types():
- return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
-
-
-def download_default_clip_interrogate_categories(content_dir):
- print("Downloading CLIP categories...")
-
- tmpdir = content_dir + "_tmp"
- category_types = ["artists", "flavors", "mediums", "movements"]
-
- try:
- os.makedirs(tmpdir)
- for category_type in category_types:
- torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
- os.rename(tmpdir, content_dir)
-
- except Exception as e:
- errors.display(e, "downloading default CLIP interrogate categories")
- finally:
- if os.path.exists(tmpdir):
- shutil.rmtree(tmpdir)  # tmpdir is a directory (possibly holding partial downloads), so remove the whole tree
-
-
-class InterrogateModels:
- blip_model = None
- clip_model = None
- clip_preprocess = None
- dtype = None
- running_on_cpu = None
-
- def __init__(self, content_dir):
- self.loaded_categories = None
- self.skip_categories = []
- self.content_dir = content_dir
- self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
-
- def categories(self):
- if not os.path.exists(self.content_dir):
- download_default_clip_interrogate_categories(self.content_dir)
-
- if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
- return self.loaded_categories
-
- self.loaded_categories = []
-
- if os.path.exists(self.content_dir):
- self.skip_categories = shared.opts.interrogate_clip_skip_categories
- category_types = []
- for filename in Path(self.content_dir).glob('*.txt'):
- category_types.append(filename.stem)
- if filename.stem in self.skip_categories:
- continue
- m = re_topn.search(filename.stem)
- topn = 1 if m is None else int(m.group(1))
- with open(filename, "r", encoding="utf8") as file:
- lines = [x.strip() for x in file.readlines()]
-
- self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
-
- return self.loaded_categories
-
- def create_fake_fairscale(self):
- class FakeFairscale:
- def checkpoint_wrapper(self):
- pass
-
- sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
-
- def load_blip_model(self):
- self.create_fake_fairscale()
- import models.blip
-
- files = modelloader.load_models(
- model_path=os.path.join(paths.models_path, "BLIP"),
- model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
- ext_filter=[".pth"],
- download_name='model_base_caption_capfilt_large.pth',
- )
-
- blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
- blip_model.eval()
-
- return blip_model
-
- def load_clip_model(self):
- import clip
-
- if self.running_on_cpu:
- model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
- else:
- model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
-
- model.eval()
- model = model.to(devices.device_interrogate)
-
- return model, preprocess
-
- def load(self):
- if self.blip_model is None:
- self.blip_model = self.load_blip_model()
- if not shared.cmd_opts.no_half and not self.running_on_cpu:
- self.blip_model = self.blip_model.half()
-
- self.blip_model = self.blip_model.to(devices.device_interrogate)
-
- if self.clip_model is None:
- self.clip_model, self.clip_preprocess = self.load_clip_model()
- if not shared.cmd_opts.no_half and not self.running_on_cpu:
- self.clip_model = self.clip_model.half()
-
- self.clip_model = self.clip_model.to(devices.device_interrogate)
-
- self.dtype = next(self.clip_model.parameters()).dtype
-
- def send_clip_to_ram(self):
- if not shared.opts.interrogate_keep_models_in_memory:
- if self.clip_model is not None:
- self.clip_model = self.clip_model.to(devices.cpu)
-
- def send_blip_to_ram(self):
- if not shared.opts.interrogate_keep_models_in_memory:
- if self.blip_model is not None:
- self.blip_model = self.blip_model.to(devices.cpu)
-
- def unload(self):
- self.send_clip_to_ram()
- self.send_blip_to_ram()
-
- devices.torch_gc()
-
- def rank(self, image_features, text_array, top_count=1):
- import clip
-
- devices.torch_gc()
-
- if shared.opts.interrogate_clip_dict_limit != 0:
- text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
-
- top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
- text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
- text_features /= text_features.norm(dim=-1, keepdim=True)
-
- similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
- for i in range(image_features.shape[0]):
- similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
- similarity /= image_features.shape[0]
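- # both feature sets are L2-normalised, so the matrix product is cosine similarity; softmax
- # over the candidate texts yields per-label scores, averaged over the image-feature batch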
-
- top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
- return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
-
- def generate_caption(self, pil_image):
- gpu_image = transforms.Compose([
- transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
- transforms.ToTensor(),
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
- ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
-
- with torch.no_grad():
- caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
-
- return caption[0]
-
- def interrogate(self, pil_image):
- res = ""
- shared.state.begin()
- shared.state.job = 'interrogate'
- try:
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
- devices.torch_gc()
-
- self.load()
-
- caption = self.generate_caption(pil_image)
- self.send_blip_to_ram()
- devices.torch_gc()
-
- res = caption
-
- clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
-
- with torch.no_grad(), devices.autocast():
- image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
-
- image_features /= image_features.norm(dim=-1, keepdim=True)
-
- for name, topn, items in self.categories():
- matches = self.rank(image_features, items, top_count=topn)
- for match, score in matches:
- if shared.opts.interrogate_return_ranks:
- res += f", ({match}:{score/100:.3f})"
- else:
- res += ", " + match
-
- except Exception:
- print("Error interrogating", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- res += ""
-
- self.unload()
- shared.state.end()
-
- return res
diff --git a/modules/localization.py b/modules/localization.py
deleted file mode 100644
index f6a6f2fbdf24a6e22db5d63e724eae3bc62324f4..0000000000000000000000000000000000000000
--- a/modules/localization.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import json
-import os
-import sys
-import traceback
-
-
-localizations = {}
-
-
-def list_localizations(dirname):
- localizations.clear()
-
- for file in os.listdir(dirname):
- fn, ext = os.path.splitext(file)
- if ext.lower() != ".json":
- continue
-
- localizations[fn] = os.path.join(dirname, file)
-
- from modules import scripts
- for file in scripts.list_scripts("localizations", ".json"):
- fn, ext = os.path.splitext(file.filename)
- localizations[fn] = file.path
-
-
-def localization_js(current_localization_name):
- fn = localizations.get(current_localization_name, None)
- data = {}
- if fn is not None:
- try:
- with open(fn, "r", encoding="utf8") as file:
- data = json.load(file)
- except Exception:
- print(f"Error loading localization from {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- return f"var localization = {json.dumps(data)}\n"
diff --git a/modules/lowvram.py b/modules/lowvram.py
deleted file mode 100644
index 042a0254a166573a57de1d269d3eaa28b4fa011a..0000000000000000000000000000000000000000
--- a/modules/lowvram.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import torch
-from modules import devices
-
-module_in_gpu = None
-cpu = torch.device("cpu")
-
-
-def send_everything_to_cpu():
- global module_in_gpu
-
- if module_in_gpu is not None:
- module_in_gpu.to(cpu)
-
- module_in_gpu = None
-
-
-def setup_for_low_vram(sd_model, use_medvram):
- parents = {}
-
- def send_me_to_gpu(module, _):
- """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
- we add this as forward_pre_hook to a lot of modules and this way all but one of them will
- be in CPU
- """
- global module_in_gpu
-
- module = parents.get(module, module)
-
- if module_in_gpu == module:
- return
-
- if module_in_gpu is not None:
- module_in_gpu.to(cpu)
-
- module.to(devices.device)
- module_in_gpu = module
-
- # see below for register_forward_pre_hook;
- # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
- # useless here, and we just replace those methods
-
- first_stage_model = sd_model.first_stage_model
- first_stage_model_encode = sd_model.first_stage_model.encode
- first_stage_model_decode = sd_model.first_stage_model.decode
-
- def first_stage_model_encode_wrap(x):
- send_me_to_gpu(first_stage_model, None)
- return first_stage_model_encode(x)
-
- def first_stage_model_decode_wrap(z):
- send_me_to_gpu(first_stage_model, None)
- return first_stage_model_decode(z)
-
- # for SD1, cond_stage_model is CLIP and its NN is in the transformer field, but for SD2, it's open clip, and it's in the model field
- if hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
-
- # remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
- # send the model to GPU. Then put the modules back; they will remain on CPU.
- stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
- sd_model.to(devices.device)
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
-
- # register hooks for the first three of those models
- sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
- sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
- sd_model.first_stage_model.encode = first_stage_model_encode_wrap
- sd_model.first_stage_model.decode = first_stage_model_decode_wrap
- if sd_model.depth_model:
- sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
- parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
-
- if hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
- del sd_model.cond_stage_model.transformer
-
- if use_medvram:
- sd_model.model.register_forward_pre_hook(send_me_to_gpu)
- else:
- diff_model = sd_model.model.diffusion_model
-
- # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
- # so that only one of them is in GPU at a time
- stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
- diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
- sd_model.model.to(devices.device)
- diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
-
- # install hooks for bits of third model
- diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
- for block in diff_model.input_blocks:
- block.register_forward_pre_hook(send_me_to_gpu)
- diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
- for block in diff_model.output_blocks:
- block.register_forward_pre_hook(send_me_to_gpu)
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
deleted file mode 100644
index ddcea53b920d63a6a0b3a00dd3c54b36201ff761..0000000000000000000000000000000000000000
--- a/modules/mac_specific.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import torch
-from modules import paths
-from modules.sd_hijack_utils import CondFunc
-from packaging import version
-
-
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
-def check_for_mps() -> bool:
- if not getattr(torch, 'has_mps', False):
- return False
- try:
- torch.zeros(1).to(torch.device("mps"))
- return True
- except Exception:
- return False
-has_mps = check_for_mps()
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
-def cumsum_fix(input, cumsum_func, *args, **kwargs):
- if input.device.type == 'mps':
- output_dtype = kwargs.get('dtype', input.dtype)
- if output_dtype == torch.int64:
- return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
- elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
- return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
- return cumsum_func(input, *args, **kwargs)
-
-
-if has_mps:
- # MPS fix for randn in torchsde
- CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
-
- if version.parse(torch.__version__) < version.parse("1.13"):
- # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-
- # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
- CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
- lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
- # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
- CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
- lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
- # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
- CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
- elif version.parse(torch.__version__) > version.parse("1.13.1"):
- cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
- cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
- cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
- CondFunc('torch.cumsum', cumsum_fix_func, None)
- CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
- CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
-
diff --git a/modules/masking.py b/modules/masking.py
deleted file mode 100644
index a5c4d2da521c2728b2285f1e16ef19bf1e804db7..0000000000000000000000000000000000000000
--- a/modules/masking.py
+++ /dev/null
@@ -1,99 +0,0 @@
-from PIL import Image, ImageFilter, ImageOps
-
-
-def get_crop_region(mask, pad=0):
- """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
- For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
-
- h, w = mask.shape
-
- crop_left = 0
- for i in range(w):
- if not (mask[:, i] == 0).all():
- break
- crop_left += 1
-
- crop_right = 0
- for i in reversed(range(w)):
- if not (mask[:, i] == 0).all():
- break
- crop_right += 1
-
- crop_top = 0
- for i in range(h):
- if not (mask[i] == 0).all():
- break
- crop_top += 1
-
- crop_bottom = 0
- for i in reversed(range(h)):
- if not (mask[i] == 0).all():
- break
- crop_bottom += 1
-
- return (
- int(max(crop_left-pad, 0)),
- int(max(crop_top-pad, 0)),
- int(min(w - crop_right + pad, w)),
- int(min(h - crop_bottom + pad, h))
- )
-
-
-def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
- """expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region
- for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
-
- x1, y1, x2, y2 = crop_region
-
- ratio_crop_region = (x2 - x1) / (y2 - y1)
- ratio_processing = processing_width / processing_height
-
- if ratio_crop_region > ratio_processing:
- desired_height = (x2 - x1) / ratio_processing
- desired_height_diff = int(desired_height - (y2-y1))
- y1 -= desired_height_diff//2
- y2 += desired_height_diff - desired_height_diff//2
- if y2 >= image_height:
- diff = y2 - image_height
- y2 -= diff
- y1 -= diff
- if y1 < 0:
- y2 -= y1
- y1 -= y1
- if y2 >= image_height:
- y2 = image_height
- else:
- desired_width = (y2 - y1) * ratio_processing
- desired_width_diff = int(desired_width - (x2-x1))
- x1 -= desired_width_diff//2
- x2 += desired_width_diff - desired_width_diff//2
- if x2 >= image_width:
- diff = x2 - image_width
- x2 -= diff
- x1 -= diff
- if x1 < 0:
- x2 -= x1
- x1 -= x1
- if x2 >= image_width:
- x2 = image_width
-
- return x1, y1, x2, y2
-
-
-def fill(image, mask):
- """fills masked regions with colors from image using blur. Not extremely effective."""
-
- image_mod = Image.new('RGBA', (image.width, image.height))
-
- image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
-
- image_masked = image_masked.convert('RGBa')
-
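- # composite progressively smaller Gaussian blurs: large radii pull surrounding colour
- # deep into the masked hole, while smaller radii keep more detail near the mask edge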
- for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
- blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
- for _ in range(repeats):
- image_mod.alpha_composite(blurred)
-
- return image_mod.convert("RGB")
-
diff --git a/modules/memmon.py b/modules/memmon.py
deleted file mode 100644
index a7060f58523a0cfc2fa9138954c801fcce00ba49..0000000000000000000000000000000000000000
--- a/modules/memmon.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import threading
-import time
-from collections import defaultdict
-
-import torch
-
-
-class MemUsageMonitor(threading.Thread):
- run_flag = None
- device = None
- disabled = False
- opts = None
- data = None
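- # Illustrative use (assuming a CUDA device and an opts object with memmon_poll_rate set):
- # mon = MemUsageMonitor("MemMon", device, opts); mon.start(); call mon.monitor() before a job,
- # then mon.stop() / mon.read() afterwards for peak VRAM statistics in bytes.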
-
- def __init__(self, name, device, opts):
- threading.Thread.__init__(self)
- self.name = name
- self.device = device
- self.opts = opts
-
- self.daemon = True
- self.run_flag = threading.Event()
- self.data = defaultdict(int)
-
- try:
- torch.cuda.mem_get_info()
- torch.cuda.memory_stats(self.device)
- except Exception as e: # AMD or whatever
- print(f"Warning: caught exception '{e}', memory monitor disabled")
- self.disabled = True
-
- def run(self):
- if self.disabled:
- return
-
- while True:
- self.run_flag.wait()
-
- torch.cuda.reset_peak_memory_stats()
- self.data.clear()
-
- if self.opts.memmon_poll_rate <= 0:
- self.run_flag.clear()
- continue
-
- self.data["min_free"] = torch.cuda.mem_get_info()[0]
-
- while self.run_flag.is_set():
- free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
- self.data["min_free"] = min(self.data["min_free"], free)
-
- time.sleep(1 / self.opts.memmon_poll_rate)
-
- def dump_debug(self):
- print(self, 'recorded data:')
- for k, v in self.read().items():
- print(k, -(v // -(1024 ** 2)))
-
- print(self, 'raw torch memory stats:')
- tm = torch.cuda.memory_stats(self.device)
- for k, v in tm.items():
- if 'bytes' not in k:
- continue
- print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
-
- print(torch.cuda.memory_summary())
-
- def monitor(self):
- self.run_flag.set()
-
- def read(self):
- if not self.disabled:
- free, total = torch.cuda.mem_get_info()
- self.data["free"] = free
- self.data["total"] = total
-
- torch_stats = torch.cuda.memory_stats(self.device)
- self.data["active"] = torch_stats["active.all.current"]
- self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
- self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
- self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
- self.data["system_peak"] = total - self.data["min_free"]
-
- return self.data
-
- def stop(self):
- self.run_flag.clear()
- return self.read()
diff --git a/modules/modelloader.py b/modules/modelloader.py
deleted file mode 100644
index fc3f6249f1ccb53c279f3e86d3ea95a4a7d03e50..0000000000000000000000000000000000000000
--- a/modules/modelloader.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import glob
-import os
-import shutil
-import importlib
-from urllib.parse import urlparse
-
-from basicsr.utils.download_util import load_file_from_url
-from modules import shared
-from modules.upscaler import Upscaler
-from modules.paths import script_path, models_path
-
-
-def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
- """
- A one-and-done loader to try finding the desired models in the specified directories.
-
- @param download_name: Specify to download from model_url immediately.
- @param model_url: If no other models are found, this will be downloaded on upscale.
- @param model_path: The location to store/find models in.
- @param command_path: A command-line argument to search for models in first.
- @param ext_filter: An optional list of filename extensions to filter by
- @param ext_blacklist: An optional list of filename suffixes to exclude
- @return: A list of paths containing the desired model(s)
- """
- output = []
-
- if ext_filter is None:
- ext_filter = []
-
- try:
- places = []
-
- if command_path is not None and command_path != model_path:
- pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
- if os.path.exists(pretrained_path):
- print(f"Appending path: {pretrained_path}")
- places.append(pretrained_path)
- elif os.path.exists(command_path):
- places.append(command_path)
-
- places.append(model_path)
-
- for place in places:
- if os.path.exists(place):
- for file in glob.iglob(place + '**/**', recursive=True):
- full_path = file
- if os.path.isdir(full_path):
- continue
- if os.path.islink(full_path) and not os.path.exists(full_path):
- print(f"Skipping broken symlink: {full_path}")
- continue
- if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
- continue
- if len(ext_filter) != 0:
- model_name, extension = os.path.splitext(file)
- if extension not in ext_filter:
- continue
- if file not in output:
- output.append(full_path)
-
- if model_url is not None and len(output) == 0:
- if download_name is not None:
- dl = load_file_from_url(model_url, model_path, True, download_name)
- output.append(dl)
- else:
- output.append(model_url)
-
- except Exception:
- pass
-
- return output
-
-
-def friendly_name(file: str):
- if "http" in file:
- file = urlparse(file).path
-
- file = os.path.basename(file)
- model_name, extension = os.path.splitext(file)
- return model_name
-
-
-def cleanup_models():
- # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
- # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
- # somehow auto-register and just do these things...
- root_path = script_path
- src_path = models_path
- dest_path = os.path.join(models_path, "Stable-diffusion")
- move_files(src_path, dest_path, ".ckpt")
- move_files(src_path, dest_path, ".safetensors")
- src_path = os.path.join(root_path, "ESRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
- move_files(src_path, dest_path)
- src_path = os.path.join(models_path, "BSRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
- move_files(src_path, dest_path, ".pth")
- src_path = os.path.join(root_path, "gfpgan")
- dest_path = os.path.join(models_path, "GFPGAN")
- move_files(src_path, dest_path)
- src_path = os.path.join(root_path, "SwinIR")
- dest_path = os.path.join(models_path, "SwinIR")
- move_files(src_path, dest_path)
- src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
- dest_path = os.path.join(models_path, "LDSR")
- move_files(src_path, dest_path)
-
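The comment at the top of cleanup_models() notes that a table of (src, dest, extension) entries would be tidier than the repeated assignments. A minimal sketch of that idea, kept separate from the module above and taking move_files as a parameter so it stays self-contained:

import os

def cleanup_models_table(root_path: str, models_path: str, move_files):
    # Hypothetical data-driven variant of cleanup_models(): one (src, dest, ext)
    # row per move instead of repeated assignments. move_files is passed in so
    # this sketch does not depend on the module being deleted above.
    moves = [
        (models_path,                          os.path.join(models_path, "Stable-diffusion"), ".ckpt"),
        (models_path,                          os.path.join(models_path, "Stable-diffusion"), ".safetensors"),
        (os.path.join(root_path, "ESRGAN"),    os.path.join(models_path, "ESRGAN"),           None),
        (os.path.join(models_path, "BSRGAN"),  os.path.join(models_path, "ESRGAN"),           ".pth"),
        (os.path.join(root_path, "gfpgan"),    os.path.join(models_path, "GFPGAN"),           None),
        (os.path.join(root_path, "SwinIR"),    os.path.join(models_path, "SwinIR"),           None),
    ]
    for src, dest, ext in moves:
        move_files(src, dest, ext)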
-
-def move_files(src_path: str, dest_path: str, ext_filter: str = None):
- try:
- if not os.path.exists(dest_path):
- os.makedirs(dest_path)
- if os.path.exists(src_path):
- for file in os.listdir(src_path):
- fullpath = os.path.join(src_path, file)
- if os.path.isfile(fullpath):
- if ext_filter is not None:
- if ext_filter not in file:
- continue
- print(f"Moving {file} from {src_path} to {dest_path}.")
- try:
- shutil.move(fullpath, dest_path)
-                    except Exception:
- pass
- if len(os.listdir(src_path)) == 0:
- print(f"Removing empty folder: {src_path}")
- shutil.rmtree(src_path, True)
-    except Exception:
- pass
-
-
-builtin_upscaler_classes = []
-forbidden_upscaler_classes = set()
-
-
-def list_builtin_upscalers():
- load_upscalers()
-
- builtin_upscaler_classes.clear()
- builtin_upscaler_classes.extend(Upscaler.__subclasses__())
-
-
-def forbid_loaded_nonbuiltin_upscalers():
- for cls in Upscaler.__subclasses__():
- if cls not in builtin_upscaler_classes:
- forbidden_upscaler_classes.add(cls)
-
-
-def load_upscalers():
- # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
- # so we'll try to import any _model.py files before looking in __subclasses__
- modules_dir = os.path.join(shared.script_path, "modules")
- for file in os.listdir(modules_dir):
- if "_model.py" in file:
- model_name = file.replace("_model.py", "")
- full_model = f"modules.{model_name}_model"
- try:
- importlib.import_module(full_model)
-            except Exception:
- pass
-
- datas = []
- commandline_options = vars(shared.cmd_opts)
- for cls in Upscaler.__subclasses__():
- if cls in forbidden_upscaler_classes:
- continue
-
- name = cls.__name__
- cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
- scaler = cls(commandline_options.get(cmd_name, None))
- datas += scaler.scalers
-
- shared.sd_upscalers = datas
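load_upscalers() above leans on the fact that a subclass only appears in Upscaler.__subclasses__() once its defining module has been imported. The same importlib-plus-__subclasses__ pattern, reduced to a generic, self-contained sketch with made-up names:

import importlib
import pkgutil

class Plugin:
    """Base class; anything that subclasses it becomes discoverable."""

def discover_plugins(package_name: str):
    # package_name must refer to a package. Importing every submodule makes the
    # subclass definitions execute; __subclasses__() then returns whatever
    # registered itself by inheriting from Plugin.
    package = importlib.import_module(package_name)
    for info in pkgutil.iter_modules(package.__path__):
        try:
            importlib.import_module(f"{package_name}.{info.name}")
        except Exception as err:
            print(f"could not import {info.name}: {err}")
    return Plugin.__subclasses__()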
diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py
deleted file mode 100644
index f3d49c44cafcc78e27a1e4f2b522faa21e135f9f..0000000000000000000000000000000000000000
--- a/modules/models/diffusion/ddpm_edit.py
+++ /dev/null
@@ -1,1459 +0,0 @@
-"""
-wild mixture of
-https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
-https://github.com/CompVis/taming-transformers
--- merci
-"""
-
-# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
-# See more details in LICENSE.
-
-import torch
-import torch.nn as nn
-import numpy as np
-import pytorch_lightning as pl
-from torch.optim.lr_scheduler import LambdaLR
-from einops import rearrange, repeat
-from contextlib import contextmanager
-from functools import partial
-from tqdm import tqdm
-from torchvision.utils import make_grid
-from pytorch_lightning.utilities.distributed import rank_zero_only
-
-from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
-from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
-from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
-from ldm.models.diffusion.ddim import DDIMSampler
-
-
-__conditioning_keys__ = {'concat': 'c_concat',
- 'crossattn': 'c_crossattn',
- 'adm': 'y'}
-
-
-def disabled_train(self, mode=True):
- """Overwrite model.train with this function to make sure train/eval mode
- does not change anymore."""
- return self
-
-
-def uniform_on_device(r1, r2, shape, device):
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
-
-
-class DDPM(pl.LightningModule):
- # classic DDPM with Gaussian diffusion, in image space
- def __init__(self,
- unet_config,
- timesteps=1000,
- beta_schedule="linear",
- loss_type="l2",
- ckpt_path=None,
- ignore_keys=[],
- load_only_unet=False,
- monitor="val/loss",
- use_ema=True,
- first_stage_key="image",
- image_size=256,
- channels=3,
- log_every_t=100,
- clip_denoised=True,
- linear_start=1e-4,
- linear_end=2e-2,
- cosine_s=8e-3,
- given_betas=None,
- original_elbo_weight=0.,
- v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
- l_simple_weight=1.,
- conditioning_key=None,
- parameterization="eps", # all assuming fixed variance schedules
- scheduler_config=None,
- use_positional_encodings=False,
- learn_logvar=False,
- logvar_init=0.,
- load_ema=True,
- ):
- super().__init__()
- assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
- self.parameterization = parameterization
- print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
- self.cond_stage_model = None
- self.clip_denoised = clip_denoised
- self.log_every_t = log_every_t
- self.first_stage_key = first_stage_key
- self.image_size = image_size # try conv?
- self.channels = channels
- self.use_positional_encodings = use_positional_encodings
- self.model = DiffusionWrapper(unet_config, conditioning_key)
- count_params(self.model, verbose=True)
- self.use_ema = use_ema
-
- self.use_scheduler = scheduler_config is not None
- if self.use_scheduler:
- self.scheduler_config = scheduler_config
-
- self.v_posterior = v_posterior
- self.original_elbo_weight = original_elbo_weight
- self.l_simple_weight = l_simple_weight
-
- if monitor is not None:
- self.monitor = monitor
-
- if self.use_ema and load_ema:
- self.model_ema = LitEma(self.model)
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
-
-        # If initializing from an EMA-only checkpoint, create the EMA model after loading.
- if self.use_ema and not load_ema:
- self.model_ema = LitEma(self.model)
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
- self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
- linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
-
- self.loss_type = loss_type
-
- self.learn_logvar = learn_logvar
- self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
- if self.learn_logvar:
- self.logvar = nn.Parameter(self.logvar, requires_grad=True)
-
-
- def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
- if exists(given_betas):
- betas = given_betas
- else:
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
- cosine_s=cosine_s)
- alphas = 1. - betas
- alphas_cumprod = np.cumprod(alphas, axis=0)
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
-
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.linear_start = linear_start
- self.linear_end = linear_end
- assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
-
- to_torch = partial(torch.tensor, dtype=torch.float32)
-
- self.register_buffer('betas', to_torch(betas))
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
-
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
- 1. - alphas_cumprod) + self.v_posterior * betas
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
- self.register_buffer('posterior_mean_coef1', to_torch(
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
- self.register_buffer('posterior_mean_coef2', to_torch(
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
-
- if self.parameterization == "eps":
- lvlb_weights = self.betas ** 2 / (
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
- elif self.parameterization == "x0":
- lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
- else:
- raise NotImplementedError("mu not supported")
- # TODO how to choose this term
- lvlb_weights[0] = lvlb_weights[1]
- self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
- assert not torch.isnan(self.lvlb_weights).all()
-
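The buffers registered above are the standard closed-form quantities of the forward process. A rough standalone check of the algebra, using a plain linear beta schedule in place of make_beta_schedule():

import numpy as np

# Standalone numeric check of the schedule algebra used above.
T = 1000
betas = np.linspace(1e-4, 2e-2, T)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)

# Closed form of q(x_t | x_0): signal scale and noise scale per timestep.
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus = np.sqrt(1.0 - alphas_cumprod)

# The two coefficients satisfy scale^2 + noise^2 == 1 at every t.
assert np.allclose(sqrt_alphas_cumprod ** 2 + sqrt_one_minus ** 2, 1.0)

# Posterior variance as in register_schedule() with v_posterior == 0.
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
assert np.all(posterior_variance >= 0)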
- @contextmanager
- def ema_scope(self, context=None):
- if self.use_ema:
- self.model_ema.store(self.model.parameters())
- self.model_ema.copy_to(self.model)
- if context is not None:
- print(f"{context}: Switched to EMA weights")
- try:
- yield None
- finally:
- if self.use_ema:
- self.model_ema.restore(self.model.parameters())
- if context is not None:
- print(f"{context}: Restored training weights")
-
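ema_scope() above is the usual swap-in/swap-back pattern around EMA weights. A generic version of the same idea, independent of LitEma:

from contextlib import contextmanager
import copy
import torch

@contextmanager
def weight_swap(model: torch.nn.Module, other_state: dict):
    # Temporarily load other_state (e.g. EMA weights) into model, restoring the
    # original parameters afterwards even if the body raises.
    backup = copy.deepcopy(model.state_dict())
    model.load_state_dict(other_state)
    try:
        yield model
    finally:
        model.load_state_dict(backup)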
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
- sd = torch.load(path, map_location="cpu")
- if "state_dict" in list(sd.keys()):
- sd = sd["state_dict"]
- keys = list(sd.keys())
-
- # Our model adds additional channels to the first layer to condition on an input image.
- # For the first layer, copy existing channel weights and initialize new channel weights to zero.
- input_keys = [
- "model.diffusion_model.input_blocks.0.0.weight",
- "model_ema.diffusion_modelinput_blocks00weight",
- ]
-
- self_sd = self.state_dict()
- for input_key in input_keys:
- if input_key not in sd or input_key not in self_sd:
- continue
-
- input_weight = self_sd[input_key]
-
- if input_weight.size() != sd[input_key].size():
- print(f"Manual init: {input_key}")
- input_weight.zero_()
- input_weight[:, :4, :, :].copy_(sd[input_key])
- ignore_keys.append(input_key)
-
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
- sd, strict=False)
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
- print(f"Missing Keys: {missing}")
- if len(unexpected) > 0:
- print(f"Unexpected Keys: {unexpected}")
-
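The input_keys loop above performs the InstructPix2Pix first-layer surgery: keep the pretrained weights for the original four latent channels and zero-initialize the extra conditioning channels. The same operation on a bare nn.Conv2d, with illustrative channel counts:

import torch
import torch.nn as nn

# The pretrained first conv expects 4 latent channels; the edited model takes 8
# (4 latent + 4 image-conditioning). Channel counts here are illustrative.
old_conv = nn.Conv2d(4, 320, kernel_size=3, padding=1)
new_conv = nn.Conv2d(8, 320, kernel_size=3, padding=1)

with torch.no_grad():
    new_conv.weight.zero_()                               # new channels start at zero
    new_conv.weight[:, :4, :, :].copy_(old_conv.weight)   # reuse pretrained channels
    new_conv.bias.copy_(old_conv.bias)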
- def q_mean_variance(self, x_start, t):
- """
- Get the distribution q(x_t | x_0).
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
- """
- mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
- variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
- log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
- return mean, variance, log_variance
-
- def predict_start_from_noise(self, x_t, t, noise):
- return (
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
- )
-
- def q_posterior(self, x_start, x_t, t):
- posterior_mean = (
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
- extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
- )
- posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
- posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
- def p_mean_variance(self, x, t, clip_denoised: bool):
- model_out = self.model(x, t)
- if self.parameterization == "eps":
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
- elif self.parameterization == "x0":
- x_recon = model_out
- if clip_denoised:
- x_recon.clamp_(-1., 1.)
-
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
- return model_mean, posterior_variance, posterior_log_variance
-
- @torch.no_grad()
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
- b, *_, device = *x.shape, x.device
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
- noise = noise_like(x.shape, device, repeat_noise)
- # no noise when t == 0
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-
- @torch.no_grad()
- def p_sample_loop(self, shape, return_intermediates=False):
- device = self.betas.device
- b = shape[0]
- img = torch.randn(shape, device=device)
- intermediates = [img]
- for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
- clip_denoised=self.clip_denoised)
- if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
- intermediates.append(img)
- if return_intermediates:
- return img, intermediates
- return img
-
- @torch.no_grad()
- def sample(self, batch_size=16, return_intermediates=False):
- image_size = self.image_size
- channels = self.channels
- return self.p_sample_loop((batch_size, channels, image_size, image_size),
- return_intermediates=return_intermediates)
-
- def q_sample(self, x_start, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
-
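q_sample() above and predict_start_from_noise() earlier are exact inverses for a known noise tensor. A tiny round-trip check of that relationship, with an ad-hoc linear schedule:

import torch

# Adding noise with the closed form and inverting it with the eps
# parameterization recovers x0 up to float error.
T = 1000
betas = torch.linspace(1e-4, 2e-2, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

t = 500
x0 = torch.randn(2, 3, 8, 8)
eps = torch.randn_like(x0)

a_bar = alphas_cumprod[t]
x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps                    # q_sample
x0_rec = (1.0 / a_bar).sqrt() * x_t - (1.0 / a_bar - 1).sqrt() * eps  # predict_start_from_noise
assert torch.allclose(x0, x0_rec, atol=1e-5)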
- def get_loss(self, pred, target, mean=True):
- if self.loss_type == 'l1':
- loss = (target - pred).abs()
- if mean:
- loss = loss.mean()
- elif self.loss_type == 'l2':
- if mean:
- loss = torch.nn.functional.mse_loss(target, pred)
- else:
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
- else:
-            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
-
- return loss
-
- def p_losses(self, x_start, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- model_out = self.model(x_noisy, t)
-
- loss_dict = {}
- if self.parameterization == "eps":
- target = noise
- elif self.parameterization == "x0":
- target = x_start
- else:
-            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
-
- loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
-
- log_prefix = 'train' if self.training else 'val'
-
- loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
- loss_simple = loss.mean() * self.l_simple_weight
-
- loss_vlb = (self.lvlb_weights[t] * loss).mean()
- loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
-
- loss = loss_simple + self.original_elbo_weight * loss_vlb
-
- loss_dict.update({f'{log_prefix}/loss': loss})
-
- return loss, loss_dict
-
- def forward(self, x, *args, **kwargs):
- # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
- # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
- return self.p_losses(x, t, *args, **kwargs)
-
- def get_input(self, batch, k):
- return batch[k]
-
- def shared_step(self, batch):
- x = self.get_input(batch, self.first_stage_key)
- loss, loss_dict = self(x)
- return loss, loss_dict
-
- def training_step(self, batch, batch_idx):
- loss, loss_dict = self.shared_step(batch)
-
- self.log_dict(loss_dict, prog_bar=True,
- logger=True, on_step=True, on_epoch=True)
-
- self.log("global_step", self.global_step,
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
-
- if self.use_scheduler:
- lr = self.optimizers().param_groups[0]['lr']
- self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
-
- return loss
-
- @torch.no_grad()
- def validation_step(self, batch, batch_idx):
- _, loss_dict_no_ema = self.shared_step(batch)
- with self.ema_scope():
- _, loss_dict_ema = self.shared_step(batch)
- loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
- self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
- self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
-
- def on_train_batch_end(self, *args, **kwargs):
- if self.use_ema:
- self.model_ema(self.model)
-
- def _get_rows_from_list(self, samples):
- n_imgs_per_row = len(samples)
- denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
- return denoise_grid
-
- @torch.no_grad()
- def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
- log = dict()
- x = self.get_input(batch, self.first_stage_key)
- N = min(x.shape[0], N)
- n_row = min(x.shape[0], n_row)
- x = x.to(self.device)[:N]
- log["inputs"] = x
-
- # get diffusion row
- diffusion_row = list()
- x_start = x[:n_row]
-
- for t in range(self.num_timesteps):
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
- t = t.to(self.device).long()
- noise = torch.randn_like(x_start)
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- diffusion_row.append(x_noisy)
-
- log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
-
- if sample:
- # get denoise row
- with self.ema_scope("Plotting"):
- samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
-
- log["samples"] = samples
- log["denoise_row"] = self._get_rows_from_list(denoise_row)
-
- if return_keys:
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
- return log
- else:
- return {key: log[key] for key in return_keys}
- return log
-
- def configure_optimizers(self):
- lr = self.learning_rate
- params = list(self.model.parameters())
- if self.learn_logvar:
- params = params + [self.logvar]
- opt = torch.optim.AdamW(params, lr=lr)
- return opt
-
-
-class LatentDiffusion(DDPM):
- """main class"""
- def __init__(self,
- first_stage_config,
- cond_stage_config,
- num_timesteps_cond=None,
- cond_stage_key="image",
- cond_stage_trainable=False,
- concat_mode=True,
- cond_stage_forward=None,
- conditioning_key=None,
- scale_factor=1.0,
- scale_by_std=False,
- load_ema=True,
- *args, **kwargs):
- self.num_timesteps_cond = default(num_timesteps_cond, 1)
- self.scale_by_std = scale_by_std
- assert self.num_timesteps_cond <= kwargs['timesteps']
- # for backwards compatibility after implementation of DiffusionWrapper
- if conditioning_key is None:
- conditioning_key = 'concat' if concat_mode else 'crossattn'
- if cond_stage_config == '__is_unconditional__':
- conditioning_key = None
- ckpt_path = kwargs.pop("ckpt_path", None)
- ignore_keys = kwargs.pop("ignore_keys", [])
- super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
- self.concat_mode = concat_mode
- self.cond_stage_trainable = cond_stage_trainable
- self.cond_stage_key = cond_stage_key
- try:
- self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
-        except Exception:
- self.num_downs = 0
- if not scale_by_std:
- self.scale_factor = scale_factor
- else:
- self.register_buffer('scale_factor', torch.tensor(scale_factor))
- self.instantiate_first_stage(first_stage_config)
- self.instantiate_cond_stage(cond_stage_config)
- self.cond_stage_forward = cond_stage_forward
- self.clip_denoised = False
- self.bbox_tokenizer = None
-
- self.restarted_from_ckpt = False
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys)
- self.restarted_from_ckpt = True
-
- if self.use_ema and not load_ema:
- self.model_ema = LitEma(self.model)
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
- def make_cond_schedule(self, ):
- self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
- ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
- self.cond_ids[:self.num_timesteps_cond] = ids
-
- @rank_zero_only
- @torch.no_grad()
- def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
- # only for very first batch
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
- # set rescale weight to 1./std of encodings
- print("### USING STD-RESCALING ###")
- x = super().get_input(batch, self.first_stage_key)
- x = x.to(self.device)
- encoder_posterior = self.encode_first_stage(x)
- z = self.get_first_stage_encoding(encoder_posterior).detach()
- del self.scale_factor
- self.register_buffer('scale_factor', 1. / z.flatten().std())
- print(f"setting self.scale_factor to {self.scale_factor}")
- print("### USING STD-RESCALING ###")
-
- def register_schedule(self,
- given_betas=None, beta_schedule="linear", timesteps=1000,
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
- super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
-
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
- if self.shorten_cond_schedule:
- self.make_cond_schedule()
-
- def instantiate_first_stage(self, config):
- model = instantiate_from_config(config)
- self.first_stage_model = model.eval()
- self.first_stage_model.train = disabled_train
- for param in self.first_stage_model.parameters():
- param.requires_grad = False
-
- def instantiate_cond_stage(self, config):
- if not self.cond_stage_trainable:
- if config == "__is_first_stage__":
- print("Using first stage also as cond stage.")
- self.cond_stage_model = self.first_stage_model
- elif config == "__is_unconditional__":
- print(f"Training {self.__class__.__name__} as an unconditional model.")
- self.cond_stage_model = None
- # self.be_unconditional = True
- else:
- model = instantiate_from_config(config)
- self.cond_stage_model = model.eval()
- self.cond_stage_model.train = disabled_train
- for param in self.cond_stage_model.parameters():
- param.requires_grad = False
- else:
- assert config != '__is_first_stage__'
- assert config != '__is_unconditional__'
- model = instantiate_from_config(config)
- self.cond_stage_model = model
-
- def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
- denoise_row = []
- for zd in tqdm(samples, desc=desc):
- denoise_row.append(self.decode_first_stage(zd.to(self.device),
- force_not_quantize=force_no_decoder_quantization))
- n_imgs_per_row = len(denoise_row)
- denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
- denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
- return denoise_grid
-
- def get_first_stage_encoding(self, encoder_posterior):
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
- z = encoder_posterior.sample()
- elif isinstance(encoder_posterior, torch.Tensor):
- z = encoder_posterior
- else:
- raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
- return self.scale_factor * z
-
- def get_learned_conditioning(self, c):
- if self.cond_stage_forward is None:
- if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
- c = self.cond_stage_model.encode(c)
- if isinstance(c, DiagonalGaussianDistribution):
- c = c.mode()
- else:
- c = self.cond_stage_model(c)
- else:
- assert hasattr(self.cond_stage_model, self.cond_stage_forward)
- c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
- return c
-
- def meshgrid(self, h, w):
- y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
- x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
-
- arr = torch.cat([y, x], dim=-1)
- return arr
-
- def delta_border(self, h, w):
- """
- :param h: height
- :param w: width
- :return: normalized distance to image border,
-        with min distance = 0 at the border and max dist = 0.5 at the image center
- """
- lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
- arr = self.meshgrid(h, w) / lower_right_corner
- dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
- dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
- edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
- return edge_dist
-
- def get_weighting(self, h, w, Ly, Lx, device):
- weighting = self.delta_border(h, w)
- weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
- self.split_input_params["clip_max_weight"], )
- weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
-
- if self.split_input_params["tie_braker"]:
- L_weighting = self.delta_border(Ly, Lx)
- L_weighting = torch.clip(L_weighting,
- self.split_input_params["clip_min_tie_weight"],
- self.split_input_params["clip_max_tie_weight"])
-
- L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
- weighting = weighting * L_weighting
- return weighting
-
- def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
- """
- :param x: img of size (bs, c, h, w)
- :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
- """
- bs, nc, h, w = x.shape
-
- # number of crops in image
- Ly = (h - kernel_size[0]) // stride[0] + 1
- Lx = (w - kernel_size[1]) // stride[1] + 1
-
- if uf == 1 and df == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
-
- weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
-
- elif uf > 1 and df == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
- dilation=1, padding=0,
- stride=(stride[0] * uf, stride[1] * uf))
- fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
-
- weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
-
- elif df > 1 and uf == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
- dilation=1, padding=0,
- stride=(stride[0] // df, stride[1] // df))
- fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
-
- weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
-
- else:
- raise NotImplementedError
-
- return fold, unfold, normalization, weighting
-
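get_fold_unfold() tiles a tensor into overlapping crops and later stitches them back, dividing by a normalization map so overlapping pixels are averaged rather than summed. The core round trip, stripped of the uf/df branches and using uniform weighting instead of the border-distance map:

import torch

# Minimal unfold/fold round trip with overlap normalization.
x = torch.randn(1, 3, 16, 16)
ks, stride = (8, 8), (4, 4)

unfold = torch.nn.Unfold(kernel_size=ks, stride=stride)
fold = torch.nn.Fold(output_size=x.shape[2:], kernel_size=ks, stride=stride)

patches = unfold(x)                    # (1, 3*8*8, L)
ones = torch.ones_like(x)
normalization = fold(unfold(ones))     # how often each pixel was covered

x_rec = fold(patches) / normalization  # overlaps averaged back to the original
assert torch.allclose(x, x_rec, atol=1e-6)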
- @torch.no_grad()
- def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
- cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
- x = super().get_input(batch, k)
- if bs is not None:
- x = x[:bs]
- x = x.to(self.device)
- encoder_posterior = self.encode_first_stage(x)
- z = self.get_first_stage_encoding(encoder_posterior).detach()
- cond_key = cond_key or self.cond_stage_key
- xc = super().get_input(batch, cond_key)
- if bs is not None:
- xc["c_crossattn"] = xc["c_crossattn"][:bs]
- xc["c_concat"] = xc["c_concat"][:bs]
- cond = {}
-
- # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.
- random = torch.rand(x.size(0), device=x.device)
- prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
- input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
-
- null_prompt = self.get_learned_conditioning([""])
- cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
- cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
-
- out = [z, cond]
- if return_first_stage_outputs:
- xrec = self.decode_first_stage(z)
- out.extend([x, xrec])
- if return_original_cond:
- out.append(xc)
- return out
-
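The uncond=0.05 logic in get_input() above drops the text prompt only, the image conditioning only, or both, each about 5% of the time, which is what classifier-free guidance training for InstructPix2Pix-style models needs. A standalone sketch of the same three-way mask, reduced to a batch dimension:

import torch

def cfg_dropout_masks(batch_size: int, uncond: float = 0.05, device="cpu"):
    # r < 2*uncond           -> drop the text prompt
    # uncond <= r < 3*uncond -> drop the image conditioning
    # The overlap [uncond, 2*uncond) drops both, so text-only, image-only and
    # both-dropped each occur roughly uncond of the time.
    r = torch.rand(batch_size, device=device)
    drop_prompt = r < 2 * uncond
    drop_image = (r >= uncond) & (r < 3 * uncond)
    return drop_prompt, drop_image

prompt_mask, image_mask = cfg_dropout_masks(8)
print(prompt_mask, image_mask)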
- @torch.no_grad()
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
- if predict_cids:
- if z.dim() == 4:
- z = torch.argmax(z.exp(), dim=1).long()
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
-
- z = 1. / self.scale_factor * z
-
- if hasattr(self, "split_input_params"):
- if self.split_input_params["patch_distributed_vq"]:
- ks = self.split_input_params["ks"] # eg. (128, 128)
- stride = self.split_input_params["stride"] # eg. (64, 64)
- uf = self.split_input_params["vqf"]
- bs, nc, h, w = z.shape
- if ks[0] > h or ks[1] > w:
- ks = (min(ks[0], h), min(ks[1], w))
- print("reducing Kernel")
-
- if stride[0] > h or stride[1] > w:
- stride = (min(stride[0], h), min(stride[1], w))
- print("reducing stride")
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
-
- z = unfold(z) # (bn, nc * prod(**ks), L)
- # 1. Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- # 2. apply model loop over last dim
- if isinstance(self.first_stage_model, VQModelInterface):
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
- force_not_quantize=predict_cids or force_not_quantize)
- for i in range(z.shape[-1])]
- else:
-
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
- for i in range(z.shape[-1])]
-
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
- o = o * weighting
- # Reverse 1. reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- decoded = fold(o)
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
- return decoded
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- # same as above but without decorator
- def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
- if predict_cids:
- if z.dim() == 4:
- z = torch.argmax(z.exp(), dim=1).long()
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
-
- z = 1. / self.scale_factor * z
-
- if hasattr(self, "split_input_params"):
- if self.split_input_params["patch_distributed_vq"]:
- ks = self.split_input_params["ks"] # eg. (128, 128)
- stride = self.split_input_params["stride"] # eg. (64, 64)
- uf = self.split_input_params["vqf"]
- bs, nc, h, w = z.shape
- if ks[0] > h or ks[1] > w:
- ks = (min(ks[0], h), min(ks[1], w))
- print("reducing Kernel")
-
- if stride[0] > h or stride[1] > w:
- stride = (min(stride[0], h), min(stride[1], w))
- print("reducing stride")
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
-
- z = unfold(z) # (bn, nc * prod(**ks), L)
- # 1. Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- # 2. apply model loop over last dim
- if isinstance(self.first_stage_model, VQModelInterface):
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
- force_not_quantize=predict_cids or force_not_quantize)
- for i in range(z.shape[-1])]
- else:
-
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
- for i in range(z.shape[-1])]
-
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
- o = o * weighting
- # Reverse 1. reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- decoded = fold(o)
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
- return decoded
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- @torch.no_grad()
- def encode_first_stage(self, x):
- if hasattr(self, "split_input_params"):
- if self.split_input_params["patch_distributed_vq"]:
- ks = self.split_input_params["ks"] # eg. (128, 128)
- stride = self.split_input_params["stride"] # eg. (64, 64)
- df = self.split_input_params["vqf"]
- self.split_input_params['original_image_size'] = x.shape[-2:]
- bs, nc, h, w = x.shape
- if ks[0] > h or ks[1] > w:
- ks = (min(ks[0], h), min(ks[1], w))
- print("reducing Kernel")
-
- if stride[0] > h or stride[1] > w:
- stride = (min(stride[0], h), min(stride[1], w))
- print("reducing stride")
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
- z = unfold(x) # (bn, nc * prod(**ks), L)
- # Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
- for i in range(z.shape[-1])]
-
- o = torch.stack(output_list, axis=-1)
- o = o * weighting
-
- # Reverse reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- decoded = fold(o)
- decoded = decoded / normalization
- return decoded
-
- else:
- return self.first_stage_model.encode(x)
- else:
- return self.first_stage_model.encode(x)
-
- def shared_step(self, batch, **kwargs):
- x, c = self.get_input(batch, self.first_stage_key)
- loss = self(x, c)
- return loss
-
- def forward(self, x, c, *args, **kwargs):
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
- if self.model.conditioning_key is not None:
- assert c is not None
- if self.cond_stage_trainable:
- c = self.get_learned_conditioning(c)
- if self.shorten_cond_schedule: # TODO: drop this option
- tc = self.cond_ids[t].to(self.device)
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
- return self.p_losses(x, c, t, *args, **kwargs)
-
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
- def rescale_bbox(bbox):
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
- return x0, y0, w, h
-
- return [rescale_bbox(b) for b in bboxes]
-
- def apply_model(self, x_noisy, t, cond, return_ids=False):
-
- if isinstance(cond, dict):
-            # hybrid case, cond is expected to be a dict
- pass
- else:
- if not isinstance(cond, list):
- cond = [cond]
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
- cond = {key: cond}
-
- if hasattr(self, "split_input_params"):
- assert len(cond) == 1 # todo can only deal with one conditioning atm
- assert not return_ids
- ks = self.split_input_params["ks"] # eg. (128, 128)
- stride = self.split_input_params["stride"] # eg. (64, 64)
-
- h, w = x_noisy.shape[-2:]
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
-
- z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
- # Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
- z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
-
- if self.cond_stage_key in ["image", "LR_image", "segmentation",
- 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
- c_key = next(iter(cond.keys())) # get key
- c = next(iter(cond.values())) # get value
- assert (len(c) == 1) # todo extend to list with more than one elem
- c = c[0] # get element
-
- c = unfold(c)
- c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
-
- elif self.cond_stage_key == 'coordinates_bbox':
-                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
-
- # assuming padding of unfold is always 0 and its dilation is always 1
- n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
- full_img_h, full_img_w = self.split_input_params['original_image_size']
- # as we are operating on latents, we need the factor from the original image size to the
- # spatial latent size to properly rescale the crops for regenerating the bbox annotations
- num_downs = self.first_stage_model.encoder.num_resolutions - 1
- rescale_latent = 2 ** (num_downs)
-
-                # get the top-left positions of the patches as expected by the bbox tokenizer; the top-left
-                # patch coordinates therefore need to be rescaled to lie in (0, 1)
- tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
- rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
- for patch_nr in range(z.shape[-1])]
-
- # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
- patch_limits = [(x_tl, y_tl,
- rescale_latent * ks[0] / full_img_w,
- rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
- # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
-
- # tokenize crop coordinates for the bounding boxes of the respective patches
- patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
- for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
- print(patch_limits_tknzd[0].shape)
- # cut tknzd crop position from conditioning
- assert isinstance(cond, dict), 'cond must be dict to be fed into model'
- cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
- print(cut_cond.shape)
-
- adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
- adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
- print(adapted_cond.shape)
- adapted_cond = self.get_learned_conditioning(adapted_cond)
- print(adapted_cond.shape)
- adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
- print(adapted_cond.shape)
-
- cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
-
- else:
- cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
-
- # apply model by loop over crops
- output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
- assert not isinstance(output_list[0],
-                                   tuple)  # todo: can't deal with multiple model outputs; check this never happens
-
- o = torch.stack(output_list, axis=-1)
- o = o * weighting
- # Reverse reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- x_recon = fold(o) / normalization
-
- else:
- x_recon = self.model(x_noisy, t, **cond)
-
- if isinstance(x_recon, tuple) and not return_ids:
- return x_recon[0]
- else:
- return x_recon
-
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
- return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
- def _prior_bpd(self, x_start):
- """
- Get the prior KL term for the variational lower-bound, measured in
- bits-per-dim.
- This term can't be optimized, as it only depends on the encoder.
- :param x_start: the [N x C x ...] tensor of inputs.
- :return: a batch of [N] KL values (in bits), one per batch element.
- """
- batch_size = x_start.shape[0]
- t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
- return mean_flat(kl_prior) / np.log(2.0)
-
- def p_losses(self, x_start, cond, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- model_output = self.apply_model(x_noisy, t, cond)
-
- loss_dict = {}
- prefix = 'train' if self.training else 'val'
-
- if self.parameterization == "x0":
- target = x_start
- elif self.parameterization == "eps":
- target = noise
- else:
- raise NotImplementedError()
-
- loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
- loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
-
- logvar_t = self.logvar[t].to(self.device)
- loss = loss_simple / torch.exp(logvar_t) + logvar_t
- # loss = loss_simple / torch.exp(self.logvar) + self.logvar
- if self.learn_logvar:
- loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
- loss_dict.update({'logvar': self.logvar.data.mean()})
-
- loss = self.l_simple_weight * loss.mean()
-
- loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
- loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
- loss += (self.original_elbo_weight * loss_vlb)
- loss_dict.update({f'{prefix}/loss': loss})
-
- return loss, loss_dict
-
- def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
- return_x0=False, score_corrector=None, corrector_kwargs=None):
- t_in = t
- model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
-
- if score_corrector is not None:
- assert self.parameterization == "eps"
- model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
-
- if return_codebook_ids:
- model_out, logits = model_out
-
- if self.parameterization == "eps":
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
- elif self.parameterization == "x0":
- x_recon = model_out
- else:
- raise NotImplementedError()
-
- if clip_denoised:
- x_recon.clamp_(-1., 1.)
- if quantize_denoised:
- x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
- if return_codebook_ids:
- return model_mean, posterior_variance, posterior_log_variance, logits
- elif return_x0:
- return model_mean, posterior_variance, posterior_log_variance, x_recon
- else:
- return model_mean, posterior_variance, posterior_log_variance
-
- @torch.no_grad()
- def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
- b, *_, device = *x.shape, x.device
- outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
- return_codebook_ids=return_codebook_ids,
- quantize_denoised=quantize_denoised,
- return_x0=return_x0,
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
- if return_codebook_ids:
- raise DeprecationWarning("Support dropped.")
- model_mean, _, model_log_variance, logits = outputs
- elif return_x0:
- model_mean, _, model_log_variance, x0 = outputs
- else:
- model_mean, _, model_log_variance = outputs
-
- noise = noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- # no noise when t == 0
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-
- if return_codebook_ids:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
- if return_x0:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
- else:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-
- @torch.no_grad()
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
- log_every_t=None):
- if not log_every_t:
- log_every_t = self.log_every_t
- timesteps = self.num_timesteps
- if batch_size is not None:
- b = batch_size if batch_size is not None else shape[0]
- shape = [batch_size] + list(shape)
- else:
- b = batch_size = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=self.device)
- else:
- img = x_T
- intermediates = []
- if cond is not None:
- if isinstance(cond, dict):
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
- else:
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
-
- if start_T is not None:
- timesteps = min(timesteps, start_T)
- iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
- total=timesteps) if verbose else reversed(
- range(0, timesteps))
- if type(temperature) == float:
- temperature = [temperature] * timesteps
-
- for i in iterator:
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
- if self.shorten_cond_schedule:
- assert self.model.conditioning_key != 'hybrid'
- tc = self.cond_ids[ts].to(cond.device)
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
-
- img, x0_partial = self.p_sample(img, cond, ts,
- clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised, return_x0=True,
- temperature=temperature[i], noise_dropout=noise_dropout,
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
- if mask is not None:
- assert x0 is not None
- img_orig = self.q_sample(x0, ts)
- img = img_orig * mask + (1. - mask) * img
-
- if i % log_every_t == 0 or i == timesteps - 1:
- intermediates.append(x0_partial)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
- return img, intermediates
-
- @torch.no_grad()
- def p_sample_loop(self, cond, shape, return_intermediates=False,
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, img_callback=None, start_T=None,
- log_every_t=None):
-
- if not log_every_t:
- log_every_t = self.log_every_t
- device = self.betas.device
- b = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=device)
- else:
- img = x_T
-
- intermediates = [img]
- if timesteps is None:
- timesteps = self.num_timesteps
-
- if start_T is not None:
- timesteps = min(timesteps, start_T)
- iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
- range(0, timesteps))
-
- if mask is not None:
- assert x0 is not None
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
-
- for i in iterator:
- ts = torch.full((b,), i, device=device, dtype=torch.long)
- if self.shorten_cond_schedule:
- assert self.model.conditioning_key != 'hybrid'
- tc = self.cond_ids[ts].to(cond.device)
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
-
- img = self.p_sample(img, cond, ts,
- clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised)
- if mask is not None:
- img_orig = self.q_sample(x0, ts)
- img = img_orig * mask + (1. - mask) * img
-
- if i % log_every_t == 0 or i == timesteps - 1:
- intermediates.append(img)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
-
- if return_intermediates:
- return img, intermediates
- return img
-
- @torch.no_grad()
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
- verbose=True, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, shape=None,**kwargs):
- if shape is None:
- shape = (batch_size, self.channels, self.image_size, self.image_size)
- if cond is not None:
- if isinstance(cond, dict):
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
- else:
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
- return self.p_sample_loop(cond,
- shape,
- return_intermediates=return_intermediates, x_T=x_T,
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
- mask=mask, x0=x0)
-
- @torch.no_grad()
- def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
-
- if ddim:
- ddim_sampler = DDIMSampler(self)
- shape = (self.channels, self.image_size, self.image_size)
- samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
- shape,cond,verbose=False,**kwargs)
-
- else:
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
- return_intermediates=True,**kwargs)
-
- return samples, intermediates
-
-
- @torch.no_grad()
- def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
- quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
- plot_diffusion_rows=False, **kwargs):
-
- use_ddim = False
-
- log = dict()
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
- return_first_stage_outputs=True,
- force_c_encode=True,
- return_original_cond=True,
- bs=N, uncond=0)
- N = min(x.shape[0], N)
- n_row = min(x.shape[0], n_row)
- log["inputs"] = x
- log["reals"] = xc["c_concat"]
- log["reconstruction"] = xrec
- if self.model.conditioning_key is not None:
- if hasattr(self.cond_stage_model, "decode"):
- xc = self.cond_stage_model.decode(c)
- log["conditioning"] = xc
- elif self.cond_stage_key in ["caption"]:
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
- log["conditioning"] = xc
- elif self.cond_stage_key == 'class_label':
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
- log['conditioning'] = xc
- elif isimage(xc):
- log["conditioning"] = xc
- if ismap(xc):
- log["original_conditioning"] = self.to_rgb(xc)
-
- if plot_diffusion_rows:
- # get diffusion row
- diffusion_row = list()
- z_start = z[:n_row]
- for t in range(self.num_timesteps):
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
- t = t.to(self.device).long()
- noise = torch.randn_like(z_start)
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
- diffusion_row.append(self.decode_first_stage(z_noisy))
-
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
- log["diffusion_row"] = diffusion_grid
-
- if sample:
- # get denoise row
- with self.ema_scope("Plotting"):
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
- ddim_steps=ddim_steps,eta=ddim_eta)
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
- x_samples = self.decode_first_stage(samples)
- log["samples"] = x_samples
- if plot_denoise_rows:
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
- log["denoise_row"] = denoise_grid
-
- if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
- self.first_stage_model, IdentityFirstStage):
- # also display when quantizing x0 while sampling
- with self.ema_scope("Plotting Quantized Denoised"):
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
- ddim_steps=ddim_steps,eta=ddim_eta,
- quantize_denoised=True)
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
- # quantize_denoised=True)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_x0_quantized"] = x_samples
-
- if inpaint:
- # make a simple center square
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
- mask = torch.ones(N, h, w).to(self.device)
- # zeros will be filled in
- mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
- mask = mask[:, None, ...]
- with self.ema_scope("Plotting Inpaint"):
-
- samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_inpainting"] = x_samples
- log["mask"] = mask
-
- # outpaint
- with self.ema_scope("Plotting Outpaint"):
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_outpainting"] = x_samples
-
- if plot_progressive_rows:
- with self.ema_scope("Plotting Progressives"):
- img, progressives = self.progressive_denoising(c,
- shape=(self.channels, self.image_size, self.image_size),
- batch_size=N)
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
- log["progressive_row"] = prog_row
-
- if return_keys:
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
- return log
- else:
- return {key: log[key] for key in return_keys}
- return log
-
- def configure_optimizers(self):
- lr = self.learning_rate
- params = list(self.model.parameters())
- if self.cond_stage_trainable:
- print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
- params = params + list(self.cond_stage_model.parameters())
- if self.learn_logvar:
- print('Diffusion model optimizing logvar')
- params.append(self.logvar)
- opt = torch.optim.AdamW(params, lr=lr)
- if self.use_scheduler:
- assert 'target' in self.scheduler_config
- scheduler = instantiate_from_config(self.scheduler_config)
-
- print("Setting up LambdaLR scheduler...")
- scheduler = [
- {
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
- 'interval': 'step',
- 'frequency': 1
- }]
- return [opt], scheduler
- return opt
-
- @torch.no_grad()
- def to_rgb(self, x):
- x = x.float()
- if not hasattr(self, "colorize"):
- self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
- x = nn.functional.conv2d(x, weight=self.colorize)
- x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
- return x
-
-
-class DiffusionWrapper(pl.LightningModule):
- def __init__(self, diff_model_config, conditioning_key):
- super().__init__()
- self.diffusion_model = instantiate_from_config(diff_model_config)
- self.conditioning_key = conditioning_key
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
-
- def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
- if self.conditioning_key is None:
- out = self.diffusion_model(x, t)
- elif self.conditioning_key == 'concat':
- xc = torch.cat([x] + c_concat, dim=1)
- out = self.diffusion_model(xc, t)
- elif self.conditioning_key == 'crossattn':
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(x, t, context=cc)
- elif self.conditioning_key == 'hybrid':
- xc = torch.cat([x] + c_concat, dim=1)
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(xc, t, context=cc)
- elif self.conditioning_key == 'adm':
- cc = c_crossattn[0]
- out = self.diffusion_model(x, t, y=cc)
- else:
- raise NotImplementedError()
-
- return out
-
-
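# A minimal, standalone sketch of how the DiffusionWrapper above routes conditioning, with a stub
# in place of the real UNet so the tensor shapes stay visible. All shapes are illustrative
# assumptions (4 latent channels, a 5-channel c_concat, 77x768 text embeddings).
import torch
import torch.nn as nn

class StubUNet(nn.Module):
    # stand-in for the real diffusion_model: it just reports the shapes it was handed
    def forward(self, x, t, context=None, y=None):
        return {"x": tuple(x.shape),
                "context": None if context is None else tuple(context.shape),
                "y": None if y is None else tuple(y.shape)}

model = StubUNet()
x = torch.randn(2, 4, 64, 64)            # latent batch
t = torch.randint(0, 1000, (2,))         # diffusion timesteps
c_concat = [torch.randn(2, 5, 64, 64)]   # e.g. inpainting mask + masked-image latent
c_crossattn = [torch.randn(2, 77, 768)]  # e.g. CLIP text embeddings

# 'concat': conditioning is stacked onto the latent along the channel axis (4 + 5 = 9 channels)
print(model(torch.cat([x] + c_concat, dim=1), t))
# 'crossattn': conditioning is handed over separately as cross-attention context
print(model(x, t, context=torch.cat(c_crossattn, 1)))
# 'hybrid': both at once, as e.g. the inpainting and depth models do
print(model(torch.cat([x] + c_concat, dim=1), t, context=torch.cat(c_crossattn, 1)))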
-class Layout2ImgDiffusion(LatentDiffusion):
- # TODO: move all layout-specific hacks to this class
- def __init__(self, cond_stage_key, *args, **kwargs):
- assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
-
- def log_images(self, batch, N=8, *args, **kwargs):
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
-
- key = 'train' if self.training else 'validation'
- dset = self.trainer.datamodule.datasets[key]
- mapper = dset.conditional_builders[self.cond_stage_key]
-
- bbox_imgs = []
- map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
- for tknzd_bbox in batch[self.cond_stage_key][:N]:
- bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
- bbox_imgs.append(bboximg)
-
- cond_img = torch.stack(bbox_imgs, dim=0)
- logs['bbox_image'] = cond_img
- return logs
diff --git a/modules/ngrok.py b/modules/ngrok.py
deleted file mode 100644
index 3df2c06bf1f10d49b7e9397758bc4f3661a51ba7..0000000000000000000000000000000000000000
--- a/modules/ngrok.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from pyngrok import ngrok, conf, exception
-
-def connect(token, port, region):
- account = None
- if token is None:
- token = 'None'
- else:
- if ':' in token:
- # token = authtoken:username:password
- account = token.split(':')[1] + ':' + token.split(':')[-1]
- token = token.split(':')[0]
-
- config = conf.PyngrokConfig(
- auth_token=token, region=region
- )
- try:
- if account is None:
- public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
- else:
- public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
- except exception.PyngrokNgrokError:
- print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
- f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
- else:
- print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
- 'You can use this link after the launch is complete.')
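# A small standalone sketch of the token format connect() above accepts: either a bare authtoken,
# or "authtoken:username:password", in which case the username:password pair is passed to
# ngrok.connect() as the tunnel's auth= credentials. The helper name and example tokens below are
# hypothetical.
def parse_ngrok_token(token):
    if token is None:
        return 'None', None
    if ':' not in token:
        return token, None
    parts = token.split(':')
    return parts[0], parts[1] + ':' + parts[-1]   # (authtoken, "username:password")

print(parse_ngrok_token("2AbCdEf"))               # ('2AbCdEf', None)
print(parse_ngrok_token("2AbCdEf:user:secret"))   # ('2AbCdEf', 'user:secret')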
diff --git a/modules/paths.py b/modules/paths.py
deleted file mode 100644
index 5883788608bafd3ae46b90aff15881c1b3e69d03..0000000000000000000000000000000000000000
--- a/modules/paths.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import argparse
-import os
-import sys
-import modules.safe
-
-script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-
-# Parse the --data-dir flag first so we can use it as a base for our other argument default values
-parser = argparse.ArgumentParser(add_help=False)
-parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
-cmd_opts_pre = parser.parse_known_args()[0]
-data_path = cmd_opts_pre.data_dir
-models_path = os.path.join(data_path, "models")
-
-# data_path = cmd_opts_pre.data
-sys.path.insert(0, script_path)
-
-# search for directory of stable diffusion in following places
-sd_path = None
-possible_sd_paths = [os.path.join(script_path, '/content/gdrive/MyDrive/sd/stablediffusion'), '.', os.path.dirname(script_path)]
-for possible_sd_path in possible_sd_paths:
- if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
- sd_path = os.path.abspath(possible_sd_path)
- break
-
-assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
-
-path_dirs = [
- (sd_path, 'ldm', 'Stable Diffusion', []),
- (os.path.join(sd_path, 'src/taming-transformers'), 'taming', 'Taming Transformers', []),
- (os.path.join(sd_path, 'src/codeformer'), 'inference_codeformer.py', 'CodeFormer', []),
- (os.path.join(sd_path, 'src/blip'), 'models/blip.py', 'BLIP', []),
- (os.path.join(sd_path, 'src/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
-]
-
-paths = {}
-
-for d, must_exist, what, options in path_dirs:
- must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
- if not os.path.exists(must_exist_path):
- print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
- else:
- d = os.path.abspath(d)
- if "atstart" in options:
- sys.path.insert(0, d)
- else:
- sys.path.append(d)
- paths[what] = d
-
-class Prioritize:
- def __init__(self, name):
- self.name = name
- self.path = None
-
- def __enter__(self):
- self.path = sys.path.copy()
- sys.path = [paths[self.name]] + sys.path
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- sys.path = self.path
- self.path = None
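# A brief usage sketch of the idea behind the Prioritize context manager above, written as a
# simplified, hypothetical variant that takes a directory directly instead of looking a name up in
# the module-level `paths` dict: the directory is temporarily pushed to the front of sys.path so
# imports made inside the block resolve against it first, then the old search order is restored.
import sys

class PrioritizePath:
    def __init__(self, directory):
        self.directory = directory
        self.saved = None

    def __enter__(self):
        self.saved = sys.path.copy()
        sys.path = [self.directory] + sys.path

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.path = self.saved
        self.saved = None

# e.g. prefer the Stable Diffusion checkout (one of possible_sd_paths above) while importing `ldm`
with PrioritizePath("/content/gdrive/MyDrive/sd/stablediffusion"):
    print(sys.path[0])   # the prioritized directory
print(sys.path[0])       # previous ordering restored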
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
deleted file mode 100644
index 09d8e6056cf49e731d312bcc4b19ac5564884bfa..0000000000000000000000000000000000000000
--- a/modules/postprocessing.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import os
-
-from PIL import Image
-
-from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
-from modules.shared import opts
-
-
-def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
- devices.torch_gc()
-
- shared.state.begin()
- shared.state.job = 'extras'
-
- image_data = []
- image_names = []
- outputs = []
-
- if extras_mode == 1:
- for img in image_folder:
- image = Image.open(img)
- image_data.append(image)
- image_names.append(os.path.splitext(img.orig_name)[0])
- elif extras_mode == 2:
- assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
- assert input_dir, 'input directory not selected'
-
- image_list = shared.listfiles(input_dir)
- for filename in image_list:
- try:
- image = Image.open(filename)
- except Exception:
- continue
- image_data.append(image)
- image_names.append(filename)
- else:
- assert image, 'image not selected'
-
- image_data.append(image)
- image_names.append(None)
-
- if extras_mode == 2 and output_dir != '':
- outpath = output_dir
- else:
- outpath = opts.outdir_samples or opts.outdir_extras_samples
-
- infotext = ''
-
- for image, name in zip(image_data, image_names):
- shared.state.textinfo = name
-
- existing_pnginfo = image.info or {}
-
- pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
-
- scripts.scripts_postproc.run(pp, args)
-
- if opts.use_original_name_batch and name is not None:
- basename = os.path.splitext(os.path.basename(name))[0]
- else:
- basename = ''
-
- infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
-
- if opts.enable_pnginfo:
- pp.image.info = existing_pnginfo
- pp.image.info["postprocessing"] = infotext
-
- if save_output:
- images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
-
- if extras_mode != 2 or show_extras_results:
- outputs.append(pp.image)
-
- devices.torch_gc()
-
- return outputs, ui_common.plaintext_to_html(infotext), ''
-
-
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
- """old handler for API"""
-
- args = scripts.scripts_postproc.create_args_for_run({
- "Upscale": {
- "upscale_mode": resize_mode,
- "upscale_by": upscaling_resize,
- "upscale_to_width": upscaling_resize_w,
- "upscale_to_height": upscaling_resize_h,
- "upscale_crop": upscaling_crop,
- "upscaler_1_name": extras_upscaler_1,
- "upscaler_2_name": extras_upscaler_2,
- "upscaler_2_visibility": extras_upscaler_2_visibility,
- },
- "GFPGAN": {
- "gfpgan_visibility": gfpgan_visibility,
- },
- "CodeFormer": {
- "codeformer_visibility": codeformer_visibility,
- "codeformer_weight": codeformer_weight,
- },
- })
-
- return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
diff --git a/modules/processing.py b/modules/processing.py
deleted file mode 100644
index 2009d3bf8167dfdee3602c0a948549b8d8b84857..0000000000000000000000000000000000000000
--- a/modules/processing.py
+++ /dev/null
@@ -1,1056 +0,0 @@
-import json
-import math
-import os
-import sys
-import warnings
-
-import torch
-import numpy as np
-from PIL import Image, ImageFilter, ImageOps
-import random
-import cv2
-from skimage import exposure
-from typing import Any, Dict, List, Optional
-
-import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
-from modules.sd_hijack import model_hijack
-from modules.shared import opts, cmd_opts, state
-import modules.shared as shared
-import modules.paths as paths
-import modules.face_restoration
-import modules.images as images
-import modules.styles
-import modules.sd_models as sd_models
-import modules.sd_vae as sd_vae
-import logging
-from ldm.data.util import AddMiDaS
-from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
-
-from einops import repeat, rearrange
-from blendmodes.blend import blendLayers, BlendType
-
-# some of those options should not be changed at all because they would break the model, so I removed them from options.
-opt_C = 4
-opt_f = 8
-
-
-def setup_color_correction(image):
- logging.info("Calibrating color correction.")
- correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
- return correction_target
-
-
-def apply_color_correction(correction, original_image):
- logging.info("Applying color correction.")
- image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
- cv2.cvtColor(
- np.asarray(original_image),
- cv2.COLOR_RGB2LAB
- ),
- correction,
- channel_axis=2
- ), cv2.COLOR_LAB2RGB).astype("uint8"))
-
- image = blendLayers(image, original_image, BlendType.LUMINOSITY)
-
- return image
-
-
-def apply_overlay(image, paste_loc, index, overlays):
- if overlays is None or index >= len(overlays):
- return image
-
- overlay = overlays[index]
-
- if paste_loc is not None:
- x, y, w, h = paste_loc
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
-
- image = image.convert('RGBA')
- image.alpha_composite(overlay)
- image = image.convert('RGB')
-
- return image
-
-
-def txt2img_image_conditioning(sd_model, x, width, height):
- if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
-        # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
- return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- return image_conditioning
-
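# A quick standalone check of the channel layout produced above for inpainting-style models
# (conditioning_key 'hybrid' or 'concat'): four latent channels for the masked image plus one mask
# channel, where F.pad with (0, 0, 0, 0, 1, 0) prepends the mask channel filled with ones. Shapes
# are illustrative (batch of 2, 64x64 latent for a 512x512 image).
import torch

image_conditioning = torch.zeros(2, 4, 64, 64)    # stand-in for the encoded all-zero masked image
padded = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
print(padded.shape)                     # torch.Size([2, 5, 64, 64])
print(padded[:, 0].unique())            # tensor([1.]) -- the prepended all-ones mask channel
print(padded[:, 1:].abs().sum().item()) # 0.0 -- the original four channels are untouched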
-
-class StableDiffusionProcessing:
- """
-    The first set of parameters: sd_model -> do_not_reload_embeddings represents the minimum required to create a StableDiffusionProcessing
- """
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
- if sampler_index is not None:
- print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
-
- self.outpath_samples: str = outpath_samples
- self.outpath_grids: str = outpath_grids
- self.prompt: str = prompt
- self.prompt_for_display: str = None
- self.negative_prompt: str = (negative_prompt or "")
- self.styles: list = styles or []
- self.seed: int = seed
- self.subseed: int = subseed
- self.subseed_strength: float = subseed_strength
- self.seed_resize_from_h: int = seed_resize_from_h
- self.seed_resize_from_w: int = seed_resize_from_w
- self.sampler_name: str = sampler_name
- self.batch_size: int = batch_size
- self.n_iter: int = n_iter
- self.steps: int = steps
- self.cfg_scale: float = cfg_scale
- self.width: int = width
- self.height: int = height
- self.restore_faces: bool = restore_faces
- self.tiling: bool = tiling
- self.do_not_save_samples: bool = do_not_save_samples
- self.do_not_save_grid: bool = do_not_save_grid
- self.extra_generation_params: dict = extra_generation_params or {}
- self.overlay_images = overlay_images
- self.eta = eta
- self.do_not_reload_embeddings = do_not_reload_embeddings
- self.paste_to = None
- self.color_corrections = None
- self.denoising_strength: float = denoising_strength
- self.sampler_noise_scheduler_override = None
- self.ddim_discretize = ddim_discretize or opts.ddim_discretize
- self.s_churn = s_churn or opts.s_churn
- self.s_tmin = s_tmin or opts.s_tmin
- self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
- self.s_noise = s_noise or opts.s_noise
- self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
- self.override_settings_restore_afterwards = override_settings_restore_afterwards
- self.is_using_inpainting_conditioning = False
- self.disable_extra_networks = False
-
- if not seed_enable_extras:
- self.subseed = -1
- self.subseed_strength = 0
- self.seed_resize_from_h = 0
- self.seed_resize_from_w = 0
-
- self.scripts = None
- self.script_args = script_args
- self.all_prompts = None
- self.all_negative_prompts = None
- self.all_seeds = None
- self.all_subseeds = None
- self.iteration = 0
-
- @property
- def sd_model(self):
- return shared.sd_model
-
- def txt2img_image_conditioning(self, x, width=None, height=None):
- self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
-
- return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
-
- def depth2img_image_conditioning(self, source_image):
- # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
- transformer = AddMiDaS(model_type="dpt_hybrid")
- transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
- midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
- midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
-
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
- conditioning = torch.nn.functional.interpolate(
- self.sd_model.depth_model(midas_in),
- size=conditioning_image.shape[2:],
- mode="bicubic",
- align_corners=False,
- )
-
- (depth_min, depth_max) = torch.aminmax(conditioning)
- conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
- return conditioning
-
- def edit_image_conditioning(self, source_image):
- conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
-
- return conditioning_image
-
- def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
- self.is_using_inpainting_conditioning = True
-
- # Handle the different mask inputs
- if image_mask is not None:
- if torch.is_tensor(image_mask):
- conditioning_mask = image_mask
- else:
- conditioning_mask = np.array(image_mask.convert("L"))
- conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
- conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
-
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
- else:
- conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
-
- # Create another latent image, this time with a masked version of the original input.
- # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
- conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
- conditioning_image = torch.lerp(
- source_image,
- source_image * (1.0 - conditioning_mask),
- getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
- )
-
- # Encode the new masked image using first stage of network.
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
-
- # Create the concatenated conditioning tensor to be fed to `c_concat`
- conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
- conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
- image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
- image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
-
- return image_conditioning
-
- def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
- source_image = devices.cond_cast_float(source_image)
-
- # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
- # identify itself with a field common to all models. The conditioning_key is also hybrid.
- if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
- return self.depth2img_image_conditioning(source_image)
-
- if self.sd_model.cond_stage_key == "edit":
- return self.edit_image_conditioning(source_image)
-
- if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
-
- # Dummy zero conditioning if we're not using inpainting or depth model.
- return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
-
- def init(self, all_prompts, all_seeds, all_subseeds):
- pass
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
- raise NotImplementedError()
-
- def close(self):
- self.sampler = None
-
-
-class Processed:
- def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
- self.images = images_list
- self.prompt = p.prompt
- self.negative_prompt = p.negative_prompt
- self.seed = seed
- self.subseed = subseed
- self.subseed_strength = p.subseed_strength
- self.info = info
- self.comments = comments
- self.width = p.width
- self.height = p.height
- self.sampler_name = p.sampler_name
- self.cfg_scale = p.cfg_scale
- self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
- self.steps = p.steps
- self.batch_size = p.batch_size
- self.restore_faces = p.restore_faces
- self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
- self.sd_model_hash = shared.sd_model.sd_model_hash
- self.seed_resize_from_w = p.seed_resize_from_w
- self.seed_resize_from_h = p.seed_resize_from_h
- self.denoising_strength = getattr(p, 'denoising_strength', None)
- self.extra_generation_params = p.extra_generation_params
- self.index_of_first_image = index_of_first_image
- self.styles = p.styles
- self.job_timestamp = state.job_timestamp
- self.clip_skip = opts.CLIP_stop_at_last_layers
-
- self.eta = p.eta
- self.ddim_discretize = p.ddim_discretize
- self.s_churn = p.s_churn
- self.s_tmin = p.s_tmin
- self.s_tmax = p.s_tmax
- self.s_noise = p.s_noise
- self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
- self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
- self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
- self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
- self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
- self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
-
- self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
- self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
- self.all_seeds = all_seeds or p.all_seeds or [self.seed]
- self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
- self.infotexts = infotexts or [info]
-
- def js(self):
- obj = {
- "prompt": self.all_prompts[0],
- "all_prompts": self.all_prompts,
- "negative_prompt": self.all_negative_prompts[0],
- "all_negative_prompts": self.all_negative_prompts,
- "seed": self.seed,
- "all_seeds": self.all_seeds,
- "subseed": self.subseed,
- "all_subseeds": self.all_subseeds,
- "subseed_strength": self.subseed_strength,
- "width": self.width,
- "height": self.height,
- "sampler_name": self.sampler_name,
- "cfg_scale": self.cfg_scale,
- "steps": self.steps,
- "batch_size": self.batch_size,
- "restore_faces": self.restore_faces,
- "face_restoration_model": self.face_restoration_model,
- "sd_model_hash": self.sd_model_hash,
- "seed_resize_from_w": self.seed_resize_from_w,
- "seed_resize_from_h": self.seed_resize_from_h,
- "denoising_strength": self.denoising_strength,
- "extra_generation_params": self.extra_generation_params,
- "index_of_first_image": self.index_of_first_image,
- "infotexts": self.infotexts,
- "styles": self.styles,
- "job_timestamp": self.job_timestamp,
- "clip_skip": self.clip_skip,
- "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
- }
-
- return json.dumps(obj)
-
- def infotext(self, p: StableDiffusionProcessing, index):
- return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
-
-
-# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
-def slerp(val, low, high):
- low_norm = low/torch.norm(low, dim=1, keepdim=True)
- high_norm = high/torch.norm(high, dim=1, keepdim=True)
- dot = (low_norm*high_norm).sum(1)
-
- if dot.mean() > 0.9995:
- return low * val + high * (1 - val)
-
- omega = torch.acos(dot)
- so = torch.sin(omega)
- res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
- return res
-
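# A standalone worked example of slerp() above on 2-D unit vectors, where the geometry is easy to
# see: the interpolant travels along the arc between `low` and `high` instead of cutting across
# the chord, which is the usual motivation for interpolating Gaussian noise spherically rather
# than linearly (a plain lerp would shrink the blend's norm). The math is restated locally so the
# snippet runs on its own.
import torch

def slerp_check(val, low, high):
    # same math as the main branch of slerp() above
    low_n = low / torch.norm(low, dim=1, keepdim=True)
    high_n = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_n * high_n).sum(1))
    so = torch.sin(omega)
    return (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high

low = torch.tensor([[1.0, 0.0]])
high = torch.tensor([[0.0, 1.0]])
for val in (0.0, 0.5, 1.0):
    res = slerp_check(val, low, high)
    print(val, res, torch.norm(res).item())
# 0.0 -> [[1., 0.]],         norm 1.0 (pure `low`)
# 0.5 -> [[0.7071, 0.7071]], norm 1.0 (a lerp would give [[0.5, 0.5]] with norm ~0.71)
# 1.0 -> [[0., 1.]],         norm 1.0 (pure `high`)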
-
-def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
- eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
- xs = []
-
- # if we have multiple seeds, this means we are working with batch size>1; this then
- # enables the generation of additional tensors with noise that the sampler will use during its processing.
- # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
- # produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
- sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
- else:
- sampler_noises = None
-
- for i, seed in enumerate(seeds):
- noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
-
- subnoise = None
- if subseeds is not None:
- subseed = 0 if i >= len(subseeds) else subseeds[i]
-
- subnoise = devices.randn(subseed, noise_shape)
-
- # randn results depend on device; gpu and cpu get different results for same seed;
- # the way I see it, it's better to do this on CPU, so that everyone gets same result;
- # but the original script had it like this, so I do not dare change it for now because
- # it will break everyone's seeds.
- noise = devices.randn(seed, noise_shape)
-
- if subnoise is not None:
- noise = slerp(subseed_strength, noise, subnoise)
-
- if noise_shape != shape:
- x = devices.randn(seed, shape)
- dx = (shape[2] - noise_shape[2]) // 2
- dy = (shape[1] - noise_shape[1]) // 2
- w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
- h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
- tx = 0 if dx < 0 else dx
- ty = 0 if dy < 0 else dy
- dx = max(-dx, 0)
- dy = max(-dy, 0)
-
- x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
- noise = x
-
- if sampler_noises is not None:
- cnt = p.sampler.number_of_needed_noises(p)
-
- if eta_noise_seed_delta > 0:
- torch.manual_seed(seed + eta_noise_seed_delta)
-
- for j in range(cnt):
- sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
-
- xs.append(noise)
-
- if sampler_noises is not None:
- p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
-
- x = torch.stack(xs).to(shared.device)
- return x
-
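# A standalone illustration of the batching guarantee described in the comments above: because
# noise is generated per seed and then stacked, one batch with seeds [100, 101] starts from
# exactly the same tensors as two single-image batches with seeds [100] and [101]. seeded_noise()
# is a hypothetical stand-in for devices.randn, using a fixed CPU generator.
import torch

def seeded_noise(seed, shape):
    generator = torch.Generator("cpu").manual_seed(seed)
    return torch.randn(shape, generator=generator)

shape = (4, 64, 64)
batched = torch.stack([seeded_noise(100, shape), seeded_noise(101, shape)])
single_a = torch.stack([seeded_noise(100, shape)])
single_b = torch.stack([seeded_noise(101, shape)])
print(torch.equal(batched[0], single_a[0]), torch.equal(batched[1], single_b[0]))   # True True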
-
-def decode_first_stage(model, x):
- with devices.autocast(disable=x.dtype == devices.dtype_vae):
- x = model.decode_first_stage(x)
-
- return x
-
-
-def get_fixed_seed(seed):
- if seed is None or seed == '' or seed == -1:
- return int(random.randrange(4294967294))
-
- return seed
-
-
-def fix_seed(p):
- p.seed = get_fixed_seed(p.seed)
- p.subseed = get_fixed_seed(p.subseed)
-
-
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
- index = position_in_batch + iteration * p.batch_size
-
- clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
-
- generation_params = {
- "Steps": p.steps,
- "Sampler": p.sampler_name,
- "CFG scale": p.cfg_scale,
- "Image CFG scale": getattr(p, 'image_cfg_scale', None),
- "Seed": all_seeds[index],
- "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
- "Size": f"{p.width}x{p.height}",
- "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
- "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
- "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
- "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
- "Denoising strength": getattr(p, 'denoising_strength', None),
- "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
- "Clip skip": None if clip_skip <= 1 else clip_skip,
- "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
- }
-
- generation_params.update(p.extra_generation_params)
-
- generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
-
- negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
-
- return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
-
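# A small standalone demonstration of how the generation_params dict above is flattened into the
# infotext line: entries with a None value are dropped and the rest become "Key: value". quote()
# here is a simplified, hypothetical stand-in for generation_parameters_copypaste.quote.
def quote(value):
    text = str(value)
    return f'"{text}"' if ',' in text else text

generation_params = {
    "Steps": 20, "Sampler": "Euler a", "CFG scale": 7.0, "Seed": 1234567890,
    "Size": "512x512", "Denoising strength": None,   # dropped: value is None
}
print(", ".join(k if k == v else f"{k}: {quote(v)}" for k, v in generation_params.items() if v is not None))
# Steps: 20, Sampler: Euler a, CFG scale: 7.0, Seed: 1234567890, Size: 512x512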
-
-def process_images(p: StableDiffusionProcessing) -> Processed:
- stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
-
- try:
- for k, v in p.override_settings.items():
- setattr(opts, k, v)
-
- if k == 'sd_model_checkpoint':
- sd_models.reload_model_weights()
-
- if k == 'sd_vae':
- sd_vae.reload_vae_weights()
-
- res = process_images_inner(p)
-
- finally:
- # restore opts to original state
- if p.override_settings_restore_afterwards:
- for k, v in stored_opts.items():
- setattr(opts, k, v)
- if k == 'sd_model_checkpoint':
- sd_models.reload_model_weights()
-
- if k == 'sd_vae':
- sd_vae.reload_vae_weights()
-
- return res
-
-
-def process_images_inner(p: StableDiffusionProcessing) -> Processed:
- """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
-
- if type(p.prompt) == list:
- assert(len(p.prompt) > 0)
- else:
- assert p.prompt is not None
-
- devices.torch_gc()
-
- seed = get_fixed_seed(p.seed)
- subseed = get_fixed_seed(p.subseed)
-
- modules.sd_hijack.model_hijack.apply_circular(p.tiling)
- modules.sd_hijack.model_hijack.clear_comments()
-
- comments = {}
-
- if type(p.prompt) == list:
- p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
- else:
- p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
-
- if type(p.negative_prompt) == list:
- p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
- else:
- p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
-
- if type(seed) == list:
- p.all_seeds = seed
- else:
- p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
-
- if type(subseed) == list:
- p.all_subseeds = subseed
- else:
- p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
-
- def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
-
- if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
- model_hijack.embedding_db.load_textual_inversion_embeddings()
-
- if p.scripts is not None:
- p.scripts.process(p)
-
- infotexts = []
- output_images = []
-
- cached_uc = [None, None]
- cached_c = [None, None]
-
- def get_conds_with_caching(function, required_prompts, steps, cache):
- """
- Returns the result of calling function(shared.sd_model, required_prompts, steps)
- using a cache to store the result if the same arguments have been used before.
-
- cache is an array containing two elements. The first element is a tuple
- representing the previously used arguments, or None if no arguments
- have been used before. The second element is where the previously
- computed result is stored.
- """
-
- if cache[0] is not None and (required_prompts, steps) == cache[0]:
- return cache[1]
-
- with devices.autocast():
- cache[1] = function(shared.sd_model, required_prompts, steps)
-
- cache[0] = (required_prompts, steps)
- return cache[1]
-
- with torch.no_grad(), p.sd_model.ema_scope():
- with devices.autocast():
- p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
-
- # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
- if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
- sd_vae_approx.model()
-
- if state.job_count == -1:
- state.job_count = p.n_iter
-
- for n in range(p.n_iter):
- p.iteration = n
-
- if state.skipped:
- state.skipped = False
-
- if state.interrupted:
- break
-
- prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
-
- if len(prompts) == 0:
- break
-
- prompts, extra_network_data = extra_networks.parse_prompts(prompts)
-
- if not p.disable_extra_networks:
- with devices.autocast():
- extra_networks.activate(p, extra_network_data)
-
- if p.scripts is not None:
- p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
-
- # params.txt should be saved after scripts.process_batch, since the
- # infotext could be modified by that callback
- # Example: a wildcard processed by process_batch sets an extra model
- # strength, which is saved as "Model Strength: 1.0" in the infotext
- if n == 0:
- with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
- file.write(processed.infotext(p, 0))
-
- uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
- c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
-
- if len(model_hijack.comments) > 0:
- for comment in model_hijack.comments:
- comments[comment] = 1
-
- if p.n_iter > 1:
- shared.state.job = f"Batch {n+1} out of {p.n_iter}"
-
- with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
-
- x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
- for x in x_samples_ddim:
- devices.test_for_nans(x, "vae")
-
- x_samples_ddim = torch.stack(x_samples_ddim).float()
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-
- del samples_ddim
-
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
-
- devices.torch_gc()
-
- if p.scripts is not None:
- p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
-
- for i, x_sample in enumerate(x_samples_ddim):
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
-
- if p.restore_faces:
- if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
- images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
-
- devices.torch_gc()
-
- x_sample = modules.face_restoration.restore_faces(x_sample)
- devices.torch_gc()
-
- image = Image.fromarray(x_sample)
-
- if p.scripts is not None:
- pp = scripts.PostprocessImageArgs(image)
- p.scripts.postprocess_image(p, pp)
- image = pp.image
-
- if p.color_corrections is not None and i < len(p.color_corrections):
- if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
- image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
- images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
- image = apply_color_correction(p.color_corrections[i], image)
-
- image = apply_overlay(image, p.paste_to, i, p.overlay_images)
-
- if opts.samples_save and not p.do_not_save_samples:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
-
- text = infotext(n, i)
- infotexts.append(text)
- if opts.enable_pnginfo:
- image.info["parameters"] = text
- output_images.append(image)
-
- del x_samples_ddim
-
- devices.torch_gc()
-
- state.nextjob()
-
- p.color_corrections = None
-
- index_of_first_image = 0
- unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
- if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
- grid = images.image_grid(output_images, p.batch_size)
-
- if opts.return_grid:
- text = infotext()
- infotexts.insert(0, text)
- if opts.enable_pnginfo:
- grid.info["parameters"] = text
- output_images.insert(0, grid)
- index_of_first_image = 1
-
- if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
-
- if not p.disable_extra_networks:
- extra_networks.deactivate(p, extra_network_data)
-
- devices.torch_gc()
-
- res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
-
- if p.scripts is not None:
- p.scripts.postprocess(p, res)
-
- return res
-
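# A tiny standalone run of the single-entry caching pattern used by get_conds_with_caching inside
# process_images_inner above: cache[0] remembers the last (prompts, steps) arguments and cache[1]
# the conditioning computed for them, so a batch that repeats the previous prompts skips the text
# encoder entirely. All names and return values below are hypothetical.
calls = []

def fake_conditioning(model, prompts, steps):
    calls.append((tuple(prompts), steps))
    return f"cond{len(calls)}"

def cached(function, prompts, steps, cache):
    if cache[0] is not None and (prompts, steps) == cache[0]:
        return cache[1]
    cache[1] = function(None, prompts, steps)
    cache[0] = (prompts, steps)
    return cache[1]

cache = [None, None]
print(cached(fake_conditioning, ["a cat"], 20, cache))   # cond1 (computed)
print(cached(fake_conditioning, ["a cat"], 20, cache))   # cond1 (cache hit, no new call)
print(cached(fake_conditioning, ["a dog"], 20, cache))   # cond2 (different prompts, recomputed)
print(calls)                                             # [(('a cat',), 20), (('a dog',), 20)]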
-
-def old_hires_fix_first_pass_dimensions(width, height):
- """old algorithm for auto-calculating first pass size"""
-
- desired_pixel_count = 512 * 512
- actual_pixel_count = width * height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
- width = math.ceil(scale * width / 64) * 64
- height = math.ceil(scale * height / 64) * 64
-
- return width, height
-
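# A worked example of the arithmetic in old_hires_fix_first_pass_dimensions above: scale the
# requested size so it carries roughly 512*512 pixels, then round each side up to a multiple of 64.
import math

def first_pass_dimensions(width, height, desired_pixel_count=512 * 512):
    scale = math.sqrt(desired_pixel_count / (width * height))
    return math.ceil(scale * width / 64) * 64, math.ceil(scale * height / 64) * 64

print(first_pass_dimensions(1024, 768))   # (640, 448) -- 640 * 448 = 286720, close to 512 * 512 = 262144
print(first_pass_dimensions(512, 512))    # (512, 512) -- already at the target, unchanged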
-
-class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
- sampler = None
-
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
- super().__init__(**kwargs)
- self.enable_hr = enable_hr
- self.denoising_strength = denoising_strength
- self.hr_scale = hr_scale
- self.hr_upscaler = hr_upscaler
- self.hr_second_pass_steps = hr_second_pass_steps
- self.hr_resize_x = hr_resize_x
- self.hr_resize_y = hr_resize_y
- self.hr_upscale_to_x = hr_resize_x
- self.hr_upscale_to_y = hr_resize_y
-
- if firstphase_width != 0 or firstphase_height != 0:
- self.hr_upscale_to_x = self.width
- self.hr_upscale_to_y = self.height
- self.width = firstphase_width
- self.height = firstphase_height
-
- self.truncate_x = 0
- self.truncate_y = 0
- self.applied_old_hires_behavior_to = None
-
- def init(self, all_prompts, all_seeds, all_subseeds):
- if self.enable_hr:
- if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
- self.hr_resize_x = self.width
- self.hr_resize_y = self.height
- self.hr_upscale_to_x = self.width
- self.hr_upscale_to_y = self.height
-
- self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
- self.applied_old_hires_behavior_to = (self.width, self.height)
-
- if self.hr_resize_x == 0 and self.hr_resize_y == 0:
- self.extra_generation_params["Hires upscale"] = self.hr_scale
- self.hr_upscale_to_x = int(self.width * self.hr_scale)
- self.hr_upscale_to_y = int(self.height * self.hr_scale)
- else:
- self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
-
- if self.hr_resize_y == 0:
- self.hr_upscale_to_x = self.hr_resize_x
- self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
- elif self.hr_resize_x == 0:
- self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
- self.hr_upscale_to_y = self.hr_resize_y
- else:
- target_w = self.hr_resize_x
- target_h = self.hr_resize_y
- src_ratio = self.width / self.height
- dst_ratio = self.hr_resize_x / self.hr_resize_y
-
- if src_ratio < dst_ratio:
- self.hr_upscale_to_x = self.hr_resize_x
- self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
- else:
- self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
- self.hr_upscale_to_y = self.hr_resize_y
-
- self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
- self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
-
- # special case: the user has chosen to do nothing
- if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
- self.enable_hr = False
- self.denoising_strength = None
- self.extra_generation_params.pop("Hires upscale", None)
- self.extra_generation_params.pop("Hires resize", None)
- return
-
- if not state.processing_has_refined_job_count:
- if state.job_count == -1:
- state.job_count = self.n_iter
-
- shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
- state.job_count = state.job_count * 2
- state.processing_has_refined_job_count = True
-
- if self.hr_second_pass_steps:
- self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
-
- if self.hr_upscaler is not None:
- self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
-
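# A worked example of the "Hires resize" branch of init() above: keep the first pass's aspect
# ratio, upscale until the requested size is fully covered, and remember how much excess to trim
# later in latent space (opt_f = 8 image pixels per latent unit). The helper is a standalone
# restatement of that branch.
def hires_resize_targets(width, height, hr_resize_x, hr_resize_y, opt_f=8):
    src_ratio = width / height
    dst_ratio = hr_resize_x / hr_resize_y
    if src_ratio < dst_ratio:
        up_x, up_y = hr_resize_x, hr_resize_x * height // width
    else:
        up_x, up_y = hr_resize_y * width // height, hr_resize_y
    return (up_x, up_y), ((up_x - hr_resize_x) // opt_f, (up_y - hr_resize_y) // opt_f)

# 512x512 first pass, 768x1024 requested: upscale to 1024x1024, then crop 32 latent columns
# (256 image pixels) from the width to land exactly on 768x1024.
print(hires_resize_targets(512, 512, 768, 1024))   # ((1024, 1024), (32, 0))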
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
- self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
-
- latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
- if self.enable_hr and latent_scale_mode is None:
- assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
-
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
-
- if not self.enable_hr:
- return samples
-
- target_width = self.hr_upscale_to_x
- target_height = self.hr_upscale_to_y
-
- def save_intermediate(image, index):
- """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
-
- if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
- return
-
- if not isinstance(image, Image.Image):
- image = sd_samplers.sample_to_image(image, index, approximation=0)
-
- info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
- images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
-
- if latent_scale_mode is not None:
- for i in range(samples.shape[0]):
- save_intermediate(samples, i)
-
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
-
- # Avoid making the inpainting conditioning unless necessary as
- # this does need some extra compute to decode / encode the image again.
- if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
- image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
- else:
- image_conditioning = self.txt2img_image_conditioning(samples)
- else:
- decoded_samples = decode_first_stage(self.sd_model, samples)
- lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
- batch_images = []
- for i, x_sample in enumerate(lowres_samples):
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- image = Image.fromarray(x_sample)
-
- save_intermediate(image, i)
-
- image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
- image = np.array(image).astype(np.float32) / 255.0
- image = np.moveaxis(image, 2, 0)
- batch_images.append(image)
-
- decoded_samples = torch.from_numpy(np.array(batch_images))
- decoded_samples = decoded_samples.to(shared.device)
- decoded_samples = 2. * decoded_samples - 1.
-
- samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
-
- image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
-
- shared.state.nextjob()
-
-        img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch to DDIM
- self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
-
- samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
-
- noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
-
- # GC now before running the next img2img to prevent running out of memory
- x = None
- devices.torch_gc()
-
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
-
- return samples
-
-
-class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
- sampler = None
-
- def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
- super().__init__(**kwargs)
-
- self.init_images = init_images
- self.resize_mode: int = resize_mode
- self.denoising_strength: float = denoising_strength
- self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
- self.init_latent = None
- self.image_mask = mask
- self.latent_mask = None
- self.mask_for_overlay = None
- self.mask_blur = mask_blur
- self.inpainting_fill = inpainting_fill
- self.inpaint_full_res = inpaint_full_res
- self.inpaint_full_res_padding = inpaint_full_res_padding
- self.inpainting_mask_invert = inpainting_mask_invert
- self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
- self.mask = None
- self.nmask = None
- self.image_conditioning = None
-
- def init(self, all_prompts, all_seeds, all_subseeds):
- self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- crop_region = None
-
- image_mask = self.image_mask
-
- if image_mask is not None:
- image_mask = image_mask.convert('L')
-
- if self.inpainting_mask_invert:
- image_mask = ImageOps.invert(image_mask)
-
- if self.mask_blur > 0:
- image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
-
- if self.inpaint_full_res:
- self.mask_for_overlay = image_mask
- mask = image_mask.convert('L')
- crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
- crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
- x1, y1, x2, y2 = crop_region
-
- mask = mask.crop(crop_region)
- image_mask = images.resize_image(2, mask, self.width, self.height)
- self.paste_to = (x1, y1, x2-x1, y2-y1)
- else:
- image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
- np_mask = np.array(image_mask)
- np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
- self.mask_for_overlay = Image.fromarray(np_mask)
-
- self.overlay_images = []
-
- latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
-
- add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
- if add_color_corrections:
- self.color_corrections = []
- imgs = []
- for img in self.init_images:
- image = images.flatten(img, opts.img2img_background_color)
-
- if crop_region is None and self.resize_mode != 3:
- image = images.resize_image(self.resize_mode, image, self.width, self.height)
-
- if image_mask is not None:
- image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
-
- self.overlay_images.append(image_masked.convert('RGBA'))
-
- # crop_region is not None if we are doing inpaint full res
- if crop_region is not None:
- image = image.crop(crop_region)
- image = images.resize_image(2, image, self.width, self.height)
-
- if image_mask is not None:
- if self.inpainting_fill != 1:
- image = masking.fill(image, latent_mask)
-
- if add_color_corrections:
- self.color_corrections.append(setup_color_correction(image))
-
- image = np.array(image).astype(np.float32) / 255.0
- image = np.moveaxis(image, 2, 0)
-
- imgs.append(image)
-
- if len(imgs) == 1:
- batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
- if self.overlay_images is not None:
- self.overlay_images = self.overlay_images * self.batch_size
-
- if self.color_corrections is not None and len(self.color_corrections) == 1:
- self.color_corrections = self.color_corrections * self.batch_size
-
- elif len(imgs) <= self.batch_size:
- self.batch_size = len(imgs)
- batch_images = np.array(imgs)
- else:
- raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
-
- image = torch.from_numpy(batch_images)
- image = 2. * image - 1.
- image = image.to(shared.device)
-
- self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
-
- if self.resize_mode == 3:
- self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
-
- if image_mask is not None:
- init_mask = latent_mask
- latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
- latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
- latmask = latmask[0]
- latmask = np.around(latmask)
- latmask = np.tile(latmask[None], (4, 1, 1))
-
- self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
- self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
-
- # this needs to be fixed to be done in sample() using actual seeds for batches
- if self.inpainting_fill == 2:
- self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
- elif self.inpainting_fill == 3:
- self.init_latent = self.init_latent * self.mask
-
- self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-
- if self.initial_noise_multiplier != 1.0:
- self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
- x *= self.initial_noise_multiplier
-
- samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
-
- if self.mask is not None:
- samples = samples * self.nmask + self.init_latent * self.mask
-
- del x
- devices.torch_gc()
-
- return samples
diff --git a/modules/progress.py b/modules/progress.py
deleted file mode 100644
index c69ecf3d1bce60a4dbc189defcb70b2532e47404..0000000000000000000000000000000000000000
--- a/modules/progress.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import base64
-import io
-import time
-
-import gradio as gr
-from pydantic import BaseModel, Field
-
-from modules.shared import opts
-
-import modules.shared as shared
-
-
-current_task = None
-pending_tasks = {}
-finished_tasks = []
-
-
-def start_task(id_task):
- global current_task
-
- current_task = id_task
- pending_tasks.pop(id_task, None)
-
-
-def finish_task(id_task):
- global current_task
-
- if current_task == id_task:
- current_task = None
-
- finished_tasks.append(id_task)
- if len(finished_tasks) > 16:
- finished_tasks.pop(0)
-
-
-def add_task_to_queue(id_job):
- pending_tasks[id_job] = time.time()
-
-
-class ProgressRequest(BaseModel):
- id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
- id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
-
-
-class ProgressResponse(BaseModel):
- active: bool = Field(title="Whether the task is being worked on right now")
- queued: bool = Field(title="Whether the task is in queue")
- completed: bool = Field(title="Whether the task has already finished")
- progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
- eta: float = Field(default=None, title="ETA in secs")
- live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
- id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
- textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
-
-
-def setup_progress_api(app):
- return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
-
-
-def progressapi(req: ProgressRequest):
- active = req.id_task == current_task
- queued = req.id_task in pending_tasks
- completed = req.id_task in finished_tasks
-
- if not active:
- return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
-
- progress = 0
-
- job_count, job_no = shared.state.job_count, shared.state.job_no
- sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
-
- if job_count > 0:
- progress += job_no / job_count
- if sampling_steps > 0 and job_count > 0:
- progress += 1 / job_count * sampling_step / sampling_steps
-
- progress = min(progress, 1)
-
- elapsed_since_start = time.time() - shared.state.time_start
- predicted_duration = elapsed_since_start / progress if progress > 0 else None
- eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
-
- id_live_preview = req.id_live_preview
- shared.state.set_current_image()
- if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
- image = shared.state.current_image
- if image is not None:
- buffered = io.BytesIO()
- image.save(buffered, format="png")
- live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
- id_live_preview = shared.state.id_live_preview
- else:
- live_preview = None
- else:
- live_preview = None
-
- return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
-
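# A standalone worked example of the progress and ETA arithmetic in progressapi() above: progress
# is the fraction of finished jobs plus the in-progress job's fraction of sampling steps, and the
# ETA extrapolates the total duration from the time elapsed so far. The numbers are hypothetical.
def compute_progress(job_no, job_count, sampling_step, sampling_steps):
    progress = 0.0
    if job_count > 0:
        progress += job_no / job_count
        if sampling_steps > 0:
            progress += (1 / job_count) * (sampling_step / sampling_steps)
    return min(progress, 1.0)

progress = compute_progress(job_no=1, job_count=4, sampling_step=10, sampling_steps=20)
elapsed = 30.0                                       # seconds since shared.state.time_start
eta = elapsed / progress - elapsed if progress > 0 else None
print(progress, eta)                                 # 0.375 50.0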
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
deleted file mode 100644
index 696653725fbe0c46187cabb3ba50082a43b78a54..0000000000000000000000000000000000000000
--- a/modules/prompt_parser.py
+++ /dev/null
@@ -1,373 +0,0 @@
-import re
-from collections import namedtuple
-from typing import List
-import lark
-
-# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
-# will be represented with prompt_schedule like this (assuming steps=100):
-# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
-# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
-# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
-# [75, 'fantasy landscape with a lake and an oak in background masterful']
-# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
-
-schedule_parser = lark.Lark(r"""
-!start: (prompt | /[][():]/+)*
-prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
-!emphasized: "(" prompt ")"
- | "(" prompt ":" prompt ")"
- | "[" prompt "]"
-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
-alternate: "[" prompt ("|" prompt)+ "]"
-WHITESPACE: /\s+/
-plain: /([^\\\[\]():|]|\\.)+/
-%import common.SIGNED_NUMBER -> NUMBER
-""")
-
-def get_learned_conditioning_prompt_schedules(prompts, steps):
- """
- >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
- >>> g("test")
- [[10, 'test']]
- >>> g("a [b:3]")
- [[3, 'a '], [10, 'a b']]
- >>> g("a [b: 3]")
- [[3, 'a '], [10, 'a b']]
- >>> g("a [[[b]]:2]")
- [[2, 'a '], [10, 'a [[b]]']]
- >>> g("[(a:2):3]")
- [[3, ''], [10, '(a:2)']]
- >>> g("a [b : c : 1] d")
- [[1, 'a b d'], [10, 'a c d']]
- >>> g("a[b:[c:d:2]:1]e")
- [[1, 'abe'], [2, 'ace'], [10, 'ade']]
- >>> g("a [unbalanced")
- [[10, 'a [unbalanced']]
- >>> g("a [b:.5] c")
- [[5, 'a c'], [10, 'a b c']]
- >>> g("a [{b|d{:.5] c") # not handling this right now
- [[5, 'a c'], [10, 'a {b|d{ c']]
- >>> g("((a][:b:c [d:3]")
- [[3, '((a][:b:c '], [10, '((a][:b:c d']]
- >>> g("[a|(b:1.1)]")
- [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
- """
-
- def collect_steps(steps, tree):
- l = [steps]
- class CollectSteps(lark.Visitor):
- def scheduled(self, tree):
- tree.children[-1] = float(tree.children[-1])
- if tree.children[-1] < 1:
- tree.children[-1] *= steps
- tree.children[-1] = min(steps, int(tree.children[-1]))
- l.append(tree.children[-1])
- def alternate(self, tree):
- l.extend(range(1, steps+1))
- CollectSteps().visit(tree)
- return sorted(set(l))
-
- def at_step(step, tree):
- class AtStep(lark.Transformer):
- def scheduled(self, args):
- before, after, _, when = args
- yield before or () if step <= when else after
- def alternate(self, args):
- yield next(args[(step - 1)%len(args)])
- def start(self, args):
- def flatten(x):
- if type(x) == str:
- yield x
- else:
- for gen in x:
- yield from flatten(gen)
- return ''.join(flatten(args))
- def plain(self, args):
- yield args[0].value
- def __default__(self, data, children, meta):
- for child in children:
- yield child
- return AtStep().transform(tree)
-
- def get_schedule(prompt):
- try:
- tree = schedule_parser.parse(prompt)
- except lark.exceptions.LarkError as e:
- if 0:
- import traceback
- traceback.print_exc()
- return [[steps, prompt]]
- return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
-
- promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
- return [promptdict[prompt] for prompt in prompts]
-
-
-ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-
-
-def get_learned_conditioning(model, prompts, steps):
- """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
- and the sampling step at which this condition is to be replaced by the next one.
-
- Input:
- (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
-
- Output:
- [
- [
- ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
- ],
- [
- ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
- ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
- ]
- ]
- """
- res = []
-
- prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
- cache = {}
-
- for prompt, prompt_schedule in zip(prompts, prompt_schedules):
-
- cached = cache.get(prompt, None)
- if cached is not None:
- res.append(cached)
- continue
-
- texts = [x[1] for x in prompt_schedule]
- conds = model.get_learned_conditioning(texts)
-
- cond_schedule = []
- for i, (end_at_step, text) in enumerate(prompt_schedule):
- cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
-
- cache[prompt] = cond_schedule
- res.append(cond_schedule)
-
- return res
-
-
-re_AND = re.compile(r"\bAND\b")
-re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
-
-def get_multicond_prompt_list(prompts):
- res_indexes = []
-
- prompt_flat_list = []
- prompt_indexes = {}
-
- for prompt in prompts:
- subprompts = re_AND.split(prompt)
-
- indexes = []
- for subprompt in subprompts:
- match = re_weight.search(subprompt)
-
- text, weight = match.groups() if match is not None else (subprompt, 1.0)
-
- weight = float(weight) if weight is not None else 1.0
-
- index = prompt_indexes.get(text, None)
- if index is None:
- index = len(prompt_flat_list)
- prompt_flat_list.append(text)
- prompt_indexes[text] = index
-
- indexes.append((index, weight))
-
- res_indexes.append(indexes)
-
- return res_indexes, prompt_flat_list, prompt_indexes
-
-
-class ComposableScheduledPromptConditioning:
- def __init__(self, schedules, weight=1.0):
- self.schedules: List[ScheduledPromptConditioning] = schedules
- self.weight: float = weight
-
-
-class MulticondLearnedConditioning:
- def __init__(self, shape, batch):
- self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
- self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
-
-def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
- """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
- For each prompt, the list is obtained by splitting the prompt using the AND separator.
-
- https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
- """
-
- res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
-
- learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
-
- res = []
- for indexes in res_indexes:
- res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
-
- return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
-
-
-def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
- param = c[0][0].cond
- res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
- for i, cond_schedule in enumerate(c):
- target_index = 0
- for current, (end_at, cond) in enumerate(cond_schedule):
- if current_step <= end_at:
- target_index = current
- break
- res[i] = cond_schedule[target_index].cond
-
- return res
-
-
-def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
- param = c.batch[0][0].schedules[0].cond
-
- tensors = []
- conds_list = []
-
- for batch_no, composable_prompts in enumerate(c.batch):
- conds_for_batch = []
-
- for cond_index, composable_prompt in enumerate(composable_prompts):
- target_index = 0
- for current, (end_at, cond) in enumerate(composable_prompt.schedules):
- if current_step <= end_at:
- target_index = current
- break
-
- conds_for_batch.append((len(tensors), composable_prompt.weight))
- tensors.append(composable_prompt.schedules[target_index].cond)
-
- conds_list.append(conds_for_batch)
-
- # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
- # and won't be able to torch.stack them. So this fixes that.
- token_count = max([x.shape[0] for x in tensors])
- for i in range(len(tensors)):
- if tensors[i].shape[0] != token_count:
- last_vector = tensors[i][-1:]
- last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
- tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
-
- return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
-
-
-re_attention = re.compile(r"""
-\\\(|
-\\\)|
-\\\[|
-\\]|
-\\\\|
-\\|
-\(|
-\[|
-:([+-]?[.\d]+)\)|
-\)|
-]|
-[^\\()\[\]:]+|
-:
-""", re.X)
-
-re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
-
-def parse_prompt_attention(text):
- """
- Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
- Accepted tokens are:
- (abc) - increases attention to abc by a multiplier of 1.1
- (abc:3.12) - increases attention to abc by a multiplier of 3.12
- [abc] - decreases attention to abc by a multiplier of 1.1
- \( - literal character '('
- \[ - literal character '['
- \) - literal character ')'
- \] - literal character ']'
- \\ - literal character '\'
- anything else - just text
-
- >>> parse_prompt_attention('normal text')
- [['normal text', 1.0]]
- >>> parse_prompt_attention('an (important) word')
- [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
- >>> parse_prompt_attention('(unbalanced')
- [['unbalanced', 1.1]]
- >>> parse_prompt_attention('\(literal\]')
- [['(literal]', 1.0]]
- >>> parse_prompt_attention('(unnecessary)(parens)')
- [['unnecessaryparens', 1.1]]
- >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
- [['a ', 1.0],
- ['house', 1.5730000000000004],
- [' ', 1.1],
- ['on', 1.0],
- [' a ', 1.1],
- ['hill', 0.55],
- [', sun, ', 1.1],
- ['sky', 1.4641000000000006],
- ['.', 1.1]]
- """
-
- res = []
- round_brackets = []
- square_brackets = []
-
- round_bracket_multiplier = 1.1
- square_bracket_multiplier = 1 / 1.1
-
- def multiply_range(start_position, multiplier):
- for p in range(start_position, len(res)):
- res[p][1] *= multiplier
-
- for m in re_attention.finditer(text):
- text = m.group(0)
- weight = m.group(1)
-
- if text.startswith('\\'):
- res.append([text[1:], 1.0])
- elif text == '(':
- round_brackets.append(len(res))
- elif text == '[':
- square_brackets.append(len(res))
- elif weight is not None and len(round_brackets) > 0:
- multiply_range(round_brackets.pop(), float(weight))
- elif text == ')' and len(round_brackets) > 0:
- multiply_range(round_brackets.pop(), round_bracket_multiplier)
- elif text == ']' and len(square_brackets) > 0:
- multiply_range(square_brackets.pop(), square_bracket_multiplier)
- else:
- parts = re.split(re_break, text)
- for i, part in enumerate(parts):
- if i > 0:
- res.append(["BREAK", -1])
- res.append([part, 1.0])
-
- for pos in round_brackets:
- multiply_range(pos, round_bracket_multiplier)
-
- for pos in square_brackets:
- multiply_range(pos, square_bracket_multiplier)
-
- if len(res) == 0:
- res = [["", 1.0]]
-
- # merge runs of identical weights
- i = 0
- while i + 1 < len(res):
- if res[i][1] == res[i + 1][1]:
- res[i][0] += res[i + 1][0]
- res.pop(i + 1)
- else:
- i += 1
-
- return res
-
-if __name__ == "__main__":
- import doctest
- doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
-else:
- import torch # doctest faster
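Both parsers above ship doctests; the short sketch below simply exercises them directly. It assumes modules.prompt_parser is importable from a webui checkout (with torch installed, since the module imports it at the bottom).

```python
from modules.prompt_parser import (
    get_learned_conditioning_prompt_schedules,
    parse_prompt_attention,
)

# [before:after:when] switches prompt text at a step (or at a fraction of total steps)
schedules = get_learned_conditioning_prompt_schedules(["a [mountain:lake:0.25] at dawn"], 20)
print(schedules[0])
# per the doctests above: [[5, 'a mountain at dawn'], [20, 'a lake at dawn']]

# (x) multiplies attention by 1.1, [x] divides it by 1.1, (x:1.5) sets an explicit multiplier
print(parse_prompt_attention("a (red:1.5) ball on [grass]"))
# -> [['a ', 1.0], ['red', 1.5], [' ball on ', 1.0], ['grass', 0.9090909090909091]]
```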
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
deleted file mode 100644
index aad4a6298540b0545ae520b54af38b8b006acd19..0000000000000000000000000000000000000000
--- a/modules/realesrgan_model.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import os
-import sys
-import traceback
-
-import numpy as np
-from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
-from realesrgan import RealESRGANer
-
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import cmd_opts, opts
-
-
-class UpscalerRealESRGAN(Upscaler):
- def __init__(self, path):
- self.name = "RealESRGAN"
- self.user_path = path
- super().__init__()
- try:
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan import RealESRGANer
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
- self.enable = True
- self.scalers = []
- scalers = self.load_models(path)
- for scaler in scalers:
- if scaler.name in opts.realesrgan_enabled_models:
- self.scalers.append(scaler)
-
- except Exception:
- print("Error importing Real-ESRGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- self.enable = False
- self.scalers = []
-
- def do_upscale(self, img, path):
- if not self.enable:
- return img
-
- info = self.load_model(path)
- if not os.path.exists(info.local_data_path):
- print("Unable to load RealESRGAN model: %s" % info.name)
- return img
-
- upsampler = RealESRGANer(
- scale=info.scale,
- model_path=info.local_data_path,
- model=info.model(),
- half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
- tile=opts.ESRGAN_tile,
- tile_pad=opts.ESRGAN_tile_overlap,
- )
-
- upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
-
- image = Image.fromarray(upsampled)
- return image
-
- def load_model(self, path):
- try:
- info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
-
- if info is None:
- print(f"Unable to find model info: {path}")
- return None
-
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
- return info
- except Exception as e:
- print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- return None
-
- def load_models(self, _):
- return get_realesrgan_models(self)
-
-
-def get_realesrgan_models(scaler):
- try:
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
- models = [
- UpscalerData(
- name="R-ESRGAN General 4xV3",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN General WDN 4xV3",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN AnimeVideo",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN 4x+",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
- ),
- UpscalerData(
- name="R-ESRGAN 4x+ Anime6B",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
- ),
- UpscalerData(
- name="R-ESRGAN 2x+",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
- scale=2,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
- ),
- ]
- return models
- except Exception as e:
- print("Error making Real-ESRGAN models list:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
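do_upscale above is a thin wrapper around RealESRGANer.enhance, with the architecture taken from the UpscalerData table. A standalone sketch of the same flow, assuming realesrgan and basicsr are installed and that the R-ESRGAN 4x+ checkpoint has already been downloaded to a local path; the path, tile size, and half-precision choice are assumptions.

```python
import numpy as np
from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

# same constructor arguments the "R-ESRGAN 4x+" entry above builds lazily
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)

upsampler = RealESRGANer(
    scale=4,
    model_path="models/RealESRGAN/RealESRGAN_x4plus.pth",  # hypothetical local path
    model=model,
    half=False,   # the webui derives this from cmd_opts; fixed here for simplicity
    tile=192,     # corresponds to opts.ESRGAN_tile
    tile_pad=8,   # corresponds to opts.ESRGAN_tile_overlap
)

img = Image.open("input.png").convert("RGB")
upscaled = upsampler.enhance(np.array(img), outscale=4)[0]
Image.fromarray(upscaled).save("output.png")
```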
diff --git a/modules/safe.py b/modules/safe.py
deleted file mode 100644
index 82d44be3163085f5d05051653fbc52bc3ba6e311..0000000000000000000000000000000000000000
--- a/modules/safe.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# this code is adapted from the script contributed by anon from /h/
-
-import io
-import pickle
-import collections
-import sys
-import traceback
-
-import torch
-import numpy
-import _codecs
-import zipfile
-import re
-
-
-# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
-TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
-
-
-def encode(*args):
- out = _codecs.encode(*args)
- return out
-
-
-class RestrictedUnpickler(pickle.Unpickler):
- extra_handler = None
-
- def persistent_load(self, saved_id):
- assert saved_id[0] == 'storage'
- return TypedStorage()
-
- def find_class(self, module, name):
- if self.extra_handler is not None:
- res = self.extra_handler(module, name)
- if res is not None:
- return res
-
- if module == 'collections' and name == 'OrderedDict':
- return getattr(collections, name)
- if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
- return getattr(torch._utils, name)
- if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
- return getattr(torch, name)
- if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
- return getattr(torch.nn.modules.container, name)
- if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
- return getattr(numpy.core.multiarray, name)
- if module == 'numpy' and name in ['dtype', 'ndarray']:
- return getattr(numpy, name)
- if module == '_codecs' and name == 'encode':
- return encode
- if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
- import pytorch_lightning.callbacks
- return pytorch_lightning.callbacks.model_checkpoint
- if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
- import pytorch_lightning.callbacks.model_checkpoint
- return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
- if module == "__builtin__" and name == 'set':
- return set
-
- # Forbid everything else.
- raise Exception(f"global '{module}/{name}' is forbidden")
-
-
-# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
-allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
-data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
-
-def check_zip_filenames(filename, names):
- for name in names:
- if allowed_zip_names_re.match(name):
- continue
-
- raise Exception(f"bad file inside {filename}: {name}")
-
-
-def check_pt(filename, extra_handler):
- try:
-
- # new pytorch format is a zip file
- with zipfile.ZipFile(filename) as z:
- check_zip_filenames(filename, z.namelist())
-
- # find filename of data.pkl in zip file: '<dirname>/data.pkl'
- data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
- if len(data_pkl_filenames) == 0:
- raise Exception(f"data.pkl not found in {filename}")
- if len(data_pkl_filenames) > 1:
- raise Exception(f"Multiple data.pkl found in {filename}")
- with z.open(data_pkl_filenames[0]) as file:
- unpickler = RestrictedUnpickler(file)
- unpickler.extra_handler = extra_handler
- unpickler.load()
-
- except zipfile.BadZipfile:
-
- # if it's not a zip file, it's an old PyTorch format, with five objects written to pickle
- with open(filename, "rb") as file:
- unpickler = RestrictedUnpickler(file)
- unpickler.extra_handler = extra_handler
- for i in range(5):
- unpickler.load()
-
-
-def load(filename, *args, **kwargs):
- return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
-
-
-def load_with_extra(filename, extra_handler=None, *args, **kwargs):
- """
- this function is intended to be used by extensions that want to load models with
- some extra classes in them that the usual unpickler would find suspicious.
-
- Use the extra_handler argument to specify a function that takes module and field name as text,
- and returns that field's value:
-
- ```python
- def extra(module, name):
- if module == 'collections' and name == 'OrderedDict':
- return collections.OrderedDict
-
- return None
-
- safe.load_with_extra('model.pt', extra_handler=extra)
- ```
-
- The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
- definitely unsafe.
- """
-
- from modules import shared
-
- try:
- if not shared.cmd_opts.disable_safe_unpickle:
- check_pt(filename, extra_handler)
-
- except pickle.UnpicklingError:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
- return None
-
- except Exception:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
- return None
-
- return unsafe_torch_load(filename, *args, **kwargs)
-
-
-class Extra:
- """
- A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
- (because it's not your code making the torch.load call). The intended use is like this:
-
-```
-import torch
-from modules import safe
-
-def handler(module, name):
- if module == 'torch' and name in ['float64', 'float16']:
- return getattr(torch, name)
-
- return None
-
-with safe.Extra(handler):
- x = torch.load('model.pt')
-```
- """
-
- def __init__(self, handler):
- self.handler = handler
-
- def __enter__(self):
- global global_extra_handler
-
- assert global_extra_handler is None, 'already inside an Extra() block'
- global_extra_handler = self.handler
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- global global_extra_handler
-
- global_extra_handler = None
-
-
-unsafe_torch_load = torch.load
-torch.load = load
-global_extra_handler = None
-
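The unpickler above resolves only an explicit allow-list of globals and raises on everything else. A minimal self-contained sketch of that behaviour, using plain pickle payloads instead of a real checkpoint; it assumes a webui checkout (and its torch/numpy dependencies) is on sys.path, and note that importing modules.safe also replaces torch.load as shown above.

```python
import io
import pickle
import collections

from modules import safe  # side effect: torch.load is replaced by safe.load

# an OrderedDict only needs collections.OrderedDict, which is on the allow-list
ok_payload = pickle.dumps(collections.OrderedDict(a=1))
print(safe.RestrictedUnpickler(io.BytesIO(ok_payload)).load())  # OrderedDict([('a', 1)])

# builtins.print is not on the allow-list, so find_class raises
bad_payload = pickle.dumps(print)
try:
    safe.RestrictedUnpickler(io.BytesIO(bad_payload)).load()
except Exception as e:
    print("rejected:", e)  # global 'builtins/print' is forbidden
```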
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
deleted file mode 100644
index edd0e2a7259eaeda5f1d0447d60a691debec2816..0000000000000000000000000000000000000000
--- a/modules/script_callbacks.py
+++ /dev/null
@@ -1,359 +0,0 @@
-import sys
-import traceback
-from collections import namedtuple
-import inspect
-from typing import Optional, Dict, Any
-
-from fastapi import FastAPI
-from gradio import Blocks
-
-
-def report_exception(c, job):
- print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
-
-class ImageSaveParams:
- def __init__(self, image, p, filename, pnginfo):
- self.image = image
- """the PIL image itself"""
-
- self.p = p
- """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
-
- self.filename = filename
- """name of file that the image would be saved to"""
-
- self.pnginfo = pnginfo
- """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
-
-
-class CFGDenoiserParams:
- def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
- self.x = x
- """Latent image representation in the process of being denoised"""
-
- self.image_cond = image_cond
- """Conditioning image"""
-
- self.sigma = sigma
- """Current sigma noise step value"""
-
- self.sampling_step = sampling_step
- """Current Sampling step number"""
-
- self.total_sampling_steps = total_sampling_steps
- """Total number of sampling steps planned"""
-
-
-class CFGDenoisedParams:
- def __init__(self, x, sampling_step, total_sampling_steps):
- self.x = x
- """Latent image representation in the process of being denoised"""
-
- self.sampling_step = sampling_step
- """Current Sampling step number"""
-
- self.total_sampling_steps = total_sampling_steps
- """Total number of sampling steps planned"""
-
-
-class UiTrainTabParams:
- def __init__(self, txt2img_preview_params):
- self.txt2img_preview_params = txt2img_preview_params
-
-
-class ImageGridLoopParams:
- def __init__(self, imgs, cols, rows):
- self.imgs = imgs
- self.cols = cols
- self.rows = rows
-
-
-ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
-callback_map = dict(
- callbacks_app_started=[],
- callbacks_model_loaded=[],
- callbacks_ui_tabs=[],
- callbacks_ui_train_tabs=[],
- callbacks_ui_settings=[],
- callbacks_before_image_saved=[],
- callbacks_image_saved=[],
- callbacks_cfg_denoiser=[],
- callbacks_cfg_denoised=[],
- callbacks_before_component=[],
- callbacks_after_component=[],
- callbacks_image_grid=[],
- callbacks_infotext_pasted=[],
- callbacks_script_unloaded=[],
- callbacks_before_ui=[],
-)
-
-
-def clear_callbacks():
- for callback_list in callback_map.values():
- callback_list.clear()
-
-
-def app_started_callback(demo: Optional[Blocks], app: FastAPI):
- for c in callback_map['callbacks_app_started']:
- try:
- c.callback(demo, app)
- except Exception:
- report_exception(c, 'app_started_callback')
-
-
-def model_loaded_callback(sd_model):
- for c in callback_map['callbacks_model_loaded']:
- try:
- c.callback(sd_model)
- except Exception:
- report_exception(c, 'model_loaded_callback')
-
-
-def ui_tabs_callback():
- res = []
-
- for c in callback_map['callbacks_ui_tabs']:
- try:
- res += c.callback() or []
- except Exception:
- report_exception(c, 'ui_tabs_callback')
-
- return res
-
-
-def ui_train_tabs_callback(params: UiTrainTabParams):
- for c in callback_map['callbacks_ui_train_tabs']:
- try:
- c.callback(params)
- except Exception:
- report_exception(c, 'callbacks_ui_train_tabs')
-
-
-def ui_settings_callback():
- for c in callback_map['callbacks_ui_settings']:
- try:
- c.callback()
- except Exception:
- report_exception(c, 'ui_settings_callback')
-
-
-def before_image_saved_callback(params: ImageSaveParams):
- for c in callback_map['callbacks_before_image_saved']:
- try:
- c.callback(params)
- except Exception:
- report_exception(c, 'before_image_saved_callback')
-
-
-def image_saved_callback(params: ImageSaveParams):
- for c in callback_map['callbacks_image_saved']:
- try:
- c.callback(params)
- except Exception:
- report_exception(c, 'image_saved_callback')
-
-
-def cfg_denoiser_callback(params: CFGDenoiserParams):
- for c in callback_map['callbacks_cfg_denoiser']:
- try:
- c.callback(params)
- except Exception:
- report_exception(c, 'cfg_denoiser_callback')
-
-
-def cfg_denoised_callback(params: CFGDenoisedParams):
- for c in callback_map['callbacks_cfg_denoised']:
- try:
- c.callback(params)
- except Exception:
- report_exception(c, 'cfg_denoised_callback')
-
-
-def before_component_callback(component, **kwargs):
- for c in callback_map['callbacks_before_component']:
- try:
- c.callback(component, **kwargs)
- except Exception:
- report_exception(c, 'before_component_callback')
-
-
-def after_component_callback(component, **kwargs):
- for c in callback_map['callbacks_after_component']:
- try:
- c.callback(component, **kwargs)
- except Exception:
- report_exception(c, 'after_component_callback')
-
-
-def image_grid_callback(params: ImageGridLoopParams):
- for c in callback_map['callbacks_image_grid']:
- try:
- c.callback(params)
- except Exception:
- report_exception(c, 'image_grid')
-
-
-def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
- for c in callback_map['callbacks_infotext_pasted']:
- try:
- c.callback(infotext, params)
- except Exception:
- report_exception(c, 'infotext_pasted')
-
-
-def script_unloaded_callback():
- for c in reversed(callback_map['callbacks_script_unloaded']):
- try:
- c.callback()
- except Exception:
- report_exception(c, 'script_unloaded')
-
-
-def before_ui_callback():
- for c in reversed(callback_map['callbacks_before_ui']):
- try:
- c.callback()
- except Exception:
- report_exception(c, 'before_ui')
-
-
-def add_callback(callbacks, fun):
- stack = [x for x in inspect.stack() if x.filename != __file__]
- filename = stack[0].filename if len(stack) > 0 else 'unknown file'
-
- callbacks.append(ScriptCallback(filename, fun))
-
-
-def remove_current_script_callbacks():
- stack = [x for x in inspect.stack() if x.filename != __file__]
- filename = stack[0].filename if len(stack) > 0 else 'unknown file'
- if filename == 'unknown file':
- return
- for callback_list in callback_map.values():
- for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
- callback_list.remove(callback_to_remove)
-
-
-def remove_callbacks_for_function(callback_func):
- for callback_list in callback_map.values():
- for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
- callback_list.remove(callback_to_remove)
-
-
-def on_app_started(callback):
- """register a function to be called when the webui started, the gradio `Block` component and
- fastapi `FastAPI` object are passed as the arguments"""
- add_callback(callback_map['callbacks_app_started'], callback)
-
-
-def on_model_loaded(callback):
- """register a function to be called when the stable diffusion model is created; the model is
- passed as an argument; this function is also called when the script is reloaded. """
- add_callback(callback_map['callbacks_model_loaded'], callback)
-
-
-def on_ui_tabs(callback):
- """register a function to be called when the UI is creating new tabs.
- The function must either return a None, which means no new tabs to be added, or a list, where
- each element is a tuple:
- (gradio_component, title, elem_id)
-
- gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
- title is tab text displayed to user in the UI
- elem_id is HTML id for the tab
- """
- add_callback(callback_map['callbacks_ui_tabs'], callback)
-
-
-def on_ui_train_tabs(callback):
- """register a function to be called when the UI is creating new tabs for the train tab.
- Create your new tabs with gr.Tab.
- """
- add_callback(callback_map['callbacks_ui_train_tabs'], callback)
-
-
-def on_ui_settings(callback):
- """register a function to be called before UI settings are populated; add your settings
- by using shared.opts.add_option(shared.OptionInfo(...)) """
- add_callback(callback_map['callbacks_ui_settings'], callback)
-
-
-def on_before_image_saved(callback):
- """register a function to be called before an image is saved to a file.
- The callback is called with one argument:
- - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
- """
- add_callback(callback_map['callbacks_before_image_saved'], callback)
-
-
-def on_image_saved(callback):
- """register a function to be called after an image is saved to a file.
- The callback is called with one argument:
- - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
- """
- add_callback(callback_map['callbacks_image_saved'], callback)
-
-
-def on_cfg_denoiser(callback):
- """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
- The callback is called with one argument:
- - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
- """
- add_callback(callback_map['callbacks_cfg_denoiser'], callback)
-
-
-def on_cfg_denoised(callback):
- """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
- The callback is called with one argument:
- - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
- """
- add_callback(callback_map['callbacks_cfg_denoised'], callback)
-
-
-def on_before_component(callback):
- """register a function to be called before a component is created.
- The callback is called with arguments:
- - component - gradio component that is about to be created.
- - **kwargs - args to gradio.components.IOComponent.__init__ function
-
- Use elem_id/label fields of kwargs to figure out which component it is.
- This can be useful to inject your own components somewhere in the middle of vanilla UI.
- """
- add_callback(callback_map['callbacks_before_component'], callback)
-
-
-def on_after_component(callback):
- """register a function to be called after a component is created. See on_before_component for more."""
- add_callback(callback_map['callbacks_after_component'], callback)
-
-
-def on_image_grid(callback):
- """register a function to be called before making an image grid.
- The callback is called with one argument:
- - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
- """
- add_callback(callback_map['callbacks_image_grid'], callback)
-
-
-def on_infotext_pasted(callback):
- """register a function to be called before applying an infotext.
- The callback is called with two arguments:
- - infotext: str - raw infotext.
- - result: Dict[str, any] - parsed infotext parameters.
- """
- add_callback(callback_map['callbacks_infotext_pasted'], callback)
-
-
-def on_script_unloaded(callback):
- """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
- the script did should be reverted here"""
-
- add_callback(callback_map['callbacks_script_unloaded'], callback)
-
-
-def on_before_ui(callback):
- """register a function to be called before the UI is created."""
-
- add_callback(callback_map['callbacks_before_ui'], callback)
diff --git a/modules/script_loading.py b/modules/script_loading.py
deleted file mode 100644
index a7d2203fc0bfd614dd40ff2dc872f6405768dff8..0000000000000000000000000000000000000000
--- a/modules/script_loading.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import os
-import sys
-import traceback
-import importlib.util
-from types import ModuleType
-
-
-def load_module(path):
- module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
- module = importlib.util.module_from_spec(module_spec)
- module_spec.loader.exec_module(module)
-
- return module
-
-
-def preload_extensions(extensions_dir, parser):
- if not os.path.isdir(extensions_dir):
- return
-
- for dirname in sorted(os.listdir(extensions_dir)):
- preload_script = os.path.join(extensions_dir, dirname, "preload.py")
- if not os.path.isfile(preload_script):
- continue
-
- try:
- module = load_module(preload_script)
- if hasattr(module, 'preload'):
- module.preload(parser)
-
- except Exception:
- print(f"Error running preload() for {preload_script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/scripts.py b/modules/scripts.py
deleted file mode 100644
index 24056a12f900e9cc12b94cc416123afb9e25a6e5..0000000000000000000000000000000000000000
--- a/modules/scripts.py
+++ /dev/null
@@ -1,501 +0,0 @@
-import os
-import re
-import sys
-import traceback
-from collections import namedtuple
-
-import gradio as gr
-
-from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
-
-AlwaysVisible = object()
-
-
-class PostprocessImageArgs:
- def __init__(self, image):
- self.image = image
-
-
-class Script:
- filename = None
- args_from = None
- args_to = None
- alwayson = False
-
- is_txt2img = False
- is_img2img = False
-
- """A gr.Group component that has all script's UI inside it"""
- group = None
-
- infotext_fields = None
- """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
- parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
- """
-
- def title(self):
- """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
-
- raise NotImplementedError()
-
- def ui(self, is_img2img):
- """this function should create gradio UI elements. See https://gradio.app/docs/#components
- The return value should be an array of all components that are used in processing.
- Values of those returned components will be passed to run() and process() functions.
- """
-
- pass
-
- def show(self, is_img2img):
- """
- is_img2img is True if this function is called for the img2img interface, and False otherwise
-
- This function should return:
- - False if the script should not be shown in UI at all
- - True if the script should be shown in UI if it's selected in the scripts dropdown
- - script.AlwaysVisible if the script should be shown in UI at all times
- """
-
- return True
-
- def run(self, p, *args):
- """
- This function is called if the script has been selected in the script dropdown.
- It must do all processing and return the Processed object with results, same as
- one returned by processing.process_images.
-
- Usually the processing is done by calling the processing.process_images function.
-
- args contains all values returned by components from ui()
- """
-
- pass
-
- def process(self, p, *args):
- """
- This function is called before processing begins for AlwaysVisible scripts.
- You can modify the processing object (p) here, inject hooks, etc.
- args contains all values returned by components from ui()
- """
-
- pass
-
- def process_batch(self, p, *args, **kwargs):
- """
- Same as process(), but called for every batch.
-
- **kwargs will have those items:
- - batch_number - index of current batch, from 0 to number of batches-1
- - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
- - seeds - list of seeds for current batch
- - subseeds - list of subseeds for current batch
- """
-
- pass
-
- def postprocess_batch(self, p, *args, **kwargs):
- """
- Same as process_batch(), but called for every batch after it has been generated.
-
- **kwargs will have same items as process_batch, and also:
- - batch_number - index of current batch, from 0 to number of batches-1
- - images - torch tensor with all generated images, with values ranging from 0 to 1;
- """
-
- pass
-
- def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
- """
- Called for every image after it has been generated.
- """
-
- pass
-
- def postprocess(self, p, processed, *args):
- """
- This function is called after processing ends for AlwaysVisible scripts.
- args contains all values returned by components from ui()
- """
-
- pass
-
- def before_component(self, component, **kwargs):
- """
- Called before a component is created.
- Use elem_id/label fields of kwargs to figure out which component it is.
- This can be useful to inject your own components somewhere in the middle of vanilla UI.
- You can return created components in the ui() function to add them to the list of arguments for your processing functions
- """
-
- pass
-
- def after_component(self, component, **kwargs):
- """
- Called after a component is created. Same as above.
- """
-
- pass
-
- def describe(self):
- """unused"""
- return ""
-
- def elem_id(self, item_id):
- """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
-
- need_tabname = self.show(True) == self.show(False)
- tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
- title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
-
- return f'script_{tabname}{title}_{item_id}'
-
-
-current_basedir = paths.script_path
-
-
-def basedir():
- """returns the base directory for the current script. For scripts in the main scripts directory,
- this is the main directory (where webui.py resides), and for scripts in the extensions directory
- (i.e. extensions/aesthetic/scripts/aesthetic.py), this is the extension's directory (extensions/aesthetic)
- """
- return current_basedir
-
-
-ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
-
-scripts_data = []
-postprocessing_scripts_data = []
-ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
-
-
-def list_scripts(scriptdirname, extension):
- scripts_list = []
-
- basedir = os.path.join(paths.script_path, scriptdirname)
- if os.path.exists(basedir):
- for filename in sorted(os.listdir(basedir)):
- scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
-
- for ext in extensions.active():
- scripts_list += ext.list_files(scriptdirname, extension)
-
- scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
-
- return scripts_list
-
-
-def list_files_with_name(filename):
- res = []
-
- dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
-
- for dirpath in dirs:
- if not os.path.isdir(dirpath):
- continue
-
- path = os.path.join(dirpath, filename)
- if os.path.isfile(path):
- res.append(path)
-
- return res
-
-
-def load_scripts():
- global current_basedir
- scripts_data.clear()
- postprocessing_scripts_data.clear()
- script_callbacks.clear_callbacks()
-
- scripts_list = list_scripts("scripts", ".py")
-
- syspath = sys.path
-
- def register_scripts_from_module(module):
- for key, script_class in module.__dict__.items():
- if type(script_class) != type:
- continue
-
- if issubclass(script_class, Script):
- scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
- elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
- postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
-
- for scriptfile in sorted(scripts_list):
- try:
- if scriptfile.basedir != paths.script_path:
- sys.path = [scriptfile.basedir] + sys.path
- current_basedir = scriptfile.basedir
-
- script_module = script_loading.load_module(scriptfile.path)
- register_scripts_from_module(script_module)
-
- except Exception:
- print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- finally:
- sys.path = syspath
- current_basedir = paths.script_path
-
-
-def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
- try:
- res = func(*args, **kwargs)
- return res
- except Exception:
- print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- return default
-
-
-class ScriptRunner:
- def __init__(self):
- self.scripts = []
- self.selectable_scripts = []
- self.alwayson_scripts = []
- self.titles = []
- self.infotext_fields = []
-
- def initialize_scripts(self, is_img2img):
- from modules import scripts_auto_postprocessing
-
- self.scripts.clear()
- self.alwayson_scripts.clear()
- self.selectable_scripts.clear()
-
- auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
-
- for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
- script = script_class()
- script.filename = path
- script.is_txt2img = not is_img2img
- script.is_img2img = is_img2img
-
- visibility = script.show(script.is_img2img)
-
- if visibility == AlwaysVisible:
- self.scripts.append(script)
- self.alwayson_scripts.append(script)
- script.alwayson = True
-
- elif visibility:
- self.scripts.append(script)
- self.selectable_scripts.append(script)
-
- def setup_ui(self):
- self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
-
- inputs = [None]
- inputs_alwayson = [True]
-
- def create_script_ui(script, inputs, inputs_alwayson):
- script.args_from = len(inputs)
- script.args_to = len(inputs)
-
- controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
-
- if controls is None:
- return
-
- for control in controls:
- control.custom_script_source = os.path.basename(script.filename)
-
- if script.infotext_fields is not None:
- self.infotext_fields += script.infotext_fields
-
- inputs += controls
- inputs_alwayson += [script.alwayson for _ in controls]
- script.args_to = len(inputs)
-
- for script in self.alwayson_scripts:
- with gr.Group() as group:
- create_script_ui(script, inputs, inputs_alwayson)
-
- script.group = group
-
- dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
- inputs[0] = dropdown
-
- for script in self.selectable_scripts:
- with gr.Group(visible=False) as group:
- create_script_ui(script, inputs, inputs_alwayson)
-
- script.group = group
-
- def select_script(script_index):
- selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
-
- return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
-
- def init_field(title):
- """called when an initial value is set from ui-config.json to show script's UI components"""
-
- if title == 'None':
- return
-
- script_index = self.titles.index(title)
- self.selectable_scripts[script_index].group.visible = True
-
- dropdown.init_field = init_field
-
- dropdown.change(
- fn=select_script,
- inputs=[dropdown],
- outputs=[script.group for script in self.selectable_scripts]
- )
-
- self.script_load_ctr = 0
- def onload_script_visibility(params):
- title = params.get('Script', None)
- if title:
- title_index = self.titles.index(title)
- visibility = title_index == self.script_load_ctr
- self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
- return gr.update(visible=visibility)
- else:
- return gr.update(visible=False)
-
- self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
- self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
-
- return inputs
-
- def run(self, p, *args):
- script_index = args[0]
-
- if script_index == 0:
- return None
-
- script = self.selectable_scripts[script_index-1]
-
- if script is None:
- return None
-
- script_args = args[script.args_from:script.args_to]
- processed = script.run(p, *script_args)
-
- shared.total_tqdm.clear()
-
- return processed
-
- def process(self, p):
- for script in self.alwayson_scripts:
- try:
- script_args = p.script_args[script.args_from:script.args_to]
- script.process(p, *script_args)
- except Exception:
- print(f"Error running process: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def process_batch(self, p, **kwargs):
- for script in self.alwayson_scripts:
- try:
- script_args = p.script_args[script.args_from:script.args_to]
- script.process_batch(p, *script_args, **kwargs)
- except Exception:
- print(f"Error running process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def postprocess(self, p, processed):
- for script in self.alwayson_scripts:
- try:
- script_args = p.script_args[script.args_from:script.args_to]
- script.postprocess(p, processed, *script_args)
- except Exception:
- print(f"Error running postprocess: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def postprocess_batch(self, p, images, **kwargs):
- for script in self.alwayson_scripts:
- try:
- script_args = p.script_args[script.args_from:script.args_to]
- script.postprocess_batch(p, *script_args, images=images, **kwargs)
- except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def postprocess_image(self, p, pp: PostprocessImageArgs):
- for script in self.alwayson_scripts:
- try:
- script_args = p.script_args[script.args_from:script.args_to]
- script.postprocess_image(p, pp, *script_args)
- except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def before_component(self, component, **kwargs):
- for script in self.scripts:
- try:
- script.before_component(component, **kwargs)
- except Exception:
- print(f"Error running before_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def after_component(self, component, **kwargs):
- for script in self.scripts:
- try:
- script.after_component(component, **kwargs)
- except Exception:
- print(f"Error running after_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def reload_sources(self, cache):
- for si, script in list(enumerate(self.scripts)):
- args_from = script.args_from
- args_to = script.args_to
- filename = script.filename
-
- module = cache.get(filename, None)
- if module is None:
- module = script_loading.load_module(script.filename)
- cache[filename] = module
-
- for key, script_class in module.__dict__.items():
- if type(script_class) == type and issubclass(script_class, Script):
- self.scripts[si] = script_class()
- self.scripts[si].filename = filename
- self.scripts[si].args_from = args_from
- self.scripts[si].args_to = args_to
-
-
-scripts_txt2img = ScriptRunner()
-scripts_img2img = ScriptRunner()
-scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
-scripts_current: ScriptRunner = None
-
-
-def reload_script_body_only():
- cache = {}
- scripts_txt2img.reload_sources(cache)
- scripts_img2img.reload_sources(cache)
-
-
-def reload_scripts():
- global scripts_txt2img, scripts_img2img, scripts_postproc
-
- load_scripts()
-
- scripts_txt2img = ScriptRunner()
- scripts_img2img = ScriptRunner()
- scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
-
-
-def IOComponent_init(self, *args, **kwargs):
- if scripts_current is not None:
- scripts_current.before_component(self, **kwargs)
-
- script_callbacks.before_component_callback(self, **kwargs)
-
- res = original_IOComponent_init(self, *args, **kwargs)
-
- script_callbacks.after_component_callback(self, **kwargs)
-
- if scripts_current is not None:
- scripts_current.after_component(self, **kwargs)
-
- return res
-
-
-original_IOComponent_init = gr.components.IOComponent.__init__
-gr.components.IOComponent.__init__ = IOComponent_init
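For reference, the smallest useful selectable script built on the Script interface above: title() labels the dropdown entry, ui() returns the gradio controls, and run() receives their values in order and calls processing.process_images, as the docstrings describe. A minimal sketch, assuming it is dropped into a scripts/ directory of a webui checkout or extension.

```python
import gradio as gr

from modules import scripts
from modules.processing import process_images


class ExampleScript(scripts.Script):
    def title(self):
        # shown in the Script dropdown
        return "Example: append suffix to prompt"

    def show(self, is_img2img):
        # selectable from the dropdown on both tabs; return scripts.AlwaysVisible to always run
        return True

    def ui(self, is_img2img):
        suffix = gr.Textbox(label="Prompt suffix", value=", highly detailed",
                            elem_id=self.elem_id("prompt_suffix"))
        return [suffix]

    def run(self, p, suffix):
        # values returned from ui() arrive here in the same order
        p.prompt = p.prompt + suffix
        return process_images(p)
```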
diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py
deleted file mode 100644
index 30d6d6586f0bc9312aee6f23934274234defb962..0000000000000000000000000000000000000000
--- a/modules/scripts_auto_postprocessing.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from modules import scripts, scripts_postprocessing, shared
-
-
-class ScriptPostprocessingForMainUI(scripts.Script):
- def __init__(self, script_postproc):
- self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
- self.postprocessing_controls = None
-
- def title(self):
- return self.script.name
-
- def show(self, is_img2img):
- return scripts.AlwaysVisible
-
- def ui(self, is_img2img):
- self.postprocessing_controls = self.script.ui()
- return self.postprocessing_controls.values()
-
- def postprocess_image(self, p, script_pp, *args):
- args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
-
- pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
- pp.info = {}
- self.script.process(pp, **args_dict)
- p.extra_generation_params.update(pp.info)
- script_pp.image = pp.image
-
-
-def create_auto_preprocessing_script_data():
- from modules import scripts
-
- res = []
-
- for name in shared.opts.postprocessing_enable_in_main_ui:
- script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
- if script is None:
- continue
-
- constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
- res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
-
- return res
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
deleted file mode 100644
index ce0ebb611d1bd3c6deb7496b107683bf9dbb6050..0000000000000000000000000000000000000000
--- a/modules/scripts_postprocessing.py
+++ /dev/null
@@ -1,152 +0,0 @@
-import os
-import gradio as gr
-
-from modules import errors, shared
-
-
-class PostprocessedImage:
- def __init__(self, image):
- self.image = image
- self.info = {}
-
-
-class ScriptPostprocessing:
- filename = None
- controls = None
- args_from = None
- args_to = None
-
- order = 1000
- """scripts will be ordred by this value in postprocessing UI"""
-
- name = None
- """this function should return the title of the script."""
-
- group = None
- """A gr.Group component that has all script's UI inside it"""
-
- def ui(self):
- """
- This function should create gradio UI elements. See https://gradio.app/docs/#components
- The return value should be a dictionary that maps parameter names to components used in processing.
- Values of those components will be passed to process() function.
- """
-
- pass
-
- def process(self, pp: PostprocessedImage, **args):
- """
- This function is called to postprocess the image.
- args contains a dictionary with all values returned by components from ui()
- """
-
- pass
-
- def image_changed(self):
- pass
-
-
-
-
-def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
- try:
- res = func(*args, **kwargs)
- return res
- except Exception as e:
- errors.display(e, f"calling {filename}/{funcname}")
-
- return default
-
-
-class ScriptPostprocessingRunner:
- def __init__(self):
- self.scripts = None
- self.ui_created = False
-
- def initialize_scripts(self, scripts_data):
- self.scripts = []
-
- for script_class, path, basedir, script_module in scripts_data:
- script: ScriptPostprocessing = script_class()
- script.filename = path
-
- if script.name == "Simple Upscale":
- continue
-
- self.scripts.append(script)
-
- def create_script_ui(self, script, inputs):
- script.args_from = len(inputs)
- script.args_to = len(inputs)
-
- script.controls = wrap_call(script.ui, script.filename, "ui")
-
- for control in script.controls.values():
- control.custom_script_source = os.path.basename(script.filename)
-
- inputs += list(script.controls.values())
- script.args_to = len(inputs)
-
- def scripts_in_preferred_order(self):
- if self.scripts is None:
- import modules.scripts
- self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
-
- scripts_order = shared.opts.postprocessing_operation_order
-
- def script_score(name):
- for i, possible_match in enumerate(scripts_order):
- if possible_match == name:
- return i
-
- return len(self.scripts)
-
- script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
-
- return sorted(self.scripts, key=lambda x: script_scores[x.name])
-
- def setup_ui(self):
- inputs = []
-
- for script in self.scripts_in_preferred_order():
- with gr.Box() as group:
- self.create_script_ui(script, inputs)
-
- script.group = group
-
- self.ui_created = True
- return inputs
-
- def run(self, pp: PostprocessedImage, args):
- for script in self.scripts_in_preferred_order():
- shared.state.job = script.name
-
- script_args = args[script.args_from:script.args_to]
-
- process_args = {}
- for (name, component), value in zip(script.controls.items(), script_args):
- process_args[name] = value
-
- script.process(pp, **process_args)
-
- def create_args_for_run(self, scripts_args):
- if not self.ui_created:
- with gr.Blocks(analytics_enabled=False):
- self.setup_ui()
-
- scripts = self.scripts_in_preferred_order()
- args = [None] * max([x.args_to for x in scripts])
-
- for script in scripts:
- script_args_dict = scripts_args.get(script.name, None)
- if script_args_dict is not None:
-
- for i, name in enumerate(script.controls):
- args[script.args_from + i] = script_args_dict.get(name, None)
-
- return args
-
- def image_changed(self):
- for script in self.scripts_in_preferred_order():
- script.image_changed()
-
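And the postprocessing counterpart: ui() returns a dict of named controls whose current values arrive in process() as keyword arguments, and anything written to pp.info is picked up downstream (the main-UI wrapper earlier copies it into p.extra_generation_params). A minimal sketch, assuming it sits next to the other postprocessing scripts in a scripts/ directory; the script name and behaviour are made up for illustration.

```python
import gradio as gr
from PIL import ImageOps

from modules import scripts_postprocessing


class ScriptPostprocessingGrayscale(scripts_postprocessing.ScriptPostprocessing):
    name = "Grayscale"   # title shown in the postprocessing UI
    order = 9000         # sorted after scripts with lower order values

    def ui(self):
        enable = gr.Checkbox(label="Convert to grayscale", value=False)
        return {"enable": enable}

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable):
        if not enable:
            return
        pp.image = ImageOps.grayscale(pp.image).convert("RGB")
        pp.info["Grayscale"] = True
```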
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
deleted file mode 100644
index c4a09d15da1369d140552a5309d04b83c24654b5..0000000000000000000000000000000000000000
--- a/modules/sd_disable_initialization.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import ldm.modules.encoders.modules
-import open_clip
-import torch
-import transformers.utils.hub
-
-
-class DisableInitialization:
- """
- When an object of this class enters a `with` block, it starts:
- - preventing torch's layer initialization functions from working
- - changes CLIP and OpenCLIP to not download model weights
- - changes CLIP to not make requests to check if there is a new version of a file you already have
-
- When it leaves the block, it reverts everything to how it was before.
-
- Use it like this:
- ```
- with DisableInitialization():
- do_things()
- ```
- """
-
- def __init__(self, disable_clip=True):
- self.replaced = []
- self.disable_clip = disable_clip
-
- def replace(self, obj, field, func):
- original = getattr(obj, field, None)
- if original is None:
- return None
-
- self.replaced.append((obj, field, original))
- setattr(obj, field, func)
-
- return original
-
- def __enter__(self):
- def do_nothing(*args, **kwargs):
- pass
-
- def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
- return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
-
- def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
- res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
- res.name_or_path = pretrained_model_name_or_path
- return res
-
- def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
- args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
- return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
-
- def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
-
- # this file always returns 404, so prevent making the request
- if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
- return None
-
- try:
- res = original(url, *args, local_files_only=True, **kwargs)
- if res is None:
- res = original(url, *args, local_files_only=False, **kwargs)
- return res
- except Exception as e:
- return original(url, *args, local_files_only=False, **kwargs)
-
- def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
- return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
-
- def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
- return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
-
- def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
- return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
-
- self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
- self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
- self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
-
- if self.disable_clip:
- self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
- self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
- self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
- self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
- self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
- self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- for obj, field, original in self.replaced:
- setattr(obj, field, original)
-
- self.replaced.clear()
-
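
A note on the pattern in the deleted sd_disable_initialization.py: `replace()` records each (object, attribute, original value) triple before swapping the attribute out, and `__exit__` walks that list to restore everything. A minimal self-contained sketch of the same patch-and-restore idea, using the standard `math` module as a stand-in target rather than torch or transformers:

```python
import math


class SkipInit:
    """Patches attributes on enter and restores them on exit - same shape as DisableInitialization."""

    def __init__(self):
        self.replaced = []

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None

        # remember (object, attribute name, original value) so it can be undone later
        self.replaced.append((obj, field, original))
        setattr(obj, field, func)
        return original

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        # while inside the block, math.floor silently does nothing
        self.replace(math, 'floor', do_nothing)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for obj, field, original in self.replaced:
            setattr(obj, field, original)
        self.replaced.clear()


with SkipInit():
    assert math.floor(1.5) is None    # patched away inside the block

assert math.floor(1.5) == 1           # restored afterwards
```

The real class follows the same shape; it just patches many more entry points and guards the CLIP-related ones behind `disable_clip`.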
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
deleted file mode 100644
index 794767831b4709522d52f4fe8bf8c50ef03a3cb1..0000000000000000000000000000000000000000
--- a/modules/sd_hijack.py
+++ /dev/null
@@ -1,264 +0,0 @@
-import torch
-from torch.nn.functional import silu
-from types import MethodType
-
-import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
-from modules.hypernetworks import hypernetwork
-from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
-
-import ldm.modules.attention
-import ldm.modules.diffusionmodules.model
-import ldm.modules.diffusionmodules.openaimodel
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-import ldm.modules.encoders.modules
-
-attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
-diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
-diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
-
-# new memory efficient cross attention blocks do not support hypernets and we already
-# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
-ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
-ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
-
-# silence new console spam from SD2
-ldm.modules.attention.print = lambda *args: None
-ldm.modules.diffusionmodules.model.print = lambda *args: None
-
-
-def apply_optimizations():
- undo_optimizations()
-
- ldm.modules.diffusionmodules.model.nonlinearity = silu
- ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
-
- optimization_method = None
-
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
- print("Applying xformers cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
- optimization_method = 'xformers'
- elif cmd_opts.opt_sub_quad_attention:
- print("Applying sub-quadratic cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
- optimization_method = 'sub-quadratic'
- elif cmd_opts.opt_split_attention_v1:
- print("Applying v1 cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- optimization_method = 'V1'
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
- print("Applying cross attention optimization (InvokeAI).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
- optimization_method = 'InvokeAI'
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- print("Applying cross attention optimization (Doggettx).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
- optimization_method = 'Doggettx'
-
- return optimization_method
-
-
-def undo_optimizations():
- ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
- ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
- ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
-
-
-def fix_checkpoint():
- """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
- checkpoints to be added when not training (there's a warning)"""
-
- pass
-
-
-def weighted_loss(sd_model, pred, target, mean=True):
- #Calculate the weight normally, but ignore the mean
- loss = sd_model._old_get_loss(pred, target, mean=False)
-
- #Check if we have weights available
- weight = getattr(sd_model, '_custom_loss_weight', None)
- if weight is not None:
- loss *= weight
-
- #Return the loss, as mean if specified
- return loss.mean() if mean else loss
-
-def weighted_forward(sd_model, x, c, w, *args, **kwargs):
- try:
- #Temporarily append weights to a place accessible during loss calc
- sd_model._custom_loss_weight = w
-
- #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
- #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
- if not hasattr(sd_model, '_old_get_loss'):
- sd_model._old_get_loss = sd_model.get_loss
- sd_model.get_loss = MethodType(weighted_loss, sd_model)
-
- #Run the standard forward function, but with the patched 'get_loss'
- return sd_model.forward(x, c, *args, **kwargs)
- finally:
- try:
- #Delete temporary weights if appended
- del sd_model._custom_loss_weight
- except AttributeError as e:
- pass
-
- #If we have an old loss function, reset the loss function to the original one
- if hasattr(sd_model, '_old_get_loss'):
- sd_model.get_loss = sd_model._old_get_loss
- del sd_model._old_get_loss
-
-def apply_weighted_forward(sd_model):
- #Add new function 'weighted_forward' that can be called to calc weighted loss
- sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
-
-def undo_weighted_forward(sd_model):
- try:
- del sd_model.weighted_forward
- except AttributeError as e:
- pass
-
-
-class StableDiffusionModelHijack:
- fixes = None
- comments = []
- layers = None
- circular_enabled = False
- clip = None
- optimization_method = None
-
- embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-
- def __init__(self):
- self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
-
- def hijack(self, m):
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
- model_embeddings = m.cond_stage_model.roberta.embeddings
- model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
- m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
-
- elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
- model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
- model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
- m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
-
- elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
- m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
- m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
-
- apply_weighted_forward(m)
- if m.cond_stage_key == "edit":
- sd_hijack_unet.hijack_ddpm_edit()
-
- self.optimization_method = apply_optimizations()
-
- self.clip = m.cond_stage_model
-
- def flatten(el):
- flattened = [flatten(children) for children in el.children()]
- res = [el]
- for c in flattened:
- res += c
- return res
-
- self.layers = flatten(m)
-
- def undo_hijack(self, m):
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
- m.cond_stage_model = m.cond_stage_model.wrapped
-
- elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
- m.cond_stage_model = m.cond_stage_model.wrapped
-
- model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
- if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
- model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
- elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
- m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
- m.cond_stage_model = m.cond_stage_model.wrapped
-
- undo_optimizations()
- undo_weighted_forward(m)
-
- self.apply_circular(False)
- self.layers = None
- self.clip = None
-
- def apply_circular(self, enable):
- if self.circular_enabled == enable:
- return
-
- self.circular_enabled = enable
-
- for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
- layer.padding_mode = 'circular' if enable else 'zeros'
-
- def clear_comments(self):
- self.comments = []
-
- def get_prompt_lengths(self, text):
- _, token_count = self.clip.process_texts([text])
-
- return token_count, self.clip.get_target_prompt_token_count(token_count)
-
-
-class EmbeddingsWithFixes(torch.nn.Module):
- def __init__(self, wrapped, embeddings):
- super().__init__()
- self.wrapped = wrapped
- self.embeddings = embeddings
-
- def forward(self, input_ids):
- batch_fixes = self.embeddings.fixes
- self.embeddings.fixes = None
-
- inputs_embeds = self.wrapped(input_ids)
-
- if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
- return inputs_embeds
-
- vecs = []
- for fixes, tensor in zip(batch_fixes, inputs_embeds):
- for offset, embedding in fixes:
- emb = devices.cond_cast_unet(embedding.vec)
- emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
- tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
-
- vecs.append(tensor)
-
- return torch.stack(vecs)
-
-
-def add_circular_option_to_conv_2d():
- conv2d_constructor = torch.nn.Conv2d.__init__
-
- def conv2d_constructor_circular(self, *args, **kwargs):
- return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
-
- torch.nn.Conv2d.__init__ = conv2d_constructor_circular
-
-
-model_hijack = StableDiffusionModelHijack()
-
-
-def register_buffer(self, name, attr):
- """
- Fix register buffer bug for Mac OS.
- """
-
- if type(attr) == torch.Tensor:
- if attr.device != devices.device:
- attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
-
- setattr(self, name, attr)
-
-
-ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
-ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
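
For reference, `EmbeddingsWithFixes.forward` in the deleted sd_hijack.py splices textual inversion vectors into the token embeddings with a three-piece `torch.cat`: everything up to and including the offset, the embedding vectors (clamped so they fit), and the remainder of the original tensor. A toy demonstration with made-up shapes (8 positions, 4-dimensional embeddings, a 2-vector embedding placed at offset 2):

```python
import torch

tensor = torch.zeros(8, 4)   # one prompt's token embeddings (T=8, C=4)
emb = torch.ones(2, 4)       # a textual inversion embedding with 2 vectors
offset = 2                   # position recorded in a PromptChunkFix

# clamp so the embedding never runs past the end of the chunk (same formula as above)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
spliced = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

assert spliced.shape == tensor.shape
assert torch.equal(spliced[offset + 1:offset + 1 + emb_len], emb[:emb_len])
```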
diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py
deleted file mode 100644
index 2604d969f910ffdd65aff66acc0b6ab09b793b38..0000000000000000000000000000000000000000
--- a/modules/sd_hijack_checkpoint.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from torch.utils.checkpoint import checkpoint
-
-import ldm.modules.attention
-import ldm.modules.diffusionmodules.openaimodel
-
-
-def BasicTransformerBlock_forward(self, x, context=None):
- return checkpoint(self._forward, x, context)
-
-
-def AttentionBlock_forward(self, x):
- return checkpoint(self._forward, x)
-
-
-def ResBlock_forward(self, x, emb):
- return checkpoint(self._forward, x, emb)
-
-
-stored = []
-
-
-def add():
- if len(stored) != 0:
- return
-
- stored.extend([
- ldm.modules.attention.BasicTransformerBlock.forward,
- ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
- ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
- ])
-
- ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
- ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
- ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
-
-
-def remove():
- if len(stored) == 0:
- return
-
- ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
- ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
- ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
-
- stored.clear()
-
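
The three forward overrides in the deleted sd_hijack_checkpoint.py all reduce to the same call, `torch.utils.checkpoint.checkpoint(self._forward, ...)`, which trades memory for compute by recomputing activations during the backward pass instead of storing them. A tiny stand-alone illustration with an arbitrary Linear layer and made-up shapes:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

# activations inside `layer` are recomputed during backward instead of being stored
y = checkpoint(layer, x)
y.sum().backward()

print(x.grad.shape)  # torch.Size([2, 8])
```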
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
deleted file mode 100644
index 9fa5c5c50cd559c5f5da89df584eca17f1cf24b0..0000000000000000000000000000000000000000
--- a/modules/sd_hijack_clip.py
+++ /dev/null
@@ -1,317 +0,0 @@
-import math
-from collections import namedtuple
-
-import torch
-
-from modules import prompt_parser, devices, sd_hijack
-from modules.shared import opts
-
-
-class PromptChunk:
- """
- This object contains token ids, weights (multipliers, e.g. 1.4), and textual inversion embedding info for a chunk of a prompt.
- If a prompt is short, it is represented by one PromptChunk; otherwise, multiple are necessary.
- Each PromptChunk contains exactly 77 tokens, which includes one start token and one end token,
- so just 75 tokens come from the prompt.
- """
-
- def __init__(self):
- self.tokens = []
- self.multipliers = []
- self.fixes = []
-
-
-PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
-"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
-chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
-are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
-
-
-class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
- """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
- have unlimited prompt length and assign weights to tokens in prompt.
- """
-
- def __init__(self, wrapped, hijack):
- super().__init__()
-
- self.wrapped = wrapped
- """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
- depending on model."""
-
- self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
- self.chunk_length = 75
-
- def empty_chunk(self):
- """creates an empty PromptChunk and returns it"""
-
- chunk = PromptChunk()
- chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
- chunk.multipliers = [1.0] * (self.chunk_length + 2)
- return chunk
-
- def get_target_prompt_token_count(self, token_count):
- """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
-
- return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
-
- def tokenize(self, texts):
- """Converts a batch of texts into a batch of token ids"""
-
- raise NotImplementedError
-
- def encode_with_transformers(self, tokens):
- """
- Converts a batch of token ids (in python lists) into a single tensor with the numeric representation of those tokens;
- all python lists with tokens are assumed to have the same length, usually 77.
- If the input is a list with B elements and each element has T tokens, the expected output shape is (B, T, C), where C depends on
- the model - it can be 768 or 1024.
- Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
- """
-
- raise NotImplementedError
-
- def encode_embedding_init_text(self, init_text, nvpt):
- """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
- transformers. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""
-
- raise NotImplementedError
-
- def tokenize_line(self, line):
- """
- this transforms a single prompt into a list of PromptChunk objects - as many as needed to
- represent the prompt.
- Returns the list and the total number of tokens in the prompt.
- """
-
- if opts.enable_emphasis:
- parsed = prompt_parser.parse_prompt_attention(line)
- else:
- parsed = [[line, 1.0]]
-
- tokenized = self.tokenize([text for text, _ in parsed])
-
- chunks = []
- chunk = PromptChunk()
- token_count = 0
- last_comma = -1
-
- def next_chunk(is_last=False):
- """puts current chunk into the list of results and produces the next one - empty;
- if is_last is true, tokens tokens at the end won't add to token_count"""
- nonlocal token_count
- nonlocal last_comma
- nonlocal chunk
-
- if is_last:
- token_count += len(chunk.tokens)
- else:
- token_count += self.chunk_length
-
- to_add = self.chunk_length - len(chunk.tokens)
- if to_add > 0:
- chunk.tokens += [self.id_end] * to_add
- chunk.multipliers += [1.0] * to_add
-
- chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
- chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
-
- last_comma = -1
- chunks.append(chunk)
- chunk = PromptChunk()
-
- for tokens, (text, weight) in zip(tokenized, parsed):
- if text == 'BREAK' and weight == -1:
- next_chunk()
- continue
-
- position = 0
- while position < len(tokens):
- token = tokens[position]
-
- if token == self.comma_token:
- last_comma = len(chunk.tokens)
-
- # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
- # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
- elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
- break_location = last_comma + 1
-
- reloc_tokens = chunk.tokens[break_location:]
- reloc_mults = chunk.multipliers[break_location:]
-
- chunk.tokens = chunk.tokens[:break_location]
- chunk.multipliers = chunk.multipliers[:break_location]
-
- next_chunk()
- chunk.tokens = reloc_tokens
- chunk.multipliers = reloc_mults
-
- if len(chunk.tokens) == self.chunk_length:
- next_chunk()
-
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
- if embedding is None:
- chunk.tokens.append(token)
- chunk.multipliers.append(weight)
- position += 1
- continue
-
- emb_len = int(embedding.vec.shape[0])
- if len(chunk.tokens) + emb_len > self.chunk_length:
- next_chunk()
-
- chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
-
- chunk.tokens += [0] * emb_len
- chunk.multipliers += [weight] * emb_len
- position += embedding_length_in_tokens
-
- if len(chunk.tokens) > 0 or len(chunks) == 0:
- next_chunk(is_last=True)
-
- return chunks, token_count
-
- def process_texts(self, texts):
- """
- Accepts a list of texts and calls tokenize_line() on each, with caching. Returns the list of results and the maximum
- length, in tokens, of all texts.
- """
-
- token_count = 0
-
- cache = {}
- batch_chunks = []
- for line in texts:
- if line in cache:
- chunks = cache[line]
- else:
- chunks, current_token_count = self.tokenize_line(line)
- token_count = max(current_token_count, token_count)
-
- cache[line] = chunks
-
- batch_chunks.append(chunks)
-
- return batch_chunks, token_count
-
- def forward(self, texts):
- """
- Accepts an array of texts; passes the texts through the transformers network to create a tensor with the numerical representation of those texts.
- Returns a tensor with shape of (B, T, C), where B is the length of the array; T is the length, in tokens, of the texts (including padding) - T will
- be a multiple of 77; and C is the dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
- An example shape returned by this function can be: (2, 77, 768).
- Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
- is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
- """
-
- if opts.use_old_emphasis_implementation:
- import modules.sd_hijack_clip_old
- return modules.sd_hijack_clip_old.forward_old(self, texts)
-
- batch_chunks, token_count = self.process_texts(texts)
-
- used_embeddings = {}
- chunk_count = max([len(x) for x in batch_chunks])
-
- zs = []
- for i in range(chunk_count):
- batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
-
- tokens = [x.tokens for x in batch_chunk]
- multipliers = [x.multipliers for x in batch_chunk]
- self.hijack.fixes = [x.fixes for x in batch_chunk]
-
- for fixes in self.hijack.fixes:
- for position, embedding in fixes:
- used_embeddings[embedding.name] = embedding
-
- z = self.process_tokens(tokens, multipliers)
- zs.append(z)
-
- if len(used_embeddings) > 0:
- embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
- self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
-
- return torch.hstack(zs)
-
- def process_tokens(self, remade_batch_tokens, batch_multipliers):
- """
- Sends a single prompt chunk to be encoded by the transformers neural network.
- remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
- there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
- Multipliers are used to give more or less weight to the outputs of the transformers network. Each multiplier
- corresponds to one token.
- """
- tokens = torch.asarray(remade_batch_tokens).to(devices.device)
-
- # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
- if self.id_end != self.id_pad:
- for batch_pos in range(len(remade_batch_tokens)):
- index = remade_batch_tokens[batch_pos].index(self.id_end)
- tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
-
- z = self.encode_with_transformers(tokens)
-
- # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
- original_mean = z.mean()
- z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
- new_mean = z.mean()
- z = z * (original_mean / new_mean)
-
- return z
-
-
-class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
- def __init__(self, wrapped, hijack):
- super().__init__(wrapped, hijack)
- self.tokenizer = wrapped.tokenizer
-
- vocab = self.tokenizer.get_vocab()
-
- self.comma_token = vocab.get(',</w>', None)
-
- self.token_mults = {}
- tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
- for text, ident in tokens_with_parens:
- mult = 1.0
- for c in text:
- if c == '[':
- mult /= 1.1
- if c == ']':
- mult *= 1.1
- if c == '(':
- mult *= 1.1
- if c == ')':
- mult /= 1.1
-
- if mult != 1.0:
- self.token_mults[ident] = mult
-
- self.id_start = self.wrapped.tokenizer.bos_token_id
- self.id_end = self.wrapped.tokenizer.eos_token_id
- self.id_pad = self.id_end
-
- def tokenize(self, texts):
- tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
-
- return tokenized
-
- def encode_with_transformers(self, tokens):
- outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
-
- if opts.CLIP_stop_at_last_layers > 1:
- z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
- z = self.wrapped.transformer.text_model.final_layer_norm(z)
- else:
- z = outputs.last_hidden_state
-
- return z
-
- def encode_embedding_init_text(self, init_text, nvpt):
- embedding_layer = self.wrapped.transformer.text_model.embeddings
- ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
-
- return embedded
diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py
deleted file mode 100644
index 6d9fbbe6ca1edfe420c8d5cbc737cdfee4a73622..0000000000000000000000000000000000000000
--- a/modules/sd_hijack_clip_old.py
+++ /dev/null
@@ -1,81 +0,0 @@
-from modules import sd_hijack_clip
-from modules import shared
-
-
-def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
- id_start = self.id_start
- id_end = self.id_end
- maxlen = self.wrapped.max_length # you get to stay at 77
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
- token_count = 0
-
- cache = {}
- batch_tokens = self.tokenize(texts)
- batch_multipliers = []
- for tokens in batch_tokens:
- tuple_tokens = tuple(tokens)
-
- if tuple_tokens in cache:
- remade_tokens, fixes, multipliers = cache[tuple_tokens]
- else:
- fixes = []
- remade_tokens = []
- multipliers = []
- mult = 1.0
-
- i = 0
- while i < len(tokens):
- token = tokens[i]
-
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
- mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
- if mult_change is not None:
- mult *= mult_change
- i += 1
- elif embedding is None:
- remade_tokens.append(token)
- multipliers.append(mult)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- fixes.append((len(remade_tokens), embedding))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- if len(remade_tokens) > maxlen - 2:
- vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
- ovf = remade_tokens[maxlen - 2:]
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
- token_count = len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
- cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
- multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
- multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
-
-def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
-
- self.hijack.comments += hijack_comments
-
- if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
- self.hijack.fixes = hijack_fixes
- return self.process_tokens(remade_batch_tokens, batch_multipliers)
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
deleted file mode 100644
index 55a2ce4d19200acafd79e6fce7e017c4abc50a73..0000000000000000000000000000000000000000
--- a/modules/sd_hijack_inpainting.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import os
-import torch
-
-from einops import repeat
-from omegaconf import ListConfig
-
-import ldm.models.diffusion.ddpm
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-
-from ldm.models.diffusion.ddpm import LatentDiffusion
-from ldm.models.diffusion.plms import PLMSSampler
-from ldm.models.diffusion.ddim import DDIMSampler, noise_like
-from ldm.models.diffusion.sampling_util import norm_thresholding
-
-
-@torch.no_grad()
-def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
- b, *_, device = *x.shape, x.device
-
- def get_model_output(x, t):
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
- e_t = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
-
- if isinstance(c, dict):
- assert isinstance(unconditional_conditioning, dict)
- c_in = dict()
- for k in c:
- if isinstance(c[k], list):
- c_in[k] = [
- torch.cat([unconditional_conditioning[k][i], c[k][i]])
- for i in range(len(c[k]))
- ]
- else:
- c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
- else:
- c_in = torch.cat([unconditional_conditioning, c])
-
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps"
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
- return e_t
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
-
- def get_x_prev_and_pred_x0(e_t, index):
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
- # current prediction for x_0
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- if dynamic_threshold is not None:
- pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
- # direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- return x_prev, pred_x0
-
- e_t = get_model_output(x, t)
- if len(old_eps) == 0:
- # Pseudo Improved Euler (2nd order)
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
- e_t_next = get_model_output(x_prev, t_next)
- e_t_prime = (e_t + e_t_next) / 2
- elif len(old_eps) == 1:
- # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (3 * e_t - old_eps[-1]) / 2
- elif len(old_eps) == 2:
- # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
- elif len(old_eps) >= 3:
- # 4th order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
-
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
-
- return x_prev, pred_x0, e_t
-
-
-def do_inpainting_hijack():
- # p_sample_plms is needed because PLMS can't work with dicts as conditionings
-
- ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
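
The multistep branch at the end of the deleted `p_sample_plms` combines the current noise prediction with up to three previous ones using Adams-Bashforth coefficients. A compact restatement with plain floats standing in for the prediction tensors (the first-step case is handled by an extra model evaluation in the real code, so it is only stubbed here):

```python
def plms_eps_prime(e_t, old_eps):
    """Combine current and previous noise predictions with the same coefficients as p_sample_plms."""
    if len(old_eps) == 0:
        raise ValueError("first step needs e_t_next from an extra model call")
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24


print(plms_eps_prime(1.0, [1.0, 1.0, 1.0]))  # constant predictions stay constant: 1.0
```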
diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py
deleted file mode 100644
index 3c727d3b75332508629458d23f7fb86cc9ede44b..0000000000000000000000000000000000000000
--- a/modules/sd_hijack_ip2p.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import collections
-import os.path
-import sys
-import gc
-import time
-
-def should_hijack_ip2p(checkpoint_info):
- from modules import sd_models_config
-
- ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
- cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
-
- return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py
deleted file mode 100644
index f733e8529fb6cd68d97b2f255bc705d0cd949fbc..0000000000000000000000000000000000000000
--- a/modules/sd_hijack_open_clip.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import open_clip.tokenizer
-import torch
-
-from modules import sd_hijack_clip, devices
-from modules.shared import opts
-
-tokenizer = open_clip.tokenizer._tokenizer
-
-
-class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
- def __init__(self, wrapped, hijack):
- super().__init__(wrapped, hijack)
-
- self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
- self.id_start = tokenizer.encoder["<start_of_text>"]
- self.id_end = tokenizer.encoder["<end_of_text>"]
- self.id_pad = 0
-
- def tokenize(self, texts):
- assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
-
- tokenized = [tokenizer.encode(text) for text in texts]
-
- return tokenized
-
- def encode_with_transformers(self, tokens):
- # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
- z = self.wrapped.encode_with_transformer(tokens)
-
- return z
-
- def encode_embedding_init_text(self, init_text, nvpt):
- ids = tokenizer.encode(init_text)
- ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
- embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
-
- return embedded
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
deleted file mode 100644
index c02d954c7ab040be940a3650a64d1d1978409fb5..0000000000000000000000000000000000000000
--- a/modules/sd_hijack_optimizations.py
+++ /dev/null
@@ -1,444 +0,0 @@
-import math
-import sys
-import traceback
-import psutil
-
-import torch
-from torch import einsum
-
-from ldm.util import default
-from einops import rearrange
-
-from modules import shared, errors, devices
-from modules.hypernetworks import hypernetwork
-
-from .sub_quadratic_attention import efficient_dot_product_attention
-
-
-if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
- try:
- import xformers.ops
- shared.xformers_available = True
- except Exception:
- print("Cannot import xformers", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
-
-def get_available_vram():
- if shared.device.type == 'cuda':
- stats = torch.cuda.memory_stats(shared.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
- return mem_free_total
- else:
- return psutil.virtual_memory().available
-
-
-# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None):
- h = self.heads
-
- q_in = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k_in = self.to_k(context_k)
- v_in = self.to_v(context_v)
- del context, context_k, context_v, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k, v = q.float(), k.float(), v.float()
-
- with devices.without_autocast(disable=not shared.opts.upcast_attn):
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
- for i in range(0, q.shape[0], 2):
- end = i + 2
- s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
- s1 *= self.scale
-
- s2 = s1.softmax(dim=-1)
- del s1
-
- r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
- del s2
- del q, k, v
-
- r1 = r1.to(dtype)
-
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-
-# taken from https://github.com/Doggettx/stable-diffusion and modified
-def split_cross_attention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q_in = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k_in = self.to_k(context_k)
- v_in = self.to_v(context_v)
-
- dtype = q_in.dtype
- if shared.opts.upcast_attn:
- q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
-
- with devices.without_autocast(disable=not shared.opts.upcast_attn):
- k_in = k_in * self.scale
-
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
- mem_free_total = get_available_vram()
-
- gb = 1024 ** 3
- tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
- modifier = 3 if q.element_size() == 2 else 2.5
- mem_required = tensor_size * modifier
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
- # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
- # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-
- if steps > 64:
- max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
- raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
- f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
- s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
- s2 = s1.softmax(dim=-1, dtype=q.dtype)
- del s1
-
- r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
- del s2
-
- del q, k, v
-
- r1 = r1.to(dtype)
-
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-
-# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
-mem_total_gb = psutil.virtual_memory().total // (1 << 30)
-
-def einsum_op_compvis(q, k, v):
- s = einsum('b i d, b j d -> b i j', q, k)
- s = s.softmax(dim=-1, dtype=s.dtype)
- return einsum('b i j, b j d -> b i d', s, v)
-
-def einsum_op_slice_0(q, k, v, slice_size):
- r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
- for i in range(0, q.shape[0], slice_size):
- end = i + slice_size
- r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
- return r
-
-def einsum_op_slice_1(q, k, v, slice_size):
- r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
- r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
- return r
-
-def einsum_op_mps_v1(q, k, v):
- if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
- return einsum_op_compvis(q, k, v)
- else:
- slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
- if slice_size % 4096 == 0:
- slice_size -= 1
- return einsum_op_slice_1(q, k, v, slice_size)
-
-def einsum_op_mps_v2(q, k, v):
- if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
- return einsum_op_compvis(q, k, v)
- else:
- return einsum_op_slice_0(q, k, v, 1)
-
-def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
- size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
- if size_mb <= max_tensor_mb:
- return einsum_op_compvis(q, k, v)
- div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
- if div <= q.shape[0]:
- return einsum_op_slice_0(q, k, v, q.shape[0] // div)
- return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
-
-def einsum_op_cuda(q, k, v):
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
- # Divide by a safety factor since there's copying and fragmentation
- return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
-
-def einsum_op(q, k, v):
- if q.device.type == 'cuda':
- return einsum_op_cuda(q, k, v)
-
- if q.device.type == 'mps':
- if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
- return einsum_op_mps_v1(q, k, v)
- return einsum_op_mps_v2(q, k, v)
-
- # Smaller slices are faster due to L2/L3/SLC caches.
- # Tested on i7 with 8MB L3 cache.
- return einsum_op_tensor_mem(q, k, v, 32)
-
-def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k = self.to_k(context_k)
- v = self.to_v(context_v)
- del context, context_k, context_v, x
-
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
-
- with devices.without_autocast(disable=not shared.opts.upcast_attn):
- k = k * self.scale
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
- r = einsum_op(q, k, v)
- r = r.to(dtype)
- return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
-
-# -- End of code from https://github.com/invoke-ai/InvokeAI --
-
-
-# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
-# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
-def sub_quad_attention_forward(self, x, context=None, mask=None):
- assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
-
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k = self.to_k(context_k)
- v = self.to_v(context_v)
- del context, context_k, context_v, x
-
- q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
- k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
- v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
-
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k = q.float(), k.float()
-
- x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
-
- x = x.to(dtype)
-
- x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
-
- out_proj, dropout = self.to_out
- x = out_proj(x)
- x = dropout(x)
-
- return x
-
-def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
- bytes_per_token = torch.finfo(q.dtype).bits//8
- batch_x_heads, q_tokens, _ = q.shape
- _, k_tokens, _ = k.shape
- qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
-
- if chunk_threshold is None:
- chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
- elif chunk_threshold == 0:
- chunk_threshold_bytes = None
- else:
- chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
-
- if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
- kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
- elif kv_chunk_size_min == 0:
- kv_chunk_size_min = None
-
- if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
- # the big matmul fits into our memory limit; do everything in 1 chunk,
- # i.e. send it down the unchunked fast-path
- q_chunk_size = q_tokens
- kv_chunk_size = k_tokens
-
- with devices.without_autocast(disable=q.dtype == v.dtype):
- return efficient_dot_product_attention(
- q,
- k,
- v,
- query_chunk_size=q_chunk_size,
- kv_chunk_size=kv_chunk_size,
- kv_chunk_size_min=kv_chunk_size_min,
- use_checkpoint=use_checkpoint,
- )
-
-
-def get_xformers_flash_attention_op(q, k, v):
- if not shared.cmd_opts.xformers_flash_attention:
- return None
-
- try:
- flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
- fw, bw = flash_attention_op
- if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
- return flash_attention_op
- except Exception as e:
- errors.display_once(e, "enabling flash attention")
-
- return None
-
-
-def xformers_attention_forward(self, x, context=None, mask=None):
- h = self.heads
- q_in = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
- k_in = self.to_k(context_k)
- v_in = self.to_v(context_v)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k = q.float(), k.float()
-
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
-
- out = out.to(dtype)
-
- out = rearrange(out, 'b n h d -> b n (h d)', h=h)
- return self.to_out(out)
-
-def cross_attention_attnblock_forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q1 = self.q(h_)
- k1 = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q1.shape
-
- q2 = q1.reshape(b, c, h*w)
- del q1
-
- q = q2.permute(0, 2, 1) # b,hw,c
- del q2
-
- k = k1.reshape(b, c, h*w) # b,c,hw
- del k1
-
- h_ = torch.zeros_like(k, device=q.device)
-
- mem_free_total = get_available_vram()
-
- tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
- mem_required = tensor_size * 2.5
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
-
- w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w2 = w1 * (int(c)**(-0.5))
- del w1
- w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
- del w2
-
- # attend to values
- v1 = v.reshape(b, c, h*w)
- w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
- del w3
-
- h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- del v1, w4
-
- h2 = h_.reshape(b, c, h, w)
- del h_
-
- h3 = self.proj_out(h2)
- del h2
-
- h3 += x
-
- return h3
-
-def xformers_attnblock_forward(self, x):
- try:
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
- b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
- dtype = q.dtype
- if shared.opts.upcast_attn:
- q, k = q.float(), k.float()
- q = q.contiguous()
- k = k.contiguous()
- v = v.contiguous()
- out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
- out = out.to(dtype)
- out = rearrange(out, 'b (h w) c -> b c h w', h=h)
- out = self.proj_out(out)
- return x + out
- except NotImplementedError:
- return cross_attention_attnblock_forward(self, x)
-
-def sub_quad_attnblock_forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
- b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
- q = q.contiguous()
- k = k.contiguous()
- v = v.contiguous()
- out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
- out = rearrange(out, 'b (h w) c -> b c h w', h=h)
- out = self.proj_out(out)
- return x + out
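
Both split-attention paths in the deleted sd_hijack_optimizations.py size their slices from the ratio of required to available memory: the step count is the next power of two that brings each slice under the budget. A stand-alone sketch of that sizing logic with hypothetical numbers in place of the real VRAM query:

```python
import math

seq_len = 4096            # q.shape[1] for a 512x512 latent
batch_x_heads = 16        # q.shape[0]
element_size = 2          # fp16 bytes per element

mem_free_total = 1 * 1024 ** 3                                   # pretend 1 GB is free
tensor_size = batch_x_heads * seq_len * seq_len * element_size
mem_required = tensor_size * 3                                   # fp16 modifier, as in the code above

steps = 1
if mem_required > mem_free_total:
    # round the required/free ratio up to the next power of two
    steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))

slice_size = seq_len // steps if seq_len % steps == 0 else seq_len
print(steps, slice_size)  # 2 2048 -> two slices of 2048 query rows each
```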
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
deleted file mode 100644
index 843ab66cfbd07e2b757a226584cc51656ff3f448..0000000000000000000000000000000000000000
--- a/modules/sd_hijack_unet.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import torch
-from packaging import version
-
-from modules import devices
-from modules.sd_hijack_utils import CondFunc
-
-
-class TorchHijackForUnet:
- """
- This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
- this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
- """
-
- def __getattr__(self, item):
- if item == 'cat':
- return self.cat
-
- if hasattr(torch, item):
- return getattr(torch, item)
-
- raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
-
- def cat(self, tensors, *args, **kwargs):
- if len(tensors) == 2:
- a, b = tensors
- if a.shape[-2:] != b.shape[-2:]:
- a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
-
- tensors = (a, b)
-
- return torch.cat(tensors, *args, **kwargs)
-
-
-th = TorchHijackForUnet()
-
-
-# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
-def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
-
- if isinstance(cond, dict):
- for y in cond.keys():
- cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
-
- with devices.autocast():
- return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
-
-
-class GELUHijack(torch.nn.GELU, torch.nn.Module):
- def __init__(self, *args, **kwargs):
- torch.nn.GELU.__init__(self, *args, **kwargs)
- def forward(self, x):
- if devices.unet_needs_upcast:
- return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
- else:
- return torch.nn.GELU.forward(self, x)
-
-
-ddpm_edit_hijack = None
-def hijack_ddpm_edit():
- global ddpm_edit_hijack
- if not ddpm_edit_hijack:
- CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
- CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
- ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
-
-
-unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
-CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
-CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
-if version.parse(torch.__version__) <= version.parse("1.13.1"):
- CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
- CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
- CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
-
-first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
-first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
-CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
-CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
-CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
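
The `cat` override in the deleted sd_hijack_unet.py exists because, with image sizes that are multiples of 8 but not of 64, the two tensors meeting at a UNet skip connection can differ by a latent pixel; the smaller one is nearest-neighbour resized before concatenation. A quick self-contained check of that behaviour with made-up shapes:

```python
import torch

a = torch.zeros(1, 4, 31, 31)   # skip-connection feature map
b = torch.zeros(1, 4, 32, 32)   # upsampled feature map, one latent pixel larger

if a.shape[-2:] != b.shape[-2:]:
    a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

out = torch.cat((a, b), dim=1)
assert out.shape == (1, 8, 32, 32)
```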
diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py
deleted file mode 100644
index f8684475ec5c85d7a6f5aa18238a3a5003b17234..0000000000000000000000000000000000000000
--- a/modules/sd_hijack_utils.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import importlib
-
-class CondFunc:
- def __new__(cls, orig_func, sub_func, cond_func):
- self = super(CondFunc, cls).__new__(cls)
- if isinstance(orig_func, str):
- func_path = orig_func.split('.')
- for i in range(len(func_path)-1, -1, -1):
- try:
- resolved_obj = importlib.import_module('.'.join(func_path[:i]))
- break
- except ImportError:
- pass
- for attr_name in func_path[i:-1]:
- resolved_obj = getattr(resolved_obj, attr_name)
- orig_func = getattr(resolved_obj, func_path[-1])
- setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
- self.__init__(orig_func, sub_func, cond_func)
- return lambda *args, **kwargs: self(*args, **kwargs)
- def __init__(self, orig_func, sub_func, cond_func):
- self.__orig_func = orig_func
- self.__sub_func = sub_func
- self.__cond_func = cond_func
- def __call__(self, *args, **kwargs):
- if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
- return self.__sub_func(self.__orig_func, *args, **kwargs)
- else:
- return self.__orig_func(*args, **kwargs)
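A minimal usage sketch for the CondFunc helper above, assuming modules.sd_hijack_utils is importable; the function names below are hypothetical illustrations, not part of the deleted module. sub_func receives the original callable as its first argument, and cond_func is evaluated on every call to decide whether the substitute runs:

    import torch
    from modules.sd_hijack_utils import CondFunc  # the class defined above

    def orig_gelu(x):
        return torch.nn.functional.gelu(x)

    def gelu_in_fp32(orig_func, x):
        # run the original in float32 and cast back, mirroring the upcast hijacks earlier in this diff
        return orig_func(x.float()).to(x.dtype)

    # passing a callable instead of a dotted path skips the monkey-patching step
    # and simply returns a conditional wrapper
    gelu = CondFunc(orig_gelu, gelu_in_fp32, lambda orig_func, x: x.dtype == torch.float16)
    y = gelu(torch.ones(4, dtype=torch.float16))  # takes the substitute path
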
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
deleted file mode 100644
index 4ac51c386fdb72610053e472c55104038ab4e6ba..0000000000000000000000000000000000000000
--- a/modules/sd_hijack_xlmr.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import open_clip.tokenizer
-import torch
-
-from modules import sd_hijack_clip, devices
-from modules.shared import opts
-
-
-class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
- def __init__(self, wrapped, hijack):
- super().__init__(wrapped, hijack)
-
- self.id_start = wrapped.config.bos_token_id
- self.id_end = wrapped.config.eos_token_id
- self.id_pad = wrapped.config.pad_token_id
-
-        self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
-
- def encode_with_transformers(self, tokens):
-        # there's no CLIP Skip here because all hidden layers have a size of 1024, and the last one uses a
-        # trained layer to transform those 1024 features into 768 for the unet; so you can't choose which
-        # transformer layer to work with - you have to use the last one
-
- attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
- features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
- z = features['projection_state']
-
- return z
-
- def encode_embedding_init_text(self, init_text, nvpt):
- embedding_layer = self.wrapped.roberta.embeddings
- ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
-
- return embedded
diff --git a/modules/sd_models.py b/modules/sd_models.py
deleted file mode 100644
index 4f9323038f05280d43a8fc40919ec996341da328..0000000000000000000000000000000000000000
--- a/modules/sd_models.py
+++ /dev/null
@@ -1,496 +0,0 @@
-import collections
-import os.path
-import sys
-import gc
-import torch
-import re
-import safetensors.torch
-from omegaconf import OmegaConf
-from os import mkdir
-from urllib import request
-import ldm.modules.midas as midas
-
-from ldm.util import instantiate_from_config
-
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
-from modules.paths import models_path
-from modules.sd_hijack_inpainting import do_inpainting_hijack
-from modules.timer import Timer
-
-model_dir = "Stable-diffusion"
-model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
-
-checkpoints_list = {}
-checkpoint_alisases = {}
-checkpoints_loaded = collections.OrderedDict()
-
-
-class CheckpointInfo:
- def __init__(self, filename):
- self.filename = filename
- abspath = os.path.abspath(filename)
-
- shared.cmd_opts.ckpt_dir='/content/gdrive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion/model.ckpt'
- if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
- name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
- elif abspath.startswith(model_path):
- name = abspath.replace(model_path, '')
- else:
- name = os.path.basename(filename)
-
- if name.startswith("\\") or name.startswith("/"):
- name = name[1:]
-
- self.name = name
- self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
- self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
- self.hash = model_hash(filename)
-
- self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
- self.shorthash = self.sha256[0:10] if self.sha256 else None
-
- self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
-
- self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
-
- def register(self):
- checkpoints_list[self.title] = self
- for id in self.ids:
- checkpoint_alisases[id] = self
-
- def calculate_shorthash(self):
- self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
- if self.sha256 is None:
- return
-
- self.shorthash = self.sha256[0:10]
-
- if self.shorthash not in self.ids:
- self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
-
- checkpoints_list.pop(self.title)
- self.title = f'{self.name} [{self.shorthash}]'
- self.register()
-
- return self.shorthash
-
-
-try:
- # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-
- from transformers import logging, CLIPModel
-
- logging.set_verbosity_error()
-except Exception:
- pass
-
-
-def setup_model():
- if not os.path.exists(model_path):
- os.makedirs(model_path)
-
- list_models()
- enable_midas_autodownload()
-
-
-def checkpoint_tiles():
- def convert(name):
- return int(name) if name.isdigit() else name.lower()
-
- def alphanumeric_key(key):
- return [convert(c) for c in re.split('([0-9]+)', key)]
-
- return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
-
-
-def list_models():
- checkpoints_list.clear()
- checkpoint_alisases.clear()
-
- cmd_ckpt = shared.cmd_opts.ckpt
- if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
- model_url = None
- else:
- model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
-
- model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
-
- if os.path.exists(cmd_ckpt):
- checkpoint_info = CheckpointInfo(cmd_ckpt)
- checkpoint_info.register()
-
- shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
- elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
- print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
-
- for filename in model_list:
- checkpoint_info = CheckpointInfo(filename)
- checkpoint_info.register()
-
-
-def get_closet_checkpoint_match(search_string):
- checkpoint_info = checkpoint_alisases.get(search_string, None)
- if checkpoint_info is not None:
- return checkpoint_info
-
- found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
- if found:
- return found[0]
-
- return None
-
-
-def model_hash(filename):
- """old hash that only looks at a small part of the file and is prone to collisions"""
-
- try:
- with open(filename, "rb") as file:
- import hashlib
- m = hashlib.sha256()
-
- file.seek(0x100000)
- m.update(file.read(0x10000))
- return m.hexdigest()[0:8]
- except FileNotFoundError:
- return 'NOFILE'
-
-
-def select_checkpoint():
- model_checkpoint = shared.opts.sd_model_checkpoint
-
- checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
- if checkpoint_info is not None:
- return checkpoint_info
-
- if len(checkpoints_list) == 0:
- print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
- if shared.cmd_opts.ckpt is not None:
- print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
- print(f" - directory {model_path}", file=sys.stderr)
- if shared.cmd_opts.ckpt_dir is not None:
- print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
- exit(1)
-
- checkpoint_info = next(iter(checkpoints_list.values()))
- if model_checkpoint is not None:
- print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)
-
- return checkpoint_info
-
-
-checkpoint_dict_replacements = {
- 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
- 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
- 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
-}
-
-
-def transform_checkpoint_dict_key(k):
-    for text, replacement in checkpoint_dict_replacements.items():
- if k.startswith(text):
- k = replacement + k[len(text):]
-
- return k
-
-
-def get_state_dict_from_checkpoint(pl_sd):
- pl_sd = pl_sd.pop("state_dict", pl_sd)
- pl_sd.pop("state_dict", None)
-
- sd = {}
- for k, v in pl_sd.items():
- new_key = transform_checkpoint_dict_key(k)
-
- if new_key is not None:
- sd[new_key] = v
-
- pl_sd.clear()
- pl_sd.update(sd)
-
- return pl_sd
-
-
-def read_state_dict(checkpoint_file, print_global_state=False, map_location='cuda'):
- _, extension = os.path.splitext(checkpoint_file)
- if extension.lower() == ".safetensors":
- device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
- else:
- pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
-
- if print_global_state and "global_step" in pl_sd:
- print(f"Global Step: {pl_sd['global_step']}")
-
- sd = get_state_dict_from_checkpoint(pl_sd)
- return sd
-
-
-def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
- sd_model_hash = checkpoint_info.calculate_shorthash()
- timer.record("calculate hash")
-
- if checkpoint_info in checkpoints_loaded:
- # use checkpoint cache
- print(f"Loading weights [{sd_model_hash}] from cache")
- return checkpoints_loaded[checkpoint_info]
-
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
- res = read_state_dict(checkpoint_info.filename)
- timer.record("load weights from disk")
-
- return res
-
-
-def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
- sd_model_hash = checkpoint_info.calculate_shorthash()
- timer.record("calculate hash")
-
- shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
-
- if state_dict is None:
- state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
-
- model.load_state_dict(state_dict, strict=False)
- del state_dict
- timer.record("apply weights to model")
-
- if shared.opts.sd_checkpoint_cache > 0:
- # cache newly loaded model
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
-
- if shared.cmd_opts.opt_channelslast:
- model.to(memory_format=torch.channels_last)
- timer.record("apply channels_last")
-
- if not shared.cmd_opts.no_half:
- vae = model.first_stage_model
- depth_model = getattr(model, 'depth_model', None)
-
- # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
- if shared.cmd_opts.no_half_vae:
- model.first_stage_model = None
- # with --upcast-sampling, don't convert the depth model weights to float16
- if shared.cmd_opts.upcast_sampling and depth_model:
- model.depth_model = None
-
- model.half()
- model.first_stage_model = vae
- if depth_model:
- model.depth_model = depth_model
-
- timer.record("apply half()")
-
- devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
- devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
- devices.dtype_unet = model.model.diffusion_model.dtype
- devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
-
- model.first_stage_model.to(devices.dtype_vae)
- timer.record("apply dtype to VAE")
-
- # clean up cache if limit is reached
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
- checkpoints_loaded.popitem(last=False)
-
- model.sd_model_hash = sd_model_hash
- model.sd_model_checkpoint = checkpoint_info.filename
- model.sd_checkpoint_info = checkpoint_info
- shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
-
- model.logvar = model.logvar.to(devices.device) # fix for training
-
- sd_vae.delete_base_vae()
- sd_vae.clear_loaded_vae()
- vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
- sd_vae.load_vae(model, vae_file, vae_source)
- timer.record("load VAE")
-
-
-def enable_midas_autodownload():
- """
-    Adds automatic downloading to the ldm.modules.midas.api.load_model function.
-
-    When the 512-depth-ema model (and other future models like it) is loaded,
-    it calls midas.api.load_model to load the associated midas depth model.
-    This function applies a wrapper that downloads the model to the correct
-    location automatically.
- """
-
- midas_path = os.path.join(paths.models_path, 'midas')
-
- # stable-diffusion-stability-ai hard-codes the midas model path to
- # a location that differs from where other scripts using this model look.
- # HACK: Overriding the path here.
- for k, v in midas.api.ISL_PATHS.items():
- file_name = os.path.basename(v)
- midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
-
- midas_urls = {
- "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
- "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
- "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
- "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
- }
-
- midas.api.load_model_inner = midas.api.load_model
-
- def load_model_wrapper(model_type):
- path = midas.api.ISL_PATHS[model_type]
- if not os.path.exists(path):
- if not os.path.exists(midas_path):
- mkdir(midas_path)
-
- print(f"Downloading midas model weights for {model_type} to {path}")
- request.urlretrieve(midas_urls[model_type], path)
- print(f"{model_type} downloaded")
-
- return midas.api.load_model_inner(model_type)
-
- midas.api.load_model = load_model_wrapper
-
-
-def repair_config(sd_config):
-
- if not hasattr(sd_config.model.params, "use_ema"):
- sd_config.model.params.use_ema = False
-
- if shared.cmd_opts.no_half:
- sd_config.model.params.unet_config.params.use_fp16 = False
- elif shared.cmd_opts.upcast_sampling:
- sd_config.model.params.unet_config.params.use_fp16 = True
-
-
-sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
-sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
-
-def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
- from modules import lowvram, sd_hijack
- checkpoint_info = checkpoint_info or select_checkpoint()
-
- if shared.sd_model:
- sd_hijack.model_hijack.undo_hijack(shared.sd_model)
- shared.sd_model = None
- gc.collect()
- devices.torch_gc()
-
- do_inpainting_hijack()
-
- timer = Timer()
-
- if already_loaded_state_dict is not None:
- state_dict = already_loaded_state_dict
- else:
- state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
-
- checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
- clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
-
- timer.record("find config")
-
- sd_config = OmegaConf.load(checkpoint_config)
- repair_config(sd_config)
-
- timer.record("load config")
-
- print(f"Creating model from config: {checkpoint_config}")
-
- sd_model = None
- try:
- with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
- sd_model = instantiate_from_config(sd_config.model)
-    except Exception:
-        pass
-
- if sd_model is None:
- print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
- sd_model = instantiate_from_config(sd_config.model)
-
- sd_model.used_config = checkpoint_config
-
- timer.record("create model")
-
- load_model_weights(sd_model, checkpoint_info, state_dict, timer)
-
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
- else:
- sd_model.to(shared.device)
-
- timer.record("move model to device")
-
- sd_hijack.model_hijack.hijack(sd_model)
-
- timer.record("hijack")
-
- sd_model.eval()
- shared.sd_model = sd_model
-
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
-
- timer.record("load textual inversion embeddings")
-
- script_callbacks.model_loaded_callback(sd_model)
-
- timer.record("scripts callbacks")
-
- print(f"Model loaded in {timer.summary()}.")
-
- return sd_model
-
-
-def reload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
- checkpoint_info = info or select_checkpoint()
-
- if not sd_model:
- sd_model = shared.sd_model
-
- if sd_model is None: # previous model load failed
- current_checkpoint_info = None
- else:
- current_checkpoint_info = sd_model.sd_checkpoint_info
- if sd_model.sd_model_checkpoint == checkpoint_info.filename:
- return
-
-    if sd_model is not None:  # may be None if the previous model load failed
-        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-            lowvram.send_everything_to_cpu()
-        else:
-            sd_model.to(devices.cpu)
-
-        sd_hijack.model_hijack.undo_hijack(sd_model)
-
- timer = Timer()
-
- state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
-
- checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
-
- timer.record("find config")
-
- if sd_model is None or checkpoint_config != sd_model.used_config:
- del sd_model
- checkpoints_loaded.clear()
- load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
- return shared.sd_model
-
- try:
- load_model_weights(sd_model, checkpoint_info, state_dict, timer)
-    except Exception:
- print("Failed to load checkpoint, restoring previous")
- load_model_weights(sd_model, current_checkpoint_info, None, timer)
- raise
- finally:
- sd_hijack.model_hijack.hijack(sd_model)
- timer.record("hijack")
-
- script_callbacks.model_loaded_callback(sd_model)
- timer.record("script callbacks")
-
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
- sd_model.to(devices.device)
- timer.record("move model to device")
-
- print(f"Weights loaded in {timer.summary()}.")
-
- return sd_model
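The module above is normally driven by the webui itself; a rough sketch of the intended flow, assuming an initialized webui environment (command-line options parsed, model directories present, the ldm repository importable) and a hypothetical checkpoint name in the last call:

    from modules import sd_models

    sd_models.setup_model()               # creates the model directory and populates checkpoints_list
    info = sd_models.select_checkpoint()  # honours opts.sd_model_checkpoint, falls back to the first entry
    model = sd_models.load_model(info)    # builds the model from its config and loads the weights

    # switching checkpoints later goes through get_checkpoint_state_dict, which can serve from the cache
    sd_models.reload_model_weights(info=sd_models.get_closet_checkpoint_match("v1-5-pruned-emaonly"))
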
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
deleted file mode 100644
index 91c21700417f615e8d14e96c4f7a3e89368b6381..0000000000000000000000000000000000000000
--- a/modules/sd_models_config.py
+++ /dev/null
@@ -1,112 +0,0 @@
-import re
-import os
-
-import torch
-
-from modules import shared, paths, sd_disable_initialization
-
-sd_configs_path = shared.sd_configs_path
-sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
-
-
-config_default = shared.sd_default_config
-config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
-config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
-config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
-config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
-config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
-config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
-config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
-
-
-def is_using_v_parameterization_for_sd2(state_dict):
- """
-    Detects whether the unet in state_dict uses v-parameterization. Returns True if it does.
- """
-
- import ldm.modules.diffusionmodules.openaimodel
- from modules import devices
-
- device = devices.cpu
-
- with sd_disable_initialization.DisableInitialization():
- unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
- use_checkpoint=True,
- use_fp16=False,
- image_size=32,
- in_channels=4,
- out_channels=4,
- model_channels=320,
- attention_resolutions=[4, 2, 1],
- num_res_blocks=2,
- channel_mult=[1, 2, 4, 4],
- num_head_channels=64,
- use_spatial_transformer=True,
- use_linear_in_transformer=True,
- transformer_depth=1,
- context_dim=1024,
- legacy=False
- )
- unet.eval()
-
- with torch.no_grad():
- unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
- unet.load_state_dict(unet_sd, strict=True)
- unet.to(device=device, dtype=torch.float)
-
- test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
- x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
-
- out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
-
- return out < -1
-
-
-def guess_model_config_from_state_dict(sd, filename):
- sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
- diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
-
- if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
- return config_depth_model
-
- if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
- if diffusion_model_input.shape[1] == 9:
- return config_sd2_inpainting
- elif is_using_v_parameterization_for_sd2(sd):
- return config_sd2v
- else:
- return config_sd2
-
- if diffusion_model_input is not None:
- if diffusion_model_input.shape[1] == 9:
- return config_inpainting
- if diffusion_model_input.shape[1] == 8:
- return config_instruct_pix2pix
-
- if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
- return config_alt_diffusion
-
- return config_default
-
-
-def find_checkpoint_config(state_dict, info):
- if info is None:
- return guess_model_config_from_state_dict(state_dict, "")
-
- config = find_checkpoint_config_near_filename(info)
- if config is not None:
- return config
-
- return guess_model_config_from_state_dict(state_dict, info.filename)
-
-
-def find_checkpoint_config_near_filename(info):
- if info is None:
- return None
-
- config = os.path.splitext(info.filename)[0] + ".yaml"
- if os.path.exists(config):
- return config
-
- return None
-
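The heuristics in guess_model_config_from_state_dict come down to a handful of key and shape checks; a standalone illustration of the two most common ones, using placeholder tensors rather than real weights:

    import torch

    fake_sd = {
        # 9 input channels marks an inpainting unet, 8 marks instruct-pix2pix, 4 is the default
        'model.diffusion_model.input_blocks.0.0.weight': torch.zeros(320, 9, 3, 3),
    }

    proj = fake_sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight')
    is_sd2 = proj is not None and proj.shape[1] == 1024          # the SD2 text encoder projects from 1024
    in_channels = fake_sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]

    print(is_sd2, in_channels)  # False 9 -> this would resolve to the v1 inpainting config
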
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
deleted file mode 100644
index 28c2136fe73ac2c093a02cbdee54dd0afd024d21..0000000000000000000000000000000000000000
--- a/modules/sd_samplers.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
-
-# imports for functions that previously were here and are used by other modules
-from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
-
-all_samplers = [
- *sd_samplers_kdiffusion.samplers_data_k_diffusion,
- *sd_samplers_compvis.samplers_data_compvis,
-]
-all_samplers_map = {x.name: x for x in all_samplers}
-
-samplers = []
-samplers_for_img2img = []
-samplers_map = {}
-
-
-def create_sampler(name, model):
- if name is not None:
- config = all_samplers_map.get(name, None)
- else:
- config = all_samplers[0]
-
- assert config is not None, f'bad sampler name: {name}'
-
- sampler = config.constructor(model)
- sampler.config = config
-
- return sampler
-
-
-def set_samplers():
- global samplers, samplers_for_img2img
-
- hidden = set(shared.opts.hide_samplers)
- hidden_img2img = set(shared.opts.hide_samplers + ['PLMS'])
-
- samplers = [x for x in all_samplers if x.name not in hidden]
- samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
-
- samplers_map.clear()
- for sampler in all_samplers:
- samplers_map[sampler.name.lower()] = sampler.name
- for alias in sampler.aliases:
- samplers_map[alias.lower()] = sampler.name
-
-
-set_samplers()
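A short sketch of how the registry above is typically consumed, assuming a webui environment with a model already loaded into shared.sd_model:

    from modules import sd_samplers, shared

    sd_samplers.set_samplers()                      # re-applies the hide_samplers option (also run once at import)
    print([x.name for x in sd_samplers.samplers])   # sampler names offered in the UI
    sampler = sd_samplers.create_sampler('Euler a', shared.sd_model)
    print(sampler.config.aliases)                   # ['k_euler_a', 'k_euler_ancestral']
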
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
deleted file mode 100644
index a1aac7cf0aaf25375dcbc2b1ac0f5486ee683a6c..0000000000000000000000000000000000000000
--- a/modules/sd_samplers_common.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from collections import namedtuple
-import numpy as np
-import torch
-from PIL import Image
-from modules import devices, processing, images, sd_vae_approx
-
-from modules.shared import opts, state
-import modules.shared as shared
-
-SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
-
-
-def setup_img2img_steps(p, steps=None):
- if opts.img2img_fix_steps or steps is not None:
- requested_steps = (steps or p.steps)
- steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
- t_enc = requested_steps - 1
- else:
- steps = p.steps
- t_enc = int(min(p.denoising_strength, 0.999) * steps)
-
- return steps, t_enc
-
-
-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
-
-
-def single_sample_to_image(sample, approximation=None):
- if approximation is None:
- approximation = approximation_indexes.get(opts.show_progress_type, 0)
-
- if approximation == 2:
- x_sample = sd_vae_approx.cheap_approximation(sample)
- elif approximation == 1:
- x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
- else:
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
-
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- return Image.fromarray(x_sample)
-
-
-def sample_to_image(samples, index=0, approximation=None):
- return single_sample_to_image(samples[index], approximation)
-
-
-def samples_to_image_grid(samples, approximation=None):
- return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
-
-
-def store_latent(decoded):
- state.current_latent = decoded
-
- if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
- if not shared.parallel_processing_allowed:
- shared.state.assign_current_image(sample_to_image(decoded))
-
-
-class InterruptedException(BaseException):
- pass
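The step arithmetic in setup_img2img_steps is easiest to see with concrete numbers; a worked example with illustrative values (20 requested steps, denoising strength 0.5):

    requested_steps = 20
    denoising_strength = 0.5

    # with opts.img2img_fix_steps enabled: run enough total steps that the denoised part
    # still performs the requested 20
    steps = int(requested_steps / min(denoising_strength, 0.999))        # 40
    t_enc = requested_steps - 1                                          # 19

    # with the option disabled: keep the requested count and encode partway into it
    steps_default = requested_steps                                      # 20
    t_enc_default = int(min(denoising_strength, 0.999) * steps_default)  # 10
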
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
deleted file mode 100644
index d03131cd49c27374e84e854777730541ed71061a..0000000000000000000000000000000000000000
--- a/modules/sd_samplers_compvis.py
+++ /dev/null
@@ -1,160 +0,0 @@
-import math
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-
-import numpy as np
-import torch
-
-from modules.shared import state
-from modules import sd_samplers_common, prompt_parser, shared
-
-
-samplers_data_compvis = [
- sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
- sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
-]
-
-
-class VanillaStableDiffusionSampler:
- def __init__(self, constructor, sd_model):
- self.sampler = constructor(sd_model)
- self.is_plms = hasattr(self.sampler, 'p_sample_plms')
- self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.sampler_noises = None
- self.step = 0
- self.stop_at = None
- self.eta = None
- self.config = None
- self.last_latent = None
-
- self.conditioning_key = sd_model.model.conditioning_key
-
- def number_of_needed_noises(self, p):
- return 0
-
- def launch_sampling(self, steps, func):
- state.sampling_steps = steps
- state.sampling_step = 0
-
- try:
- return func()
- except sd_samplers_common.InterruptedException:
- return self.last_latent
-
- def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
- if state.interrupted or state.skipped:
- raise sd_samplers_common.InterruptedException
-
- if self.stop_at is not None and self.step > self.stop_at:
- raise sd_samplers_common.InterruptedException
-
- # Have to unwrap the inpainting conditioning here to perform pre-processing
- image_conditioning = None
- if isinstance(cond, dict):
- image_conditioning = cond["c_concat"][0]
- cond = cond["c_crossattn"][0]
- unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
-
- assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
- cond = tensor
-
- # for DDIM, shapes must match, we can't just process cond and uncond independently;
- # filling unconditional_conditioning with repeats of the last vector to match length is
- # not 100% correct but should work well enough
- if unconditional_conditioning.shape[1] < cond.shape[1]:
- last_vector = unconditional_conditioning[:, -1:]
- last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
- unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
- elif unconditional_conditioning.shape[1] > cond.shape[1]:
- unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
-
- if self.mask is not None:
- img_orig = self.sampler.model.q_sample(self.init_latent, ts)
- x_dec = img_orig * self.mask + self.nmask * x_dec
-
- # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
- # Note that they need to be lists because it just concatenates them later.
- if image_conditioning is not None:
- cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
- res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
-
- if self.mask is not None:
- self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
- else:
- self.last_latent = res[1]
-
- sd_samplers_common.store_latent(self.last_latent)
-
- self.step += 1
- state.sampling_step = self.step
- shared.total_tqdm.update()
-
- return res
-
- def initialize(self, p):
- self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
- if self.eta != 0.0:
- p.extra_generation_params["Eta DDIM"] = self.eta
-
- for fieldname in ['p_sample_ddim', 'p_sample_plms']:
- if hasattr(self.sampler, fieldname):
- setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
-
- self.mask = p.mask if hasattr(p, 'mask') else None
- self.nmask = p.nmask if hasattr(p, 'nmask') else None
-
- def adjust_steps_if_invalid(self, p, num_steps):
- if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
- valid_step = 999 / (1000 // num_steps)
- if valid_step == math.floor(valid_step):
- return int(valid_step) + 1
-
- return num_steps
-
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
- steps = self.adjust_steps_if_invalid(p, steps)
- self.initialize(p)
-
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
- x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
-
- self.init_latent = x
- self.last_latent = x
- self.step = 0
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- if image_conditioning is not None:
- conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
- samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
-
- return samples
-
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- self.initialize(p)
-
- self.init_latent = None
- self.last_latent = x
- self.step = 0
-
- steps = self.adjust_steps_if_invalid(p, steps or p.steps)
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
- if image_conditioning is not None:
- conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
- unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
-
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
-
- return samples_ddim
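The adjust_steps_if_invalid check above is pure arithmetic: for DDIM with uniform discretization and for PLMS, one extra step is requested whenever 999 is divisible by the schedule stride 1000 // num_steps. Two worked examples:

    for num_steps in (20, 333):
        valid_step = 999 / (1000 // num_steps)
        adjusted = int(valid_step) + 1 if valid_step == int(valid_step) else num_steps
        print(num_steps, adjusted)   # 20 -> 20, 333 -> 334
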
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
deleted file mode 100644
index 528f513fe5a7dc3227ef238e7c3a5091b2f868b6..0000000000000000000000000000000000000000
--- a/modules/sd_samplers_kdiffusion.py
+++ /dev/null
@@ -1,357 +0,0 @@
-from collections import deque
-import torch
-import inspect
-import einops
-import k_diffusion.sampling
-from modules import prompt_parser, devices, sd_samplers_common
-
-from modules.shared import opts, state
-import modules.shared as shared
-from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
-from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
-
-samplers_k_diffusion = [
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
- ('Euler', 'sample_euler', ['k_euler'], {}),
- ('LMS', 'sample_lms', ['k_lms'], {}),
- ('Heun', 'sample_heun', ['k_heun'], {}),
- ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
- ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
- ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
- ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
- ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
- ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
-]
-
-samplers_data_k_diffusion = [
- sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
- for label, funcname, aliases, options in samplers_k_diffusion
- if hasattr(k_diffusion.sampling, funcname)
-]
-
-sampler_extra_params = {
- 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
- 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
- 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
-}
-
-
-class CFGDenoiser(torch.nn.Module):
- """
-    Classifier-free guidance denoiser. A wrapper for the stable diffusion model (specifically for the unet)
-    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
-    instead of one. Originally, the second prompt was just an empty string, but we use a non-empty
-    negative prompt.
- """
-
- def __init__(self, model):
- super().__init__()
- self.inner_model = model
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.step = 0
- self.image_cfg_scale = None
-
- def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
- denoised_uncond = x_out[-uncond.shape[0]:]
- denoised = torch.clone(denoised_uncond)
-
- for i, conds in enumerate(conds_list):
- for cond_index, weight in conds:
- denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
-
- return denoised
-
- def combine_denoised_for_edit_model(self, x_out, cond_scale):
- out_cond, out_img_cond, out_uncond = x_out.chunk(3)
- denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
-
- return denoised
-
- def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
- if state.interrupted or state.skipped:
- raise sd_samplers_common.InterruptedException
-
-        # at self.image_cfg_scale == 1.0, the results produced for the edit model are the same as with normal
-        # sampling, so is_edit_model is set to False to support AND composition.
- is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
-
- assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
-
- batch_size = len(conds_list)
- repeats = [len(conds_list[i]) for i in range(batch_size)]
-
- if not is_edit_model:
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
- else:
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
-
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
- cfg_denoiser_callback(denoiser_params)
- x_in = denoiser_params.x
- image_cond_in = denoiser_params.image_cond
- sigma_in = denoiser_params.sigma
-
- if tensor.shape[1] == uncond.shape[1]:
- if not is_edit_model:
- cond_in = torch.cat([tensor, uncond])
- else:
- cond_in = torch.cat([tensor, uncond, uncond])
-
- if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
- else:
- x_out = torch.zeros_like(x_in)
- for batch_offset in range(0, x_out.shape[0], batch_size):
- a = batch_offset
- b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
- else:
- x_out = torch.zeros_like(x_in)
- batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
- for batch_offset in range(0, tensor.shape[0], batch_size):
- a = batch_offset
- b = min(a + batch_size, tensor.shape[0])
-
- if not is_edit_model:
- c_crossattn = [tensor[a:b]]
- else:
- c_crossattn = torch.cat([tensor[a:b]], uncond)
-
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
-
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
-
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
- cfg_denoised_callback(denoised_params)
-
- devices.test_for_nans(x_out, "unet")
-
- if opts.live_preview_content == "Prompt":
- sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
- elif opts.live_preview_content == "Negative prompt":
- sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
-
- if not is_edit_model:
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
- else:
- denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
-
- if self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
-
- self.step += 1
-
- return denoised
-
-
-class TorchHijack:
- def __init__(self, sampler_noises):
- # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
- # implementation.
- self.sampler_noises = deque(sampler_noises)
-
- def __getattr__(self, item):
- if item == 'randn_like':
- return self.randn_like
-
- if hasattr(torch, item):
- return getattr(torch, item)
-
- raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
-
- def randn_like(self, x):
- if self.sampler_noises:
- noise = self.sampler_noises.popleft()
- if noise.shape == x.shape:
- return noise
-
- if x.device.type == 'mps':
- return torch.randn_like(x, device=devices.cpu).to(x.device)
- else:
- return torch.randn_like(x)
-
-
-class KDiffusionSampler:
- def __init__(self, funcname, sd_model):
- denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-
- self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
- self.funcname = funcname
- self.func = getattr(k_diffusion.sampling, self.funcname)
- self.extra_params = sampler_extra_params.get(funcname, [])
- self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
- self.sampler_noises = None
- self.stop_at = None
- self.eta = None
- self.config = None
- self.last_latent = None
-
- self.conditioning_key = sd_model.model.conditioning_key
-
- def callback_state(self, d):
- step = d['i']
- latent = d["denoised"]
- if opts.live_preview_content == "Combined":
- sd_samplers_common.store_latent(latent)
- self.last_latent = latent
-
- if self.stop_at is not None and step > self.stop_at:
- raise sd_samplers_common.InterruptedException
-
- state.sampling_step = step
- shared.total_tqdm.update()
-
- def launch_sampling(self, steps, func):
- state.sampling_steps = steps
- state.sampling_step = 0
-
- try:
- return func()
- except sd_samplers_common.InterruptedException:
- return self.last_latent
-
- def number_of_needed_noises(self, p):
- return p.steps
-
- def initialize(self, p):
- self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
- self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
- self.model_wrap_cfg.step = 0
- self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
- self.eta = p.eta if p.eta is not None else opts.eta_ancestral
-
- k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
-
- extra_params_kwargs = {}
- for param_name in self.extra_params:
- if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
- extra_params_kwargs[param_name] = getattr(p, param_name)
-
- if 'eta' in inspect.signature(self.func).parameters:
- if self.eta != 1.0:
- p.extra_generation_params["Eta"] = self.eta
-
- extra_params_kwargs['eta'] = self.eta
-
- return extra_params_kwargs
-
- def get_sigmas(self, p, steps):
- discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
- if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
- discard_next_to_last_sigma = True
- p.extra_generation_params["Discard penultimate sigma"] = True
-
- steps += 1 if discard_next_to_last_sigma else 0
-
- if p.sampler_noise_scheduler_override:
- sigmas = p.sampler_noise_scheduler_override(steps)
- elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
- sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
-
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
- else:
- sigmas = self.model_wrap.get_sigmas(steps)
-
- if discard_next_to_last_sigma:
- sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
-
- return sigmas
-
- def create_noise_sampler(self, x, sigmas, p):
- """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
- if shared.opts.no_dpmpp_sde_batch_determinism:
- return None
-
- from k_diffusion.sampling import BrownianTreeNoiseSampler
- sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
- current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
- return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
-
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
-
- sigmas = self.get_sigmas(p, steps)
-
- sigma_sched = sigmas[steps - t_enc - 1:]
- xi = x + noise * sigma_sched[0]
-
- extra_params_kwargs = self.initialize(p)
- parameters = inspect.signature(self.func).parameters
-
- if 'sigma_min' in parameters:
-            # the last sigma is zero, which isn't allowed by DPM Fast & Adaptive, so take the value before last
- extra_params_kwargs['sigma_min'] = sigma_sched[-2]
- if 'sigma_max' in parameters:
- extra_params_kwargs['sigma_max'] = sigma_sched[0]
- if 'n' in parameters:
- extra_params_kwargs['n'] = len(sigma_sched) - 1
- if 'sigma_sched' in parameters:
- extra_params_kwargs['sigma_sched'] = sigma_sched
- if 'sigmas' in parameters:
- extra_params_kwargs['sigmas'] = sigma_sched
-
- if self.funcname == 'sample_dpmpp_sde':
- noise_sampler = self.create_noise_sampler(x, sigmas, p)
- extra_params_kwargs['noise_sampler'] = noise_sampler
-
- self.model_wrap_cfg.init_latent = x
- self.last_latent = x
- extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale,
- }
-
- samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
-
- return samples
-
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps = steps or p.steps
-
- sigmas = self.get_sigmas(p, steps)
-
- x = x * sigmas[0]
-
- extra_params_kwargs = self.initialize(p)
- parameters = inspect.signature(self.func).parameters
-
- if 'sigma_min' in parameters:
- extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
- extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
- if 'n' in parameters:
- extra_params_kwargs['n'] = steps
- else:
- extra_params_kwargs['sigmas'] = sigmas
-
- if self.funcname == 'sample_dpmpp_sde':
- noise_sampler = self.create_noise_sampler(x, sigmas, p)
- extra_params_kwargs['noise_sampler'] = noise_sampler
-
- self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
- }, disable=False, callback=self.callback_state, **extra_params_kwargs))
-
- return samples
-
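The combine_denoised method above implements the standard classifier-free-guidance combination; with a single prompt of weight 1.0 per image it reduces to the familiar formula, sketched here with placeholder tensors:

    import torch

    cond_scale = 7.0                      # the CFG scale from the UI
    x_cond = torch.randn(1, 4, 64, 64)    # denoised prediction for the prompt
    x_uncond = torch.randn(1, 4, 64, 64)  # denoised prediction for the negative prompt

    denoised = x_uncond + (x_cond - x_uncond) * cond_scale

    # combine_denoised_for_edit_model works the same way but adds a second term
    # scaled by image_cfg_scale for instruct-pix2pix checkpoints
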
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
deleted file mode 100644
index 9b00f76e9c62c794b3a27b36bae0f168ff4f5ab8..0000000000000000000000000000000000000000
--- a/modules/sd_vae.py
+++ /dev/null
@@ -1,216 +0,0 @@
-import torch
-import safetensors.torch
-import os
-import collections
-from collections import namedtuple
-from modules import paths, shared, devices, script_callbacks, sd_models
-import glob
-from copy import deepcopy
-
-
-vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
-vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-vae_dict = {}
-
-
-base_vae = None
-loaded_vae_file = None
-checkpoint_info = None
-
-checkpoints_loaded = collections.OrderedDict()
-
-def get_base_vae(model):
- if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
- return base_vae
- return None
-
-
-def store_base_vae(model):
- global base_vae, checkpoint_info
- if checkpoint_info != model.sd_checkpoint_info:
- assert not loaded_vae_file, "Trying to store non-base VAE!"
- base_vae = deepcopy(model.first_stage_model.state_dict())
- checkpoint_info = model.sd_checkpoint_info
-
-
-def delete_base_vae():
- global base_vae, checkpoint_info
- base_vae = None
- checkpoint_info = None
-
-
-def restore_base_vae(model):
- global loaded_vae_file
- if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
- print("Restoring base VAE")
- _load_vae_dict(model, base_vae)
- loaded_vae_file = None
- delete_base_vae()
-
-
-def get_filename(filepath):
- return os.path.basename(filepath)
-
-
-def refresh_vae_list():
- vae_dict.clear()
-
- paths = [
- os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
- os.path.join(sd_models.model_path, '**/*.vae.pt'),
- os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
- os.path.join(vae_path, '**/*.ckpt'),
- os.path.join(vae_path, '**/*.pt'),
- os.path.join(vae_path, '**/*.safetensors'),
- ]
-
- if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
- paths += [
- os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
- os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
- os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
- ]
-
- if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
- paths += [
- os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
- os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
- os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
- ]
-
- candidates = []
- for path in paths:
- candidates += glob.iglob(path, recursive=True)
-
- for filepath in candidates:
- name = get_filename(filepath)
- vae_dict[name] = filepath
-
-
-def find_vae_near_checkpoint(checkpoint_file):
- checkpoint_path = os.path.splitext(checkpoint_file)[0]
- for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
- if os.path.isfile(vae_location):
- return vae_location
-
- return None
-
-
-def resolve_vae(checkpoint_file):
- if shared.cmd_opts.vae_path is not None:
- return shared.cmd_opts.vae_path, 'from commandline argument'
-
- is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
-
- vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
- if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
- return vae_near_checkpoint, 'found near the checkpoint'
-
- if shared.opts.sd_vae == "None":
- return None, None
-
- vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
- if vae_from_options is not None:
- return vae_from_options, 'specified in settings'
-
- if not is_automatic:
- print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
-
- return None, None
-
-
-def load_vae_dict(filename, map_location):
- vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
- return vae_dict_1
-
-
-def load_vae(model, vae_file=None, vae_source="from unknown source"):
- global vae_dict, loaded_vae_file
- # save_settings = False
-
- cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
-
- if vae_file:
- if cache_enabled and vae_file in checkpoints_loaded:
- # use vae checkpoint cache
- print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
- store_base_vae(model)
- _load_vae_dict(model, checkpoints_loaded[vae_file])
- else:
- assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
- print(f"Loading VAE weights {vae_source}: {vae_file}")
- store_base_vae(model)
-
- vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
- _load_vae_dict(model, vae_dict_1)
-
- if cache_enabled:
- # cache newly loaded vae
- checkpoints_loaded[vae_file] = vae_dict_1.copy()
-
- # clean up cache if limit is reached
- if cache_enabled:
- while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
- checkpoints_loaded.popitem(last=False) # LRU
-
- # If vae used is not in dict, update it
- # It will be removed on refresh though
- vae_opt = get_filename(vae_file)
- if vae_opt not in vae_dict:
- vae_dict[vae_opt] = vae_file
-
- elif loaded_vae_file:
- restore_base_vae(model)
-
- loaded_vae_file = vae_file
-
-
-# don't call this from outside
-def _load_vae_dict(model, vae_dict_1):
- model.first_stage_model.load_state_dict(vae_dict_1)
- model.first_stage_model.to(devices.dtype_vae)
-
-
-def clear_loaded_vae():
- global loaded_vae_file
- loaded_vae_file = None
-
-
-unspecified = object()
-
-
-def reload_vae_weights(sd_model=None, vae_file=unspecified):
- from modules import lowvram, devices, sd_hijack
-
- if not sd_model:
- sd_model = shared.sd_model
-
- checkpoint_info = sd_model.sd_checkpoint_info
- checkpoint_file = checkpoint_info.filename
-
- if vae_file == unspecified:
- vae_file, vae_source = resolve_vae(checkpoint_file)
- else:
- vae_source = "from function argument"
-
- if loaded_vae_file == vae_file:
- return
-
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
- else:
- sd_model.to(devices.cpu)
-
- sd_hijack.model_hijack.undo_hijack(sd_model)
-
- load_vae(sd_model, vae_file, vae_source)
-
- sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
-
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
- sd_model.to(devices.device)
-
- print("VAE weights loaded.")
- return sd_model
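A sketch of how the VAE helpers above are driven around checkpoint load, mirroring the calls made from load_model_weights in sd_models.py earlier in this diff; it assumes an initialized webui environment with a model in shared.sd_model:

    from modules import sd_vae, shared

    sd_vae.refresh_vae_list()   # fills vae_dict from the model and VAE directories
    vae_file, vae_source = sd_vae.resolve_vae(shared.sd_model.sd_checkpoint_info.filename)
    sd_vae.load_vae(shared.sd_model, vae_file, vae_source)   # swaps first_stage_model weights, optionally caching them
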
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
deleted file mode 100644
index 0027343a74500502b79bccd8d4b885ddcd7db568..0000000000000000000000000000000000000000
--- a/modules/sd_vae_approx.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import os
-
-import torch
-from torch import nn
-from modules import devices, paths
-
-sd_vae_approx_model = None
-
-
-class VAEApprox(nn.Module):
- def __init__(self):
- super(VAEApprox, self).__init__()
- self.conv1 = nn.Conv2d(4, 8, (7, 7))
- self.conv2 = nn.Conv2d(8, 16, (5, 5))
- self.conv3 = nn.Conv2d(16, 32, (3, 3))
- self.conv4 = nn.Conv2d(32, 64, (3, 3))
- self.conv5 = nn.Conv2d(64, 32, (3, 3))
- self.conv6 = nn.Conv2d(32, 16, (3, 3))
- self.conv7 = nn.Conv2d(16, 8, (3, 3))
- self.conv8 = nn.Conv2d(8, 3, (3, 3))
-
- def forward(self, x):
- extra = 11
- x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
- x = nn.functional.pad(x, (extra, extra, extra, extra))
-
- for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
- x = layer(x)
- x = nn.functional.leaky_relu(x, 0.1)
-
- return x
-
-
-def model():
- global sd_vae_approx_model
-
- if sd_vae_approx_model is None:
- sd_vae_approx_model = VAEApprox()
- sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
- sd_vae_approx_model.eval()
- sd_vae_approx_model.to(devices.device, devices.dtype)
-
- return sd_vae_approx_model
-
-
-def cheap_approximation(sample):
- # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
-
- coefs = torch.tensor([
- [0.298, 0.207, 0.208],
- [0.187, 0.286, 0.173],
- [-0.158, 0.189, 0.264],
- [-0.184, -0.271, -0.473],
- ]).to(sample.device)
-
- x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
-
- return x_sample
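cheap_approximation above is just a fixed 4-to-3 channel projection of the latent; a standalone illustration with a random latent, coefficients copied from the code above:

    import torch

    latent = torch.randn(4, 64, 64)   # a single SD latent, channels first
    coefs = torch.tensor([
        [0.298, 0.207, 0.208],
        [0.187, 0.286, 0.173],
        [-0.158, 0.189, 0.264],
        [-0.184, -0.271, -0.473],
    ])
    rgb = torch.einsum("lxy,lr -> rxy", latent, coefs)   # (3, 64, 64) rough RGB preview
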
diff --git a/modules/shared.py b/modules/shared.py
deleted file mode 100644
index 805f9cc19cf9e529f19f1f94449a454b5050ec52..0000000000000000000000000000000000000000
--- a/modules/shared.py
+++ /dev/null
@@ -1,720 +0,0 @@
-import argparse
-import datetime
-import json
-import os
-import sys
-import time
-
-from PIL import Image
-import gradio as gr
-import tqdm
-
-import modules.interrogate
-import modules.memmon
-import modules.styles
-import modules.devices as devices
-from modules import localization, extensions, script_loading, errors, ui_components, shared_items
-from modules.paths import models_path, script_path, data_path
-
-
-demo = None
-
-sd_configs_path = os.path.join(script_path, "configs")
-sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
-sd_model_file = os.path.join(script_path, 'model.ckpt')
-default_sd_model_file = sd_model_file
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
-parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
-parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
-parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
-parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
-parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
-parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
-parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
-parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
-parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
-parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
-parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
-parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
-parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
-parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
-parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
-parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
-parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
-parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
-parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
-parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
-parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
-parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
-parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
-parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
-parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
-parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
-parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
-parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
-parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
-parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
-parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
-parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
-parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
-parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
-parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
-parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
-parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
-parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
-parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
-parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
-parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
-parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
-parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
-parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
-parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
-parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
-parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
-parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
-parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
-parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
-parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
-parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
-parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
-parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
-parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
-parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
-parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
-parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
-parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
-parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
-parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
-parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
-parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
-parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
-parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
-parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
-parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
-parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
-parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
-parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
-parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
-parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
-parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
-parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
-parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
-parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
-parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
-parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
-parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
-parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
-parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
-
-
-script_loading.preload_extensions(extensions.extensions_dir, parser)
-script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
-
-cmd_opts = parser.parse_args()
-
-restricted_opts = {
- "samples_filename_pattern",
- "directories_filename_pattern",
- "outdir_samples",
- "outdir_txt2img_samples",
- "outdir_img2img_samples",
- "outdir_extras_samples",
- "outdir_grids",
- "outdir_txt2img_grids",
- "outdir_save",
-}
-
-ui_reorder_categories = [
- "inpaint",
- "sampler",
- "checkboxes",
- "hires_fix",
- "dimensions",
- "cfg",
- "seed",
- "batch",
- "override_settings",
- "scripts",
-]
-
-cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
-
-devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
- (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
-
-device = devices.device
-weight_load_location = None if cmd_opts.lowram else "cpu"
-
-batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
-parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
-xformers_available = False
-config_filename = cmd_opts.ui_settings_file
-
-os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
-hypernetworks = {}
-loaded_hypernetworks = []
-
-
-def reload_hypernetworks():
- from modules.hypernetworks import hypernetwork
- global hypernetworks
-
- hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
-
-
-class State:
- skipped = False
- interrupted = False
- job = ""
- job_no = 0
- job_count = 0
- processing_has_refined_job_count = False
- job_timestamp = '0'
- sampling_step = 0
- sampling_steps = 0
- current_latent = None
- current_image = None
- current_image_sampling_step = 0
- id_live_preview = 0
- textinfo = None
- time_start = None
- need_restart = False
- server_start = None
-
- def skip(self):
- self.skipped = True
-
- def interrupt(self):
- self.interrupted = True
-
- def nextjob(self):
- if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
- self.do_set_current_image()
-
- self.job_no += 1
- self.sampling_step = 0
- self.current_image_sampling_step = 0
-
- def dict(self):
- obj = {
- "skipped": self.skipped,
- "interrupted": self.interrupted,
- "job": self.job,
- "job_count": self.job_count,
- "job_timestamp": self.job_timestamp,
- "job_no": self.job_no,
- "sampling_step": self.sampling_step,
- "sampling_steps": self.sampling_steps,
- }
-
- return obj
-
- def begin(self):
- self.sampling_step = 0
- self.job_count = -1
- self.processing_has_refined_job_count = False
- self.job_no = 0
- self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
- self.current_latent = None
- self.current_image = None
- self.current_image_sampling_step = 0
- self.id_live_preview = 0
- self.skipped = False
- self.interrupted = False
- self.textinfo = None
- self.time_start = time.time()
-
- devices.torch_gc()
-
- def end(self):
- self.job = ""
- self.job_count = 0
-
- devices.torch_gc()
-
- def set_current_image(self):
- """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
- if not parallel_processing_allowed:
- return
-
- if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
- self.do_set_current_image()
-
- def do_set_current_image(self):
- if self.current_latent is None:
- return
-
- import modules.sd_samplers
- if opts.show_progress_grid:
- self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
- else:
- self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
-
- self.current_image_sampling_step = self.sampling_step
-
- def assign_current_image(self, image):
- self.current_image = image
- self.id_live_preview += 1
-
-
-state = State()
-state.server_start = time.time()
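
A rough sketch of how this State singleton is driven during a generation run (the job name and counts here are invented; the sampler and UI code do the real bookkeeping):

    state.begin()
    state.job = "txt2img"                        # hypothetical job label
    state.job_count = 4                          # e.g. four batches in this run
    for _ in range(state.job_count):
        for _ in range(20):                      # a sampler callback normally drives these steps
            if state.interrupted:
                break
            state.sampling_step += 1
            state.set_current_image()            # refreshes the live preview when it is due
        state.nextjob()
    state.end()
    progress = state.dict()                      # roughly what the progress API reports back to the UI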
-
-styles_filename = cmd_opts.styles_file
-prompt_styles = modules.styles.StyleDatabase(styles_filename)
-
-interrogator = modules.interrogate.InterrogateModels("interrogate")
-
-face_restorers = []
-
-class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
- self.default = default
- self.label = label
- self.component = component
- self.component_args = component_args
- self.onchange = onchange
- self.section = section
- self.refresh = refresh
-
-
-def options_section(section_identifier, options_dict):
- for k, v in options_dict.items():
- v.section = section_identifier
-
- return options_dict
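
As a sketch, a hypothetical section (names invented for illustration) is registered exactly like the real templates that follow; options_section only stamps the section tuple onto each OptionInfo and hands the dict back:

    demo_section = options_section(('demo', "Demo (hypothetical)"), {
        "demo_flag": OptionInfo(False, "A hypothetical checkbox option"),
        "demo_scale": OptionInfo(1.0, "A hypothetical slider", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.1}),
    })
    # every OptionInfo in demo_section now has .section == ('demo', "Demo (hypothetical)")

Extensions typically add one-off settings the same way through opts.add_option, defined further down in this file.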
-
-
-def list_checkpoint_tiles():
- import modules.sd_models
- return modules.sd_models.checkpoint_tiles()
-
-
-def refresh_checkpoints():
- import modules.sd_models
- return modules.sd_models.list_models()
-
-
-def list_samplers():
- import modules.sd_samplers
- return modules.sd_samplers.all_samplers
-
-
-hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
-
-options_templates = {}
-
-options_templates.update(options_section(('saving-images', "Saving images/grids"), {
- "samples_save": OptionInfo(True, "Always save all generated images"),
- "samples_format": OptionInfo('png', 'File format for images'),
- "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
- "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
-
- "grid_save": OptionInfo(True, "Always save all generated image grids"),
- "grid_format": OptionInfo('png', 'File format for grids'),
- "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
- "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
- "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
- "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
-
- "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
- "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
- "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
- "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
- "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
- "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
- "export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
- "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
- "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
-
- "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
- "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
- "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
- "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
-
- "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
- "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
-
-}))
-
-options_templates.update(options_section(('saving-paths', "Paths for saving"), {
- "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
- "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
- "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
- "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
- "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
- "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
- "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
- "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
-}))
-
-options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
- "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
- "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
- "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
- "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
- "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
-}))
-
-options_templates.update(options_section(('upscaling', "Upscaling"), {
- "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
- "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
- "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
- "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
-}))
-
-options_templates.update(options_section(('face-restoration', "Face restoration"), {
- "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
- "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
- "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
-}))
-
-options_templates.update(options_section(('system', "System"), {
- "show_warnings": OptionInfo(False, "Show warnings in console."),
- "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
- "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
- "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
- "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
-}))
-
-options_templates.update(options_section(('training', "Training"), {
- "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
- "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
- "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
- "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
- "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
- "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
- "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
- "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
- "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
- "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
- "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
- "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
-}))
-
-options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
- "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
- "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
- "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
- "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
- "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
- "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
- "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
- "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
- "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
- "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
- "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
- "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
-}))
-
-options_templates.update(options_section(('compatibility', "Compatibility"), {
- "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
- "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
- "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
- "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
-}))
-
-options_templates.update(options_section(('interrogate', "Interrogate Options"), {
- "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
- "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
- "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
- "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
- "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
- "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
- "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
- "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
- "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
- "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
- "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
- "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
-}))
-
-options_templates.update(options_section(('extra_networks', "Extra Networks"), {
- "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
- "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
-}))
-
-options_templates.update(options_section(('ui', "User interface"), {
- "return_grid": OptionInfo(True, "Show grid in results for web"),
- "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
- "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
- "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
- "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
- "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
- "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
- "font": OptionInfo("", "Font for image grids that have text"),
- "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
- "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
- "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
- "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
- "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
- "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
- "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
- "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
- "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
- "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
- "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
-}))
-
-options_templates.update(options_section(('ui', "Live previews"), {
- "show_progressbar": OptionInfo(True, "Show progressbar"),
- "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
- "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
- "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
- "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
- "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
- "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
-}))
-
-options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
- "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
- "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
- 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
- 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
-}))
-
-options_templates.update(options_section(('postprocessing', "Postprocessing"), {
- 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
- 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
- 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-}))
-
-options_templates.update(options_section((None, "Hidden options"), {
- "disabled_extensions": OptionInfo([], "Disable those extensions"),
- "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
-}))
-
-options_templates.update()
-
-
-class Options:
- data = None
- data_labels = options_templates
- typemap = {int: float}
-
- def __init__(self):
- self.data = {k: v.default for k, v in self.data_labels.items()}
-
- def __setattr__(self, key, value):
- if self.data is not None:
- if key in self.data or key in self.data_labels:
- assert not cmd_opts.freeze_settings, "changing settings is disabled"
-
- info = opts.data_labels.get(key, None)
- comp_args = info.component_args if info else None
- if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
- raise RuntimeError(f"not possible to set {key} because it is restricted")
-
- if cmd_opts.hide_ui_dir_config and key in restricted_opts:
- raise RuntimeError(f"not possible to set {key} because it is restricted")
-
- self.data[key] = value
- return
-
- return super(Options, self).__setattr__(key, value)
-
- def __getattr__(self, item):
- if self.data is not None:
- if item in self.data:
- return self.data[item]
-
- if item in self.data_labels:
- return self.data_labels[item].default
-
- return super(Options, self).__getattribute__(item)
-
- def set(self, key, value):
- """sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
-
- oldval = self.data.get(key, None)
- if oldval == value:
- return False
-
- try:
- setattr(self, key, value)
- except RuntimeError:
- return False
-
- if self.data_labels[key].onchange is not None:
- try:
- self.data_labels[key].onchange()
- except Exception as e:
- errors.display(e, f"changing setting {key} to {value}")
- setattr(self, key, oldval)
- return False
-
- return True
-
- def save(self, filename):
- assert not cmd_opts.freeze_settings, "saving settings is disabled"
-
- with open(filename, "w", encoding="utf8") as file:
- json.dump(self.data, file, indent=4)
-
- def same_type(self, x, y):
- if x is None or y is None:
- return True
-
- type_x = self.typemap.get(type(x), type(x))
- type_y = self.typemap.get(type(y), type(y))
-
- return type_x == type_y
-
- def load(self, filename):
- with open(filename, "r", encoding="utf8") as file:
- self.data = json.load(file)
-
- bad_settings = 0
- for k, v in self.data.items():
- info = self.data_labels.get(k, None)
- if info is not None and not self.same_type(info.default, v):
- print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
- bad_settings += 1
-
- if bad_settings > 0:
- print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
-
- def onchange(self, key, func, call=True):
- item = self.data_labels.get(key)
- item.onchange = func
-
- if call:
- func()
-
- def dumpjson(self):
- d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
- return json.dumps(d)
-
- def add_option(self, key, info):
- self.data_labels[key] = info
-
- def reorder(self):
- """reorder settings so that all items related to section always go together"""
-
- section_ids = {}
- settings_items = self.data_labels.items()
- for k, item in settings_items:
- if item.section not in section_ids:
- section_ids[item.section] = len(section_ids)
-
- self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
-
- def cast_value(self, key, value):
- """casts an arbitrary to the same type as this setting's value with key
- Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
- """
-
- if value is None:
- return None
-
- default_value = self.data_labels[key].default
- if default_value is None:
- default_value = getattr(self, key, None)
- if default_value is None:
- return None
-
- expected_type = type(default_value)
- if expected_type == bool and value == "False":
- value = False
- else:
- value = expected_type(value)
-
- return value
-
-
-
-opts = Options()
-if os.path.exists(config_filename):
- opts.load(config_filename)
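
A short sketch of the Options API defined above (the keys come from the templates earlier in this file; the callback and the extra option are hypothetical):

    print(opts.samples_format)                   # 'png' unless overridden in config.json
    changed = opts.set("jpeg_quality", 90)       # True if the value changed; the onchange callback fires, if any
    opts.onchange("sd_vae", lambda: print("VAE setting changed"), call=False)
    opts.add_option("demo_flag", OptionInfo(False, "A hypothetical extension option", section=('demo', "Demo")))
    settings_json = opts.dumpjson()              # defaults merged with overrides, as sent to the UI
    opts.save(config_filename)                   # refused via assert when --freeze-settings is set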
-
-settings_components = None
-"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
-
-latent_upscale_default_mode = "Latent"
-latent_upscale_modes = {
- "Latent": {"mode": "bilinear", "antialias": False},
- "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
- "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
- "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
- "Latent (nearest)": {"mode": "nearest", "antialias": False},
- "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False},
-}
-
-sd_upscalers = []
-
-sd_model = None
-
-clip_model = None
-
-progress_print_out = sys.stdout
-
-
-class TotalTQDM:
- def __init__(self):
- self._tqdm = None
-
- def reset(self):
- self._tqdm = tqdm.tqdm(
- desc="Total progress",
- total=state.job_count * state.sampling_steps,
- position=1,
- file=progress_print_out
- )
-
- def update(self):
- if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
- return
- if self._tqdm is None:
- self.reset()
- self._tqdm.update()
-
- def updateTotal(self, new_total):
- if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
- return
- if self._tqdm is None:
- self.reset()
- self._tqdm.total = new_total
-
- def clear(self):
- if self._tqdm is not None:
- self._tqdm.close()
- self._tqdm = None
-
-
-total_tqdm = TotalTQDM()
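
The console progress bar above is a no-op unless opts.multiple_tqdm is on and --disable-console-progressbars is off; roughly, the processing code drives it like this (an illustrative sketch):

    total_tqdm.updateTotal(state.job_count * state.sampling_steps)   # resize once the job count is known
    for _ in range(state.job_count * state.sampling_steps):
        total_tqdm.update()                                          # one tick per sampling step
    total_tqdm.clear()                                               # close the bar at the end of the run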
-
-mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
-mem_mon.start()
-
-
-def listfiles(dirname):
- filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
- return [file for file in filenames if os.path.isfile(file)]
-
-
-def html_path(filename):
- return os.path.join(script_path, "html", filename)
-
-
-def html(filename):
- path = html_path(filename)
-
- if os.path.exists(path):
- with open(path, encoding="utf8") as file:
- return file.read()
-
- return ""
diff --git a/modules/shared_items.py b/modules/shared_items.py
deleted file mode 100644
index e792a1349a2aaf763f3ede98479d0a2b5d92a454..0000000000000000000000000000000000000000
--- a/modules/shared_items.py
+++ /dev/null
@@ -1,23 +0,0 @@
-
-
-def realesrgan_models_names():
- import modules.realesrgan_model
- return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
-
-
-def postprocessing_scripts():
- import modules.scripts
-
- return modules.scripts.scripts_postproc.scripts
-
-
-def sd_vae_items():
- import modules.sd_vae
-
- return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
-
-
-def refresh_vae_list():
- import modules.sd_vae
-
- modules.sd_vae.refresh_vae_list()
diff --git a/modules/styles.py b/modules/styles.py
deleted file mode 100644
index 990d562369b49c4c5ce593d571a50640d9966301..0000000000000000000000000000000000000000
--- a/modules/styles.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# We need this so Python doesn't complain about the unknown StableDiffusionProcessing type hint at runtime
-from __future__ import annotations
-
-import csv
-import os
-import os.path
-import typing
-import collections.abc as abc
-import tempfile
-import shutil
-
-if typing.TYPE_CHECKING:
- # Only import this when code is being type-checked, it doesn't have any effect at runtime
- from .processing import StableDiffusionProcessing
-
-
-class PromptStyle(typing.NamedTuple):
- name: str
- prompt: str
- negative_prompt: str
-
-
-def merge_prompts(style_prompt: str, prompt: str) -> str:
- if "{prompt}" in style_prompt:
- res = style_prompt.replace("{prompt}", prompt)
- else:
- parts = filter(None, (prompt.strip(), style_prompt.strip()))
- res = ", ".join(parts)
-
- return res
-
-
-def apply_styles_to_prompt(prompt, styles):
- for style in styles:
- prompt = merge_prompts(style, prompt)
-
- return prompt
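
Concretely, the two helpers above behave like this (worked examples, not part of the original file):

    merge_prompts("masterpiece, {prompt}, highly detailed", "a red fox")
    # -> "masterpiece, a red fox, highly detailed"          (the {prompt} placeholder is substituted)

    merge_prompts("cinematic lighting", "a red fox")
    # -> "a red fox, cinematic lighting"                    (no placeholder: the style is appended)

    apply_styles_to_prompt("a red fox", ["cinematic lighting", "film grain"])
    # -> "a red fox, cinematic lighting, film grain"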
-
-
-class StyleDatabase:
- def __init__(self, path: str):
- self.no_style = PromptStyle("None", "", "")
- self.styles = {}
- self.path = path
-
- self.reload()
-
- def reload(self):
- self.styles.clear()
-
- if not os.path.exists(self.path):
- return
-
- with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
- reader = csv.DictReader(file)
- for row in reader:
- # Support loading old CSV format with "name, text"-columns
- prompt = row["prompt"] if "prompt" in row else row["text"]
- negative_prompt = row.get("negative_prompt", "")
- self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
-
- def get_style_prompts(self, styles):
- return [self.styles.get(x, self.no_style).prompt for x in styles]
-
- def get_negative_style_prompts(self, styles):
- return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
-
- def apply_styles_to_prompt(self, prompt, styles):
- return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
-
- def apply_negative_styles_to_prompt(self, prompt, styles):
- return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
-
- def save_styles(self, path: str) -> None:
- # Write to temporary file first, so we don't nuke the file if something goes wrong
- fd, temp_path = tempfile.mkstemp(".csv")
- with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
- # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.namedtuple,
- # and collections.namedtuple has explicit documentation for accessing _fields. Same goes for _asdict()
- writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
- writer.writeheader()
- writer.writerows(style._asdict() for k, style in self.styles.items())
-
- # Always keep a backup file around
- if os.path.exists(path):
- shutil.move(path, path + ".bak")
- shutil.move(temp_path, path)
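
A round-trip sketch of StyleDatabase (the file name and the style itself are hypothetical):

    db = StyleDatabase("styles.csv")                                    # missing file -> empty database
    db.styles["cinematic"] = PromptStyle("cinematic", "cinematic lighting, {prompt}", "blurry")
    db.apply_styles_to_prompt("a red fox", ["cinematic"])               # -> "cinematic lighting, a red fox"
    db.apply_negative_styles_to_prompt("low quality", ["cinematic"])    # -> "low quality, blurry"
    db.save_styles("styles.csv")                                        # atomic write via a temp file; keeps styles.csv.bak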
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
deleted file mode 100644
index 055953236e5e1d1401feca93ca6a1cc342cb8595..0000000000000000000000000000000000000000
--- a/modules/sub_quadratic_attention.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# original source:
-# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
-# license:
-# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
-# credit:
-# Amin Rezaei (original author)
-# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
-# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
-# implementation of:
-# "Self-attention Does Not Need O(n^2) Memory":
-# https://arxiv.org/abs/2112.05682v2
-
-from functools import partial
-import torch
-from torch import Tensor
-from torch.utils.checkpoint import checkpoint
-import math
-from typing import Optional, NamedTuple, List
-
-
-def narrow_trunc(
- input: Tensor,
- dim: int,
- start: int,
- length: int
-) -> Tensor:
- return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
-
-
-class AttnChunk(NamedTuple):
- exp_values: Tensor
- exp_weights_sum: Tensor
- max_score: Tensor
-
-
-class SummarizeChunk:
- @staticmethod
- def __call__(
- query: Tensor,
- key: Tensor,
- value: Tensor,
- ) -> AttnChunk: ...
-
-
-class ComputeQueryChunkAttn:
- @staticmethod
- def __call__(
- query: Tensor,
- key: Tensor,
- value: Tensor,
- ) -> Tensor: ...
-
-
-def _summarize_chunk(
- query: Tensor,
- key: Tensor,
- value: Tensor,
- scale: float,
-) -> AttnChunk:
- attn_weights = torch.baddbmm(
- torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
- query,
- key.transpose(1,2),
- alpha=scale,
- beta=0,
- )
- max_score, _ = torch.max(attn_weights, -1, keepdim=True)
- max_score = max_score.detach()
- exp_weights = torch.exp(attn_weights - max_score)
- exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
- max_score = max_score.squeeze(-1)
- return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
-
-
-def _query_chunk_attention(
- query: Tensor,
- key: Tensor,
- value: Tensor,
- summarize_chunk: SummarizeChunk,
- kv_chunk_size: int,
-) -> Tensor:
- batch_x_heads, k_tokens, k_channels_per_head = key.shape
- _, _, v_channels_per_head = value.shape
-
- def chunk_scanner(chunk_idx: int) -> AttnChunk:
- key_chunk = narrow_trunc(
- key,
- 1,
- chunk_idx,
- kv_chunk_size
- )
- value_chunk = narrow_trunc(
- value,
- 1,
- chunk_idx,
- kv_chunk_size
- )
- return summarize_chunk(query, key_chunk, value_chunk)
-
- chunks: List[AttnChunk] = [
- chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
- ]
- acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
- chunk_values, chunk_weights, chunk_max = acc_chunk
-
- global_max, _ = torch.max(chunk_max, 0, keepdim=True)
- max_diffs = torch.exp(chunk_max - global_max)
- chunk_values *= torch.unsqueeze(max_diffs, -1)
- chunk_weights *= max_diffs
-
- all_values = chunk_values.sum(dim=0)
- all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
- return all_values / all_weights
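
The accumulation above is the usual streaming-softmax trick: each chunk reports exp_values, exp_weights_sum and max_score relative to its own maximum, and the chunks are rescaled by exp(chunk_max - global_max) before summing, which reproduces softmax(Q K^T * scale) V exactly. A small self-contained check of that identity (illustrative, not part of the module):

    import torch

    q, k, v = torch.randn(1, 4, 8), torch.randn(1, 6, 8), torch.randn(1, 6, 8)
    scale = 8 ** -0.5
    full = torch.softmax(torch.bmm(q, k.transpose(1, 2)) * scale, dim=-1) @ v

    chunks = []
    for start in (0, 3):                                    # two key/value chunks of 3 tokens each
        kc, vc = k[:, start:start + 3], v[:, start:start + 3]
        w = torch.bmm(q, kc.transpose(1, 2)) * scale
        m = w.max(-1, keepdim=True).values                  # per-chunk max, as in _summarize_chunk
        e = torch.exp(w - m)
        chunks.append((torch.bmm(e, vc), e.sum(-1), m.squeeze(-1)))

    vals, sums, maxs = (torch.stack(t) for t in zip(*chunks))
    gmax = maxs.max(0, keepdim=True).values
    r = torch.exp(maxs - gmax)                              # rescale every chunk to the global max
    out = (vals * r.unsqueeze(-1)).sum(0) / (sums * r).sum(0).unsqueeze(-1)
    assert torch.allclose(out, full, atol=1e-5)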
-
-
-# TODO: refactor CrossAttention#get_attention_scores to share code with this
-def _get_attention_scores_no_kv_chunking(
- query: Tensor,
- key: Tensor,
- value: Tensor,
- scale: float,
-) -> Tensor:
- attn_scores = torch.baddbmm(
- torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
- query,
- key.transpose(1,2),
- alpha=scale,
- beta=0,
- )
- attn_probs = attn_scores.softmax(dim=-1)
- del attn_scores
- hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
- return hidden_states_slice
-
-
-class ScannedChunk(NamedTuple):
- chunk_idx: int
- attn_chunk: AttnChunk
-
-
-def efficient_dot_product_attention(
- query: Tensor,
- key: Tensor,
- value: Tensor,
- query_chunk_size=1024,
- kv_chunk_size: Optional[int] = None,
- kv_chunk_size_min: Optional[int] = None,
- use_checkpoint=True,
-):
- """Computes efficient dot-product attention given query, key, and value.
- This is efficient version of attention presented in
- https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
- Args:
- query: queries for calculating attention with shape of
- `[batch * num_heads, tokens, channels_per_head]`.
- key: keys for calculating attention with shape of
- `[batch * num_heads, tokens, channels_per_head]`.
- value: values to be used in attention with shape of
- `[batch * num_heads, tokens, channels_per_head]`.
- query_chunk_size: int: query chunks size
- kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
- kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
- use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
- Returns:
- Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
- """
- batch_x_heads, q_tokens, q_channels_per_head = query.shape
- _, k_tokens, _ = key.shape
- scale = q_channels_per_head ** -0.5
-
- kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
- if kv_chunk_size_min is not None:
- kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
-
- def get_query_chunk(chunk_idx: int) -> Tensor:
- return narrow_trunc(
- query,
- 1,
- chunk_idx,
- min(query_chunk_size, q_tokens)
- )
-
- summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
- summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
- compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
- _get_attention_scores_no_kv_chunking,
- scale=scale
- ) if k_tokens <= kv_chunk_size else (
- # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
- partial(
- _query_chunk_attention,
- kv_chunk_size=kv_chunk_size,
- summarize_chunk=summarize_chunk,
- )
- )
-
- if q_tokens <= query_chunk_size:
- # fast-path for when there's just 1 query chunk
- return compute_query_chunk_attn(
- query=query,
- key=key,
- value=value,
- )
-
- # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
- # and pass slices to be mutated, instead of torch.cat()ing the returned slices
- res = torch.cat([
- compute_query_chunk_attn(
- query=get_query_chunk(i * query_chunk_size),
- key=key,
- value=value,
- ) for i in range(math.ceil(q_tokens / query_chunk_size))
- ], dim=1)
- return res
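
A usage sketch with the shapes from the docstring (the concrete dimensions are arbitrary examples):

    q = torch.randn(2 * 8, 4096, 40)      # [batch * num_heads, q_tokens, channels_per_head]
    k = torch.randn(2 * 8, 4096, 40)
    v = torch.randn(2 * 8, 4096, 40)
    out = efficient_dot_product_attention(q, k, v, query_chunk_size=1024, use_checkpoint=False)
    # out.shape == (16, 4096, 40); with kv_chunk_size=None, keys are scanned 64 (= sqrt(4096)) tokens at a time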
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
deleted file mode 100644
index 68e1103c514c5d2d75f23175126cfea0b3dcfca9..0000000000000000000000000000000000000000
--- a/modules/textual_inversion/autocrop.py
+++ /dev/null
@@ -1,341 +0,0 @@
-import cv2
-import requests
-import os
-from collections import defaultdict
-from math import log, sqrt
-import numpy as np
-from PIL import Image, ImageDraw
-
-GREEN = "#0F0"
-BLUE = "#00F"
-RED = "#F00"
-
-
-def crop_image(im, settings):
- """ Intelligently crop an image to the subject matter """
-
- scale_by = 1
- if is_landscape(im.width, im.height):
- scale_by = settings.crop_height / im.height
- elif is_portrait(im.width, im.height):
- scale_by = settings.crop_width / im.width
- elif is_square(im.width, im.height):
- if is_square(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_width / im.width
- elif is_landscape(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_width / im.width
- elif is_portrait(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_height / im.height
-
- im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
- im_debug = im.copy()
-
- focus = focal_point(im_debug, settings)
-
- # take the focal point and turn it into crop coordinates that try to center over the focal
- # point but then get adjusted back into the frame
- y_half = int(settings.crop_height / 2)
- x_half = int(settings.crop_width / 2)
-
- x1 = focus.x - x_half
- if x1 < 0:
- x1 = 0
- elif x1 + settings.crop_width > im.width:
- x1 = im.width - settings.crop_width
-
- y1 = focus.y - y_half
- if y1 < 0:
- y1 = 0
- elif y1 + settings.crop_height > im.height:
- y1 = im.height - settings.crop_height
-
- x2 = x1 + settings.crop_width
- y2 = y1 + settings.crop_height
-
- crop = [x1, y1, x2, y2]
-
- results = []
-
- results.append(im.crop(tuple(crop)))
-
- if settings.annotate_image:
- d = ImageDraw.Draw(im_debug)
- rect = list(crop)
- rect[2] -= 1
- rect[3] -= 1
- d.rectangle(rect, outline=GREEN)
- results.append(im_debug)
- if settings.destop_view_image:
- im_debug.show()
-
- return results
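
A usage sketch for crop_image; the settings object here is a hypothetical stand-in built with SimpleNamespace, carrying the attribute names the functions above and below read (the real caller constructs its own settings object):

    from types import SimpleNamespace
    from PIL import Image

    settings = SimpleNamespace(                      # hypothetical stand-in, attribute names taken from this file
        crop_width=512, crop_height=512,
        corner_points_weight=0.5, entropy_points_weight=0.3, face_points_weight=0.9,
        annotate_image=False, destop_view_image=False,
        dnn_model_path=None,                         # None -> fall back to the OpenCV Haar cascades below
    )
    crops = crop_image(Image.open("input.png"), settings)
    # crops[0] is the 512x512 crop centred, as far as the frame allows, on the detected focal point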
-
-def focal_point(im, settings):
- corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
- entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
- face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
-
- pois = []
-
- weight_pref_total = 0
- if len(corner_points) > 0:
- weight_pref_total += settings.corner_points_weight
- if len(entropy_points) > 0:
- weight_pref_total += settings.entropy_points_weight
- if len(face_points) > 0:
- weight_pref_total += settings.face_points_weight
-
- corner_centroid = None
- if len(corner_points) > 0:
- corner_centroid = centroid(corner_points)
- corner_centroid.weight = settings.corner_points_weight / weight_pref_total
- pois.append(corner_centroid)
-
- entropy_centroid = None
- if len(entropy_points) > 0:
- entropy_centroid = centroid(entropy_points)
- entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
- pois.append(entropy_centroid)
-
- face_centroid = None
- if len(face_points) > 0:
- face_centroid = centroid(face_points)
- face_centroid.weight = settings.face_points_weight / weight_pref_total
- pois.append(face_centroid)
-
- average_point = poi_average(pois, settings)
-
- if settings.annotate_image:
- d = ImageDraw.Draw(im)
- max_size = min(im.width, im.height) * 0.07
- if corner_centroid is not None:
- color = BLUE
- box = corner_centroid.bounding(max_size * corner_centroid.weight)
- d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
- d.ellipse(box, outline=color)
- if len(corner_points) > 1:
- for f in corner_points:
- d.rectangle(f.bounding(4), outline=color)
- if entropy_centroid is not None:
- color = "#ff0"
- box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
- d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
- d.ellipse(box, outline=color)
- if len(entropy_points) > 1:
- for f in entropy_points:
- d.rectangle(f.bounding(4), outline=color)
- if face_centroid is not None:
- color = RED
- box = face_centroid.bounding(max_size * face_centroid.weight)
- d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
- d.ellipse(box, outline=color)
- if len(face_points) > 1:
- for f in face_points:
- d.rectangle(f.bounding(4), outline=color)
-
- d.ellipse(average_point.bounding(max_size), outline=GREEN)
-
- return average_point
-
-
-def image_face_points(im, settings):
- if settings.dnn_model_path is not None:
- detector = cv2.FaceDetectorYN.create(
- settings.dnn_model_path,
- "",
- (im.width, im.height),
- 0.9, # score threshold
- 0.3, # nms threshold
- 5000 # keep top k before nms
- )
- faces = detector.detect(np.array(im))
- results = []
- if faces[1] is not None:
- for face in faces[1]:
- x = face[0]
- y = face[1]
- w = face[2]
- h = face[3]
- results.append(
- PointOfInterest(
- int(x + (w * 0.5)), # face focus left/right is center
- int(y + (h * 0.33)), # face focus up/down is close to the top of the head
- size = w,
- weight = 1/len(faces[1])
- )
- )
- return results
- else:
- np_im = np.array(im)
- gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
-
- tries = [
- [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
- ]
- for t in tries:
- classifier = cv2.CascadeClassifier(t[0])
- minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
- try:
- faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
- minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
-            except Exception:
- continue
-
- if len(faces) > 0:
- rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
- return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
- return []
-
-
-def image_corner_points(im, settings):
- grayscale = im.convert("L")
-
- # naive attempt at preventing focal points from collecting at watermarks near the bottom
- gd = ImageDraw.Draw(grayscale)
- gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
-
- np_im = np.array(grayscale)
-
- points = cv2.goodFeaturesToTrack(
- np_im,
- maxCorners=100,
- qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.06,
- useHarrisDetector=False,
- )
-
- if points is None:
- return []
-
- focal_points = []
- for point in points:
- x, y = point.ravel()
- focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
-
- return focal_points
-
-
-def image_entropy_points(im, settings):
- landscape = im.height < im.width
- portrait = im.height > im.width
- if landscape:
- move_idx = [0, 2]
- move_max = im.size[0]
- elif portrait:
- move_idx = [1, 3]
- move_max = im.size[1]
- else:
- return []
-
- e_max = 0
- crop_current = [0, 0, settings.crop_width, settings.crop_height]
- crop_best = crop_current
- while crop_current[move_idx[1]] < move_max:
- crop = im.crop(tuple(crop_current))
- e = image_entropy(crop)
-
- if (e > e_max):
- e_max = e
- crop_best = list(crop_current)
-
- crop_current[move_idx[0]] += 4
- crop_current[move_idx[1]] += 4
-
- x_mid = int(crop_best[0] + settings.crop_width/2)
- y_mid = int(crop_best[1] + settings.crop_height/2)
-
- return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
-
-
-def image_entropy(im):
-    # entropy of the image after conversion to 1-bit (mode "1"); the greyscale variant is kept below for reference
-    # band = np.asarray(im.convert("L"))
- band = np.asarray(im.convert("1"), dtype=np.uint8)
- hist, _ = np.histogram(band, bins=range(0, 256))
- hist = hist[hist > 0]
- return -np.log2(hist / hist.sum()).sum()
-
-def centroid(pois):
- x = [poi.x for poi in pois]
- y = [poi.y for poi in pois]
- return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
-
-
-def poi_average(pois, settings):
- weight = 0.0
- x = 0.0
- y = 0.0
- for poi in pois:
- weight += poi.weight
- x += poi.x * poi.weight
- y += poi.y * poi.weight
- avg_x = round(weight and x / weight)
- avg_y = round(weight and y / weight)
-
- return PointOfInterest(avg_x, avg_y)
-
-
-def is_landscape(w, h):
- return w > h
-
-
-def is_portrait(w, h):
- return h > w
-
-
-def is_square(w, h):
- return w == h
-
-
-def download_and_cache_models(dirname):
- download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
- model_file_name = 'face_detection_yunet.onnx'
-
- if not os.path.exists(dirname):
- os.makedirs(dirname)
-
- cache_file = os.path.join(dirname, model_file_name)
- if not os.path.exists(cache_file):
- print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
- response = requests.get(download_url)
- with open(cache_file, "wb") as f:
- f.write(response.content)
-
- if os.path.exists(cache_file):
- return cache_file
- return None
-
-
-class PointOfInterest:
- def __init__(self, x, y, weight=1.0, size=10):
- self.x = x
- self.y = y
- self.weight = weight
- self.size = size
-
- def bounding(self, size):
- return [
- self.x - size//2,
- self.y - size//2,
- self.x + size//2,
- self.y + size//2
- ]
-
-
-class Settings:
- def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
- self.crop_width = crop_width
- self.crop_height = crop_height
- self.corner_points_weight = corner_points_weight
- self.entropy_points_weight = entropy_points_weight
- self.face_points_weight = face_points_weight
- self.annotate_image = annotate_image
-        self.desktop_view_image = False
- self.dnn_model_path = dnn_model_path
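
Note on the removed focal-point logic above: each detector (corners, entropy, faces) contributes a centroid whose weight is its configured preference divided by the total preference of the detectors that actually fired, and poi_average then takes a weighted mean that falls back to (0, 0) when the total weight is zero. A minimal standalone sketch of that combination step, with illustrative values (not part of the deleted module):

    def weighted_average(points):
        # points: list of (x, y, weight) tuples, one per detector centroid
        total = sum(w for _, _, w in points)
        if total == 0:
            return 0, 0  # mirrors the `weight and x / weight` fallback in poi_average
        x = sum(px * w for px, _, w in points) / total
        y = sum(py * w for _, py, w in points) / total
        return round(x), round(y)

    # e.g. a face centroid weighted 2:1 against an entropy centroid
    print(weighted_average([(100, 80, 2 / 3), (300, 200, 1 / 3)]))  # -> (167, 120)
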
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
deleted file mode 100644
index af9fbcf288ccc7a5402ed1cae2da853022ab080d..0000000000000000000000000000000000000000
--- a/modules/textual_inversion/dataset.py
+++ /dev/null
@@ -1,246 +0,0 @@
-import os
-import numpy as np
-import PIL
-import torch
-from PIL import Image
-from torch.utils.data import Dataset, DataLoader, Sampler
-from torchvision import transforms
-from collections import defaultdict
-from random import shuffle, choices
-
-import random
-import tqdm
-from modules import devices, shared
-import re
-
-from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-
-re_numbers_at_start = re.compile(r"^[-\d]+\s*")
-
-
-class DatasetEntry:
- def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
- self.filename = filename
- self.filename_text = filename_text
- self.weight = weight
- self.latent_dist = latent_dist
- self.latent_sample = latent_sample
- self.cond = cond
- self.cond_text = cond_text
- self.pixel_values = pixel_values
-
-
-class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
- re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
-
- self.placeholder_token = placeholder_token
-
- self.flip = transforms.RandomHorizontalFlip(p=flip_p)
-
- self.dataset = []
-
- with open(template_file, "r") as file:
- lines = [x.strip() for x in file.readlines()]
-
- self.lines = lines
-
- assert data_root, 'dataset directory not specified'
- assert os.path.isdir(data_root), "Dataset directory doesn't exist"
- assert os.listdir(data_root), "Dataset directory is empty"
-
- self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
-
- self.shuffle_tags = shuffle_tags
- self.tag_drop_out = tag_drop_out
- groups = defaultdict(list)
-
- print("Preparing dataset...")
- for path in tqdm.tqdm(self.image_paths):
- alpha_channel = None
- if shared.state.interrupted:
- raise Exception("interrupted")
- try:
- image = Image.open(path)
- #Currently does not work for single color transparency
- #We would need to read image.info['transparency'] for that
- if use_weight and 'A' in image.getbands():
- alpha_channel = image.getchannel('A')
- image = image.convert('RGB')
- if not varsize:
- image = image.resize((width, height), PIL.Image.BICUBIC)
- except Exception:
- continue
-
- text_filename = os.path.splitext(path)[0] + ".txt"
- filename = os.path.basename(path)
-
- if os.path.exists(text_filename):
- with open(text_filename, "r", encoding="utf8") as file:
- filename_text = file.read()
- else:
- filename_text = os.path.splitext(filename)[0]
- filename_text = re.sub(re_numbers_at_start, '', filename_text)
- if re_word:
- tokens = re_word.findall(filename_text)
- filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
-
- npimage = np.array(image).astype(np.uint8)
- npimage = (npimage / 127.5 - 1.0).astype(np.float32)
-
- torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
- latent_sample = None
-
- with devices.autocast():
- latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
-
- #Perform latent sampling, even for random sampling.
- #We need the sample dimensions for the weights
- if latent_sampling_method == "deterministic":
- if isinstance(latent_dist, DiagonalGaussianDistribution):
- # Works only for DiagonalGaussianDistribution
- latent_dist.std = 0
- else:
- latent_sampling_method = "once"
- latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
-
- if use_weight and alpha_channel is not None:
- channels, *latent_size = latent_sample.shape
- weight_img = alpha_channel.resize(latent_size)
- npweight = np.array(weight_img).astype(np.float32)
- #Repeat for every channel in the latent sample
- weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
- #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
- weight -= weight.min()
- weight /= weight.mean()
- elif use_weight:
-                #If an image does not have an alpha channel, add a ones weight map anyway so we can stack it later
- weight = torch.ones(latent_sample.shape)
- else:
- weight = None
-
- if latent_sampling_method == "random":
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
- else:
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
-
- if not (self.tag_drop_out != 0 or self.shuffle_tags):
- entry.cond_text = self.create_text(filename_text)
-
- if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
- with devices.autocast():
- entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
- groups[image.size].append(len(self.dataset))
- self.dataset.append(entry)
- del torchdata
- del latent_dist
- del latent_sample
- del weight
-
- self.length = len(self.dataset)
- self.groups = list(groups.values())
- assert self.length > 0, "No images have been found in the dataset."
- self.batch_size = min(batch_size, self.length)
- self.gradient_step = min(gradient_step, self.length // self.batch_size)
- self.latent_sampling_method = latent_sampling_method
-
- if len(groups) > 1:
- print("Buckets:")
- for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
- print(f" {w}x{h}: {len(ids)}")
- print()
-
- def create_text(self, filename_text):
- text = random.choice(self.lines)
- tags = filename_text.split(',')
- if self.tag_drop_out != 0:
- tags = [t for t in tags if random.random() > self.tag_drop_out]
- if self.shuffle_tags:
- random.shuffle(tags)
- text = text.replace("[filewords]", ','.join(tags))
- text = text.replace("[name]", self.placeholder_token)
- return text
-
- def __len__(self):
- return self.length
-
- def __getitem__(self, i):
- entry = self.dataset[i]
- if self.tag_drop_out != 0 or self.shuffle_tags:
- entry.cond_text = self.create_text(entry.filename_text)
- if self.latent_sampling_method == "random":
- entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
- return entry
-
-
-class GroupedBatchSampler(Sampler):
- def __init__(self, data_source: PersonalizedBase, batch_size: int):
- super().__init__(data_source)
-
- n = len(data_source)
- self.groups = data_source.groups
- self.len = n_batch = n // batch_size
- expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
- self.base = [int(e) // batch_size for e in expected]
- self.n_rand_batches = nrb = n_batch - sum(self.base)
- self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
- self.batch_size = batch_size
-
- def __len__(self):
- return self.len
-
- def __iter__(self):
- b = self.batch_size
-
- for g in self.groups:
- shuffle(g)
-
- batches = []
- for g in self.groups:
- batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
- for _ in range(self.n_rand_batches):
- rand_group = choices(self.groups, self.probs)[0]
- batches.append(choices(rand_group, k=b))
-
- shuffle(batches)
-
- yield from batches
-
-
-class PersonalizedDataLoader(DataLoader):
- def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
- super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
- if latent_sampling_method == "random":
- self.collate_fn = collate_wrapper_random
- else:
- self.collate_fn = collate_wrapper
-
-
-class BatchLoader:
- def __init__(self, data):
- self.cond_text = [entry.cond_text for entry in data]
- self.cond = [entry.cond for entry in data]
- self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
- if all(entry.weight is not None for entry in data):
- self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
- else:
- self.weight = None
- #self.emb_index = [entry.emb_index for entry in data]
- #print(self.latent_sample.device)
-
- def pin_memory(self):
- self.latent_sample = self.latent_sample.pin_memory()
- return self
-
-def collate_wrapper(batch):
- return BatchLoader(batch)
-
-class BatchLoaderRandom(BatchLoader):
- def __init__(self, data):
- super().__init__(data)
-
- def pin_memory(self):
- return self
-
-def collate_wrapper_random(batch):
- return BatchLoaderRandom(batch)
\ No newline at end of file
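
Note on the alpha-channel weighting removed above: the weight map is shifted to a minimum of 0 and scaled to a mean of 1 so that a weighted loss stays on the same scale as the unweighted default. A standalone sketch of that normalisation, with an added guard for degenerate (constant) alpha maps; torch is assumed available and the names are illustrative:

    import torch

    def normalize_weight_map(alpha: torch.Tensor) -> torch.Tensor:
        w = alpha.clone().float()
        w -= w.min()                # minimum becomes 0
        mean = w.mean()
        if mean > 0:
            w /= mean               # mean becomes 1, keeping loss magnitudes comparable
        else:
            w = torch.ones_like(w)  # constant alpha: fall back to uniform weighting
        return w

    print(normalize_weight_map(torch.tensor([0.0, 0.5, 1.0])))  # tensor([0., 1., 2.])
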
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
deleted file mode 100644
index 5593f88c799c53a8529ae8fd12a84c2e92396b26..0000000000000000000000000000000000000000
--- a/modules/textual_inversion/image_embedding.py
+++ /dev/null
@@ -1,220 +0,0 @@
-import base64
-import json
-import numpy as np
-import zlib
-from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
-from fonts.ttf import Roboto
-import torch
-from modules.shared import opts
-
-
-class EmbeddingEncoder(json.JSONEncoder):
- def default(self, obj):
- if isinstance(obj, torch.Tensor):
- return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
- return json.JSONEncoder.default(self, obj)
-
-
-class EmbeddingDecoder(json.JSONDecoder):
- def __init__(self, *args, **kwargs):
- json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
-
- def object_hook(self, d):
- if 'TORCHTENSOR' in d:
- return torch.from_numpy(np.array(d['TORCHTENSOR']))
- return d
-
-
-def embedding_to_b64(data):
- d = json.dumps(data, cls=EmbeddingEncoder)
- return base64.b64encode(d.encode())
-
-
-def embedding_from_b64(data):
- d = base64.b64decode(data)
- return json.loads(d, cls=EmbeddingDecoder)
-
-
-def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
- while True:
- seed = (a * seed + c) % m
- yield seed % 255
-
-
-def xor_block(block):
- g = lcg()
-    randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
- return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
-
-
-def style_block(block, sequence):
- im = Image.new('RGB', (block.shape[1], block.shape[0]))
- draw = ImageDraw.Draw(im)
- i = 0
- for x in range(-6, im.size[0], 8):
- for yi, y in enumerate(range(-6, im.size[1], 8)):
- offset = 0
- if yi % 2 == 0:
- offset = 4
- shade = sequence[i % len(sequence)]
- i += 1
- draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
-
- fg = np.array(im).astype(np.uint8) & 0xF0
-
- return block ^ fg
-
-
-def insert_image_data_embed(image, data):
- d = 3
- data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
- data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
- data_np_high = data_np_ >> 4
- data_np_low = data_np_ & 0x0F
-
- h = image.size[1]
- next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
- next_size = next_size + ((h*d)-(next_size % (h*d)))
-
- data_np_low = np.resize(data_np_low, next_size)
- data_np_low = data_np_low.reshape((h, -1, d))
-
- data_np_high = np.resize(data_np_high, next_size)
- data_np_high = data_np_high.reshape((h, -1, d))
-
- edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
- edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
-
- data_np_low = style_block(data_np_low, sequence=edge_style)
- data_np_low = xor_block(data_np_low)
- data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
- data_np_high = xor_block(data_np_high)
-
- im_low = Image.fromarray(data_np_low, mode='RGB')
- im_high = Image.fromarray(data_np_high, mode='RGB')
-
- background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
- background.paste(im_low, (0, 0))
- background.paste(image, (im_low.size[0]+1, 0))
- background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
-
- return background
-
-
-def crop_black(img, tol=0):
- mask = (img > tol).all(2)
- mask0, mask1 = mask.any(0), mask.any(1)
- col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
- row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
- return img[row_start:row_end, col_start:col_end]
-
-
-def extract_image_data_embed(image):
- d = 3
- outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
- black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
- if black_cols[0].shape[0] < 2:
- print('No Image data blocks found.')
- return None
-
- data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
- data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
-
- data_block_lower = xor_block(data_block_lower)
- data_block_upper = xor_block(data_block_upper)
-
- data_block = (data_block_upper << 4) | (data_block_lower)
- data_block = data_block.flatten().tobytes()
-
- data = zlib.decompress(data_block)
- return json.loads(data, cls=EmbeddingDecoder)
-
-
-def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
- from math import cos
-
- image = srcimage.copy()
- fontsize = 32
- if textfont is None:
- try:
- textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
- textfont = opts.font or Roboto
- except Exception:
- textfont = Roboto
-
- factor = 1.5
- gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
- for y in range(image.size[1]):
- mag = 1-cos(y/image.size[1]*factor)
- mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
- gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
- image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
-
- draw = ImageDraw.Draw(image)
-
- font = ImageFont.truetype(textfont, fontsize)
- padding = 10
-
- _, _, w, h = draw.textbbox((0, 0), title, font=font)
- fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
- font = ImageFont.truetype(textfont, fontsize)
- _, _, w, h = draw.textbbox((0, 0), title, font=font)
- draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
-
- _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
- fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
- _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
- fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
- _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
- fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
-
- font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
-
- draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
- draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
- draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
-
- return image
-
-
-if __name__ == '__main__':
-
- testEmbed = Image.open('test_embedding.png')
- data = extract_image_data_embed(testEmbed)
- assert data is not None
-
- data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
- assert data is not None
-
- image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
- cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
-
- test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
-
- embedded_image = insert_image_data_embed(cap_image, test_embed)
-
-    retrieved_embed = extract_image_data_embed(embedded_image)
-
-    assert str(retrieved_embed) == str(test_embed)
-
-    embedded_image2 = insert_image_data_embed(cap_image, retrieved_embed)
-
- assert embedded_image == embedded_image2
-
- g = lcg()
- shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
-
- reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
- 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
- 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
- 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
- 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
- 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
- 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
- 204, 86, 73, 222, 44, 198, 118, 240, 97]
-
- assert shared_random == reference_random
-
- hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
-
- assert 12731374 == hunna_kay_random_sum
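
Note on the steganographic embedding removed above: xor_block is its own inverse, because XOR-ing with the same LCG-derived keystream twice restores the original nibbles, which is why extract_image_data_embed can reuse the exact same function. A small round-trip sketch using the same LCG constants (numpy assumed available, names illustrative):

    import numpy as np

    def lcg_stream(n, m=2**32, a=1664525, c=1013904223, seed=0):
        out, s = [], seed
        for _ in range(n):
            s = (a * s + c) % m
            out.append(s % 255)
        return np.array(out, dtype=np.uint8)

    payload = np.arange(16, dtype=np.uint8) & 0x0F  # low-nibble data, as stored in the image
    key = lcg_stream(payload.size) & 0x0F           # keystream masked to the low nibble
    encoded = payload ^ key
    assert np.array_equal(encoded ^ key, payload)   # applying the same XOR again decodes it
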
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
deleted file mode 100644
index f63fc72ff8d4269967dc0f92b61a278d068324b5..0000000000000000000000000000000000000000
--- a/modules/textual_inversion/learn_schedule.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import tqdm
-
-
-class LearnScheduleIterator:
- def __init__(self, learn_rate, max_steps, cur_step=0):
- """
- specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
- """
-
- pairs = learn_rate.split(',')
- self.rates = []
- self.it = 0
- self.maxit = 0
- try:
- for i, pair in enumerate(pairs):
- if not pair.strip():
- continue
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
- return
- elif step == -1:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
- assert self.rates
- except (ValueError, AssertionError):
- raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
-
-
- def __iter__(self):
- return self
-
- def __next__(self):
- if self.it < self.maxit:
- self.it += 1
- return self.rates[self.it - 1]
- else:
- raise StopIteration
-
-
-class LearnRateScheduler:
- def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
- self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
- (self.learn_rate, self.end_step) = next(self.schedules)
- self.verbose = verbose
-
- if self.verbose:
- print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
-
- self.finished = False
-
- def step(self, step_number):
- if step_number < self.end_step:
- return False
-
- try:
- (self.learn_rate, self.end_step) = next(self.schedules)
- except StopIteration:
- self.finished = True
- return False
- return True
-
- def apply(self, optimizer, step_number):
- if not self.step(step_number):
- return
-
- if self.verbose:
- tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
-
- for pg in optimizer.param_groups:
- pg['lr'] = self.learn_rate
-
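
Note on the schedule syntax handled by the removed LearnScheduleIterator: a string such as "0.005:100, 1e-3:1000, 1e-5" yields (rate, end_step) pairs, with the final unbounded entry running to max_steps. A simplified standalone parser illustrating just that happy path; it deliberately ignores the -1 sentinel and the mid-training resume handling of the original:

    def parse_schedule(learn_rate: str, max_steps: int):
        rates = []
        for pair in learn_rate.split(','):
            pair = pair.strip()
            if not pair:
                continue
            if ':' in pair:
                rate, step = pair.split(':')
                rates.append((float(rate), min(int(step), max_steps)))
            else:
                rates.append((float(pair), max_steps))  # no end step: run to max_steps and stop
                break
        return rates

    print(parse_schedule("0.005:100, 1e-3:1000, 1e-5", max_steps=10000))
    # [(0.005, 100), (0.001, 1000), (1e-05, 10000)]
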
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
deleted file mode 100644
index 734a4b6f463d11685933771825b2792c8270e53a..0000000000000000000000000000000000000000
--- a/modules/textual_inversion/logging.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import datetime
-import json
-import os
-
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
-saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
-saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
-saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
-saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
-
-
-def save_settings_to_file(log_directory, all_params):
- now = datetime.datetime.now()
- params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
-
- keys = saved_params_all
- if all_params.get('preview_from_txt2img'):
- keys = keys | saved_params_previews
-
- params.update({k: v for k, v in all_params.items() if k in keys})
-
- filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
- with open(os.path.join(log_directory, filename), "w") as file:
- json.dump(params, file, indent=4)
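
Note on the removed save_settings_to_file: it filters the caller's parameter dict down to the whitelisted key sets and writes the result as a timestamped JSON file, keeping the preview_* keys only when preview_from_txt2img is set. A hypothetical call (illustrative values, directory assumed to exist) would therefore persist only the recognised keys:

    save_settings_to_file("/tmp/ti-logs", {
        "model_name": "v1-5-pruned",  # kept: listed in saved_params_shared
        "learn_rate": "0.005",        # kept
        "preview_prompt": "a photo",  # dropped unless preview_from_txt2img is truthy
        "some_local_variable": 123,   # dropped: not whitelisted
    })
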
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
deleted file mode 100644
index 2239cb842ed5d25f9bde19f10509220ad6c01c18..0000000000000000000000000000000000000000
--- a/modules/textual_inversion/preprocess.py
+++ /dev/null
@@ -1,230 +0,0 @@
-import os
-from PIL import Image, ImageOps
-import math
-import platform
-import sys
-import tqdm
-import time
-
-from modules import paths, shared, images, deepbooru
-from modules.shared import opts, cmd_opts
-from modules.textual_inversion import autocrop
-
-
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
- try:
- if process_caption:
- shared.interrogator.load()
-
- if process_caption_deepbooru:
- deepbooru.model.start()
-
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
-
- finally:
-
- if process_caption:
- shared.interrogator.send_blip_to_ram()
-
- if process_caption_deepbooru:
- deepbooru.model.stop()
-
-
-def listfiles(dirname):
- return os.listdir(dirname)
-
-
-class PreprocessParams:
- src = None
- dstdir = None
- subindex = 0
- flip = False
- process_caption = False
- process_caption_deepbooru = False
- preprocess_txt_action = None
-
-
-def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
- caption = ""
-
- if params.process_caption:
- caption += shared.interrogator.generate_caption(image)
-
- if params.process_caption_deepbooru:
- if len(caption) > 0:
- caption += ", "
- caption += deepbooru.model.tag_multi(image)
-
- filename_part = params.src
- filename_part = os.path.splitext(filename_part)[0]
- filename_part = os.path.basename(filename_part)
-
- basename = f"{index:05}-{params.subindex}-{filename_part}"
- image.save(os.path.join(params.dstdir, f"{basename}.png"))
-
- if params.preprocess_txt_action == 'prepend' and existing_caption:
- caption = existing_caption + ' ' + caption
- elif params.preprocess_txt_action == 'append' and existing_caption:
- caption = caption + ' ' + existing_caption
- elif params.preprocess_txt_action == 'copy' and existing_caption:
- caption = existing_caption
-
- caption = caption.strip()
-
- if len(caption) > 0:
- with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
- file.write(caption)
-
- params.subindex += 1
-
-
-def save_pic(image, index, params, existing_caption=None):
- save_pic_with_caption(image, index, params, existing_caption=existing_caption)
-
- if params.flip:
- save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
-
-
-def split_pic(image, inverse_xy, width, height, overlap_ratio):
- if inverse_xy:
- from_w, from_h = image.height, image.width
- to_w, to_h = height, width
- else:
- from_w, from_h = image.width, image.height
- to_w, to_h = width, height
- h = from_h * to_w // from_w
- if inverse_xy:
- image = image.resize((h, to_w))
- else:
- image = image.resize((to_w, h))
-
- split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
- y_step = (h - to_h) / (split_count - 1)
- for i in range(split_count):
- y = int(y_step * i)
- if inverse_xy:
- splitted = image.crop((y, 0, y + to_h, to_w))
- else:
- splitted = image.crop((0, y, to_w, y + to_h))
- yield splitted
-
-# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
-def center_crop(image: Image, w: int, h: int):
- iw, ih = image.size
- if ih / h < iw / w:
- sw = w * ih / h
- box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
- else:
- sh = h * iw / w
- box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
- return image.resize((w, h), Image.Resampling.LANCZOS, box)
-
-
-def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
- iw, ih = image.size
- err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
- wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
- if minarea <= w * h <= maxarea and err(w, h) <= threshold),
- key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
- default=None
- )
- return wh and center_crop(image, *wh)
-
-
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
- width = process_width
- height = process_height
- src = os.path.abspath(process_src)
- dst = os.path.abspath(process_dst)
- split_threshold = max(0.0, min(1.0, split_threshold))
- overlap_ratio = max(0.0, min(0.9, overlap_ratio))
-
- assert src != dst, 'same directory specified as source and destination'
-
- os.makedirs(dst, exist_ok=True)
-
- files = listfiles(src)
-
- shared.state.job = "preprocess"
- shared.state.textinfo = "Preprocessing..."
- shared.state.job_count = len(files)
-
- params = PreprocessParams()
- params.dstdir = dst
- params.flip = process_flip
- params.process_caption = process_caption
- params.process_caption_deepbooru = process_caption_deepbooru
- params.preprocess_txt_action = preprocess_txt_action
-
- pbar = tqdm.tqdm(files)
- for index, imagefile in enumerate(pbar):
- params.subindex = 0
- filename = os.path.join(src, imagefile)
- try:
- img = Image.open(filename).convert("RGB")
- except Exception:
- continue
-
- description = f"Preprocessing [Image {index}/{len(files)}]"
- pbar.set_description(description)
- shared.state.textinfo = description
-
- params.src = filename
-
- existing_caption = None
- existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
- if os.path.exists(existing_caption_filename):
- with open(existing_caption_filename, 'r', encoding="utf8") as file:
- existing_caption = file.read()
-
- if shared.state.interrupted:
- break
-
- if img.height > img.width:
- ratio = (img.width * height) / (img.height * width)
- inverse_xy = False
- else:
- ratio = (img.height * width) / (img.width * height)
- inverse_xy = True
-
- process_default_resize = True
-
- if process_split and ratio < 1.0 and ratio <= split_threshold:
- for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
- save_pic(splitted, index, params, existing_caption=existing_caption)
- process_default_resize = False
-
- if process_focal_crop and img.height != img.width:
-
- dnn_model_path = None
- try:
- dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
- except Exception as e:
- print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
-
- autocrop_settings = autocrop.Settings(
- crop_width = width,
- crop_height = height,
- face_points_weight = process_focal_crop_face_weight,
- entropy_points_weight = process_focal_crop_entropy_weight,
- corner_points_weight = process_focal_crop_edges_weight,
- annotate_image = process_focal_crop_debug,
- dnn_model_path = dnn_model_path,
- )
- for focal in autocrop.crop_image(img, autocrop_settings):
- save_pic(focal, index, params, existing_caption=existing_caption)
- process_default_resize = False
-
- if process_multicrop:
- cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
- if cropped is not None:
- save_pic(cropped, index, params, existing_caption=existing_caption)
- else:
- print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
- process_default_resize = False
-
- if process_default_resize:
- img = images.resize_image(1, img, width, height)
- save_pic(img, index, params, existing_caption=existing_caption)
-
- shared.state.nextjob()
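
Note on the removed multicrop_pic: candidate crop sizes are scored by an aspect-ratio error, err(w, h) = 1 - min(q, 1/q) with q = (image aspect) / (crop aspect), so 0 means a perfect aspect match and larger values a worse fit. An unpacked sketch of that scoring for readability (illustrative names):

    def aspect_error(img_w, img_h, crop_w, crop_h):
        q = (img_w / img_h) / (crop_w / crop_h)  # > 1 means the image is wider than the crop
        q = q if q < 1 else 1 / q                # fold into (0, 1]; 1 is a perfect match
        return 1 - q                             # 0 = same aspect ratio

    # a 1920x1080 image matches a 768x448 crop far better than a 512x512 one
    print(aspect_error(1920, 1080, 768, 448))  # ~0.036
    print(aspect_error(1920, 1080, 512, 512))  # ~0.44
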
diff --git a/modules/textual_inversion/test_embedding.png b/modules/textual_inversion/test_embedding.png
deleted file mode 100644
index 07e2d9afaeaff3751b68a7c0f49d8b3466474282..0000000000000000000000000000000000000000
Binary files a/modules/textual_inversion/test_embedding.png and /dev/null differ
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
deleted file mode 100644
index c63c7d1dda0840fb337cedcd95d16e05e60935a7..0000000000000000000000000000000000000000
--- a/modules/textual_inversion/textual_inversion.py
+++ /dev/null
@@ -1,657 +0,0 @@
-import os
-import sys
-import traceback
-import inspect
-from collections import namedtuple
-
-import torch
-import tqdm
-import html
-import datetime
-import csv
-import safetensors.torch
-
-import numpy as np
-from PIL import Image, PngImagePlugin
-from torch.utils.tensorboard import SummaryWriter
-
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
-import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnRateScheduler
-
-from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
-from modules.textual_inversion.logging import save_settings_to_file
-
-
-TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
-textual_inversion_templates = {}
-
-
-def list_textual_inversion_templates():
- textual_inversion_templates.clear()
-
- for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
- for fn in fns:
- path = os.path.join(root, fn)
-
- textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
-
- return textual_inversion_templates
-
-
-class Embedding:
- def __init__(self, vec, name, step=None):
- self.vec = vec
- self.name = name
- self.step = step
- self.shape = None
- self.vectors = 0
- self.cached_checksum = None
- self.sd_checkpoint = None
- self.sd_checkpoint_name = None
- self.optimizer_state_dict = None
- self.filename = None
-
- def save(self, filename):
- embedding_data = {
- "string_to_token": {"*": 265},
- "string_to_param": {"*": self.vec},
- "name": self.name,
- "step": self.step,
- "sd_checkpoint": self.sd_checkpoint,
- "sd_checkpoint_name": self.sd_checkpoint_name,
- }
-
- torch.save(embedding_data, filename)
-
- if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
- optimizer_saved_dict = {
- 'hash': self.checksum(),
- 'optimizer_state_dict': self.optimizer_state_dict,
- }
- torch.save(optimizer_saved_dict, filename + '.optim')
-
- def checksum(self):
- if self.cached_checksum is not None:
- return self.cached_checksum
-
- def const_hash(a):
- r = 0
- for v in a:
- r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
- return r
-
- self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
- return self.cached_checksum
-
-
-class DirWithTextualInversionEmbeddings:
- def __init__(self, path):
- self.path = path
- self.mtime = None
-
- def has_changed(self):
- if not os.path.isdir(self.path):
- return False
-
- mt = os.path.getmtime(self.path)
- if self.mtime is None or mt > self.mtime:
- return True
-
- def update(self):
- if not os.path.isdir(self.path):
- return
-
- self.mtime = os.path.getmtime(self.path)
-
-
-class EmbeddingDatabase:
- def __init__(self):
- self.ids_lookup = {}
- self.word_embeddings = {}
- self.skipped_embeddings = {}
- self.expected_shape = -1
- self.embedding_dirs = {}
- self.previously_displayed_embeddings = ()
-
- def add_embedding_dir(self, path):
- self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
-
- def clear_embedding_dirs(self):
- self.embedding_dirs.clear()
-
- def register_embedding(self, embedding, model):
- self.word_embeddings[embedding.name] = embedding
-
- ids = model.cond_stage_model.tokenize([embedding.name])[0]
-
- first_id = ids[0]
- if first_id not in self.ids_lookup:
- self.ids_lookup[first_id] = []
-
- self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
-
- return embedding
-
- def get_expected_shape(self):
- vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
- return vec.shape[1]
-
- def load_from_file(self, path, filename):
- name, ext = os.path.splitext(filename)
- ext = ext.upper()
-
- if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
- _, second_ext = os.path.splitext(name)
- if second_ext.upper() == '.PREVIEW':
- return
-
- embed_image = Image.open(path)
- if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
- data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
- name = data.get('name', name)
- else:
- data = extract_image_data_embed(embed_image)
- name = data.get('name', name)
- elif ext in ['.BIN', '.PT']:
- data = torch.load(path, map_location="cpu")
- elif ext in ['.SAFETENSORS']:
- data = safetensors.torch.load_file(path, device="cpu")
- else:
- return
-
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
- else:
-            raise Exception(f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept.")
-
- vec = emb.detach().to(devices.device, dtype=torch.float32)
- embedding = Embedding(vec, name)
- embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('sd_checkpoint', None)
- embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- embedding.vectors = vec.shape[0]
- embedding.shape = vec.shape[-1]
- embedding.filename = path
-
- if self.expected_shape == -1 or self.expected_shape == embedding.shape:
- self.register_embedding(embedding, shared.sd_model)
- else:
- self.skipped_embeddings[name] = embedding
-
- def load_from_dir(self, embdir):
- if not os.path.isdir(embdir.path):
- return
-
- for root, dirs, fns in os.walk(embdir.path, followlinks=True):
- for fn in fns:
- try:
- fullfn = os.path.join(root, fn)
-
- if os.stat(fullfn).st_size == 0:
- continue
-
- self.load_from_file(fullfn, fn)
- except Exception:
- print(f"Error loading embedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- continue
-
- def load_textual_inversion_embeddings(self, force_reload=False):
- if not force_reload:
- need_reload = False
- for path, embdir in self.embedding_dirs.items():
- if embdir.has_changed():
- need_reload = True
- break
-
- if not need_reload:
- return
-
- self.ids_lookup.clear()
- self.word_embeddings.clear()
- self.skipped_embeddings.clear()
- self.expected_shape = self.get_expected_shape()
-
- for path, embdir in self.embedding_dirs.items():
- self.load_from_dir(embdir)
- embdir.update()
-
- displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
- if self.previously_displayed_embeddings != displayed_embeddings:
- self.previously_displayed_embeddings = displayed_embeddings
- print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
- if len(self.skipped_embeddings) > 0:
- print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
-
- def find_embedding_at_position(self, tokens, offset):
- token = tokens[offset]
- possible_matches = self.ids_lookup.get(token, None)
-
- if possible_matches is None:
- return None, None
-
- for ids, embedding in possible_matches:
- if tokens[offset:offset + len(ids)] == ids:
- return embedding, len(ids)
-
- return None, None
-
-
-def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
- cond_model = shared.sd_model.cond_stage_model
-
- with devices.autocast():
- cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
-
- #cond_model expects at least some text, so we provide '*' as backup.
- embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
- vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
-
- #Only copy if we provided an init_text, otherwise keep vectors as zeros
- if init_text:
- for i in range(num_vectors_per_token):
- vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
-
- # Remove illegal characters from name.
- name = "".join( x for x in name if (x.isalnum() or x in "._- "))
- fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
- if not overwrite_old:
- assert not os.path.exists(fn), f"file {fn} already exists"
-
- embedding = Embedding(vec, name)
- embedding.step = 0
- embedding.save(fn)
-
- return fn
-
-
-def write_loss(log_directory, filename, step, epoch_len, values):
- if shared.opts.training_write_csv_every == 0:
- return
-
- if step % shared.opts.training_write_csv_every != 0:
- return
-    write_csv_header = not os.path.exists(os.path.join(log_directory, filename))
-
- with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
- csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
-
- if write_csv_header:
- csv_writer.writeheader()
-
- epoch = (step - 1) // epoch_len
- epoch_step = (step - 1) % epoch_len
-
- csv_writer.writerow({
- "step": step,
- "epoch": epoch,
- "epoch_step": epoch_step,
- **values,
- })
-
-def tensorboard_setup(log_directory):
- os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
- return SummaryWriter(
- log_dir=os.path.join(log_directory, "tensorboard"),
- flush_secs=shared.opts.training_tensorboard_flush_every)
-
-def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
- tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
- tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
- tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
- tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
-
-def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
- tensorboard_writer.add_scalar(tag=tag,
- scalar_value=value, global_step=step)
-
-def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
- # Convert a pil image to a torch tensor
- img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
- img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
- len(pil_image.getbands()))
- img_tensor = img_tensor.permute((2, 0, 1))
-
- tensorboard_writer.add_image(tag, img_tensor, global_step=step)
-
-def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
- assert model_name, f"{name} not selected"
- assert learn_rate, "Learning rate is empty or 0"
- assert isinstance(batch_size, int), "Batch size must be integer"
- assert batch_size > 0, "Batch size must be positive"
- assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
- assert gradient_step > 0, "Gradient accumulation step must be positive"
- assert data_root, "Dataset directory is empty"
- assert os.path.isdir(data_root), "Dataset directory doesn't exist"
- assert os.listdir(data_root), "Dataset directory is empty"
- assert template_filename, "Prompt template file not selected"
- assert template_file, f"Prompt template file {template_filename} not found"
- assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
- assert steps, "Max steps is empty or 0"
- assert isinstance(steps, int), "Max steps must be integer"
- assert steps > 0, "Max steps must be positive"
-    assert isinstance(save_model_every, int), f"Save {name} must be integer"
-    assert save_model_every >= 0, f"Save {name} must be positive or 0"
- assert isinstance(create_image_every, int), "Create image must be integer"
- assert create_image_every >= 0, "Create image must be positive or 0"
- if save_model_every or create_image_every:
- assert log_directory, "Log directory is empty"
-
-
-def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- save_embedding_every = save_embedding_every or 0
- create_image_every = create_image_every or 0
- template_file = textual_inversion_templates.get(template_filename, None)
- validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
- template_file = template_file.path
-
- shared.state.job = "train-embedding"
- shared.state.textinfo = "Initializing textual inversion training..."
- shared.state.job_count = steps
-
- filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
-
- log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
- unload = shared.opts.unload_models_when_training
-
- if save_embedding_every > 0:
- embedding_dir = os.path.join(log_directory, "embeddings")
- os.makedirs(embedding_dir, exist_ok=True)
- else:
- embedding_dir = None
-
- if create_image_every > 0:
- images_dir = os.path.join(log_directory, "images")
- os.makedirs(images_dir, exist_ok=True)
- else:
- images_dir = None
-
- if create_image_every > 0 and save_image_with_stored_embedding:
- images_embeds_dir = os.path.join(log_directory, "image_embeddings")
- os.makedirs(images_embeds_dir, exist_ok=True)
- else:
- images_embeds_dir = None
-
- hijack = sd_hijack.model_hijack
-
- embedding = hijack.embedding_db.word_embeddings[embedding_name]
- checkpoint = sd_models.select_checkpoint()
-
- initial_step = embedding.step or 0
- if initial_step >= steps:
- shared.state.textinfo = "Model has already been trained beyond specified max steps"
- return embedding, filename
-
- scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
- clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
- torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
- None
- if clip_grad:
- clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
- # dataset loading may take a while, so input validations and early returns should be done before this
- shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
- old_parallel_processing_allowed = shared.parallel_processing_allowed
-
- if shared.opts.training_enable_tensorboard:
- tensorboard_writer = tensorboard_setup(log_directory)
-
- pin_memory = shared.opts.pin_memory
-
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
-
- if shared.opts.save_training_settings_to_txt:
- save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
-
- latent_sampling_method = ds.latent_sampling_method
-
- dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
-
- if unload:
- shared.parallel_processing_allowed = False
- shared.sd_model.first_stage_model.to(devices.cpu)
-
- embedding.vec.requires_grad = True
- optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
- if shared.opts.save_optimizer_state:
- optimizer_state_dict = None
- if os.path.exists(filename + '.optim'):
- optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
- if embedding.checksum() == optimizer_saved_dict.get('hash', None):
- optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
-
- if optimizer_state_dict is not None:
- optimizer.load_state_dict(optimizer_state_dict)
- print("Loaded existing optimizer from checkpoint")
- else:
- print("No saved optimizer exists in checkpoint")
-
- scaler = torch.cuda.amp.GradScaler()
-
- batch_size = ds.batch_size
- gradient_step = ds.gradient_step
-    # each optimizer step consumes batch_size * gradient_step images
- steps_per_epoch = len(ds) // batch_size // gradient_step
- max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
- loss_step = 0
- _loss_step = 0 #internal
-
- last_saved_file = ""
- last_saved_image = ""
- forced_filename = ""
- embedding_yet_to_be_embedded = False
-
- is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
- img_c = None
-
- pbar = tqdm.tqdm(total=steps - initial_step)
- try:
- sd_hijack_checkpoint.add()
-
- for i in range((steps-initial_step) * gradient_step):
- if scheduler.finished:
- break
- if shared.state.interrupted:
- break
- for j, batch in enumerate(dl):
- # works as a drop_last=True for gradient accumulation
- if j == max_steps_per_epoch:
- break
- scheduler.apply(optimizer, embedding.step)
- if scheduler.finished:
- break
- if shared.state.interrupted:
- break
-
- if clip_grad:
- clip_grad_sched.step(embedding.step)
-
- with devices.autocast():
- x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
- if use_weight:
- w = batch.weight.to(devices.device, non_blocking=pin_memory)
- c = shared.sd_model.cond_stage_model(batch.cond_text)
-
- if is_training_inpainting_model:
- if img_c is None:
- img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
-
- cond = {"c_concat": [img_c], "c_crossattn": [c]}
- else:
- cond = c
-
- if use_weight:
- loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
- del w
- else:
- loss = shared.sd_model.forward(x, cond)[0] / gradient_step
- del x
-
- _loss_step += loss.item()
- scaler.scale(loss).backward()
-
-                # keep accumulating until gradient_step batches have been processed
- if (j + 1) % gradient_step != 0:
- continue
-
- if clip_grad:
- clip_grad(embedding.vec, clip_grad_sched.learn_rate)
-
- scaler.step(optimizer)
- scaler.update()
- embedding.step += 1
- pbar.update()
- optimizer.zero_grad(set_to_none=True)
- loss_step = _loss_step
- _loss_step = 0
-
- steps_done = embedding.step + 1
-
- epoch_num = embedding.step // steps_per_epoch
- epoch_step = embedding.step % steps_per_epoch
-
- description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
- pbar.set_description(description)
- if embedding_dir is not None and steps_done % save_embedding_every == 0:
- # Before saving, change name to match current checkpoint.
- embedding_name_every = f'{embedding_name}-{steps_done}'
- last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
- save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
- embedding_yet_to_be_embedded = True
-
- write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
- "loss": f"{loss_step:.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{embedding_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
-
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- do_not_reload_embeddings=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = batch.cond_text[0]
- p.steps = 20
- p.width = training_width
- p.height = training_height
-
- preview_text = p.prompt
-
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images) > 0 else None
-
- if unload:
- shared.sd_model.first_stage_model.to(devices.cpu)
-
- if image is not None:
- shared.state.assign_current_image(image)
-
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
-
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
- tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
-
- if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
-
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
-
- info = PngImagePlugin.PngInfo()
- data = torch.load(last_saved_file)
- info.add_text("sd-ti-embedding", embedding_to_b64(data))
-
- title = "<{}>".format(data.get('name', '???'))
-
- try:
- vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
- vectorSize = '?'
-
- checkpoint = sd_models.select_checkpoint()
- footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.shorthash)
- footer_right = '{}v {}s'.format(vectorSize, steps_done)
-
- captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
- captioned_image = insert_image_data_embed(captioned_image, data)
-
- captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
- embedding_yet_to_be_embedded = False
-
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
-
- shared.state.job_no = embedding.step
-
- shared.state.textinfo = f"""
-
-Loss: {loss_step:.7f}
-Step: {steps_done}
-Last prompt: {html.escape(batch.cond_text[0])}
-Last saved embedding: {html.escape(last_saved_file)}
-Last saved image: {html.escape(last_saved_image)}
-
-"""
- filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
- save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
- except Exception:
- print(traceback.format_exc(), file=sys.stderr)
- pass
- finally:
- pbar.leave = False
- pbar.close()
- shared.sd_model.first_stage_model.to(devices.device)
- shared.parallel_processing_allowed = old_parallel_processing_allowed
- sd_hijack_checkpoint.remove()
-
- return embedding, filename
-
-
-def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
- old_embedding_name = embedding.name
- old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
- old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
- old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
- try:
- embedding.sd_checkpoint = checkpoint.shorthash
- embedding.sd_checkpoint_name = checkpoint.model_name
- if remove_cached_checksum:
- embedding.cached_checksum = None
- embedding.name = embedding_name
- embedding.optimizer_state_dict = optimizer.state_dict()
- embedding.save(filename)
- except:
- embedding.sd_checkpoint = old_sd_checkpoint
- embedding.sd_checkpoint_name = old_sd_checkpoint_name
- embedding.name = old_embedding_name
- embedding.cached_checksum = old_cached_checksum
- raise
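
Note on the removed training loop above: each micro-batch loss is divided by gradient_step and the optimizer only steps every gradient_step iterations, which emulates a larger effective batch size under mixed precision. The following is a minimal, self-contained sketch of that gradient-accumulation pattern; the toy model, data, and loss function are illustrative stand-ins and are not part of the removed code.

```python
import torch
from torch import nn

# Toy setup (illustrative only; not from the deleted module).
model = nn.Linear(8, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
data = [(torch.randn(16, 8), torch.randn(16, 1)) for _ in range(8)]

gradient_step = 4                      # accumulate 4 micro-batches per optimizer step
scaler = torch.cuda.amp.GradScaler()

optimizer.zero_grad(set_to_none=True)
for j, (inputs, targets) in enumerate(data):
    inputs, targets = inputs.cuda(), targets.cuda()
    with torch.autocast("cuda"):
        # Divide so the accumulated gradient matches one large batch.
        loss = criterion(model(inputs), targets) / gradient_step
    scaler.scale(loss).backward()      # accumulate scaled gradients
    if (j + 1) % gradient_step != 0:   # not a full "virtual" batch yet
        continue
    scaler.step(optimizer)             # unscales gradients, then steps the optimizer
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
```
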
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
deleted file mode 100644
index 35c4feeff455b6bd2b699dd72b27f852932c78c2..0000000000000000000000000000000000000000
--- a/modules/textual_inversion/ui.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import html
-
-import gradio as gr
-
-import modules.textual_inversion.textual_inversion
-import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared
-
-
-def create_embedding(name, initialization_text, nvpt, overwrite_old):
- filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
-
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
-
- return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
-
-
-def preprocess(*args):
- modules.textual_inversion.preprocess.preprocess(*args)
-
- return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
-
-
-def train_embedding(*args):
-
- assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
-
- apply_optimizations = shared.opts.training_xattention_optimizations
- try:
- if not apply_optimizations:
- sd_hijack.undo_optimizations()
-
- embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
-
- res = f"""
-Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
-Embedding saved to {html.escape(filename)}
-"""
- return res, ""
- except Exception:
- raise
- finally:
- if not apply_optimizations:
- sd_hijack.apply_optimizations()
-
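
The deleted train_embedding wrapper disables cross-attention optimizations before training unless shared.opts.training_xattention_optimizations is set, and always restores them in a finally block. A small sketch of that disable/restore pattern, written as a context manager, is shown below; it assumes sd_hijack.undo_optimizations, sd_hijack.apply_optimizations, and the option above behave as they do in the removed code.

```python
import contextlib

from modules import sd_hijack, shared


@contextlib.contextmanager
def training_optimizations_disabled():
    """Mirror the try/finally of the removed train_embedding wrapper (sketch only)."""
    apply_optimizations = shared.opts.training_xattention_optimizations
    if not apply_optimizations:
        sd_hijack.undo_optimizations()       # train without the optimized attention kernels
    try:
        yield
    finally:
        if not apply_optimizations:
            sd_hijack.apply_optimizations()  # always restore them for normal inference
```
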
diff --git a/modules/timer.py b/modules/timer.py
deleted file mode 100644
index 57a4f17a16b4cb46fafb690b1180c054db284a17..0000000000000000000000000000000000000000
--- a/modules/timer.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import time
-
-
-class Timer:
- def __init__(self):
- self.start = time.time()
- self.records = {}
- self.total = 0
-
- def elapsed(self):
- end = time.time()
- res = end - self.start
- self.start = end
- return res
-
- def record(self, category, extra_time=0):
- e = self.elapsed()
- if category not in self.records:
- self.records[category] = 0
-
- self.records[category] += e + extra_time
- self.total += e + extra_time
-
- def summary(self):
- res = f"{self.total:.1f}s"
-
- additions = [x for x in self.records.items() if x[1] >= 0.1]
- if not additions:
- return res
-
- res += " ("
- res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
- res += ")"
-
- return res
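
For reference, the Timer class removed above measures wall-clock time between successive record() calls and folds the per-category totals into a one-line summary, dropping categories under 0.1s. The usage sketch below uses time.sleep as a stand-in for real work such as loading weights; the import path is the one that existed before this change.

```python
import time

from modules.timer import Timer  # module path prior to this change

timer = Timer()
time.sleep(0.25)
timer.record("load weights")     # elapsed time since the Timer was created
time.sleep(0.15)
timer.record("build UI")         # elapsed time since the previous record()
print(timer.summary())           # e.g. "0.4s (load weights: 0.3s, build UI: 0.2s)"
```
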
diff --git a/modules/txt2img.py b/modules/txt2img.py
deleted file mode 100644
index 16841d0f2f3095ef27667ef90e5c0f1baeef733d..0000000000000000000000000000000000000000
--- a/modules/txt2img.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import modules.scripts
-from modules import sd_samplers
-from modules.generation_parameters_copypaste import create_override_settings_dict
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
- StableDiffusionProcessingImg2Img, process_images
-from modules.shared import opts, cmd_opts
-import modules.shared as shared
-import modules.processing as processing
-from modules.ui import plaintext_to_html
-
-
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
- override_settings = create_override_settings_dict(override_settings_texts)
-
- p = StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
- outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
- prompt=prompt,
- styles=prompt_styles,
- negative_prompt=negative_prompt,
- seed=seed,
- subseed=subseed,
- subseed_strength=subseed_strength,
- seed_resize_from_h=seed_resize_from_h,
- seed_resize_from_w=seed_resize_from_w,
- seed_enable_extras=seed_enable_extras,
- sampler_name=sd_samplers.samplers[sampler_index].name,
- batch_size=batch_size,
- n_iter=n_iter,
- steps=steps,
- cfg_scale=cfg_scale,
- width=width,
- height=height,
- restore_faces=restore_faces,
- tiling=tiling,
- enable_hr=enable_hr,
- denoising_strength=denoising_strength if enable_hr else None,
- hr_scale=hr_scale,
- hr_upscaler=hr_upscaler,
- hr_second_pass_steps=hr_second_pass_steps,
- hr_resize_x=hr_resize_x,
- hr_resize_y=hr_resize_y,
- override_settings=override_settings,
- )
-
- p.scripts = modules.scripts.scripts_txt2img
- p.script_args = args
-
- if cmd_opts.enable_console_prompts:
- print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
-
- processed = modules.scripts.scripts_txt2img.run(p, *args)
-
- if processed is None:
- processed = process_images(p)
-
- p.close()
-
- shared.total_tqdm.clear()
-
- generation_info_js = processed.js()
- if opts.samples_log_stdout:
- print(generation_info_js)
-
- if opts.do_not_show_images:
- processed.images = []
-
- return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/ui.py b/modules/ui.py
deleted file mode 100644
index 6fe8f4db76de26509f865f5450e23282466ad860..0000000000000000000000000000000000000000
--- a/modules/ui.py
+++ /dev/null
@@ -1,1800 +0,0 @@
-import html
-import json
-import math
-import mimetypes
-import os
-import platform
-import random
-import sys
-import tempfile
-import time
-import traceback
-from functools import partial, reduce
-import warnings
-
-import gradio as gr
-import gradio.routes
-import gradio.utils
-import numpy as np
-from PIL import Image, PngImagePlugin
-from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
-from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
-from modules.paths import script_path, data_path
-
-from modules.shared import opts, cmd_opts, restricted_opts
-
-import modules.codeformer_model
-import modules.generation_parameters_copypaste as parameters_copypaste
-import modules.gfpgan_model
-import modules.hypernetworks.ui
-import modules.scripts
-import modules.shared as shared
-import modules.styles
-import modules.textual_inversion.ui
-from modules import prompt_parser
-from modules.images import save_image
-from modules.sd_hijack import model_hijack
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.textual_inversion import textual_inversion
-import modules.hypernetworks.ui
-from modules.generation_parameters_copypaste import image_from_url_text
-import modules.extras
-
-warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
-
-# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
-mimetypes.init()
-mimetypes.add_type('application/javascript', '.js')
-
-if not cmd_opts.share and not cmd_opts.listen:
- # fix gradio phoning home
- gradio.utils.version_check = lambda: None
- gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
-
-if cmd_opts.ngrok is not None:
- import modules.ngrok as ngrok
- print('ngrok authtoken detected, trying to connect...')
- ngrok.connect(
- cmd_opts.ngrok,
- cmd_opts.port if cmd_opts.port is not None else 7860,
- cmd_opts.ngrok_region
- )
-
-
-def gr_show(visible=True):
- return {"visible": visible, "__type__": "update"}
-
-
-sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
-sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
-
-css_hide_progressbar = """
-.wrap .m-12 svg { display:none!important; }
-.wrap .m-12::before { content:"Loading..." }
-.wrap .z-20 svg { display:none!important; }
-.wrap .z-20::before { content:"Loading..." }
-.wrap.cover-bg .z-20::before { content:"" }
-.progress-bar { display:none!important; }
-.meta-text { display:none!important; }
-.meta-text-center { display:none!important; }
-"""
-
-# Using constants for these since the variation selector isn't visible.
-# Important that they exactly match script.js for tooltip to work.
-random_symbol = '\U0001f3b2\ufe0f' # 🎲️
-reuse_symbol = '\u267b\ufe0f' # ♻️
-paste_symbol = '\u2199\ufe0f' # ↙
-refresh_symbol = '\U0001f504' # 🔄
-save_style_symbol = '\U0001f4be' # 💾
-apply_style_symbol = '\U0001f4cb' # 📋
-clear_prompt_symbol = '\U0001F5D1' # 🗑️
-extra_networks_symbol = '\U0001F3B4' # 🎴
-switch_values_symbol = '\U000021C5' # ⇅
-
-
-def plaintext_to_html(text):
- return ui_common.plaintext_to_html(text)
-
-
-def send_gradio_gallery_to_image(x):
- if len(x) == 0:
- return None
- return image_from_url_text(x[0])
-
-def visit(x, func, path=""):
- if hasattr(x, 'children'):
- for c in x.children:
- visit(c, func, path)
- elif x.label is not None:
- func(path + "/" + str(x.label), x)
-
-
-def add_style(name: str, prompt: str, negative_prompt: str):
- if name is None:
- return [gr_show() for x in range(4)]
-
- style = modules.styles.PromptStyle(name, prompt, negative_prompt)
- shared.prompt_styles.styles[style.name] = style
- # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
- # reserialize all styles every time we save them
- shared.prompt_styles.save_styles(shared.styles_filename)
-
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
-
-
-def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
- from modules import processing, devices
-
- if not enable:
- return ""
-
- p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
-
- with devices.autocast():
- p.init([""], [0], [0])
-
- return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y} "
-
-
-def apply_styles(prompt, prompt_neg, styles):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
-
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
-
-
-def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
- if mode in {0, 1, 3, 4}:
- return [interrogation_function(ii_singles[mode]), None]
- elif mode == 2:
- return [interrogation_function(ii_singles[mode]["image"]), None]
- elif mode == 5:
- assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
- images = shared.listfiles(ii_input_dir)
- print(f"Will process {len(images)} images.")
- if ii_output_dir != "":
- os.makedirs(ii_output_dir, exist_ok=True)
- else:
- ii_output_dir = ii_input_dir
-
- for image in images:
- img = Image.open(image)
- filename = os.path.basename(image)
- left, _ = os.path.splitext(filename)
- print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
-
- return [gr.update(), None]
-
-
-def interrogate(image):
- prompt = shared.interrogator.interrogate(image.convert("RGB"))
- return gr.update() if prompt is None else prompt
-
-
-def interrogate_deepbooru(image):
- prompt = deepbooru.model.tag(image)
- return gr.update() if prompt is None else prompt
-
-
-def create_seed_inputs(target_interface):
- with FormRow(elem_id=target_interface + '_seed_row'):
- seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
- seed.style(container=False)
- random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
- reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
-
- with gr.Group(elem_id=target_interface + '_subseed_show_box'):
- seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
-
- # Components to show/hide based on the 'Extra' checkbox
- seed_extras = []
-
- with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
- seed_extras.append(seed_extra_row_1)
- subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
- subseed.style(container=False)
- random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
- reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
- subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
-
- with FormRow(visible=False) as seed_extra_row_2:
- seed_extras.append(seed_extra_row_2)
- seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
- seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
-
- random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
- random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
-
- def change_visibility(show):
- return {comp: gr_show(show) for comp in seed_extras}
-
- seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
-
- return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-
-
-
-def connect_clear_prompt(button):
-    """Given clear button, prompt, and token_counter objects, set up the clear-prompt button's click event"""
- button.click(
- _js="clear_prompt",
- fn=None,
- inputs=[],
- outputs=[],
- )
-
-
-def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
- """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
- (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
- was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
- def copy_seed(gen_info_string: str, index):
- res = -1
-
- try:
- gen_info = json.loads(gen_info_string)
- index -= gen_info.get('index_of_first_image', 0)
-
- if is_subseed and gen_info.get('subseed_strength', 0) > 0:
- all_subseeds = gen_info.get('all_subseeds', [-1])
- res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
- else:
- all_seeds = gen_info.get('all_seeds', [-1])
- res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
-
- except json.decoder.JSONDecodeError as e:
- if gen_info_string != '':
- print("Error parsing JSON generation info:", file=sys.stderr)
- print(gen_info_string, file=sys.stderr)
-
- return [res, gr_show(False)]
-
- reuse_seed.click(
- fn=copy_seed,
- _js="(x, y) => [x, selected_gallery_index()]",
- show_progress=False,
- inputs=[generation_info, dummy_component],
- outputs=[seed, dummy_component]
- )
-
-
-def update_token_counter(text, steps):
- try:
- text, _ = extra_networks.parse_prompt(text)
-
- _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
- prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
-
- except Exception:
- # a parsing error can happen here during typing, and we don't want to bother the user with
- # messages related to it in console
- prompt_schedules = [[[steps, text]]]
-
- flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
- prompts = [prompt_text for step, prompt_text in flat_prompts]
- token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
- return f"{token_count}/{max_length} "
-
-
-def create_toprow(is_img2img):
- id_part = "img2img" if is_img2img else "txt2img"
-
- with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
- with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
-
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
-
- button_interrogate = None
- button_deepbooru = None
- if is_img2img:
- with gr.Column(scale=1, elem_id="interrogate_col"):
- button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
- with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
- with gr.Row(elem_id=f"{id_part}_generate_box"):
- interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
- submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
- skip.click(
- fn=lambda: shared.state.skip(),
- inputs=[],
- outputs=[],
- )
-
- interrupt.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
- with gr.Row(elem_id=f"{id_part}_tools"):
- paste = ToolButton(value=paste_symbol, elem_id="paste")
- clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
- prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
- save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
-
- token_counter = gr.HTML(value=" ", elem_id=f"{id_part}_token_counter")
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
- negative_token_counter = gr.HTML(value=" ", elem_id=f"{id_part}_negative_token_counter")
- negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
-
- clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[prompt, negative_prompt],
- outputs=[prompt, negative_prompt],
- )
-
- with gr.Row(elem_id=f"{id_part}_styles_row"):
- prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
- create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
-
- return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
-
-
-def setup_progressbar(*args, **kwargs):
- pass
-
-
-def apply_setting(key, value):
- if value is None:
- return gr.update()
-
- if shared.cmd_opts.freeze_settings:
- return gr.update()
-
-    # don't allow the model to be swapped when a model hash exists in the prompt
- if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
- return gr.update()
-
- if key == "sd_model_checkpoint":
- ckpt_info = sd_models.get_closet_checkpoint_match(value)
-
- if ckpt_info is not None:
- value = ckpt_info.title
- else:
- return gr.update()
-
- comp_args = opts.data_labels[key].component_args
- if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
- return
-
- valtype = type(opts.data_labels[key].default)
- oldval = opts.data.get(key, None)
- opts.data[key] = valtype(value) if valtype != type(None) else value
- if oldval != value and opts.data_labels[key].onchange is not None:
- opts.data_labels[key].onchange()
-
- opts.save(shared.config_filename)
- return getattr(opts, key)
-
-
-def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
-
- for k, v in args.items():
- setattr(refresh_component, k, v)
-
- return gr.update(**(args or {}))
-
- refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn=refresh,
- inputs=[],
- outputs=[refresh_component]
- )
- return refresh_button
-
-
-def create_output_panel(tabname, outdir):
- return ui_common.create_output_panel(tabname, outdir)
-
-
-def create_sampler_and_steps_selection(choices, tabname):
- if opts.samplers_in_dropdown:
- with FormRow(elem_id=f"sampler_selection_{tabname}"):
- sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- else:
- with FormGroup(elem_id=f"sampler_selection_{tabname}"):
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
-
- return steps, sampler_index
-
-
-def ordered_ui_categories():
- user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
-
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
- yield category
-
-
-def get_value_for_setting(key):
- value = getattr(opts, key)
-
- info = opts.data_labels[key]
- args = info.component_args() if callable(info.component_args) else info.component_args or {}
- args = {k: v for k, v in args.items() if k not in {'precision'}}
-
- return gr.update(value=value, **args)
-
-
-def create_override_settings_dropdown(tabname, row):
- dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
-
- dropdown.change(
- fn=lambda x: gr.Dropdown.update(visible=len(x) > 0),
- inputs=[dropdown],
- outputs=[dropdown],
- )
-
- return dropdown
-
-
-def create_ui():
- import modules.img2img
- import modules.txt2img
-
- reload_javascript()
-
- parameters_copypaste.reset()
-
- modules.scripts.scripts_current = modules.scripts.scripts_txt2img
- modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
-
- with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
-
- dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
-
- with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
- from modules import ui_extra_networks
- extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
-
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='compact', elem_id="txt2img_settings"):
- for category in ordered_ui_categories():
- if category == "sampler":
- steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
-
- elif category == "dimensions":
- with FormRow():
- with gr.Column(elem_id="txt2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
-
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
- if opts.dimensions_and_batch_together:
- with gr.Column(elem_id="txt2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- elif category == "cfg":
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
-
- elif category == "seed":
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
-
- elif category == "checkboxes":
- with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
- tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
- enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
- hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
-
- elif category == "hires_fix":
- with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
- with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
- hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
- hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
-
- with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
- hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
- hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
- hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
-
- elif category == "batch":
- if not opts.dimensions_and_batch_together:
- with FormRow(elem_id="txt2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
-
- elif category == "override_settings":
- with FormRow(elem_id="txt2img_override_settings_row") as row:
- override_settings = create_override_settings_dropdown('txt2img', row)
-
- elif category == "scripts":
- with FormGroup(elem_id="txt2img_script_container"):
- custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
-
- hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
- for input in hr_resolution_preview_inputs:
- input.change(
- fn=calc_resolution_hires,
- inputs=hr_resolution_preview_inputs,
- outputs=[hr_final_resolution],
- show_progress=False,
- )
- input.change(
- None,
- _js="onCalcResolutionHires",
- inputs=hr_resolution_preview_inputs,
- outputs=[],
- show_progress=False,
- )
-
- txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
-
- connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
- connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
-
- txt2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
- _js="submit",
- inputs=[
- dummy_component,
- txt2img_prompt,
- txt2img_negative_prompt,
- txt2img_prompt_styles,
- steps,
- sampler_index,
- restore_faces,
- tiling,
- batch_count,
- batch_size,
- cfg_scale,
- seed,
- subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
- height,
- width,
- enable_hr,
- denoising_strength,
- hr_scale,
- hr_upscaler,
- hr_second_pass_steps,
- hr_resize_x,
- hr_resize_y,
- override_settings,
- ] + custom_inputs,
-
- outputs=[
- txt2img_gallery,
- generation_info,
- html_info,
- html_log,
- ],
- show_progress=False,
- )
-
- txt2img_prompt.submit(**txt2img_args)
- submit.click(**txt2img_args)
-
- res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
-
- txt_prompt_img.change(
- fn=modules.images.image_data,
- inputs=[
- txt_prompt_img
- ],
- outputs=[
- txt2img_prompt,
- txt_prompt_img
- ]
- )
-
- enable_hr.change(
- fn=lambda x: gr_show(x),
- inputs=[enable_hr],
- outputs=[hr_options],
- show_progress = False,
- )
-
- txt2img_paste_fields = [
- (txt2img_prompt, "Prompt"),
- (txt2img_negative_prompt, "Negative prompt"),
- (steps, "Steps"),
- (sampler_index, "Sampler"),
- (restore_faces, "Face restoration"),
- (cfg_scale, "CFG scale"),
- (seed, "Seed"),
- (width, "Size-1"),
- (height, "Size-2"),
- (batch_size, "Batch size"),
- (subseed, "Variation seed"),
- (subseed_strength, "Variation seed strength"),
- (seed_resize_from_w, "Seed resize from-1"),
- (seed_resize_from_h, "Seed resize from-2"),
- (denoising_strength, "Denoising strength"),
- (enable_hr, lambda d: "Denoising strength" in d),
- (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
- (hr_scale, "Hires upscale"),
- (hr_upscaler, "Hires upscaler"),
- (hr_second_pass_steps, "Hires steps"),
- (hr_resize_x, "Hires resize-1"),
- (hr_resize_y, "Hires resize-2"),
- *modules.scripts.scripts_txt2img.infotext_fields
- ]
- parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
- parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
- ))
-
- txt2img_preview_params = [
- txt2img_prompt,
- txt2img_negative_prompt,
- steps,
- sampler_index,
- cfg_scale,
- seed,
- width,
- height,
- ]
-
- token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
- negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
-
- ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
-
- modules.scripts.scripts_current = modules.scripts.scripts_img2img
- modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
-
- with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
-
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
-
- with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
- from modules import ui_extra_networks
- extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
-
- with FormRow().style(equal_height=False):
- with gr.Column(variant='compact', elem_id="img2img_settings"):
- copy_image_buttons = []
- copy_image_destinations = {}
-
- def add_copy_image_controls(tab_name, elem):
- with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
- gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
-
- for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
- if name == tab_name:
- gr.Button(title, interactive=False)
- copy_image_destinations[name] = elem
- continue
-
- button = gr.Button(title)
- copy_image_buttons.append((button, name, elem))
-
- with gr.Tabs(elem_id="mode_img2img"):
- with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
- add_copy_image_controls('img2img', init_img)
-
- with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
- sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
- add_copy_image_controls('sketch', sketch)
-
- with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
- add_copy_image_controls('inpaint', init_img_with_mask)
-
- with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
- inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
- inpaint_color_sketch_orig = gr.State(None)
- add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
-
- def update_orig(image, state):
- if image is not None:
- same_size = state is not None and state.size == image.size
- has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
- edited = same_size and has_exact_match
- return image if not edited or state is None else state
-
- inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
-
- with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
- init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
-
- with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
- hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(
- f"Process images in a directory on the same machine where the server is running." +
- f" Use an empty output directory to save pictures normally instead of writing to the output directory." +
- f" Add inpaint batch mask directory to enable inpaint batch processing."
- f"{hidden}"
- )
- img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
- img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
- img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
-
- def copy_image(img):
- if isinstance(img, dict) and 'image' in img:
- return img['image']
-
- return img
-
- for button, name, elem in copy_image_buttons:
- button.click(
- fn=copy_image,
- inputs=[elem],
- outputs=[copy_image_destinations[name]],
- )
- button.click(
- fn=lambda: None,
- _js="switch_to_"+name.replace(" ", "_"),
- inputs=[],
- outputs=[],
- )
-
- with FormRow():
- resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
-
- for category in ordered_ui_categories():
- if category == "sampler":
- steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
-
- elif category == "dimensions":
- with FormRow():
- with gr.Column(elem_id="img2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
-
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
- if opts.dimensions_and_batch_together:
- with gr.Column(elem_id="img2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
-
- elif category == "cfg":
- with FormGroup():
- with FormRow():
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
- image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
-
- elif category == "seed":
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
-
- elif category == "checkboxes":
- with FormRow(elem_id="img2img_checkboxes", variant="compact"):
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
- tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
-
- elif category == "batch":
- if not opts.dimensions_and_batch_together:
- with FormRow(elem_id="img2img_column_batch"):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
-
- elif category == "override_settings":
- with FormRow(elem_id="img2img_override_settings_row") as row:
- override_settings = create_override_settings_dropdown('img2img', row)
-
- elif category == "scripts":
- with FormGroup(elem_id="img2img_script_container"):
- custom_inputs = modules.scripts.scripts_img2img.setup_ui()
-
- elif category == "inpaint":
- with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
- with FormRow():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
- mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
-
- with FormRow():
- inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
-
- with FormRow():
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
-
- with FormRow():
- with gr.Column():
- inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
-
- with gr.Column(scale=4):
- inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
-
- def select_img2img_tab(tab):
- return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
-
- for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
- elem.select(
- fn=lambda tab=i: select_img2img_tab(tab),
- inputs=[],
- outputs=[inpaint_controls, mask_alpha],
- )
-
- img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
-
- connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
- connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
-
- img2img_prompt_img.change(
- fn=modules.images.image_data,
- inputs=[
- img2img_prompt_img
- ],
- outputs=[
- img2img_prompt,
- img2img_prompt_img
- ]
- )
-
- img2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
- _js="submit_img2img",
- inputs=[
- dummy_component,
- dummy_component,
- img2img_prompt,
- img2img_negative_prompt,
- img2img_prompt_styles,
- init_img,
- sketch,
- init_img_with_mask,
- inpaint_color_sketch,
- inpaint_color_sketch_orig,
- init_img_inpaint,
- init_mask_inpaint,
- steps,
- sampler_index,
- mask_blur,
- mask_alpha,
- inpainting_fill,
- restore_faces,
- tiling,
- batch_count,
- batch_size,
- cfg_scale,
- image_cfg_scale,
- denoising_strength,
- seed,
- subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
- height,
- width,
- resize_mode,
- inpaint_full_res,
- inpaint_full_res_padding,
- inpainting_mask_invert,
- img2img_batch_input_dir,
- img2img_batch_output_dir,
- img2img_batch_inpaint_mask_dir,
- override_settings,
- ] + custom_inputs,
- outputs=[
- img2img_gallery,
- generation_info,
- html_info,
- html_log,
- ],
- show_progress=False,
- )
-
- interrogate_args = dict(
- _js="get_img2img_tab_index",
- inputs=[
- dummy_component,
- img2img_batch_input_dir,
- img2img_batch_output_dir,
- init_img,
- sketch,
- init_img_with_mask,
- inpaint_color_sketch,
- init_img_inpaint,
- ],
- outputs=[img2img_prompt, dummy_component],
- )
-
- img2img_prompt.submit(**img2img_args)
- submit.click(**img2img_args)
- res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
-
- img2img_interrogate.click(
- fn=lambda *args: process_interrogate(interrogate, *args),
- **interrogate_args,
- )
-
- img2img_deepbooru.click(
- fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
- **interrogate_args,
- )
-
- prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
- style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
- style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
-
- for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
- button.click(
- fn=add_style,
- _js="ask_for_style_name",
- # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
- # the same number of parameters, but we only know the style-name after the JavaScript prompt
- inputs=[dummy_component, prompt, negative_prompt],
- outputs=[txt2img_prompt_styles, img2img_prompt_styles],
- )
-
- for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
- button.click(
- fn=apply_styles,
- _js=js_func,
- inputs=[prompt, negative_prompt, styles],
- outputs=[prompt, negative_prompt, styles],
- )
-
- token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
- negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
-
- ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
-
- img2img_paste_fields = [
- (img2img_prompt, "Prompt"),
- (img2img_negative_prompt, "Negative prompt"),
- (steps, "Steps"),
- (sampler_index, "Sampler"),
- (restore_faces, "Face restoration"),
- (cfg_scale, "CFG scale"),
- (image_cfg_scale, "Image CFG scale"),
- (seed, "Seed"),
- (width, "Size-1"),
- (height, "Size-2"),
- (batch_size, "Batch size"),
- (subseed, "Variation seed"),
- (subseed_strength, "Variation seed strength"),
- (seed_resize_from_w, "Seed resize from-1"),
- (seed_resize_from_h, "Seed resize from-2"),
- (denoising_strength, "Denoising strength"),
- (mask_blur, "Mask blur"),
- *modules.scripts.scripts_img2img.infotext_fields
- ]
- parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
- parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
- parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
- ))
-
- modules.scripts.scripts_current = None
-
- with gr.Blocks(analytics_enabled=False) as extras_interface:
- ui_postprocessing.create_ui()
-
- with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
- image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
-
- with gr.Column(variant='panel'):
- html = gr.HTML()
- generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
- html2 = gr.HTML()
- with gr.Row():
- buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
-
- for tabname, button in buttons.items():
- parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
- ))
-
- image.change(
- fn=wrap_gradio_call(modules.extras.run_pnginfo),
- inputs=[image],
- outputs=[html, generation_info, html2],
- )
-
- def update_interp_description(value):
-        interp_description_css = "{}"
- interp_descriptions = {
- "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
- "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
- "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
- }
- return interp_descriptions[value]
-
- with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='compact'):
- interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
-
- with FormRow(elem_id="modelmerger_models"):
- primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
- create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
-
- secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
- create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
-
- tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
- create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
-
- custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
- interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
- interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
- interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
-
- with FormRow():
- checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
- save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
-
- with FormRow():
- with gr.Column():
- config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
-
- with gr.Column():
- with FormRow():
- bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
- create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
-
- with FormRow():
- discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
-
- with gr.Row():
- modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
-
- with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
- with gr.Group(elem_id="modelmerger_results_panel"):
- modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
-
- with gr.Blocks(analytics_enabled=False) as train_interface:
- with gr.Row().style(equal_height=False):
-            gr.HTML(value="See wiki for detailed explanation.")
-
- with gr.Row(variant="compact").style(equal_height=False):
- with gr.Tabs(elem_id="train_tabs"):
-
- with gr.Tab(label="Create embedding"):
- new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
- initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
- nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
- overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
-
- with gr.Row():
- with gr.Column(scale=3):
- gr.HTML(value="")
-
- with gr.Column():
- create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
-
- with gr.Tab(label="Create hypernetwork"):
- new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
- new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
- new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
- new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
- new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
- new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
- new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
- new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
- overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
-
- with gr.Row():
- with gr.Column(scale=3):
- gr.HTML(value="")
-
- with gr.Column():
- create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
-
- with gr.Tab(label="Preprocess images"):
- process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
- process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
- process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
- process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
- preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
-
- with gr.Row():
- process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
- process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
- process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
- process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
- process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
- process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
-
- with gr.Row(visible=False) as process_split_extra_row:
- process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
- process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
-
- with gr.Row(visible=False) as process_focal_crop_row:
- process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
- process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
- process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
- process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
-
- with gr.Column(visible=False) as process_multicrop_col:
- gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
- with gr.Row():
- process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
- process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
- with gr.Row():
- process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
- process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
- with gr.Row():
- process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
- process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
-
- with gr.Row():
- with gr.Column(scale=3):
- gr.HTML(value="")
-
- with gr.Column():
- with gr.Row():
- interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
- run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
-
- process_split.change(
- fn=lambda show: gr_show(show),
- inputs=[process_split],
- outputs=[process_split_extra_row],
- )
-
- process_focal_crop.change(
- fn=lambda show: gr_show(show),
- inputs=[process_focal_crop],
- outputs=[process_focal_crop_row],
- )
-
- process_multicrop.change(
- fn=lambda show: gr_show(show),
- inputs=[process_multicrop],
- outputs=[process_multicrop_col],
- )
-
- def get_textual_inversion_template_names():
- return sorted([x for x in textual_inversion.textual_inversion_templates])
-
- with gr.Tab(label="Train"):
- gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
")
- with FormRow():
- train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
- create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
-
- train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
- create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
-
- with FormRow():
- embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
- hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
-
- with FormRow():
- clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
- clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
-
- with FormRow():
- batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
- gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
-
- dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
- log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
-
- with FormRow():
- template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
- create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
-
- training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
- training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
- varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
- steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
-
- with FormRow():
- create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
- save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
-
- use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight")
-
- save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
- preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
-
- shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
- tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
-
- latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
-
- with gr.Row():
- train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
- interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
- train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
-
- params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
-
- script_callbacks.ui_train_tabs_callback(params)
-
- with gr.Column(elem_id='ti_gallery_container'):
- ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
- ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
- ti_progress = gr.HTML(elem_id="ti_progress", value="")
- ti_outcome = gr.HTML(elem_id="ti_error", value="")
-
- create_embedding.click(
- fn=modules.textual_inversion.ui.create_embedding,
- inputs=[
- new_embedding_name,
- initialization_text,
- nvpt,
- overwrite_old_embedding,
- ],
- outputs=[
- train_embedding_name,
- ti_output,
- ti_outcome,
- ]
- )
-
- create_hypernetwork.click(
- fn=modules.hypernetworks.ui.create_hypernetwork,
- inputs=[
- new_hypernetwork_name,
- new_hypernetwork_sizes,
- overwrite_old_hypernetwork,
- new_hypernetwork_layer_structure,
- new_hypernetwork_activation_func,
- new_hypernetwork_initialization_option,
- new_hypernetwork_add_layer_norm,
- new_hypernetwork_use_dropout,
- new_hypernetwork_dropout_structure
- ],
- outputs=[
- train_hypernetwork_name,
- ti_output,
- ti_outcome,
- ]
- )
-
- run_preprocess.click(
- fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
- _js="start_training_textual_inversion",
- inputs=[
- dummy_component,
- process_src,
- process_dst,
- process_width,
- process_height,
- preprocess_txt_action,
- process_flip,
- process_split,
- process_caption,
- process_caption_deepbooru,
- process_split_threshold,
- process_overlap_ratio,
- process_focal_crop,
- process_focal_crop_face_weight,
- process_focal_crop_entropy_weight,
- process_focal_crop_edges_weight,
- process_focal_crop_debug,
- process_multicrop,
- process_multicrop_mindim,
- process_multicrop_maxdim,
- process_multicrop_minarea,
- process_multicrop_maxarea,
- process_multicrop_objective,
- process_multicrop_threshold,
- ],
- outputs=[
- ti_output,
- ti_outcome,
- ],
- )
-
- train_embedding.click(
- fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
- _js="start_training_textual_inversion",
- inputs=[
- dummy_component,
- train_embedding_name,
- embedding_learn_rate,
- batch_size,
- gradient_step,
- dataset_directory,
- log_directory,
- training_width,
- training_height,
- varsize,
- steps,
- clip_grad_mode,
- clip_grad_value,
- shuffle_tags,
- tag_drop_out,
- latent_sampling_method,
- use_weight,
- create_image_every,
- save_embedding_every,
- template_file,
- save_image_with_stored_embedding,
- preview_from_txt2img,
- *txt2img_preview_params,
- ],
- outputs=[
- ti_output,
- ti_outcome,
- ]
- )
-
- train_hypernetwork.click(
- fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
- _js="start_training_textual_inversion",
- inputs=[
- dummy_component,
- train_hypernetwork_name,
- hypernetwork_learn_rate,
- batch_size,
- gradient_step,
- dataset_directory,
- log_directory,
- training_width,
- training_height,
- varsize,
- steps,
- clip_grad_mode,
- clip_grad_value,
- shuffle_tags,
- tag_drop_out,
- latent_sampling_method,
- use_weight,
- create_image_every,
- save_embedding_every,
- template_file,
- preview_from_txt2img,
- *txt2img_preview_params,
- ],
- outputs=[
- ti_output,
- ti_outcome,
- ]
- )
-
- interrupt_training.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
- interrupt_preprocessing.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
- def create_setting_component(key, is_quicksettings=False):
- def fun():
- return opts.data[key] if key in opts.data else opts.data_labels[key].default
-
- info = opts.data_labels[key]
- t = type(info.default)
-
- args = info.component_args() if callable(info.component_args) else info.component_args
-
- if info.component is not None:
- comp = info.component
- elif t == str:
- comp = gr.Textbox
- elif t == int:
- comp = gr.Number
- elif t == bool:
- comp = gr.Checkbox
- else:
- raise Exception(f'bad options item type: {str(t)} for key {key}')
-
- elem_id = "setting_"+key
-
- if info.refresh is not None:
- if is_quicksettings:
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
- else:
- with FormRow():
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
- else:
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
-
- return res
-
- components = []
- component_dict = {}
- shared.settings_components = component_dict
-
- script_callbacks.ui_settings_callback()
- opts.reorder()
-
- def run_settings(*args):
- changed = []
-
- for key, value, comp in zip(opts.data_labels.keys(), args, components):
- assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
-
- for key, value, comp in zip(opts.data_labels.keys(), args, components):
- if comp == dummy_component:
- continue
-
- if opts.set(key, value):
- changed.append(key)
-
- try:
- opts.save(shared.config_filename)
- except RuntimeError:
- return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
- return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
-
- def run_settings_single(value, key):
- if not opts.same_type(value, opts.data_labels[key].default):
- return gr.update(visible=True), opts.dumpjson()
-
- if not opts.set(key, value):
- return gr.update(value=getattr(opts, key)), opts.dumpjson()
-
- opts.save(shared.config_filename)
-
- return get_value_for_setting(key), opts.dumpjson()
-
- with gr.Blocks(analytics_enabled=False) as settings_interface:
- with gr.Row():
- with gr.Column(scale=6):
- settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
- with gr.Column():
- restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
-
- result = gr.HTML(elem_id="settings_result")
-
- quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
- quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
-
- quicksettings_list = []
-
- previous_section = None
- current_tab = None
- current_row = None
- with gr.Tabs(elem_id="settings"):
- for i, (k, item) in enumerate(opts.data_labels.items()):
- section_must_be_skipped = item.section[0] is None
-
- if previous_section != item.section and not section_must_be_skipped:
- elem_id, text = item.section
-
- if current_tab is not None:
- current_row.__exit__()
- current_tab.__exit__()
-
- gr.Group()
- current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
- current_tab.__enter__()
- current_row = gr.Column(variant='compact')
- current_row.__enter__()
-
- previous_section = item.section
-
- if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
- quicksettings_list.append((i, k, item))
- components.append(dummy_component)
- elif section_must_be_skipped:
- components.append(dummy_component)
- else:
- component = create_setting_component(k)
- component_dict[k] = component
- components.append(component)
-
- if current_tab is not None:
- current_row.__exit__()
- current_tab.__exit__()
-
- with gr.TabItem("Actions"):
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
- download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
-
- with gr.TabItem("Licenses"):
- gr.HTML(shared.html("licenses.html"), elem_id="licenses")
-
- gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
-
- request_notifications.click(
- fn=lambda: None,
- inputs=[],
- outputs=[],
- _js='function(){}'
- )
-
- download_localization.click(
- fn=lambda: None,
- inputs=[],
- outputs=[],
- _js='download_localization'
- )
-
- def reload_scripts():
- modules.scripts.reload_script_body_only()
- reload_javascript() # need to refresh the html page
-
- reload_script_bodies.click(
- fn=reload_scripts,
- inputs=[],
- outputs=[]
- )
-
- def request_restart():
- shared.state.interrupt()
- shared.state.need_restart = True
-
- restart_gradio.click(
- fn=request_restart,
- _js='restart_reload',
- inputs=[],
- outputs=[],
- )
-
- interfaces = [
- (txt2img_interface, "txt2img", "txt2img"),
- (img2img_interface, "img2img", "img2img"),
- (extras_interface, "Extras", "extras"),
- (pnginfo_interface, "PNG Info", "pnginfo"),
- (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
- (train_interface, "Train", "ti"),
- ]
-
- css = ""
-
- for cssfile in modules.scripts.list_files_with_name("style.css"):
- if not os.path.isfile(cssfile):
- continue
-
- with open(cssfile, "r", encoding="utf8") as file:
- css += file.read() + "\n"
-
- if os.path.exists(os.path.join(data_path, "user.css")):
- with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
- css += file.read() + "\n"
-
- if not cmd_opts.no_progressbar_hiding:
- css += css_hide_progressbar
-
- interfaces += script_callbacks.ui_tabs_callback()
- interfaces += [(settings_interface, "Settings", "settings")]
-
- extensions_interface = ui_extensions.create_ui()
- interfaces += [(extensions_interface, "Extensions", "extensions")]
-
- with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
- with gr.Row(elem_id="quicksettings", variant="compact"):
- for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
- component = create_setting_component(k, is_quicksettings=True)
- component_dict[k] = component
-
- parameters_copypaste.connect_paste_params_buttons()
-
- with gr.Tabs(elem_id="tabs") as tabs:
- for interface, label, ifid in interfaces:
- with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
- interface.render()
-
- if os.path.exists(os.path.join(script_path, "notification.mp3")):
- audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
-
- footer = shared.html("footer.html")
- footer = footer.format(versions=versions_html())
- gr.HTML(footer, elem_id="footer")
-
- text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
- settings_submit.click(
- fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
- inputs=components,
- outputs=[text_settings, result],
- )
-
- for i, k, item in quicksettings_list:
- component = component_dict[k]
-
- component.change(
- fn=lambda value, k=k: run_settings_single(value, key=k),
- inputs=[component],
- outputs=[component, text_settings],
- )
-
- text_settings.change(
- fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
- inputs=[],
- outputs=[image_cfg_scale],
- )
-
- button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
- button_set_checkpoint.click(
- fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
- _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
- inputs=[component_dict['sd_model_checkpoint'], dummy_component],
- outputs=[component_dict['sd_model_checkpoint'], text_settings],
- )
-
- component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
-
- def get_settings_values():
- return [get_value_for_setting(key) for key in component_keys]
-
- demo.load(
- fn=get_settings_values,
- inputs=[],
- outputs=[component_dict[k] for k in component_keys],
- )
-
- def modelmerger(*args):
- try:
- results = modules.extras.run_modelmerger(*args)
- except Exception as e:
- print("Error loading/saving model file:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- modules.sd_models.list_models() # to remove the potentially missing models from the list
- return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
- return results
-
- modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
- modelmerger_merge.click(
- fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
- _js='modelmerger',
- inputs=[
- dummy_component,
- primary_model_name,
- secondary_model_name,
- tertiary_model_name,
- interp_method,
- interp_amount,
- save_as_half,
- custom_name,
- checkpoint_format,
- config_source,
- bake_in_vae,
- discard_weights,
- ],
- outputs=[
- primary_model_name,
- secondary_model_name,
- tertiary_model_name,
- component_dict['sd_model_checkpoint'],
- modelmerger_result,
- ]
- )
-
- ui_config_file = cmd_opts.ui_config_file
- ui_settings = {}
- settings_count = len(ui_settings)
- error_loading = False
-
- try:
- if os.path.exists(ui_config_file):
- with open(ui_config_file, "r", encoding="utf8") as file:
- ui_settings = json.load(file)
- except Exception:
- error_loading = True
- print("Error loading settings:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def loadsave(path, x):
- def apply_field(obj, field, condition=None, init_field=None):
- key = path + "/" + field
-
- if getattr(obj, 'custom_script_source', None) is not None:
- key = 'customscript/' + obj.custom_script_source + '/' + key
-
- if getattr(obj, 'do_not_save_to_config', False):
- return
-
- saved_value = ui_settings.get(key, None)
- if saved_value is None:
- ui_settings[key] = getattr(obj, field)
- elif condition and not condition(saved_value):
- pass
-
- # this warning is generally not useful;
- # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
- else:
- setattr(obj, field, saved_value)
- if init_field is not None:
- init_field(saved_value)
-
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
- apply_field(x, 'visible')
-
- if type(x) == gr.Slider:
- apply_field(x, 'value')
- apply_field(x, 'minimum')
- apply_field(x, 'maximum')
- apply_field(x, 'step')
-
- if type(x) == gr.Radio:
- apply_field(x, 'value', lambda val: val in x.choices)
-
- if type(x) == gr.Checkbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Textbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Number:
- apply_field(x, 'value')
-
- if type(x) == gr.Dropdown:
- def check_dropdown(val):
- if getattr(x, 'multiselect', False):
- return all([value in x.choices for value in val])
- else:
- return val in x.choices
-
- apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
-
- visit(txt2img_interface, loadsave, "txt2img")
- visit(img2img_interface, loadsave, "img2img")
- visit(extras_interface, loadsave, "extras")
- visit(modelmerger_interface, loadsave, "modelmerger")
- visit(train_interface, loadsave, "train")
- visit(settings_interface, loadsave, "settings")
- visit(extensions_interface, loadsave, "extensions")
-
- if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
- with open(ui_config_file, "w", encoding="utf8") as file:
- json.dump(ui_settings, file, indent=4)
-
- # Required as a workaround for change() event not triggering when loading values from ui-config.json
- interp_description.value = update_interp_description(interp_method.value)
-
- return demo
-
-
-def reload_javascript():
- head = f'\n'
-
- inline = f"{localization.localization_js(shared.opts.localization)};"
- if cmd_opts.theme is not None:
- inline += f"set_theme('{cmd_opts.theme}');"
-
- for script in modules.scripts.list_scripts("javascript", ".js"):
- head += f'\n'
-
- head += f'\n'
-
- def template_response(*args, **kwargs):
- res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
- res.body = res.body.replace(b'', f'{head}'.encode("utf8"))
- res.init_headers()
- return res
-
- gradio.routes.templates.TemplateResponse = template_response
-
-
-if not hasattr(shared, 'GradioTemplateResponseOriginal'):
- shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
-
-
-def versions_html():
- import torch
- import launch
-
- python_version = ".".join([str(x) for x in sys.version_info[0:3]])
- commit = launch.commit_hash()
- short_commit = commit[0:8]
-
- if shared.xformers_available:
- import xformers
- xformers_version = xformers.__version__
- else:
- xformers_version = "N/A"
-
- return f"""
-python: {python_version}
- •
-torch: {getattr(torch, '__long_version__',torch.__version__)}
- •
-xformers: {xformers_version}
- •
-gradio: {gr.__version__}
- •
-commit: {short_commit}
- •
-checkpoint: N/A
-"""
diff --git a/modules/ui_common.py b/modules/ui_common.py
deleted file mode 100644
index fd047f318879ebb038981aeaa0f96549e1c91bcf..0000000000000000000000000000000000000000
--- a/modules/ui_common.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import json
-import html
-import os
-import platform
-import sys
-
-import gradio as gr
-import subprocess as sp
-
-from modules import call_queue, shared
-from modules.generation_parameters_copypaste import image_from_url_text
-import modules.images
-
-folder_symbol = '\U0001f4c2' # 📂
-
-
-def update_generation_info(generation_info, html_info, img_index):
- try:
- generation_info = json.loads(generation_info)
- if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info, gr.update()
- return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
- except Exception:
- pass
- # if the json parse or anything else fails, just return the old html_info
- return html_info, gr.update()
-
-
-def plaintext_to_html(text):
- text = "" + " \n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "
"
- return text
-
-
-def save_files(js_data, images, do_make_zip, index):
- import csv
- filenames = []
- fullfns = []
-
- # quick dictionary-to-class-object conversion; it's necessary because apply_filename_pattern requires an object
- class MyObject:
- def __init__(self, d=None):
- if d is not None:
- for key, value in d.items():
- setattr(self, key, value)
-
- data = json.loads(js_data)
-
- p = MyObject(data)
- path = shared.opts.outdir_save
- save_to_dirs = shared.opts.use_save_to_dirs_for_ui
- extension: str = shared.opts.samples_format
- start_index = 0
-
- if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
-
- images = [images[index]]
- start_index = index
-
- os.makedirs(shared.opts.outdir_save, exist_ok=True)
-
- with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
- at_start = file.tell() == 0
- writer = csv.writer(file)
- if at_start:
- writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
-
- for image_index, filedata in enumerate(images, start_index):
- image = image_from_url_text(filedata)
-
- is_grid = image_index < p.index_of_first_image
- i = 0 if is_grid else (image_index - p.index_of_first_image)
-
- fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
-
- filename = os.path.relpath(fullfn, path)
- filenames.append(filename)
- fullfns.append(fullfn)
- if txt_fullfn:
- filenames.append(os.path.basename(txt_fullfn))
- fullfns.append(txt_fullfn)
-
- writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
-
- # Make Zip
- if do_make_zip:
- zip_filepath = os.path.join(path, "images.zip")
-
- from zipfile import ZipFile
- with ZipFile(zip_filepath, "w") as zip_file:
- for i in range(len(fullfns)):
- with open(fullfns[i], mode="rb") as f:
- zip_file.writestr(filenames[i], f.read())
- fullfns.insert(0, zip_filepath)
-
- return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
-
-
-def create_output_panel(tabname, outdir):
- from modules import shared
- import modules.generation_parameters_copypaste as parameters_copypaste
-
- def open_folder(f):
- if not os.path.exists(f):
- print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
- return
- elif not os.path.isdir(f):
- print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(f)
- if platform.system() == "Windows":
- os.startfile(path)
- elif platform.system() == "Darwin":
- sp.Popen(["open", path])
- elif "microsoft-standard-WSL2" in platform.uname().release:
- sp.Popen(["wsl-open", path])
- else:
- sp.Popen(["xdg-open", path])
-
- with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
- with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
-
- generation_info = None
- with gr.Column():
- with gr.Row(elem_id=f"image_buttons_{tabname}"):
- open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
-
- if tabname != "extras":
- save = gr.Button('Save', elem_id=f'save_{tabname}')
- save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
-
- buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
-
- open_folder_button.click(
- fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
- inputs=[],
- outputs=[],
- )
-
- if tabname != "extras":
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
-
- with gr.Group():
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
- if tabname == 'txt2img' or tabname == 'img2img':
- generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
- generation_info_button.click(
- fn=update_generation_info,
- _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
- inputs=[generation_info, html_info, html_info],
- outputs=[html_info, html_info],
- )
-
- save.click(
- fn=call_queue.wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ],
- show_progress=False,
- )
-
- save_zip.click(
- fn=call_queue.wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- else:
- html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- for paste_tabname, paste_button in buttons.items():
- parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
- ))
-
- return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
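create_output_panel's open_folder dispatches to a per-OS launcher before opening the output directory. Below is a trimmed sketch of just that dispatch, with the WSL branch and the webui hide_ui_dir_config check omitted.

```python
# Trimmed sketch of the open_folder dispatch above; the WSL and
# hide_ui_dir_config branches from the original are omitted for brevity.
import os
import platform
import subprocess as sp
import sys


def open_folder(path: str) -> None:
    if not os.path.isdir(path):
        print(f'"{path}" is not a folder.', file=sys.stderr)
        return

    path = os.path.normpath(path)
    system = platform.system()
    if system == "Windows":
        os.startfile(path)            # Explorer
    elif system == "Darwin":
        sp.Popen(["open", path])      # Finder
    else:
        sp.Popen(["xdg-open", path])  # desktop default on Linux
```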
diff --git a/modules/ui_components.py b/modules/ui_components.py
deleted file mode 100644
index 284ca0cf2d6951be1b2cedeffe5fb6f11000c78b..0000000000000000000000000000000000000000
--- a/modules/ui_components.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import gradio as gr
-
-
-class ToolButton(gr.Button, gr.components.FormComponent):
- """Small button with single emoji as text, fits inside gradio forms"""
-
- def __init__(self, **kwargs):
- super().__init__(variant="tool", **kwargs)
-
- def get_block_name(self):
- return "button"
-
-
-class ToolButtonTop(gr.Button, gr.components.FormComponent):
- """Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
-
- def __init__(self, **kwargs):
- super().__init__(variant="tool-top", **kwargs)
-
- def get_block_name(self):
- return "button"
-
-
-class FormRow(gr.Row, gr.components.FormComponent):
- """Same as gr.Row but fits inside gradio forms"""
-
- def get_block_name(self):
- return "row"
-
-
-class FormGroup(gr.Group, gr.components.FormComponent):
- """Same as gr.Row but fits inside gradio forms"""
-
- def get_block_name(self):
- return "group"
-
-
-class FormHTML(gr.HTML, gr.components.FormComponent):
- """Same as gr.HTML but fits inside gradio forms"""
-
- def get_block_name(self):
- return "html"
-
-
-class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
- """Same as gr.ColorPicker but fits inside gradio forms"""
-
- def get_block_name(self):
- return "colorpicker"
-
-
-class DropdownMulti(gr.Dropdown):
- """Same as gr.Dropdown but always multiselect"""
- def __init__(self, **kwargs):
- super().__init__(multiselect=True, **kwargs)
-
- def get_block_name(self):
- return "dropdown"
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
deleted file mode 100644
index bd4308ef028c55dfdfb22ec6dc36702f274445a3..0000000000000000000000000000000000000000
--- a/modules/ui_extensions.py
+++ /dev/null
@@ -1,354 +0,0 @@
-import json
-import os.path
-import shutil
-import sys
-import time
-import traceback
-
-import git
-
-import gradio as gr
-import html
-import shutil
-import errno
-
-from modules import extensions, shared, paths
-from modules.call_queue import wrap_gradio_gpu_call
-
-available_extensions = {"extensions": []}
-
-
-def check_access():
- assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
-
-
-def apply_and_restart(disable_list, update_list):
- check_access()
-
- disabled = json.loads(disable_list)
- assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
-
- update = json.loads(update_list)
- assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
-
- update = set(update)
-
- for ext in extensions.extensions:
- if ext.name not in update:
- continue
-
- try:
- ext.fetch_and_reset_hard()
- except Exception:
- print(f"Error getting updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- shared.opts.disabled_extensions = disabled
- shared.opts.save(shared.config_filename)
-
- shared.state.interrupt()
- shared.state.need_restart = True
-
-
-def check_updates(id_task, disable_list):
- check_access()
-
- disabled = json.loads(disable_list)
- assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
-
- exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
- shared.state.job_count = len(exts)
-
- for ext in exts:
- shared.state.textinfo = ext.name
-
- try:
- ext.check_updates()
- except Exception:
- print(f"Error checking updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- shared.state.nextjob()
-
- return extension_table(), ""
-
-
-def extension_table():
- code = f"""
-
- """
-
- return code
-
-
-def normalize_git_url(url):
- if url is None:
- return ""
-
- url = url.replace(".git", "")
- return url
-
-
-def install_extension_from_url(dirname, url):
- check_access()
-
- assert url, 'No URL specified'
-
- if dirname is None or dirname == "":
- *parts, last_part = url.split('/')
- last_part = normalize_git_url(last_part)
-
- dirname = last_part
-
- target_dir = os.path.join(extensions.extensions_dir, dirname)
- assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
-
- normalized_url = normalize_git_url(url)
- assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
-
- tmpdir = os.path.join(paths.data_path, "tmp", dirname)
-
- try:
- shutil.rmtree(tmpdir, True)
-
- repo = git.Repo.clone_from(url, tmpdir)
- repo.remote().fetch()
-
- try:
- os.rename(tmpdir, target_dir)
- except OSError as err:
- # TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
- # Shouldn't cause any new issues at least but we probably want to handle it there too.
- if err.errno == errno.EXDEV:
- # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
- # Since we can't use a rename, do the slower but more versatile shutil.move()
- shutil.move(tmpdir, target_dir)
- else:
- # Something else (not enough free space, permissions, etc.); rethrow it so that it gets handled.
- raise(err)
-
- import launch
- launch.run_extension_installer(target_dir)
-
- extensions.list_extensions()
- return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
- finally:
- shutil.rmtree(tmpdir, True)
-
-
-def install_extension_from_index(url, hide_tags, sort_column):
- ext_table, message = install_extension_from_url(None, url)
-
- code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
-
- return code, ext_table, message
-
-
-def refresh_available_extensions(url, hide_tags, sort_column):
- global available_extensions
-
- import urllib.request
- with urllib.request.urlopen(url) as response:
- text = response.read()
-
- available_extensions = json.loads(text)
-
- code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)
-
- return url, code, gr.CheckboxGroup.update(choices=tags), ''
-
-
-def refresh_available_extensions_for_tags(hide_tags, sort_column):
- code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
-
- return code, ''
-
-
-sort_ordering = [
- # (reverse, order_by_function)
- (True, lambda x: x.get('added', 'z')),
- (False, lambda x: x.get('added', 'z')),
- (False, lambda x: x.get('name', 'z')),
- (True, lambda x: x.get('name', 'z')),
- (False, lambda x: 'z'),
-]
-
-
-def refresh_available_extensions_from_data(hide_tags, sort_column):
- extlist = available_extensions["extensions"]
- installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
-
- tags = available_extensions.get("tags", {})
- tags_to_hide = set(hide_tags)
- hidden = 0
-
- code = f"""
-
-
-
- Extension
- Description
- Action
-
-
-
- """
-
- sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]
-
- for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
- name = ext.get("name", "noname")
- added = ext.get('added', 'unknown')
- url = ext.get("url", None)
- description = ext.get("description", "")
- extension_tags = ext.get("tags", [])
-
- if url is None:
- continue
-
- existing = installed_extension_urls.get(normalize_git_url(url), None)
- extension_tags = extension_tags + ["installed"] if existing else extension_tags
-
- if len([x for x in extension_tags if x in tags_to_hide]) > 0:
- hidden += 1
- continue
-
- install_code = f""" """
-
- tags_text = ", ".join([f"{x} " for x in extension_tags])
-
- code += f"""
-
- {html.escape(name)} {tags_text}
- {html.escape(description)}Added: {html.escape(added)}
- {install_code}
-
-
- """
-
- for tag in [x for x in extension_tags if x not in tags]:
- tags[tag] = tag
-
- code += """
-
-
- """
-
- if hidden > 0:
- code += f"Extension hidden: {hidden}
"
-
- return code, list(tags)
-
-
-def create_ui():
- import modules.ui
-
- with gr.Blocks(analytics_enabled=False) as ui:
- with gr.Tabs(elem_id="tabs_extensions") as tabs:
- with gr.TabItem("Installed"):
-
- with gr.Row(elem_id="extensions_installed_top"):
- apply = gr.Button(value="Apply and restart UI", variant="primary")
- check = gr.Button(value="Check for updates")
- extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
- extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
-
- info = gr.HTML()
- extensions_table = gr.HTML(lambda: extension_table())
-
- apply.click(
- fn=apply_and_restart,
- _js="extensions_apply",
- inputs=[extensions_disabled_list, extensions_update_list],
- outputs=[],
- )
-
- check.click(
- fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),
- _js="extensions_check",
- inputs=[info, extensions_disabled_list],
- outputs=[extensions_table, info],
- )
-
- with gr.TabItem("Available"):
- with gr.Row():
- refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
- available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
- extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
- install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
-
- with gr.Row():
- hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
- sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
-
- install_result = gr.HTML()
- available_extensions_table = gr.HTML()
-
- refresh_available_extensions_button.click(
- fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
- inputs=[available_extensions_index, hide_tags, sort_column],
- outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
- )
-
- install_extension_button.click(
- fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
- inputs=[extension_to_install, hide_tags, sort_column],
- outputs=[available_extensions_table, extensions_table, install_result],
- )
-
- hide_tags.change(
- fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
- inputs=[hide_tags, sort_column],
- outputs=[available_extensions_table, install_result]
- )
-
- sort_column.change(
- fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
- inputs=[hide_tags, sort_column],
- outputs=[available_extensions_table, install_result]
- )
-
- with gr.TabItem("Install from URL"):
- install_url = gr.Text(label="URL for extension's git repository")
- install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
- install_button = gr.Button(value="Install", variant="primary")
- install_result = gr.HTML(elem_id="extension_install_result")
-
- install_button.click(
- fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
- inputs=[install_dirname, install_url],
- outputs=[extensions_table, install_result],
- )
-
- return ui
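install_extension_from_url rejects duplicates by normalizing git URLs (dropping ".git") before comparing against installed extensions. A compact, self-contained sketch of that check; the Extension dataclass and the example URLs are stand-ins, not the webui's own types.

```python
# Compact sketch of the URL normalization / duplicate check used by
# install_extension_from_url above; Extension is an illustrative stand-in.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Extension:
    """Stand-in for modules.extensions.Extension."""
    name: str
    remote: Optional[str]


def normalize_git_url(url: Optional[str]) -> str:
    if url is None:
        return ""
    return url.replace(".git", "")


def is_already_installed(url: str, installed: List[Extension]) -> bool:
    target = normalize_git_url(url)
    return any(normalize_git_url(ext.remote) == target for ext in installed)


installed = [Extension("tag-complete", "https://github.com/example/tag-complete.git")]
print(is_already_installed("https://github.com/example/tag-complete", installed))     # True
print(is_already_installed("https://github.com/example/other-extension", installed))  # False
```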
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
deleted file mode 100644
index 71f1d81f23eb92a1ed2fdfd5a2d986ffd13e8e9f..0000000000000000000000000000000000000000
--- a/modules/ui_extra_networks.py
+++ /dev/null
@@ -1,251 +0,0 @@
-import glob
-import os.path
-import urllib.parse
-from pathlib import Path
-
-from modules import shared
-import gradio as gr
-import json
-import html
-
-from modules.generation_parameters_copypaste import image_from_url_text
-
-extra_pages = []
-allowed_dirs = set()
-
-
-def register_page(page):
- """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
-
- extra_pages.append(page)
- allowed_dirs.clear()
- allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
-
-
-def add_pages_to_demo(app):
- def fetch_file(filename: str = ""):
- from starlette.responses import FileResponse
-
- if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
- raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
-
- ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
-
- # would profit from returning 304
- return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
-
- app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
-
-
-class ExtraNetworksPage:
- def __init__(self, title):
- self.title = title
- self.name = title.lower()
- self.card_page = shared.html("extra-networks-card.html")
- self.allow_negative_prompt = False
-
- def refresh(self):
- pass
-
- def link_preview(self, filename):
- return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
-
- def search_terms_from_path(self, filename, possible_directories=None):
- abspath = os.path.abspath(filename)
-
- for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
- parentdir = os.path.abspath(parentdir)
- if abspath.startswith(parentdir):
- return abspath[len(parentdir):].replace('\\', '/')
-
- return ""
-
- def create_html(self, tabname):
- view = shared.opts.extra_networks_default_view
- items_html = ''
-
- subdirs = {}
- for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
- for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
- if not os.path.isdir(x):
- continue
-
- subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
- while subdir.startswith("/"):
- subdir = subdir[1:]
-
- is_empty = len(os.listdir(x)) == 0
- if not is_empty and not subdir.endswith("/"):
- subdir = subdir + "/"
-
- subdirs[subdir] = 1
-
- if subdirs:
- subdirs = {"": 1, **subdirs}
-
- subdirs_html = "".join([f"""
-
-{html.escape(subdir if subdir!="" else "all")}
-
-""" for subdir in subdirs])
-
- for item in self.list_items():
- items_html += self.create_html_for_item(item, tabname)
-
- if items_html == '':
- dirs = "".join([f"{x} " for x in self.allowed_directories_for_previews()])
- items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
-
- self_name_id = self.name.replace(" ", "_")
-
- res = f"""
-
-
-"""
-
- return res
-
- def list_items(self):
- raise NotImplementedError()
-
- def allowed_directories_for_previews(self):
- return []
-
- def create_html_for_item(self, item, tabname):
- preview = item.get("preview", None)
-
- onclick = item.get("onclick", None)
- if onclick is None:
- onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
-
- args = {
- "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
- "prompt": item.get("prompt", None),
- "tabname": json.dumps(tabname),
- "local_preview": json.dumps(item["local_preview"]),
- "name": item["name"],
- "card_clicked": onclick,
- "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
- "search_term": item.get("search_term", ""),
- }
-
- return self.card_page.format(**args)
-
-
-def intialize():
- extra_pages.clear()
-
-
-class ExtraNetworksUi:
- def __init__(self):
- self.pages = None
- self.stored_extra_pages = None
-
- self.button_save_preview = None
- self.preview_target_filename = None
-
- self.tabname = None
-
-
-def pages_in_preferred_order(pages):
- tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]
-
- def tab_name_score(name):
- name = name.lower()
- for i, possible_match in enumerate(tab_order):
- if possible_match in name:
- return i
-
- return len(pages)
-
- tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}
-
- return sorted(pages, key=lambda x: tab_scores[x.name])
-
-
-def create_ui(container, button, tabname):
- ui = ExtraNetworksUi()
- ui.pages = []
- ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
- ui.tabname = tabname
-
- with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
- for page in ui.stored_extra_pages:
- with gr.Tab(page.title):
- page_elem = gr.HTML(page.create_html(ui.tabname))
- ui.pages.append(page_elem)
-
- filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
- button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
- button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
-
- ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
- ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
-
- def toggle_visibility(is_visible):
- is_visible = not is_visible
- return is_visible, gr.update(visible=is_visible)
-
- state_visible = gr.State(value=False)
- button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
- button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
-
- def refresh():
- res = []
-
- for pg in ui.stored_extra_pages:
- pg.refresh()
- res.append(pg.create_html(ui.tabname))
-
- return res
-
- button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
-
- return ui
-
-
-def path_is_parent(parent_path, child_path):
- parent_path = os.path.abspath(parent_path)
- child_path = os.path.abspath(child_path)
-
- return child_path.startswith(parent_path)
-
-
-def setup_ui(ui, gallery):
- def save_preview(index, images, filename):
- if len(images) == 0:
- print("There is no image in gallery to save as a preview.")
- return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
-
- index = int(index)
- index = 0 if index < 0 else index
- index = len(images) - 1 if index >= len(images) else index
-
- img_info = images[index if index >= 0 else 0]
- image = image_from_url_text(img_info)
-
- is_allowed = False
- for extra_page in ui.stored_extra_pages:
- if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
- is_allowed = True
- break
-
- assert is_allowed, f'writing to {filename} is not allowed'
-
- image.save(filename)
-
- return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
-
- ui.button_save_preview.click(
- fn=save_preview,
- _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
- inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
- outputs=[*ui.pages]
- )
-
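As the register_page docstring above suggests, extensions register custom extra-networks pages from the on_before_ui() callback. A hedged sketch of such a page, following the pattern of the checkpoint/hypernetwork pages that come next; the "My models" title, the models/my_models directory, and the item fields chosen here are illustrative.

```python
# Hedged sketch of a custom extra-networks page registered via on_before_ui(),
# mirroring the subclasses below; paths and names are illustrative.
import json
import os

from modules import script_callbacks, ui_extra_networks


class ExtraNetworksPageMyModels(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__("My models")

    def list_items(self):
        for root in self.allowed_directories_for_previews():
            if not os.path.isdir(root):
                continue
            for filename in sorted(os.listdir(root)):
                path, _ = os.path.splitext(os.path.join(root, filename))
                yield {
                    "name": filename,
                    "filename": path,
                    "preview": None,                         # or self.link_preview(...) if a preview image exists
                    "search_term": self.search_terms_from_path(path),
                    "prompt": json.dumps(filename),          # text inserted into the prompt box on click
                    "local_preview": path + ".png",
                }

    def allowed_directories_for_previews(self):
        return ["models/my_models"]                          # illustrative directory


def before_ui():
    ui_extra_networks.register_page(ExtraNetworksPageMyModels())


script_callbacks.on_before_ui(before_ui)
```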
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
deleted file mode 100644
index 04097a7945911b8f73be49963325134da929d10f..0000000000000000000000000000000000000000
--- a/modules/ui_extra_networks_checkpoints.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import html
-import json
-import os
-import urllib.parse
-
-from modules import shared, ui_extra_networks, sd_models
-
-
-class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
- def __init__(self):
- super().__init__('Checkpoints')
-
- def refresh(self):
- shared.refresh_checkpoints()
-
- def list_items(self):
- checkpoint: sd_models.CheckpointInfo
- for name, checkpoint in sd_models.checkpoints_list.items():
- path, ext = os.path.splitext(checkpoint.filename)
- previews = [path + ".png", path + ".preview.png"]
-
- preview = None
- for file in previews:
- if os.path.isfile(file):
- preview = self.link_preview(file)
- break
-
- yield {
- "name": checkpoint.name_for_extra,
- "filename": path,
- "preview": preview,
- "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
- "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
- "local_preview": path + ".png",
- }
-
- def allowed_directories_for_previews(self):
- return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
-
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
deleted file mode 100644
index 57851088734daa77c904cccf120b2c00c44c994e..0000000000000000000000000000000000000000
--- a/modules/ui_extra_networks_hypernets.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import json
-import os
-
-from modules import shared, ui_extra_networks
-
-
-class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
- def __init__(self):
- super().__init__('Hypernetworks')
-
- def refresh(self):
- shared.reload_hypernetworks()
-
- def list_items(self):
- for name, path in shared.hypernetworks.items():
- path, ext = os.path.splitext(path)
- previews = [path + ".png", path + ".preview.png"]
-
- preview = None
- for file in previews:
- if os.path.isfile(file):
- preview = self.link_preview(file)
- break
-
- yield {
- "name": name,
- "filename": path,
- "preview": preview,
- "search_term": self.search_terms_from_path(path),
- "prompt": json.dumps(f""),
- "local_preview": path + ".png",
- }
-
- def allowed_directories_for_previews(self):
- return [shared.cmd_opts.hypernetwork_dir]
-
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
deleted file mode 100644
index bb64eb81e294ebb6a94adf94d069a013e37331f1..0000000000000000000000000000000000000000
--- a/modules/ui_extra_networks_textual_inversion.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import json
-import os
-
-from modules import ui_extra_networks, sd_hijack
-
-
-class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
- def __init__(self):
- super().__init__('Textual Inversion')
- self.allow_negative_prompt = True
-
- def refresh(self):
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
-
- def list_items(self):
- for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
- path, ext = os.path.splitext(embedding.filename)
- preview_file = path + ".preview.png"
-
- preview = None
- if os.path.isfile(preview_file):
- preview = self.link_preview(preview_file)
-
- yield {
- "name": embedding.name,
- "filename": embedding.filename,
- "preview": preview,
- "search_term": self.search_terms_from_path(embedding.filename),
- "prompt": json.dumps(embedding.name),
- "local_preview": path + ".preview.png",
- }
-
- def allowed_directories_for_previews(self):
- return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
deleted file mode 100644
index b418d9553059d946c34b4e251062c46fe098fd75..0000000000000000000000000000000000000000
--- a/modules/ui_postprocessing.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import gradio as gr
-from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
-import modules.generation_parameters_copypaste as parameters_copypaste
-
-
-def create_ui():
- tab_index = gr.State(value=0)
-
- with gr.Row().style(equal_height=False, variant='compact'):
- with gr.Column(variant='compact'):
- with gr.Tabs(elem_id="mode_extras"):
- with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
- extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
-
- with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
- image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
-
- with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
- extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
- extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
- show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
-
- submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
- script_inputs = scripts.scripts_postproc.setup_ui()
-
- with gr.Column():
- result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
-
- tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
- tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
- tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
-
- submit.click(
- fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
- inputs=[
- tab_index,
- extras_image,
- image_batch,
- extras_batch_input_dir,
- extras_batch_output_dir,
- show_extras_results,
- *script_inputs
- ],
- outputs=[
- result_images,
- html_info_x,
- html_info,
- ]
- )
-
- parameters_copypaste.add_paste_fields("extras", extras_image, None)
-
- extras_image.change(
- fn=scripts.scripts_postproc.image_changed,
- inputs=[], outputs=[]
- )
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
deleted file mode 100644
index 21945235ef360823fafbc92c358078f343e874da..0000000000000000000000000000000000000000
--- a/modules/ui_tempdir.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import os
-import tempfile
-from collections import namedtuple
-from pathlib import Path
-
-import gradio as gr
-
-from PIL import PngImagePlugin
-
-from modules import shared
-
-
-Savedfile = namedtuple("Savedfile", ["name"])
-
-
-def register_tmp_file(gradio, filename):
- if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
- gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
-
- if hasattr(gradio, 'temp_dirs'): # gradio 3.9
- gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
-
-
-def check_tmp_file(gradio, filename):
- if hasattr(gradio, 'temp_file_sets'):
- return any([filename in fileset for fileset in gradio.temp_file_sets])
-
- if hasattr(gradio, 'temp_dirs'):
- return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
-
- return False
-
-
-def save_pil_to_file(pil_image, dir=None):
- already_saved_as = getattr(pil_image, 'already_saved_as', None)
- if already_saved_as and os.path.isfile(already_saved_as):
- register_tmp_file(shared.demo, already_saved_as)
-
- file_obj = Savedfile(already_saved_as)
- return file_obj
-
- if shared.opts.temp_dir != "":
- dir = shared.opts.temp_dir
-
- use_metadata = False
- metadata = PngImagePlugin.PngInfo()
- for key, value in pil_image.info.items():
- if isinstance(key, str) and isinstance(value, str):
- metadata.add_text(key, value)
- use_metadata = True
-
- file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
- pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
- return file_obj
-
-
-# override save to file function so that it also writes PNG info
-gr.processing_utils.save_pil_to_file = save_pil_to_file
-
-
-def on_tmpdir_changed():
- if shared.opts.temp_dir == "" or shared.demo is None:
- return
-
- os.makedirs(shared.opts.temp_dir, exist_ok=True)
-
- register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
-
-
-def cleanup_tmpdr():
- temp_dir = shared.opts.temp_dir
- if temp_dir == "" or not os.path.isdir(temp_dir):
- return
-
- for root, dirs, files in os.walk(temp_dir, topdown=False):
- for name in files:
- _, extension = os.path.splitext(name)
- if extension != ".png":
- continue
-
- filename = os.path.join(root, name)
- os.remove(filename)
diff --git a/modules/upscaler.py b/modules/upscaler.py
deleted file mode 100644
index e2eaa7308af0091b6e8f407e889b2e446679e149..0000000000000000000000000000000000000000
--- a/modules/upscaler.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import os
-from abc import abstractmethod
-
-import PIL
-import numpy as np
-import torch
-from PIL import Image
-
-import modules.shared
-from modules import modelloader, shared
-
-LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
-NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
-
-
-class Upscaler:
- name = None
- model_path = None
- model_name = None
- model_url = None
- enable = True
- filter = None
- model = None
- user_path = None
- scalers: []
- tile = True
-
- def __init__(self, create_dirs=False):
- self.mod_pad_h = None
- self.tile_size = modules.shared.opts.ESRGAN_tile
- self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
- self.device = modules.shared.device
- self.img = None
- self.output = None
- self.scale = 1
- self.half = not modules.shared.cmd_opts.no_half
- self.pre_pad = 0
- self.mod_scale = None
-
- if self.model_path is None and self.name:
- self.model_path = os.path.join(shared.models_path, self.name)
- if self.model_path and create_dirs:
- os.makedirs(self.model_path, exist_ok=True)
-
- try:
- import cv2
- self.can_tile = True
- except:
- pass
-
- @abstractmethod
- def do_upscale(self, img: PIL.Image, selected_model: str):
- return img
-
- def upscale(self, img: PIL.Image, scale, selected_model: str = None):
- self.scale = scale
- dest_w = int(img.width * scale)
- dest_h = int(img.height * scale)
-
- for i in range(3):
- shape = (img.width, img.height)
-
- img = self.do_upscale(img, selected_model)
-
- if shape == (img.width, img.height):
- break
-
- if img.width >= dest_w and img.height >= dest_h:
- break
-
- if img.width != dest_w or img.height != dest_h:
- img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
-
- return img
-
- @abstractmethod
- def load_model(self, path: str):
- pass
-
- def find_models(self, ext_filter=None) -> list:
- return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
-
- def update_status(self, prompt):
- print(f"\nextras: {prompt}", file=shared.progress_print_out)
-
-
-class UpscalerData:
- name = None
- data_path = None
- scale: int = 4
- scaler: Upscaler = None
- model: None
-
- def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
- self.name = name
- self.data_path = path
- self.local_data_path = path
- self.scaler = upscaler
- self.scale = scale
- self.model = model
-
-
-class UpscalerNone(Upscaler):
- name = "None"
- scalers = []
-
- def load_model(self, path):
- pass
-
- def do_upscale(self, img, selected_model=None):
- return img
-
- def __init__(self, dirname=None):
- super().__init__(False)
- self.scalers = [UpscalerData("None", None, self)]
-
-
-class UpscalerLanczos(Upscaler):
- scalers = []
-
- def do_upscale(self, img, selected_model=None):
- return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
-
- def load_model(self, _):
- pass
-
- def __init__(self, dirname=None):
- super().__init__(False)
- self.name = "Lanczos"
- self.scalers = [UpscalerData("Lanczos", None, self)]
-
-
-class UpscalerNearest(Upscaler):
- scalers = []
-
- def do_upscale(self, img, selected_model=None):
- return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
-
- def load_model(self, _):
- pass
-
- def __init__(self, dirname=None):
- super().__init__(False)
- self.name = "Nearest"
- self.scalers = [UpscalerData("Nearest", None, self)]
diff --git a/modules/xlmr.py b/modules/xlmr.py
deleted file mode 100644
index beab3fdf55e7bcffd96f3b36679e7a90c0f390dc..0000000000000000000000000000000000000000
--- a/modules/xlmr.py
+++ /dev/null
@@ -1,137 +0,0 @@
-from transformers import BertPreTrainedModel,BertModel,BertConfig
-import torch.nn as nn
-import torch
-from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
-from transformers import XLMRobertaModel,XLMRobertaTokenizer
-from typing import Optional
-
-class BertSeriesConfig(BertConfig):
- def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
-
- super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
- self.project_dim = project_dim
- self.pooler_fn = pooler_fn
- self.learn_encoder = learn_encoder
-
-class RobertaSeriesConfig(XLMRobertaConfig):
- def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
- super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
- self.project_dim = project_dim
- self.pooler_fn = pooler_fn
- self.learn_encoder = learn_encoder
-
-
-class BertSeriesModelWithTransformation(BertPreTrainedModel):
-
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
- _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
- config_class = BertSeriesConfig
-
- def __init__(self, config=None, **kargs):
- # modify initialization for autoloading
- if config is None:
- config = XLMRobertaConfig()
- config.attention_probs_dropout_prob= 0.1
- config.bos_token_id=0
- config.eos_token_id=2
- config.hidden_act='gelu'
- config.hidden_dropout_prob=0.1
- config.hidden_size=1024
- config.initializer_range=0.02
- config.intermediate_size=4096
- config.layer_norm_eps=1e-05
- config.max_position_embeddings=514
-
- config.num_attention_heads=16
- config.num_hidden_layers=24
- config.output_past=True
- config.pad_token_id=1
- config.position_embedding_type= "absolute"
-
- config.type_vocab_size= 1
- config.use_cache=True
- config.vocab_size= 250002
- config.project_dim = 768
- config.learn_encoder = False
- super().__init__(config)
- self.roberta = XLMRobertaModel(config)
- self.transformation = nn.Linear(config.hidden_size,config.project_dim)
- self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
- self.pooler = lambda x: x[:,0]
- self.post_init()
-
- def encode(self,c):
- device = next(self.parameters()).device
- text = self.tokenizer(c,
- truncation=True,
- max_length=77,
- return_length=False,
- return_overflowing_tokens=False,
- padding="max_length",
- return_tensors="pt")
- text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
- text["attention_mask"] = torch.tensor(
- text['attention_mask']).to(device)
- features = self(**text)
- return features['projection_state']
-
- def forward(
- self,
- input_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- token_type_ids: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.Tensor] = None,
- output_attentions: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- ) :
- r"""
- """
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-
- outputs = self.roberta(
- input_ids=input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- output_attentions=output_attentions,
- output_hidden_states=True,
- return_dict=return_dict,
- )
-
- # last module outputs
- sequence_output = outputs[0]
-
-
- # project every module
- sequence_output_ln = self.pre_LN(sequence_output)
-
- # pooler
- pooler_output = self.pooler(sequence_output_ln)
- pooler_output = self.transformation(pooler_output)
- projection_state = self.transformation(outputs.last_hidden_state)
-
- return {
- 'pooler_output':pooler_output,
- 'last_hidden_state':outputs.last_hidden_state,
- 'hidden_states':outputs.hidden_states,
- 'attentions':outputs.attentions,
- 'projection_state':projection_state,
- 'sequence_out': sequence_output
- }
-
-
-class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
- base_model_prefix = 'roberta'
- config_class= RobertaSeriesConfig
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 6d53f08933a7f02e6266f9ae4fc8f23dfa119ecd..0000000000000000000000000000000000000000
--- a/requirements.txt
+++ /dev/null
@@ -1,32 +0,0 @@
-blendmodes
-accelerate
-basicsr
-fonts
-font-roboto
-gfpgan
-gradio==3.16.2
-invisible-watermark
-numpy
-omegaconf
-opencv-contrib-python
-requests
-piexif
-Pillow
-pytorch_lightning==1.7.7
-realesrgan
-scikit-image>=0.19
-timm==0.4.12
-transformers==4.25.1
-torch
-einops
-jsonmerge
-clean-fid
-resize-right
-torchdiffeq
-kornia
-lark
-inflection
-GitPython
-torchsde
-safetensors
-psutil
diff --git a/requirements_versions.txt b/requirements_versions.txt
deleted file mode 100644
index 331d0fe86513ae9e0dbe3cd8299878de78281e90..0000000000000000000000000000000000000000
--- a/requirements_versions.txt
+++ /dev/null
@@ -1,30 +0,0 @@
-blendmodes==2022
-transformers==4.25.1
-accelerate==0.12.0
-basicsr==1.4.2
-gfpgan==1.3.8
-gradio==3.16.2
-numpy==1.23.3
-Pillow==9.4.0
-realesrgan==0.3.0
-torch
-omegaconf==2.2.3
-pytorch_lightning==1.7.6
-scikit-image==0.19.2
-fonts
-font-roboto
-timm==0.6.7
-piexif==1.1.3
-einops==0.4.1
-jsonmerge==1.8.0
-clean-fid==0.1.29
-resize-right==0.0.2
-torchdiffeq==0.2.3
-kornia==0.6.7
-lark==1.1.2
-inflection==0.5.1
-GitPython==3.1.27
-torchsde==0.2.5
-safetensors==0.2.7
-httpcore<=0.15
-fastapi==0.90.1
diff --git a/screenshot.png b/screenshot.png
deleted file mode 100644
index 47a1be4ec43e315f3e47139b10b0f9a8045904f3..0000000000000000000000000000000000000000
Binary files a/screenshot.png and /dev/null differ
diff --git a/script.js b/script.js
deleted file mode 100644
index 97e0bfcf9fa3cc6b5823d86b0949ab9c947c6418..0000000000000000000000000000000000000000
--- a/script.js
+++ /dev/null
@@ -1,102 +0,0 @@
-function gradioApp() {
- const elems = document.getElementsByTagName('gradio-app')
- const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
- return !!gradioShadowRoot ? gradioShadowRoot : document;
-}
-
-function get_uiCurrentTab() {
- return gradioApp().querySelector('#tabs button:not(.border-transparent)')
-}
-
-function get_uiCurrentTabContent() {
- return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])')
-}
-
-uiUpdateCallbacks = []
-uiLoadedCallbacks = []
-uiTabChangeCallbacks = []
-optionsChangedCallbacks = []
-let uiCurrentTab = null
-
-function onUiUpdate(callback){
- uiUpdateCallbacks.push(callback)
-}
-function onUiLoaded(callback){
- uiLoadedCallbacks.push(callback)
-}
-function onUiTabChange(callback){
- uiTabChangeCallbacks.push(callback)
-}
-function onOptionsChanged(callback){
- optionsChangedCallbacks.push(callback)
-}
-
-function runCallback(x, m){
- try {
- x(m)
- } catch (e) {
- (console.error || console.log).call(console, e.message, e);
- }
-}
-function executeCallbacks(queue, m) {
- queue.forEach(function(x){runCallback(x, m)})
-}
-
-var executedOnLoaded = false;
-
-document.addEventListener("DOMContentLoaded", function() {
- var mutationObserver = new MutationObserver(function(m){
- if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
- executedOnLoaded = true;
- executeCallbacks(uiLoadedCallbacks);
- }
-
- executeCallbacks(uiUpdateCallbacks, m);
- const newTab = get_uiCurrentTab();
- if ( newTab && ( newTab !== uiCurrentTab ) ) {
- uiCurrentTab = newTab;
- executeCallbacks(uiTabChangeCallbacks);
- }
- });
- mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
-});
-
-/**
- * Add a ctrl+enter as a shortcut to start a generation
- */
-document.addEventListener('keydown', function(e) {
- var handled = false;
- if (e.key !== undefined) {
- if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
- } else if (e.keyCode !== undefined) {
- if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
- }
- if (handled) {
- button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
- if (button) {
- button.click();
- }
- e.preventDefault();
- }
-})
-
-/**
- * checks that a UI element is not in another hidden element or tab content
- */
-function uiElementIsVisible(el) {
- let isVisible = !el.closest('.\\!hidden');
- if ( ! isVisible ) {
- return false;
- }
-
- while( isVisible = el.closest('.tabitem')?.style.display !== 'none' ) {
- if ( ! isVisible ) {
- return false;
- } else if ( el.parentElement ) {
- el = el.parentElement
- } else {
- break;
- }
- }
- return isVisible;
-}
diff --git a/scripts/custom_code.py b/scripts/custom_code.py
deleted file mode 100644
index d29113e671070d10bdb7d7a23059255681fa0ef2..0000000000000000000000000000000000000000
--- a/scripts/custom_code.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import modules.scripts as scripts
-import gradio as gr
-
-from modules.processing import Processed
-from modules.shared import opts, cmd_opts, state
-
-class Script(scripts.Script):
-
- def title(self):
- return "Custom code"
-
- def show(self, is_img2img):
- return cmd_opts.allow_code
-
- def ui(self, is_img2img):
- code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code"))
-
- return [code]
-
-
- def run(self, p, code):
- assert cmd_opts.allow_code, '--allow-code option must be enabled'
-
- display_result_data = [[], -1, ""]
-
- def display(imgs, s=display_result_data[1], i=display_result_data[2]):
- display_result_data[0] = imgs
- display_result_data[1] = s
- display_result_data[2] = i
-
- from types import ModuleType
- compiled = compile(code, '', 'exec')
- module = ModuleType("testmodule")
- module.__dict__.update(globals())
- module.p = p
- module.display = display
- exec(compiled, module.__dict__)
-
- return Processed(p, *display_result_data)
-
-
\ No newline at end of file
diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py
deleted file mode 100644
index 2572443f97478ad984e8128e3749939946e7f59b..0000000000000000000000000000000000000000
--- a/scripts/img2imgalt.py
+++ /dev/null
@@ -1,216 +0,0 @@
-from collections import namedtuple
-
-import numpy as np
-from tqdm import trange
-
-import modules.scripts as scripts
-import gradio as gr
-
-from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
-from modules.processing import Processed
-from modules.shared import opts, cmd_opts, state
-
-import torch
-import k_diffusion as K
-
-from PIL import Image
-from torch import autocast
-from einops import rearrange, repeat
-
-
-def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
- x = p.init_latent
-
- s_in = x.new_ones([x.shape[0]])
- dnw = K.external.CompVisDenoiser(shared.sd_model)
- sigmas = dnw.get_sigmas(steps).flip(0)
-
- shared.state.sampling_steps = steps
-
- for i in trange(1, len(sigmas)):
- shared.state.sampling_step += 1
-
- x_in = torch.cat([x] * 2)
- sigma_in = torch.cat([sigmas[i] * s_in] * 2)
- cond_in = torch.cat([uncond, cond])
-
- image_conditioning = torch.cat([p.image_conditioning] * 2)
- cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
-
- c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
- t = dnw.sigma_to_t(sigma_in)
-
- eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
- denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
-
- denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
-
- d = (x - denoised) / sigmas[i]
- dt = sigmas[i] - sigmas[i - 1]
-
- x = x + d * dt
-
- sd_samplers_common.store_latent(x)
-
- # This shouldn't be necessary, but solved some VRAM issues
- del x_in, sigma_in, cond_in, c_out, c_in, t,
- del eps, denoised_uncond, denoised_cond, denoised, d, dt
-
- shared.state.nextjob()
-
- return x / x.std()
-
-
-Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", "sigma_adjustment"])
-
-
-# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
-def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
- x = p.init_latent
-
- s_in = x.new_ones([x.shape[0]])
- dnw = K.external.CompVisDenoiser(shared.sd_model)
- sigmas = dnw.get_sigmas(steps).flip(0)
-
- shared.state.sampling_steps = steps
-
- for i in trange(1, len(sigmas)):
- shared.state.sampling_step += 1
-
- x_in = torch.cat([x] * 2)
- sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
- cond_in = torch.cat([uncond, cond])
-
- image_conditioning = torch.cat([p.image_conditioning] * 2)
- cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
-
- c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
-
- if i == 1:
- t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
- else:
- t = dnw.sigma_to_t(sigma_in)
-
- eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
- denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
-
- denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
-
- if i == 1:
- d = (x - denoised) / (2 * sigmas[i])
- else:
- d = (x - denoised) / sigmas[i - 1]
-
- dt = sigmas[i] - sigmas[i - 1]
- x = x + d * dt
-
- sd_samplers_common.store_latent(x)
-
- # This shouldn't be necessary, but solved some VRAM issues
- del x_in, sigma_in, cond_in, c_out, c_in, t,
- del eps, denoised_uncond, denoised_cond, denoised, d, dt
-
- shared.state.nextjob()
-
- return x / sigmas[-1]
-
-
-class Script(scripts.Script):
- def __init__(self):
- self.cache = None
-
- def title(self):
- return "img2img alternative test"
-
- def show(self, is_img2img):
- return is_img2img
-
- def ui(self, is_img2img):
- info = gr.Markdown('''
- * `CFG Scale` should be 2 or lower.
- ''')
-
- override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=self.elem_id("override_sampler"))
-
- override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=self.elem_id("override_prompt"))
- original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=self.elem_id("original_prompt"))
- original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=self.elem_id("original_negative_prompt"))
-
- override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=self.elem_id("override_steps"))
- st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id("st"))
-
- override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=self.elem_id("override_strength"))
-
- cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg"))
- randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness"))
- sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment"))
-
- return [
- info,
- override_sampler,
- override_prompt, original_prompt, original_negative_prompt,
- override_steps, st,
- override_strength,
- cfg, randomness, sigma_adjustment,
- ]
-
- def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
- # Override
- if override_sampler:
- p.sampler_name = "Euler"
- if override_prompt:
- p.prompt = original_prompt
- p.negative_prompt = original_negative_prompt
- if override_steps:
- p.steps = st
- if override_strength:
- p.denoising_strength = 1.0
-
- def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
- lat = (p.init_latent.cpu().numpy() * 10).astype(int)
-
- same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
- and self.cache.original_prompt == original_prompt \
- and self.cache.original_negative_prompt == original_negative_prompt \
- and self.cache.sigma_adjustment == sigma_adjustment
- same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100
-
- if same_everything:
- rec_noise = self.cache.noise
- else:
- shared.state.job_count += 1
- cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
- uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])
- if sigma_adjustment:
- rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st)
- else:
- rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
- self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
-
- rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
-
- combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
-
- sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
-
- sigmas = sampler.model_wrap.get_sigmas(p.steps)
-
- noise_dt = combined_noise - (p.init_latent / sigmas[0])
-
- p.seed = p.seed + 1
-
- return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
-
- p.sample = sample_extra
-
- p.extra_generation_params["Decode prompt"] = original_prompt
- p.extra_generation_params["Decode negative prompt"] = original_negative_prompt
- p.extra_generation_params["Decode CFG scale"] = cfg
- p.extra_generation_params["Decode steps"] = st
- p.extra_generation_params["Randomness"] = randomness
- p.extra_generation_params["Sigma Adjustment"] = sigma_adjustment
-
- processed = processing.process_images(p)
-
- return processed
-
diff --git a/scripts/loopback.py b/scripts/loopback.py
deleted file mode 100644
index ec1f85e5891cc97d15ff06faaf749e29d98e362d..0000000000000000000000000000000000000000
--- a/scripts/loopback.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import numpy as np
-from tqdm import trange
-
-import modules.scripts as scripts
-import gradio as gr
-
-from modules import processing, shared, sd_samplers, images
-from modules.processing import Processed
-from modules.sd_samplers import samplers
-from modules.shared import opts, cmd_opts, state
-from modules import deepbooru
-
-
-class Script(scripts.Script):
- def title(self):
- return "Loopback"
-
- def show(self, is_img2img):
- return is_img2img
-
- def ui(self, is_img2img):
- loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
- denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor"))
- append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None")
-
- return [loops, denoising_strength_change_factor, append_interrogation]
-
- def run(self, p, loops, denoising_strength_change_factor, append_interrogation):
- processing.fix_seed(p)
- batch_count = p.n_iter
- p.extra_generation_params = {
- "Denoising strength change factor": denoising_strength_change_factor,
- }
-
- p.batch_size = 1
- p.n_iter = 1
-
- output_images, info = None, None
- initial_seed = None
- initial_info = None
-
- grids = []
- all_images = []
- original_init_image = p.init_images
- original_prompt = p.prompt
- state.job_count = loops * batch_count
-
- initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]
-
- for n in range(batch_count):
- history = []
-
- # Reset to original init image at the start of each batch
- p.init_images = original_init_image
-
- for i in range(loops):
- p.n_iter = 1
- p.batch_size = 1
- p.do_not_save_grid = True
-
- if opts.img2img_color_correction:
- p.color_corrections = initial_color_corrections
-
- if append_interrogation != "None":
- p.prompt = original_prompt + ", " if original_prompt != "" else ""
- if append_interrogation == "CLIP":
- p.prompt += shared.interrogator.interrogate(p.init_images[0])
- elif append_interrogation == "DeepBooru":
- p.prompt += deepbooru.model.tag(p.init_images[0])
-
- state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}"
-
- processed = processing.process_images(p)
-
- if initial_seed is None:
- initial_seed = processed.seed
- initial_info = processed.info
-
- init_img = processed.images[0]
-
- p.init_images = [init_img]
- p.seed = processed.seed + 1
- p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1)
- history.append(processed.images[0])
-
- grid = images.image_grid(history, rows=1)
- if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
-
- grids.append(grid)
- all_images += history
-
- if opts.return_grid:
- all_images = grids + all_images
-
- processed = Processed(p, all_images, initial_seed, initial_info)
-
- return processed
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
deleted file mode 100644
index 0906da6ae7da2281e83a517f2278f6cf1315c518..0000000000000000000000000000000000000000
--- a/scripts/outpainting_mk_2.py
+++ /dev/null
@@ -1,283 +0,0 @@
-import math
-
-import numpy as np
-import skimage
-
-import modules.scripts as scripts
-import gradio as gr
-from PIL import Image, ImageDraw
-
-from modules import images, processing, devices
-from modules.processing import Processed, process_images
-from modules.shared import opts, cmd_opts, state
-
-
-# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
-def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
- # helper fft routines that keep ortho normalization and auto-shift before and after fft
- def _fft2(data):
- if data.ndim > 2: # has channels
- out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
- for c in range(data.shape[2]):
- c_data = data[:, :, c]
- out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
- out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
- else: # one channel
- out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
- out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
- out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
-
- return out_fft
-
- def _ifft2(data):
- if data.ndim > 2: # has channels
- out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
- for c in range(data.shape[2]):
- c_data = data[:, :, c]
- out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
- out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
- else: # one channel
- out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
- out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
- out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
-
- return out_ifft
-
- def _get_gaussian_window(width, height, std=3.14, mode=0):
- window_scale_x = float(width / min(width, height))
- window_scale_y = float(height / min(width, height))
-
- window = np.zeros((width, height))
- x = (np.arange(width) / width * 2. - 1.) * window_scale_x
- for y in range(height):
- fy = (y / height * 2. - 1.) * window_scale_y
- if mode == 0:
- window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std)
- else:
- window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14) # hey wait a minute that's not gaussian
-
- return window
-
- def _get_masked_window_rgb(np_mask_grey, hardness=1.):
- np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
- if hardness != 1.:
- hardened = np_mask_grey[:] ** hardness
- else:
- hardened = np_mask_grey[:]
- for c in range(3):
- np_mask_rgb[:, :, c] = hardened[:]
- return np_mask_rgb
-
- width = _np_src_image.shape[0]
- height = _np_src_image.shape[1]
- num_channels = _np_src_image.shape[2]
-
- np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
- np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)
- img_mask = np_mask_grey > 1e-6
- ref_mask = np_mask_grey < 1e-3
-
- windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey))
- windowed_image /= np.max(windowed_image)
- windowed_image += np.average(_np_src_image) * np_mask_rgb # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
-
- src_fft = _fft2(windowed_image) # get feature statistics from masked src img
- src_dist = np.absolute(src_fft)
- src_phase = src_fft / src_dist
-
- # create a generator with a static seed to make outpainting deterministic / only follow global seed
- rng = np.random.default_rng(0)
-
- noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
- noise_rgb = rng.random((width, height, num_channels))
- noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
- noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
- for c in range(num_channels):
- noise_rgb[:, :, c] += (1. - color_variation) * noise_grey
-
- noise_fft = _fft2(noise_rgb)
- for c in range(num_channels):
- noise_fft[:, :, c] *= noise_window
- noise_rgb = np.real(_ifft2(noise_fft))
- shaped_noise_fft = _fft2(noise_rgb)
- shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
-
- brightness_variation = 0. # color_variation # todo: temporarily tying brightness variation to color variation for now
- contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
-
- # scikit-image is used for histogram matching, very convenient!
- shaped_noise = np.real(_ifft2(shaped_noise_fft))
- shaped_noise -= np.min(shaped_noise)
- shaped_noise /= np.max(shaped_noise)
- shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1)
- shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
-
- matched_noise = shaped_noise[:]
-
- return np.clip(matched_noise, 0., 1.)
-
-
-
-class Script(scripts.Script):
- def title(self):
- return "Outpainting mk2"
-
- def show(self, is_img2img):
- return is_img2img
-
- def ui(self, is_img2img):
- if not is_img2img:
- return None
-
- info = gr.HTML("Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8")
-
- pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur"))
- direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
- noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q"))
- color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation"))
-
- return [info, pixels, mask_blur, direction, noise_q, color_variation]
-
- def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation):
- initial_seed_and_info = [None, None]
-
- process_width = p.width
- process_height = p.height
-
- p.mask_blur = mask_blur*4
- p.inpaint_full_res = False
- p.inpainting_fill = 1
- p.do_not_save_samples = True
- p.do_not_save_grid = True
-
- left = pixels if "left" in direction else 0
- right = pixels if "right" in direction else 0
- up = pixels if "up" in direction else 0
- down = pixels if "down" in direction else 0
-
- init_img = p.init_images[0]
- target_w = math.ceil((init_img.width + left + right) / 64) * 64
- target_h = math.ceil((init_img.height + up + down) / 64) * 64
-
- if left > 0:
- left = left * (target_w - init_img.width) // (left + right)
-
- if right > 0:
- right = target_w - init_img.width - left
-
- if up > 0:
- up = up * (target_h - init_img.height) // (up + down)
-
- if down > 0:
- down = target_h - init_img.height - up
-
- def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
- is_horiz = is_left or is_right
- is_vert = is_top or is_bottom
- pixels_horiz = expand_pixels if is_horiz else 0
- pixels_vert = expand_pixels if is_vert else 0
-
- images_to_process = []
- output_images = []
- for n in range(count):
- res_w = init[n].width + pixels_horiz
- res_h = init[n].height + pixels_vert
- process_res_w = math.ceil(res_w / 64) * 64
- process_res_h = math.ceil(res_h / 64) * 64
-
- img = Image.new("RGB", (process_res_w, process_res_h))
- img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
- mask = Image.new("RGB", (process_res_w, process_res_h), "white")
- draw = ImageDraw.Draw(mask)
- draw.rectangle((
- expand_pixels + mask_blur if is_left else 0,
- expand_pixels + mask_blur if is_top else 0,
- mask.width - expand_pixels - mask_blur if is_right else res_w,
- mask.height - expand_pixels - mask_blur if is_bottom else res_h,
- ), fill="black")
-
- np_image = (np.asarray(img) / 255.0).astype(np.float64)
- np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
- noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
- output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
-
- target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
- target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
- p.width = target_width if is_horiz else img.width
- p.height = target_height if is_vert else img.height
-
- crop_region = (
- 0 if is_left else output_images[n].width - target_width,
- 0 if is_top else output_images[n].height - target_height,
- target_width if is_left else output_images[n].width,
- target_height if is_top else output_images[n].height,
- )
- mask = mask.crop(crop_region)
- p.image_mask = mask
-
- image_to_process = output_images[n].crop(crop_region)
- images_to_process.append(image_to_process)
-
- p.init_images = images_to_process
-
- latent_mask = Image.new("RGB", (p.width, p.height), "white")
- draw = ImageDraw.Draw(latent_mask)
- draw.rectangle((
- expand_pixels + mask_blur * 2 if is_left else 0,
- expand_pixels + mask_blur * 2 if is_top else 0,
- mask.width - expand_pixels - mask_blur * 2 if is_right else res_w,
- mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h,
- ), fill="black")
- p.latent_mask = latent_mask
-
- proc = process_images(p)
-
- if initial_seed_and_info[0] is None:
- initial_seed_and_info[0] = proc.seed
- initial_seed_and_info[1] = proc.info
-
- for n in range(count):
- output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
- output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
-
- return output_images
-
- batch_count = p.n_iter
- batch_size = p.batch_size
- p.n_iter = 1
- state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
- all_processed_images = []
-
- for i in range(batch_count):
- imgs = [init_img] * batch_size
- state.job = f"Batch {i + 1} out of {batch_count}"
-
- if left > 0:
- imgs = expand(imgs, batch_size, left, is_left=True)
- if right > 0:
- imgs = expand(imgs, batch_size, right, is_right=True)
- if up > 0:
- imgs = expand(imgs, batch_size, up, is_top=True)
- if down > 0:
- imgs = expand(imgs, batch_size, down, is_bottom=True)
-
- all_processed_images += imgs
-
- all_images = all_processed_images
-
- combined_grid_image = images.image_grid(all_processed_images)
- unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
- if opts.return_grid and not unwanted_grid_because_of_img_count:
- all_images = [combined_grid_image] + all_processed_images
-
- res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
-
- if opts.samples_save:
- for img in all_processed_images:
- images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
-
- if opts.grid_save and not unwanted_grid_because_of_img_count:
- images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
-
- return res
diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py
deleted file mode 100644
index d8feda00acde4e3ff644330cc62eb2dfdee060f6..0000000000000000000000000000000000000000
--- a/scripts/poor_mans_outpainting.py
+++ /dev/null
@@ -1,146 +0,0 @@
-import math
-
-import modules.scripts as scripts
-import gradio as gr
-from PIL import Image, ImageDraw
-
-from modules import images, processing, devices
-from modules.processing import Processed, process_images
-from modules.shared import opts, cmd_opts, state
-
-
-class Script(scripts.Script):
- def title(self):
- return "Poor man's outpainting"
-
- def show(self, is_img2img):
- return is_img2img
-
- def ui(self, is_img2img):
- if not is_img2img:
- return None
-
- pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur"))
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill"))
- direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
-
- return [pixels, mask_blur, inpainting_fill, direction]
-
- def run(self, p, pixels, mask_blur, inpainting_fill, direction):
- initial_seed = None
- initial_info = None
-
- p.mask_blur = mask_blur * 2
- p.inpainting_fill = inpainting_fill
- p.inpaint_full_res = False
-
- left = pixels if "left" in direction else 0
- right = pixels if "right" in direction else 0
- up = pixels if "up" in direction else 0
- down = pixels if "down" in direction else 0
-
- init_img = p.init_images[0]
- target_w = math.ceil((init_img.width + left + right) / 64) * 64
- target_h = math.ceil((init_img.height + up + down) / 64) * 64
-
- if left > 0:
- left = left * (target_w - init_img.width) // (left + right)
- if right > 0:
- right = target_w - init_img.width - left
-
- if up > 0:
- up = up * (target_h - init_img.height) // (up + down)
-
- if down > 0:
- down = target_h - init_img.height - up
-
- img = Image.new("RGB", (target_w, target_h))
- img.paste(init_img, (left, up))
-
- mask = Image.new("L", (img.width, img.height), "white")
- draw = ImageDraw.Draw(mask)
- draw.rectangle((
- left + (mask_blur * 2 if left > 0 else 0),
- up + (mask_blur * 2 if up > 0 else 0),
- mask.width - right - (mask_blur * 2 if right > 0 else 0),
- mask.height - down - (mask_blur * 2 if down > 0 else 0)
- ), fill="black")
-
- latent_mask = Image.new("L", (img.width, img.height), "white")
- latent_draw = ImageDraw.Draw(latent_mask)
- latent_draw.rectangle((
- left + (mask_blur//2 if left > 0 else 0),
- up + (mask_blur//2 if up > 0 else 0),
- mask.width - right - (mask_blur//2 if right > 0 else 0),
- mask.height - down - (mask_blur//2 if down > 0 else 0)
- ), fill="black")
-
- devices.torch_gc()
-
- grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
- grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
- grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
-
- p.n_iter = 1
- p.batch_size = 1
- p.do_not_save_grid = True
- p.do_not_save_samples = True
-
- work = []
- work_mask = []
- work_latent_mask = []
- work_results = []
-
- for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles):
- for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask):
- x, w = tiledata[0:2]
-
- if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:
- continue
-
- work.append(tiledata[2])
- work_mask.append(tiledata_mask[2])
- work_latent_mask.append(tiledata_latent_mask[2])
-
- batch_count = len(work)
- print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.")
-
- state.job_count = batch_count
-
- for i in range(batch_count):
- p.init_images = [work[i]]
- p.image_mask = work_mask[i]
- p.latent_mask = work_latent_mask[i]
-
- state.job = f"Batch {i + 1} out of {batch_count}"
- processed = process_images(p)
-
- if initial_seed is None:
- initial_seed = processed.seed
- initial_info = processed.info
-
- p.seed = processed.seed + 1
- work_results += processed.images
-
-
- image_index = 0
- for y, h, row in grid.tiles:
- for tiledata in row:
- x, w = tiledata[0:2]
-
- if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:
- continue
-
- tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
- image_index += 1
-
- combined_image = images.combine_grid(grid)
-
- if opts.samples_save:
- images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info, p=p)
-
- processed = Processed(p, [combined_image], initial_seed, initial_info)
-
- return processed
-
diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py
deleted file mode 100644
index a7d80d40e2b946fd35206f8bbc302a3cf476f081..0000000000000000000000000000000000000000
--- a/scripts/postprocessing_codeformer.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from PIL import Image
-import numpy as np
-
-from modules import scripts_postprocessing, codeformer_model
-import gradio as gr
-
-from modules.ui_components import FormRow
-
-
-class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
- name = "CodeFormer"
- order = 3000
-
- def ui(self):
- with FormRow():
- codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
- codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
-
- return {
- "codeformer_visibility": codeformer_visibility,
- "codeformer_weight": codeformer_weight,
- }
-
- def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
- if codeformer_visibility == 0:
- return
-
- restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
- res = Image.fromarray(restored_img)
-
- if codeformer_visibility < 1.0:
- res = Image.blend(pp.image, res, codeformer_visibility)
-
- pp.image = res
- pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3)
- pp.info["CodeFormer weight"] = round(codeformer_weight, 3)
diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py
deleted file mode 100644
index d854f3f7748dd8dec9575eb8914344db86a7f0c0..0000000000000000000000000000000000000000
--- a/scripts/postprocessing_gfpgan.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from PIL import Image
-import numpy as np
-
-from modules import scripts_postprocessing, gfpgan_model
-import gradio as gr
-
-from modules.ui_components import FormRow
-
-
-class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
- name = "GFPGAN"
- order = 2000
-
- def ui(self):
- with FormRow():
- gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
-
- return {
- "gfpgan_visibility": gfpgan_visibility,
- }
-
- def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
- if gfpgan_visibility == 0:
- return
-
- restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
- res = Image.fromarray(restored_img)
-
- if gfpgan_visibility < 1.0:
- res = Image.blend(pp.image, res, gfpgan_visibility)
-
- pp.image = res
- pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3)
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
deleted file mode 100644
index 8842bd91c926be4382bf2e5c2877021e43f37a69..0000000000000000000000000000000000000000
--- a/scripts/postprocessing_upscale.py
+++ /dev/null
@@ -1,131 +0,0 @@
-from PIL import Image
-import numpy as np
-
-from modules import scripts_postprocessing, shared
-import gradio as gr
-
-from modules.ui_components import FormRow
-
-
-upscale_cache = {}
-
-
-class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
- name = "Upscale"
- order = 1000
-
- def ui(self):
- selected_tab = gr.State(value=0)
-
- with gr.Tabs(elem_id="extras_resize_mode"):
- with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
- upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
-
- with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
- with FormRow():
- upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
- upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
- upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
-
- with FormRow():
- extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
-
- with FormRow():
- extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
- extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
-
- tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
- tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
-
- return {
- "upscale_mode": selected_tab,
- "upscale_by": upscaling_resize,
- "upscale_to_width": upscaling_resize_w,
- "upscale_to_height": upscaling_resize_h,
- "upscale_crop": upscaling_crop,
- "upscaler_1_name": extras_upscaler_1,
- "upscaler_2_name": extras_upscaler_2,
- "upscaler_2_visibility": extras_upscaler_2_visibility,
- }
-
- def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop):
- if upscale_mode == 1:
- upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height)
- info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}"
- else:
- info["Postprocess upscale by"] = upscale_by
-
- cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
- cached_image = upscale_cache.pop(cache_key, None)
-
- if cached_image is not None:
- image = cached_image
- else:
- image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path)
-
- upscale_cache[cache_key] = image
- if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache:
- upscale_cache.pop(next(iter(upscale_cache), None), None)
-
- if upscale_mode == 1 and upscale_crop:
- cropped = Image.new("RGB", (upscale_to_width, upscale_to_height))
- cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2))
- image = cropped
- info["Postprocess crop to"] = f"{image.width}x{image.height}"
-
- return image
-
- def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
- if upscaler_1_name == "None":
- upscaler_1_name = None
-
- upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None)
- assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}'
-
- if not upscaler1:
- return
-
- if upscaler_2_name == "None":
- upscaler_2_name = None
-
- upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None)
- assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}'
-
- upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
- pp.info[f"Postprocess upscaler"] = upscaler1.name
-
- if upscaler2 and upscaler_2_visibility > 0:
- second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
- upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility)
-
- pp.info[f"Postprocess upscaler 2"] = upscaler2.name
-
- pp.image = upscaled_image
-
- def image_changed(self):
- upscale_cache.clear()
-
-
-class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
- name = "Simple Upscale"
- order = 900
-
- def ui(self):
- with FormRow():
- upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
- upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
-
- return {
- "upscale_by": upscale_by,
- "upscaler_name": upscaler_name,
- }
-
- def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
- if upscaler_name is None or upscaler_name == "None":
- return
-
- upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
- assert upscaler1, f'could not find upscaler named {upscaler_name}'
-
- pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
- pp.info[f"Postprocess upscaler"] = upscaler1.name
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
deleted file mode 100644
index b1c486d44e3c3e4df4dc663e598239a7632b0bf2..0000000000000000000000000000000000000000
--- a/scripts/prompt_matrix.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import math
-from collections import namedtuple
-from copy import copy
-import random
-
-import modules.scripts as scripts
-import gradio as gr
-
-from modules import images
-from modules.processing import process_images, Processed
-from modules.shared import opts, cmd_opts, state
-import modules.sd_samplers
-
-
-def draw_xy_grid(xs, ys, x_label, y_label, cell):
- res = []
-
- ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
- hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
-
- first_processed = None
-
- state.job_count = len(xs) * len(ys)
-
- for iy, y in enumerate(ys):
- for ix, x in enumerate(xs):
- state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
-
- processed = cell(x, y)
- if first_processed is None:
- first_processed = processed
-
- res.append(processed.images[0])
-
- grid = images.image_grid(res, rows=len(ys))
- grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
-
- first_processed.images = [grid]
-
- return first_processed
-
-
-class Script(scripts.Script):
- def title(self):
- return "Prompt matrix"
-
- def ui(self, is_img2img):
- gr.HTML(' ')
- with gr.Row():
- with gr.Column():
- put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
- different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
- with gr.Column():
- prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive")
- variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma")
- with gr.Column():
- margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
-
- return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size]
-
- def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size):
- modules.processing.fix_seed(p)
-        # Raise error if prompt type is not positive or negative
- if prompt_type not in ["positive", "negative"]:
- raise ValueError(f"Unknown prompt type {prompt_type}")
- # Raise error if variations delimiter is not comma or space
- if variations_delimiter not in ["comma", "space"]:
- raise ValueError(f"Unknown variations delimiter {variations_delimiter}")
-
- prompt = p.prompt if prompt_type == "positive" else p.negative_prompt
- original_prompt = prompt[0] if type(prompt) == list else prompt
- positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
-
- delimiter = ", " if variations_delimiter == "comma" else " "
-
- all_prompts = []
- prompt_matrix_parts = original_prompt.split("|")
- combination_count = 2 ** (len(prompt_matrix_parts) - 1)
- for combination_num in range(combination_count):
- selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
-
- if put_at_start:
- selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
- else:
- selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
-
- all_prompts.append(delimiter.join(selected_prompts))
-
- p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
- p.do_not_save_grid = True
-
- print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
-
- if prompt_type == "positive":
- p.prompt = all_prompts
- else:
- p.negative_prompt = all_prompts
- p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
- p.prompt_for_display = positive_prompt
- processed = process_images(p)
-
- grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
-        grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)
- processed.images.insert(0, grid)
- processed.index_of_first_image = 1
- processed.infotexts.insert(0, processed.infotexts[0])
-
- if opts.grid_save:
- images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p)
-
- return processed
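For reference, the bitmask loop in `run()` above enumerates every subset of the "|"-separated variable parts while always keeping the first part, giving 2**(n-1) variants for n parts. A small standalone sketch of that expansion (the example prompt is made up for illustration):

```python
# Illustrative sketch of the prompt-matrix expansion; the example prompt is hypothetical.
def expand_prompt_matrix(prompt: str, delimiter: str = ", ") -> list:
    parts = [part.strip() for part in prompt.split("|")]
    base, options = parts[0], parts[1:]
    variants = []
    for mask in range(2 ** len(options)):              # 2**(n-1) combinations for n parts
        selected = [opt for n, opt in enumerate(options) if mask & (1 << n)]
        variants.append(delimiter.join([base] + selected))
    return variants

print(expand_prompt_matrix("a landscape|sunset|watercolor"))
# ['a landscape', 'a landscape, sunset', 'a landscape, watercolor', 'a landscape, sunset, watercolor']
```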
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
deleted file mode 100644
index 76dc5778b2754c3eda49d1b66220ebec42265623..0000000000000000000000000000000000000000
--- a/scripts/prompts_from_file.py
+++ /dev/null
@@ -1,177 +0,0 @@
-import copy
-import math
-import os
-import random
-import sys
-import traceback
-import shlex
-
-import modules.scripts as scripts
-import gradio as gr
-
-from modules import sd_samplers
-from modules.processing import Processed, process_images
-from PIL import Image
-from modules.shared import opts, cmd_opts, state
-
-
-def process_string_tag(tag):
- return tag
-
-
-def process_int_tag(tag):
- return int(tag)
-
-
-def process_float_tag(tag):
- return float(tag)
-
-
-def process_boolean_tag(tag):
-    return tag == "true"
-
-
-prompt_tags = {
- "sd_model": None,
- "outpath_samples": process_string_tag,
- "outpath_grids": process_string_tag,
- "prompt_for_display": process_string_tag,
- "prompt": process_string_tag,
- "negative_prompt": process_string_tag,
- "styles": process_string_tag,
- "seed": process_int_tag,
- "subseed_strength": process_float_tag,
- "subseed": process_int_tag,
- "seed_resize_from_h": process_int_tag,
- "seed_resize_from_w": process_int_tag,
- "sampler_index": process_int_tag,
- "sampler_name": process_string_tag,
- "batch_size": process_int_tag,
- "n_iter": process_int_tag,
- "steps": process_int_tag,
- "cfg_scale": process_float_tag,
- "width": process_int_tag,
- "height": process_int_tag,
- "restore_faces": process_boolean_tag,
- "tiling": process_boolean_tag,
- "do_not_save_samples": process_boolean_tag,
- "do_not_save_grid": process_boolean_tag
-}
-
-
-def cmdargs(line):
- args = shlex.split(line)
- pos = 0
- res = {}
-
- while pos < len(args):
- arg = args[pos]
-
- assert arg.startswith("--"), f'must start with "--": {arg}'
- assert pos+1 < len(args), f'missing argument for command line option {arg}'
-
- tag = arg[2:]
-
- if tag == "prompt" or tag == "negative_prompt":
- pos += 1
- prompt = args[pos]
- pos += 1
- while pos < len(args) and not args[pos].startswith("--"):
- prompt += " "
- prompt += args[pos]
- pos += 1
- res[tag] = prompt
- continue
-
-
- func = prompt_tags.get(tag, None)
- assert func, f'unknown commandline option: {arg}'
-
- val = args[pos+1]
- if tag == "sampler_name":
- val = sd_samplers.samplers_map.get(val.lower(), None)
-
- res[tag] = func(val)
-
- pos += 2
-
- return res
-
-
-def load_prompt_file(file):
- if file is None:
- lines = []
- else:
- lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")]
-
- return None, "\n".join(lines), gr.update(lines=7)
-
-
-class Script(scripts.Script):
- def title(self):
- return "Prompts from file or textbox"
-
- def ui(self, is_img2img):
- checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
- checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
-
- prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
- file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
-
- file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])
-
- # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n.
- # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
- # be unclear to the user that shift-enter is needed.
- prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
- return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
-
- def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
- lines = [x.strip() for x in prompt_txt.splitlines()]
- lines = [x for x in lines if len(x) > 0]
-
- p.do_not_save_grid = True
-
- job_count = 0
- jobs = []
-
- for line in lines:
- if "--" in line:
- try:
- args = cmdargs(line)
- except Exception:
- print(f"Error parsing line {line} as commandline:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- args = {"prompt": line}
- else:
- args = {"prompt": line}
-
- job_count += args.get("n_iter", p.n_iter)
-
- jobs.append(args)
-
- print(f"Will process {len(lines)} lines in {job_count} jobs.")
- if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1:
- p.seed = int(random.randrange(4294967294))
-
- state.job_count = job_count
-
- images = []
- all_prompts = []
- infotexts = []
- for n, args in enumerate(jobs):
- state.job = f"{state.job_no + 1} out of {state.job_count}"
-
- copy_p = copy.copy(p)
- for k, v in args.items():
- setattr(copy_p, k, v)
-
- proc = process_images(copy_p)
- images += proc.images
-
- if checkbox_iterate:
- p.seed = p.seed + (p.batch_size * p.n_iter)
- all_prompts += proc.all_prompts
- infotexts += proc.infotexts
-
- return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)
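As a usage note for the script above: each non-empty line is either taken verbatim as a prompt or, when it contains `--` options, parsed by `cmdargs()` into per-job overrides. Below is a simplified sketch of that parse; the example line and its values are made up, and the real `cmdargs()` additionally validates option names, handles multi-word prompts, and coerces types via `prompt_tags`.

```python
# Simplified sketch of cmdargs()-style parsing; the example line is hypothetical.
import shlex

line = '--prompt "a cozy cabin in the woods" --steps 20 --cfg_scale 7.5'
tokens = shlex.split(line)

overrides = {}
pos = 0
while pos < len(tokens):
    key = tokens[pos].lstrip("-")      # "prompt", "steps", "cfg_scale"
    overrides[key] = tokens[pos + 1]   # values stay strings here; the real code coerces types
    pos += 2

print(overrides)  # {'prompt': 'a cozy cabin in the woods', 'steps': '20', 'cfg_scale': '7.5'}
```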
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
deleted file mode 100644
index 332d76d918e53e5c2ed8d8cc1371413f1f29d7ee..0000000000000000000000000000000000000000
--- a/scripts/sd_upscale.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import math
-
-import modules.scripts as scripts
-import gradio as gr
-from PIL import Image
-
-from modules import processing, shared, sd_samplers, images, devices
-from modules.processing import Processed
-from modules.shared import opts, cmd_opts, state
-
-
-class Script(scripts.Script):
- def title(self):
- return "SD upscale"
-
- def show(self, is_img2img):
- return is_img2img
-
- def ui(self, is_img2img):
-        info = gr.HTML("Will upscale the image by the selected scale factor; use width and height sliders to set tile size")
- overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap"))
- scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor"))
- upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index"))
-
- return [info, overlap, upscaler_index, scale_factor]
-
- def run(self, p, _, overlap, upscaler_index, scale_factor):
- if isinstance(upscaler_index, str):
- upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())
- processing.fix_seed(p)
- upscaler = shared.sd_upscalers[upscaler_index]
-
- p.extra_generation_params["SD upscale overlap"] = overlap
- p.extra_generation_params["SD upscale upscaler"] = upscaler.name
-
- initial_info = None
- seed = p.seed
-
- init_img = p.init_images[0]
- init_img = images.flatten(init_img, opts.img2img_background_color)
-
- if upscaler.name != "None":
- img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
- else:
- img = init_img
-
- devices.torch_gc()
-
- grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap)
-
- batch_size = p.batch_size
- upscale_count = p.n_iter
- p.n_iter = 1
- p.do_not_save_grid = True
- p.do_not_save_samples = True
-
- work = []
-
- for y, h, row in grid.tiles:
- for tiledata in row:
- work.append(tiledata[2])
-
- batch_count = math.ceil(len(work) / batch_size)
- state.job_count = batch_count * upscale_count
-
- print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")
-
- result_images = []
- for n in range(upscale_count):
- start_seed = seed + n
- p.seed = start_seed
-
- work_results = []
- for i in range(batch_count):
- p.batch_size = batch_size
- p.init_images = work[i * batch_size:(i + 1) * batch_size]
-
- state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
- processed = processing.process_images(p)
-
- if initial_info is None:
- initial_info = processed.info
-
- p.seed = processed.seed + 1
- work_results += processed.images
-
- image_index = 0
- for y, h, row in grid.tiles:
- for tiledata in row:
- tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
- image_index += 1
-
- combined_image = images.combine_grid(grid)
- result_images.append(combined_image)
-
- if opts.samples_save:
- images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p)
-
- processed = Processed(p, result_images, seed, initial_info)
-
- return processed
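To get a feel for the batching arithmetic in `run()` above: the upscaled image is split into overlapping tiles of the img2img width/height, and those tiles are processed in batches. The standalone estimate below assumes a ceil-based split with the given overlap; the authoritative tiling logic is `images.split_grid`, which is not shown here.

```python
# Rough estimate of tile and batch counts for an SD-upscale pass (assumed formulas).
import math

def estimate_tiles_and_batches(width, height, tile_w, tile_h, overlap, batch_size):
    cols = math.ceil((width - overlap) / (tile_w - overlap))
    rows = math.ceil((height - overlap) / (tile_h - overlap))
    tiles = cols * rows
    return tiles, math.ceil(tiles / batch_size)

# e.g. a 1024x1024 upscaled image, 512x512 tiles, 64px overlap, batch size 4:
print(estimate_tiles_and_batches(1024, 1024, 512, 512, 64, 4))  # (9, 3)
```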
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
deleted file mode 100644
index 53511b121a2c401f6980d9a3c41cabe9b1c099fe..0000000000000000000000000000000000000000
--- a/scripts/xyz_grid.py
+++ /dev/null
@@ -1,620 +0,0 @@
-from collections import namedtuple
-from copy import copy
-from itertools import permutations, chain
-import random
-import csv
-from io import StringIO
-from PIL import Image
-import numpy as np
-
-import modules.scripts as scripts
-import gradio as gr
-
-from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
-from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
-from modules.shared import opts, cmd_opts, state
-import modules.shared as shared
-import modules.sd_samplers
-import modules.sd_models
-import modules.sd_vae
-import glob
-import os
-import re
-
-from modules.ui_components import ToolButton
-
-fill_values_symbol = "\U0001f4d2" # 📒
-
-AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])
-
-
-def apply_field(field):
- def fun(p, x, xs):
- setattr(p, field, x)
-
- return fun
-
-
-def apply_prompt(p, x, xs):
- if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
- raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
-
- p.prompt = p.prompt.replace(xs[0], x)
- p.negative_prompt = p.negative_prompt.replace(xs[0], x)
-
-
-def apply_order(p, x, xs):
- token_order = []
-
-    # Initially grab the tokens from the prompt, so they can be replaced in order of earliest seen
- for token in x:
- token_order.append((p.prompt.find(token), token))
-
- token_order.sort(key=lambda t: t[0])
-
- prompt_parts = []
-
- # Split the prompt up, taking out the tokens
- for _, token in token_order:
- n = p.prompt.find(token)
- prompt_parts.append(p.prompt[0:n])
- p.prompt = p.prompt[n + len(token):]
-
- # Rebuild the prompt with the tokens in the order we want
- prompt_tmp = ""
- for idx, part in enumerate(prompt_parts):
- prompt_tmp += part
- prompt_tmp += x[idx]
- p.prompt = prompt_tmp + p.prompt
-
-
-def apply_sampler(p, x, xs):
- sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
- if sampler_name is None:
- raise RuntimeError(f"Unknown sampler: {x}")
-
- p.sampler_name = sampler_name
-
-
-def confirm_samplers(p, xs):
- for x in xs:
- if x.lower() not in sd_samplers.samplers_map:
- raise RuntimeError(f"Unknown sampler: {x}")
-
-
-def apply_checkpoint(p, x, xs):
- info = modules.sd_models.get_closet_checkpoint_match(x)
- if info is None:
- raise RuntimeError(f"Unknown checkpoint: {x}")
- modules.sd_models.reload_model_weights(shared.sd_model, info)
-
-
-def confirm_checkpoints(p, xs):
- for x in xs:
- if modules.sd_models.get_closet_checkpoint_match(x) is None:
- raise RuntimeError(f"Unknown checkpoint: {x}")
-
-
-def apply_clip_skip(p, x, xs):
- opts.data["CLIP_stop_at_last_layers"] = x
-
-
-def apply_upscale_latent_space(p, x, xs):
- if x.lower().strip() != '0':
- opts.data["use_scale_latent_for_hires_fix"] = True
- else:
- opts.data["use_scale_latent_for_hires_fix"] = False
-
-
-def find_vae(name: str):
- if name.lower() in ['auto', 'automatic']:
- return modules.sd_vae.unspecified
- if name.lower() == 'none':
- return None
- else:
- choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
- if len(choices) == 0:
- print(f"No VAE found for {name}; using automatic")
- return modules.sd_vae.unspecified
- else:
- return modules.sd_vae.vae_dict[choices[0]]
-
-
-def apply_vae(p, x, xs):
- modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
-
-
-def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
- p.styles.extend(x.split(','))
-
-
-def format_value_add_label(p, opt, x):
- if type(x) == float:
- x = round(x, 8)
-
- return f"{opt.label}: {x}"
-
-
-def format_value(p, opt, x):
- if type(x) == float:
- x = round(x, 8)
- return x
-
-
-def format_value_join_list(p, opt, x):
- return ", ".join(x)
-
-
-def do_nothing(p, x, xs):
- pass
-
-
-def format_nothing(p, opt, x):
- return ""
-
-
-def str_permutations(x):
- """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
- return x
-
-
-class AxisOption:
- def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
- self.label = label
- self.type = type
- self.apply = apply
- self.format_value = format_value
- self.confirm = confirm
- self.cost = cost
- self.choices = choices
-
-
-class AxisOptionImg2Img(AxisOption):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.is_img2img = True
-
-class AxisOptionTxt2Img(AxisOption):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.is_img2img = False
-
-
-axis_options = [
- AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
- AxisOption("Seed", int, apply_field("seed")),
- AxisOption("Var. seed", int, apply_field("subseed")),
- AxisOption("Var. strength", float, apply_field("subseed_strength")),
- AxisOption("Steps", int, apply_field("steps")),
- AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
- AxisOption("CFG Scale", float, apply_field("cfg_scale")),
- AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
- AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
- AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
- AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
- AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
- AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
- AxisOption("Sigma Churn", float, apply_field("s_churn")),
- AxisOption("Sigma min", float, apply_field("s_tmin")),
- AxisOption("Sigma max", float, apply_field("s_tmax")),
- AxisOption("Sigma noise", float, apply_field("s_noise")),
- AxisOption("Eta", float, apply_field("eta")),
- AxisOption("Clip skip", int, apply_clip_skip),
- AxisOption("Denoising", float, apply_field("denoising_strength")),
- AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
- AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
- AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
- AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
-]
-
-
-def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
- hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
- ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
- title_texts = [[images.GridAnnotation(z)] for z in z_labels]
-
- # Temporary list of all the images that are generated to be populated into the grid.
- # Will be filled with empty images for any individual step that fails to process properly
- image_cache = [None] * (len(xs) * len(ys) * len(zs))
-
- processed_result = None
- cell_mode = "P"
- cell_size = (1, 1)
-
- state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter
-
- def process_cell(x, y, z, ix, iy, iz):
- nonlocal image_cache, processed_result, cell_mode, cell_size
-
- def index(ix, iy, iz):
- return ix + iy * len(xs) + iz * len(xs) * len(ys)
-
- state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}"
-
- processed: Processed = cell(x, y, z)
-
- try:
- # this dereference will throw an exception if the image was not processed
- # (this happens in cases such as if the user stops the process from the UI)
- processed_image = processed.images[0]
-
- if processed_result is None:
- # Use our first valid processed result as a template container to hold our full results
- processed_result = copy(processed)
- cell_mode = processed_image.mode
- cell_size = processed_image.size
- processed_result.images = [Image.new(cell_mode, cell_size)]
- processed_result.all_prompts = [processed.prompt]
- processed_result.all_seeds = [processed.seed]
- processed_result.infotexts = [processed.infotexts[0]]
-
- image_cache[index(ix, iy, iz)] = processed_image
- if include_lone_images:
- processed_result.images.append(processed_image)
- processed_result.all_prompts.append(processed.prompt)
- processed_result.all_seeds.append(processed.seed)
- processed_result.infotexts.append(processed.infotexts[0])
- except:
- image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)
-
- if first_axes_processed == 'x':
- for ix, x in enumerate(xs):
- if second_axes_processed == 'y':
- for iy, y in enumerate(ys):
- for iz, z in enumerate(zs):
- process_cell(x, y, z, ix, iy, iz)
- else:
- for iz, z in enumerate(zs):
- for iy, y in enumerate(ys):
- process_cell(x, y, z, ix, iy, iz)
- elif first_axes_processed == 'y':
- for iy, y in enumerate(ys):
- if second_axes_processed == 'x':
- for ix, x in enumerate(xs):
- for iz, z in enumerate(zs):
- process_cell(x, y, z, ix, iy, iz)
- else:
- for iz, z in enumerate(zs):
- for ix, x in enumerate(xs):
- process_cell(x, y, z, ix, iy, iz)
- elif first_axes_processed == 'z':
- for iz, z in enumerate(zs):
- if second_axes_processed == 'x':
- for ix, x in enumerate(xs):
- for iy, y in enumerate(ys):
- process_cell(x, y, z, ix, iy, iz)
- else:
- for iy, y in enumerate(ys):
- for ix, x in enumerate(xs):
- process_cell(x, y, z, ix, iy, iz)
-
- if not processed_result:
- print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
- return Processed(p, [])
-
- sub_grids = [None] * len(zs)
- for i in range(len(zs)):
- start_index = i * len(xs) * len(ys)
- end_index = start_index + len(xs) * len(ys)
- grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
- if draw_legend:
- grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
- sub_grids[i] = grid
- if include_sub_grids and len(zs) > 1:
- processed_result.images.insert(i+1, grid)
-
- sub_grid_size = sub_grids[0].size
- z_grid = images.image_grid(sub_grids, rows=1)
- if draw_legend:
- z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
- processed_result.images[0] = z_grid
-
- return processed_result, sub_grids
-
-
-class SharedSettingsStackHelper(object):
- def __enter__(self):
- self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
- self.vae = opts.sd_vae
-
- def __exit__(self, exc_type, exc_value, tb):
- opts.data["sd_vae"] = self.vae
- modules.sd_models.reload_model_weights()
- modules.sd_vae.reload_vae_weights()
-
- opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
-
-
-re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
-re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*")
-
-re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
-re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*")
-
-
-class Script(scripts.Script):
- def title(self):
- return "X/Y/Z plot"
-
- def ui(self, is_img2img):
- self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]
-
- with gr.Row():
- with gr.Column(scale=19):
- with gr.Row():
- x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
- x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
- fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
-
- with gr.Row():
- y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
- y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
- fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
-
- with gr.Row():
- z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
- z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
- fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
-
- with gr.Row(variant="compact", elem_id="axis_options"):
- with gr.Column():
- draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
- no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
- with gr.Column():
- include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
- include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
- with gr.Column():
- margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
-
- with gr.Row(variant="compact", elem_id="swap_axes"):
- swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
- swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
- swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
-
- def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values):
- return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values
-
- xy_swap_args = [x_type, x_values, y_type, y_values]
- swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
- yz_swap_args = [y_type, y_values, z_type, z_values]
- swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
- xz_swap_args = [x_type, x_values, z_type, z_values]
- swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
-
- def fill(x_type):
- axis = self.current_axis_options[x_type]
- return ", ".join(axis.choices()) if axis.choices else gr.update()
-
- fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
- fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
- fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])
-
- def select_axis(x_type):
- return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
-
- x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
- y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
- z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
-
- self.infotext_fields = (
- (x_type, "X Type"),
- (x_values, "X Values"),
- (y_type, "Y Type"),
- (y_values, "Y Values"),
- (z_type, "Z Type"),
- (z_values, "Z Values"),
- )
-
- return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
-
- def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
- if not no_fixed_seeds:
- modules.processing.fix_seed(p)
-
- if not opts.return_grid:
- p.batch_size = 1
-
- def process_axis(opt, vals):
- if opt.label == 'Nothing':
- return [0]
-
- valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
-
- if opt.type == int:
- valslist_ext = []
-
- for val in valslist:
- m = re_range.fullmatch(val)
- mc = re_range_count.fullmatch(val)
- if m is not None:
- start = int(m.group(1))
- end = int(m.group(2))+1
- step = int(m.group(3)) if m.group(3) is not None else 1
-
- valslist_ext += list(range(start, end, step))
- elif mc is not None:
- start = int(mc.group(1))
- end = int(mc.group(2))
- num = int(mc.group(3)) if mc.group(3) is not None else 1
-
- valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
- else:
- valslist_ext.append(val)
-
- valslist = valslist_ext
- elif opt.type == float:
- valslist_ext = []
-
- for val in valslist:
- m = re_range_float.fullmatch(val)
- mc = re_range_count_float.fullmatch(val)
- if m is not None:
- start = float(m.group(1))
- end = float(m.group(2))
- step = float(m.group(3)) if m.group(3) is not None else 1
-
- valslist_ext += np.arange(start, end + step, step).tolist()
- elif mc is not None:
- start = float(mc.group(1))
- end = float(mc.group(2))
- num = int(mc.group(3)) if mc.group(3) is not None else 1
-
- valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
- else:
- valslist_ext.append(val)
-
- valslist = valslist_ext
- elif opt.type == str_permutations:
- valslist = list(permutations(valslist))
-
- valslist = [opt.type(x) for x in valslist]
-
- # Confirm options are valid before starting
- if opt.confirm:
- opt.confirm(p, valslist)
-
- return valslist
-
- x_opt = self.current_axis_options[x_type]
- xs = process_axis(x_opt, x_values)
-
- y_opt = self.current_axis_options[y_type]
- ys = process_axis(y_opt, y_values)
-
- z_opt = self.current_axis_options[z_type]
- zs = process_axis(z_opt, z_values)
-
- def fix_axis_seeds(axis_opt, axis_list):
- if axis_opt.label in ['Seed', 'Var. seed']:
- return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
- else:
- return axis_list
-
- if not no_fixed_seeds:
- xs = fix_axis_seeds(x_opt, xs)
- ys = fix_axis_seeds(y_opt, ys)
- zs = fix_axis_seeds(z_opt, zs)
-
- if x_opt.label == 'Steps':
- total_steps = sum(xs) * len(ys) * len(zs)
- elif y_opt.label == 'Steps':
- total_steps = sum(ys) * len(xs) * len(zs)
- elif z_opt.label == 'Steps':
- total_steps = sum(zs) * len(xs) * len(ys)
- else:
- total_steps = p.steps * len(xs) * len(ys) * len(zs)
-
- if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
- if x_opt.label == "Hires steps":
- total_steps += sum(xs) * len(ys) * len(zs)
- elif y_opt.label == "Hires steps":
- total_steps += sum(ys) * len(xs) * len(zs)
- elif z_opt.label == "Hires steps":
- total_steps += sum(zs) * len(xs) * len(ys)
- elif p.hr_second_pass_steps:
- total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
- else:
- total_steps *= 2
-
- total_steps *= p.n_iter
-
- image_cell_count = p.n_iter * p.batch_size
- cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
- plural_s = 's' if len(zs) > 1 else ''
- print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
- shared.total_tqdm.updateTotal(total_steps)
-
- grid_infotext = [None]
-
- state.xyz_plot_x = AxisInfo(x_opt, xs)
- state.xyz_plot_y = AxisInfo(y_opt, ys)
- state.xyz_plot_z = AxisInfo(z_opt, zs)
-
- # If one of the axes is very slow to change between (like SD model
- # checkpoint), then make sure it is in the outer iteration of the nested
- # `for` loop.
- first_axes_processed = 'x'
- second_axes_processed = 'y'
- if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
- first_axes_processed = 'x'
- if y_opt.cost > z_opt.cost:
- second_axes_processed = 'y'
- else:
- second_axes_processed = 'z'
- elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:
- first_axes_processed = 'y'
- if x_opt.cost > z_opt.cost:
- second_axes_processed = 'x'
- else:
- second_axes_processed = 'z'
- elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:
- first_axes_processed = 'z'
- if x_opt.cost > y_opt.cost:
- second_axes_processed = 'x'
- else:
- second_axes_processed = 'y'
-
- def cell(x, y, z):
- if shared.state.interrupted:
- return Processed(p, [], p.seed, "")
-
- pc = copy(p)
- pc.styles = pc.styles[:]
- x_opt.apply(pc, x, xs)
- y_opt.apply(pc, y, ys)
- z_opt.apply(pc, z, zs)
-
- res = process_images(pc)
-
- if grid_infotext[0] is None:
- pc.extra_generation_params = copy(pc.extra_generation_params)
- pc.extra_generation_params['Script'] = self.title()
-
- if x_opt.label != 'Nothing':
- pc.extra_generation_params["X Type"] = x_opt.label
- pc.extra_generation_params["X Values"] = x_values
- if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
- pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
-
- if y_opt.label != 'Nothing':
- pc.extra_generation_params["Y Type"] = y_opt.label
- pc.extra_generation_params["Y Values"] = y_values
- if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
- pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
-
- if z_opt.label != 'Nothing':
- pc.extra_generation_params["Z Type"] = z_opt.label
- pc.extra_generation_params["Z Values"] = z_values
- if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
- pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs])
-
- grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
-
- return res
-
- with SharedSettingsStackHelper():
- processed, sub_grids = draw_xyz_grid(
- p,
- xs=xs,
- ys=ys,
- zs=zs,
- x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
- y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
- z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
- cell=cell,
- draw_legend=draw_legend,
- include_lone_images=include_lone_images,
- include_sub_grids=include_sub_grids,
- first_axes_processed=first_axes_processed,
- second_axes_processed=second_axes_processed,
- margin_size=margin_size
- )
-
- if opts.grid_save and len(sub_grids) > 1:
- for sub_grid in sub_grids:
- images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
-
- if opts.grid_save:
- images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
-
- return processed
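For reference, `process_axis()` above accepts two range notations in addition to plain comma-separated values: `start-end (+step)` expands with a fixed step, and `start-end [count]` expands to evenly spaced values. A minimal sketch of both expansions, using numpy the same way the script does (the concrete numbers are illustrative):

```python
# Minimal sketch of the two range notations handled by process_axis(); inputs are illustrative.
import numpy as np

# "1-9 (+2)" -> fixed-step integer range
start, end, step = 1, 9, 2
print(list(range(start, end + 1, step)))                      # [1, 3, 5, 7, 9]

# "0.0-1.0 [5]" -> evenly spaced values
start, end, num = 0.0, 1.0, 5
print(np.linspace(start=start, stop=end, num=num).tolist())   # [0.0, 0.25, 0.5, 0.75, 1.0]
```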
diff --git a/style.css b/style.css
deleted file mode 100644
index 05572f662d1aa817d4000e16789c70899c51c648..0000000000000000000000000000000000000000
--- a/style.css
+++ /dev/null
@@ -1,961 +0,0 @@
-.container {
- max-width: 100%;
-}
-
-.token-counter{
- position: absolute;
- display: inline-block;
- right: 2em;
- min-width: 0 !important;
- width: auto;
- z-index: 100;
-}
-
-.token-counter.error span{
- box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075);
- border: 2px solid rgba(255,0,0,0.4) !important;
-}
-
-.token-counter div{
- display: inline;
-}
-
-.token-counter span{
- padding: 0.1em 0.75em;
-}
-
-#sh{
- min-width: 2em;
- min-height: 2em;
- max-width: 2em;
- max-height: 2em;
- flex-grow: 0;
- padding-left: 0.25em;
- padding-right: 0.25em;
- margin: 0.1em 0;
- opacity: 0%;
- cursor: default;
-}
-
-.output-html p {margin: 0 0.5em;}
-
-.row > *,
-.row > .gr-form > * {
- min-width: min(120px, 100%);
- flex: 1 1 0%;
-}
-
-.performance {
- font-size: 0.85em;
- color: #444;
-}
-
-.performance p{
- display: inline-block;
-}
-
-.performance .time {
- margin-right: 0;
-}
-
-.performance .vram {
-}
-
-#txt2img_generate, #img2img_generate {
- min-height: 4.5em;
-}
-
-@media screen and (min-width: 2500px) {
- #txt2img_gallery, #img2img_gallery {
- min-height: 768px;
- }
-}
-
-#txt2img_gallery img, #img2img_gallery img{
- object-fit: scale-down;
-}
-#txt2img_actions_column, #img2img_actions_column {
- margin: 0.35rem 0.75rem 0.35rem 0;
-}
-#script_list {
- padding: .625rem .75rem 0 .625rem;
-}
-.justify-center.overflow-x-scroll {
- justify-content: left;
-}
-
-.justify-center.overflow-x-scroll button:first-of-type {
- margin-left: auto;
-}
-
-.justify-center.overflow-x-scroll button:last-of-type {
- margin-right: auto;
-}
-
-[id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{
- min-width: 2.3em;
- height: 2.5em;
- flex-grow: 0;
- padding-left: 0.25em;
- padding-right: 0.25em;
-}
-
-#hidden_element{
- display: none;
-}
-
-[id$=_seed_row], [id$=_subseed_row]{
- gap: 0.5rem;
- padding: 0.6em;
-}
-
-[id$=_subseed_show_box]{
- min-width: auto;
- flex-grow: 0;
-}
-
-[id$=_subseed_show_box] > div{
- border: 0;
- height: 100%;
-}
-
-[id$=_subseed_show]{
- min-width: auto;
- flex-grow: 0;
- padding: 0;
-}
-
-[id$=_subseed_show] label{
- height: 100%;
-}
-
-#txt2img_actions_column, #img2img_actions_column{
- gap: 0;
- margin-right: .75rem;
-}
-
-#txt2img_tools, #img2img_tools{
- gap: 0.4em;
-}
-
-#interrogate_col{
- min-width: 0 !important;
- max-width: 8em !important;
- margin-right: 1em;
- gap: 0;
-}
-#interrogate, #deepbooru{
- margin: 0em 0.25em 0.5em 0.25em;
- min-width: 8em;
- max-width: 8em;
-}
-
-#style_pos_col, #style_neg_col{
- min-width: 8em !important;
-}
-
-#txt2img_styles_row, #img2img_styles_row{
- gap: 0.25em;
- margin-top: 0.3em;
-}
-
-#txt2img_styles_row > button, #img2img_styles_row > button{
- margin: 0;
-}
-
-#txt2img_styles, #img2img_styles{
- padding: 0;
-}
-
-#txt2img_styles > label > div, #img2img_styles > label > div{
- min-height: 3.2em;
-}
-
-ul.list-none{
- max-height: 35em;
- z-index: 2000;
-}
-
-.gr-form{
- background: transparent;
-}
-
-.my-4{
- margin-top: 0;
- margin-bottom: 0;
-}
-
-#resize_mode{
- flex: 1.5;
-}
-
-button{
- align-self: stretch !important;
-}
-
-.overflow-hidden, .gr-panel{
- overflow: visible !important;
-}
-
-#x_type, #y_type{
- max-width: 10em;
-}
-
-#txt2img_preview, #img2img_preview, #ti_preview{
- position: absolute;
- width: 320px;
- left: 0;
- right: 0;
- margin-left: auto;
- margin-right: auto;
- margin-top: 34px;
- z-index: 100;
- border: none;
- border-top-left-radius: 0;
- border-top-right-radius: 0;
-}
-
-@media screen and (min-width: 768px) {
- #txt2img_preview, #img2img_preview, #ti_preview {
- position: absolute;
- }
-}
-
-@media screen and (max-width: 767px) {
- #txt2img_preview, #img2img_preview, #ti_preview {
- position: relative;
- }
-}
-
-#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
- display: none;
-}
-
-fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{
- position: absolute;
- top: -0.7em;
- line-height: 1.2em;
- padding: 0;
- margin: 0 0.5em;
-
- background-color: white;
- box-shadow: 6px 0 6px 0px white, -6px 0 6px 0px white;
-
- z-index: 300;
-}
-
-.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
- background-color: rgb(31, 41, 55);
- box-shadow: none;
- border: 1px solid rgba(128, 128, 128, 0.1);
- border-radius: 6px;
- padding: 0.1em 0.5em;
-}
-
-#txt2img_column_batch, #img2img_column_batch{
- min-width: min(13.5em, 100%) !important;
-}
-
-#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
- position: relative;
- border: none;
- margin-right: 8em;
-}
-
-#settings .gr-panel div.flex-col div.justify-between div{
- position: relative;
- z-index: 200;
-}
-
-#settings{
- display: block;
-}
-
-#settings > div{
- border: none;
- margin-left: 10em;
-}
-
-#settings > div.flex-wrap{
- float: left;
- display: block;
- margin-left: 0;
- width: 10em;
-}
-
-#settings > div.flex-wrap button{
- display: block;
- border: none;
- text-align: left;
-}
-
-#settings_result{
- height: 1.4em;
- margin: 0 1.2em;
-}
-
-input[type="range"]{
- margin: 0.5em 0 -0.3em 0;
-}
-
-#mask_bug_info {
- text-align: center;
- display: block;
- margin-top: -0.75em;
- margin-bottom: -0.75em;
-}
-
-#txt2img_negative_prompt, #img2img_negative_prompt{
-}
-
-/* gradio 3.8 adds opacity to progressbar which makes it blink; disable it here */
-.transition.opacity-20 {
- opacity: 1 !important;
-}
-
-/* more gradio's garbage cleanup */
-.min-h-\[4rem\] { min-height: unset !important; }
-.min-h-\[6rem\] { min-height: unset !important; }
-
-.progressDiv{
- position: relative;
- height: 20px;
- background: #b4c0cc;
- border-radius: 3px !important;
- margin-bottom: -3px;
-}
-
-.dark .progressDiv{
- background: #424c5b;
-}
-
-.progressDiv .progress{
- width: 0%;
- height: 20px;
- background: #0060df;
- color: white;
- font-weight: bold;
- line-height: 20px;
- padding: 0 8px 0 0;
- text-align: right;
- border-radius: 3px;
- overflow: visible;
- white-space: nowrap;
- padding: 0 0.5em;
-}
-
-.livePreview{
- position: absolute;
- z-index: 300;
- background-color: white;
- margin: -4px;
-}
-
-.dark .livePreview{
- background-color: rgb(17 24 39 / var(--tw-bg-opacity));
-}
-
-.livePreview img{
- position: absolute;
- object-fit: contain;
- width: 100%;
- height: 100%;
-}
-
-#lightboxModal{
- display: none;
- position: fixed;
- z-index: 1001;
- padding-top: 100px;
- left: 0;
- top: 0;
- width: 100%;
- height: 100%;
- overflow: auto;
- background-color: rgba(20, 20, 20, 0.95);
- user-select: none;
- -webkit-user-select: none;
-}
-
-.modalControls {
- display: grid;
- grid-template-columns: 32px 32px 32px 1fr 32px;
- grid-template-areas: "zoom tile save space close";
- position: absolute;
- top: 0;
- left: 0;
- right: 0;
- padding: 16px;
- gap: 16px;
- background-color: rgba(0,0,0,0.2);
-}
-
-.modalClose {
- grid-area: close;
-}
-
-.modalZoom {
- grid-area: zoom;
-}
-
-.modalSave {
- grid-area: save;
-}
-
-.modalTileImage {
- grid-area: tile;
-}
-
-.modalClose,
-.modalZoom,
-.modalTileImage {
- color: white;
- font-size: 35px;
- font-weight: bold;
- cursor: pointer;
-}
-
-.modalSave {
- color: white;
- font-size: 28px;
- margin-top: 8px;
- font-weight: bold;
- cursor: pointer;
-}
-
-.modalClose:hover,
-.modalClose:focus,
-.modalSave:hover,
-.modalSave:focus,
-.modalZoom:hover,
-.modalZoom:focus {
- color: #999;
- text-decoration: none;
- cursor: pointer;
-}
-
-#modalImage {
- display: block;
- margin-left: auto;
- margin-right: auto;
- margin-top: auto;
- width: auto;
-}
-
-.modalImageFullscreen {
- object-fit: contain;
- height: 90%;
-}
-
-.modalPrev,
-.modalNext {
- cursor: pointer;
- position: absolute;
- top: 50%;
- width: auto;
- padding: 16px;
- margin-top: -50px;
- color: white;
- font-weight: bold;
- font-size: 20px;
- transition: 0.6s ease;
- border-radius: 0 3px 3px 0;
- user-select: none;
- -webkit-user-select: none;
-}
-
-.modalNext {
- right: 0;
- border-radius: 3px 0 0 3px;
-}
-
-.modalPrev:hover,
-.modalNext:hover {
- background-color: rgba(0, 0, 0, 0.8);
-}
-
-#imageARPreview{
- position:absolute;
- top:0px;
- left:0px;
- border:2px solid red;
- background:rgba(255, 0, 0, 0.3);
- z-index: 900;
- pointer-events:none;
- display:none
-}
-
-#txt2img_generate_box, #img2img_generate_box{
- position: relative;
-}
-
-#txt2img_interrupt, #img2img_interrupt, #txt2img_skip, #img2img_skip{
- position: absolute;
- width: 50%;
- height: 100%;
- background: #b4c0cc;
- display: none;
-}
-
-#txt2img_interrupt, #img2img_interrupt{
- left: 0;
- border-radius: 0.5rem 0 0 0.5rem;
-}
-#txt2img_skip, #img2img_skip{
- right: 0;
- border-radius: 0 0.5rem 0.5rem 0;
-}
-
-.red {
- color: red;
-}
-
-.gallery-item {
- --tw-bg-opacity: 0 !important;
-}
-
-#context-menu{
- z-index:9999;
- position:absolute;
- display:block;
- padding:0px 0;
- border:2px solid #a55000;
- border-radius:8px;
- box-shadow:1px 1px 2px #CE6400;
- width: 200px;
-}
-
-.context-menu-items{
- list-style: none;
- margin: 0;
- padding: 0;
-}
-
-.context-menu-items a{
- display:block;
- padding:5px;
- cursor:pointer;
-}
-
-.context-menu-items a:hover{
- background: #a55000;
-}
-
-#quicksettings {
- width: fit-content;
-}
-
-#quicksettings > div, #quicksettings > fieldset{
- max-width: 24em;
- min-width: 24em;
- padding: 0;
- border: none;
- box-shadow: none;
- background: none;
- margin-right: 10px;
-}
-
-#quicksettings > div > div > div > label > span {
- position: relative;
- margin-right: 9em;
- margin-bottom: -1em;
-}
-
-canvas[key="mask"] {
- z-index: 12 !important;
- filter: invert();
- mix-blend-mode: multiply;
- pointer-events: none;
-}
-
-
-/* gradio 3.4.1 stuff for editable scrollbar values */
-.gr-box > div > div > input.gr-text-input{
- position: absolute;
- right: 0.5em;
- top: -0.6em;
- z-index: 400;
- width: 6em;
-}
-#quicksettings .gr-box > div > div > input.gr-text-input {
- top: -1.12em;
-}
-
-.row.gr-compact{
- overflow: visible;
-}
-
-#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img,
-#img2img_sketch, #img2img_sketch > .h-60, #img2img_sketch > .h-60 > div, #img2img_sketch > .h-60 > div > img,
-#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img,
-#inpaint_sketch, #inpaint_sketch > .h-60, #inpaint_sketch > .h-60 > div, #inpaint_sketch > .h-60 > div > img
-{
- height: 480px !important;
- max-height: 480px !important;
- min-height: 480px !important;
-}
-
-/* Extensions */
-
-#tab_extensions table{
- border-collapse: collapse;
-}
-
-#tab_extensions table td, #tab_extensions table th{
- border: 1px solid #ccc;
- padding: 0.25em 0.5em;
-}
-
-#tab_extensions table input[type="checkbox"]{
- margin-right: 0.5em;
-}
-
-#tab_extensions button{
- max-width: 16em;
-}
-
-#tab_extensions input[disabled="disabled"]{
- opacity: 0.5;
-}
-
-.extension-tag{
- font-weight: bold;
- font-size: 95%;
-}
-
-#available_extensions .info{
- margin: 0;
-}
-
-#available_extensions .date_added{
- opacity: 0.85;
- font-size: 90%;
-}
-
-#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{
- min-width: auto;
- padding-left: 0.5em;
- padding-right: 0.5em;
-}
-
-.gr-form{
- background-color: white;
-}
-
-.dark .gr-form{
- background-color: rgb(31 41 55 / var(--tw-bg-opacity));
-}
-
-.gr-button-tool, .gr-button-tool-top{
- max-width: 2.5em;
- min-width: 2.5em !important;
- height: 2.4em;
-}
-
-.gr-button-tool{
- margin: 0.6em 0em 0.55em 0;
-}
-
-.gr-button-tool-top, #settings .gr-button-tool{
- margin: 1.6em 0.7em 0.55em 0;
-}
-
-
-#modelmerger_results_container{
- margin-top: 1em;
- overflow: visible;
-}
-
-#modelmerger_models{
- gap: 0;
-}
-
-
-#quicksettings .gr-button-tool{
- margin: 0;
- border-color: unset;
- background-color: unset;
-}
-
-#modelmerger_interp_description>p {
- margin: 0!important;
- text-align: center;
-}
-#modelmerger_interp_description {
- margin: 0.35rem 0.75rem 1.23rem;
-}
-#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
- padding-top: 0.9em;
- padding-bottom: 0.9em;
-}
-#txt2img_settings {
- padding-top: 1.16em;
- padding-bottom: 0.9em;
-}
-#img2img_settings {
- padding-bottom: 0.9em;
-}
-
-#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{
- border: none;
- padding-bottom: 0.5em;
-}
-
-footer {
- display: none !important;
-}
-
-#footer{
- text-align: center;
-}
-
-#footer div{
- display: inline-block;
-}
-
-#footer .versions{
- font-size: 85%;
- opacity: 0.85;
-}
-
-#txtimg_hr_finalres{
- min-height: 0 !important;
- padding: .625rem .75rem;
- margin-left: -0.75em
-
-}
-
-#txtimg_hr_finalres .resolution{
- font-weight: bold;
-}
-
-#txt2img_checkboxes, #img2img_checkboxes{
- margin-bottom: 0.5em;
- margin-left: 0em;
-}
-#txt2img_checkboxes > div, #img2img_checkboxes > div{
- flex: 0;
- white-space: nowrap;
- min-width: auto;
-}
-
-#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
- margin-left: 0em;
-}
-
-#axis_options {
- margin-left: 0em;
-}
-
-.inactive{
- opacity: 0.5;
-}
-
-[id*='_prompt_container']{
- gap: 0;
-}
-
-[id*='_prompt_container'] > div{
- margin: -0.4em 0 0 0;
-}
-
-.gr-compact {
- border: none;
-}
-
-.dark .gr-compact{
- background-color: rgb(31 41 55 / var(--tw-bg-opacity));
- margin-left: 0;
-}
-
-.gr-compact{
- overflow: visible;
-}
-
-.gr-compact > *{
-}
-
-.gr-compact .gr-block, .gr-compact .gr-form{
- border: none;
- box-shadow: none;
-}
-
-.gr-compact .gr-box{
- border-radius: .5rem !important;
- border-width: 1px !important;
-}
-
-#mode_img2img > div > div{
- gap: 0 !important;
-}
-
-[id*='img2img_copy_to_'] {
- border: none;
-}
-
-[id*='img2img_copy_to_'] > button {
-}
-
-[id*='img2img_label_copy_to_'] {
- font-size: 1.0em;
- font-weight: bold;
- text-align: center;
- line-height: 2.4em;
-}
-
-.extra-networks > div > [id *= '_extra_']{
- margin: 0.3em;
-}
-
-.extra-network-subdirs{
- padding: 0.2em 0.35em;
-}
-
-.extra-network-subdirs button{
- margin: 0 0.15em;
-}
-
-#txt2img_extra_networks .search, #img2img_extra_networks .search{
- display: inline-block;
- max-width: 16em;
- margin: 0.3em;
- align-self: center;
-}
-
-#txt2img_extra_view, #img2img_extra_view {
- width: auto;
-}
-
-.extra-network-cards .nocards, .extra-network-thumbs .nocards{
- margin: 1.25em 0.5em 0.5em 0.5em;
-}
-
-.extra-network-cards .nocards h1, .extra-network-thumbs .nocards h1{
- font-size: 1.5em;
- margin-bottom: 1em;
-}
-
-.extra-network-cards .nocards li, .extra-network-thumbs .nocards li{
- margin-left: 0.5em;
-}
-
-.extra-network-thumbs {
- display: flex;
- flex-flow: row wrap;
- gap: 10px;
-}
-
-.extra-network-thumbs .card {
- height: 6em;
- width: 6em;
- cursor: pointer;
- background-image: url('./file=html/card-no-preview.png');
- background-size: cover;
- background-position: center center;
- position: relative;
-}
-
-.extra-network-thumbs .card:hover .additional a {
- display: block;
-}
-
-.extra-network-thumbs .actions .additional a {
- background-image: url('./file=html/image-update.svg');
- background-repeat: no-repeat;
- background-size: cover;
- background-position: center center;
- position: absolute;
- top: 0;
- left: 0;
- width: 24px;
- height: 24px;
- display: none;
- font-size: 0;
-    text-indent: -9999px;
-}
-
-.extra-network-thumbs .actions .name {
- position: absolute;
- bottom: 0;
- font-size: 10px;
- padding: 3px;
- width: 100%;
- overflow: hidden;
- white-space: nowrap;
- text-overflow: ellipsis;
- background: rgba(0,0,0,.5);
- color: white;
-}
-
-.extra-network-thumbs .card:hover .actions .name {
- white-space: normal;
- word-break: break-all;
-}
-
-.extra-network-cards .card{
- display: inline-block;
- margin: 0.5em;
- width: 16em;
- height: 24em;
- box-shadow: 0 0 5px rgba(128, 128, 128, 0.5);
- border-radius: 0.2em;
- position: relative;
-
- background-size: auto 100%;
- background-position: center;
- overflow: hidden;
- cursor: pointer;
-
- background-image: url('./file=html/card-no-preview.png')
-}
-
-.extra-network-cards .card:hover{
- box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35);
-}
-
-.extra-network-cards .card .actions .additional{
- display: none;
-}
-
-.extra-network-cards .card .actions{
- position: absolute;
- bottom: 0;
- left: 0;
- right: 0;
- padding: 0.5em;
- color: white;
- background: rgba(0,0,0,0.5);
- box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5);
- text-shadow: 0 0 0.2em black;
-}
-
-.extra-network-cards .card .actions:hover{
- box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important;
-}
-
-.extra-network-cards .card .actions .name{
- font-size: 1.7em;
- font-weight: bold;
- line-break: anywhere;
-}
-
-.extra-network-cards .card .actions:hover .additional{
- display: block;
-}
-
-.extra-network-cards .card ul{
- margin: 0.25em 0 0.75em 0.25em;
- cursor: unset;
-}
-
-.extra-network-cards .card ul a{
- cursor: pointer;
-}
-
-.extra-network-cards .card ul a:hover{
- color: red;
-}
-
-[id*='_prompt_container'] > div {
- margin: 0!important;
-}
diff --git a/test/__init__.py b/test/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/test/basic_features/__init__.py b/test/basic_features/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/test/basic_features/extras_test.py b/test/basic_features/extras_test.py
deleted file mode 100644
index 0170c511fe54cc6bcf49ec7f75ca7c747de41db5..0000000000000000000000000000000000000000
--- a/test/basic_features/extras_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import unittest
-import requests
-from gradio.processing_utils import encode_pil_to_base64
-from PIL import Image
-
-class TestExtrasWorking(unittest.TestCase):
- def setUp(self):
- self.url_extras_single = "http://localhost:7860/sdapi/v1/extra-single-image"
- self.extras_single = {
- "resize_mode": 0,
- "show_extras_results": True,
- "gfpgan_visibility": 0,
- "codeformer_visibility": 0,
- "codeformer_weight": 0,
- "upscaling_resize": 2,
- "upscaling_resize_w": 128,
- "upscaling_resize_h": 128,
- "upscaling_crop": True,
- "upscaler_1": "None",
- "upscaler_2": "None",
- "extras_upscaler_2_visibility": 0,
- "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
- }
-
- def test_simple_upscaling_performed(self):
- self.extras_single["upscaler_1"] = "Lanczos"
- self.assertEqual(requests.post(self.url_extras_single, json=self.extras_single).status_code, 200)
-
-
-class TestPngInfoWorking(unittest.TestCase):
- def setUp(self):
-        self.url_png_info = "http://localhost:7860/sdapi/v1/png-info"
- self.png_info = {
- "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
- }
-
- def test_png_info_performed(self):
- self.assertEqual(requests.post(self.url_png_info, json=self.png_info).status_code, 200)
-
-
-class TestInterrogateWorking(unittest.TestCase):
- def setUp(self):
-        self.url_interrogate = "http://localhost:7860/sdapi/v1/interrogate"
- self.interrogate = {
- "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")),
- "model": "clip"
- }
-
- def test_interrogate_performed(self):
- self.assertEqual(requests.post(self.url_interrogate, json=self.interrogate).status_code, 200)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py
deleted file mode 100644
index 08c5c903e8382ef4b969b01da87bc69fb06ff2b4..0000000000000000000000000000000000000000
--- a/test/basic_features/img2img_test.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import unittest
-import requests
-from gradio.processing_utils import encode_pil_to_base64
-from PIL import Image
-
-
-class TestImg2ImgWorking(unittest.TestCase):
- def setUp(self):
- self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
- self.simple_img2img = {
- "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))],
- "resize_mode": 0,
- "denoising_strength": 0.75,
- "mask": None,
- "mask_blur": 4,
- "inpainting_fill": 0,
- "inpaint_full_res": False,
- "inpaint_full_res_padding": 0,
- "inpainting_mask_invert": False,
- "prompt": "example prompt",
- "styles": [],
- "seed": -1,
- "subseed": -1,
- "subseed_strength": 0,
- "seed_resize_from_h": -1,
- "seed_resize_from_w": -1,
- "batch_size": 1,
- "n_iter": 1,
- "steps": 3,
- "cfg_scale": 7,
- "width": 64,
- "height": 64,
- "restore_faces": False,
- "tiling": False,
- "negative_prompt": "",
- "eta": 0,
- "s_churn": 0,
- "s_tmax": 0,
- "s_tmin": 0,
- "s_noise": 1,
- "override_settings": {},
- "sampler_index": "Euler a",
- "include_init_images": False
- }
-
- def test_img2img_simple_performed(self):
- self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
-
- def test_inpainting_masked_performed(self):
- self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
- self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
-
- def test_inpainting_with_inverted_masked_performed(self):
- self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
- self.simple_img2img["inpainting_mask_invert"] = True
- self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
-
- def test_img2img_sd_upscale_performed(self):
- self.simple_img2img["script_name"] = "sd upscale"
- self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0]
-
- self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py
deleted file mode 100644
index 5aa43a44a3e3818d7220a98acf5a6a504cdca3e3..0000000000000000000000000000000000000000
--- a/test/basic_features/txt2img_test.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import unittest
-import requests
-
-
-class TestTxt2ImgWorking(unittest.TestCase):
- def setUp(self):
- self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
- self.simple_txt2img = {
- "enable_hr": False,
- "denoising_strength": 0,
- "firstphase_width": 0,
- "firstphase_height": 0,
- "prompt": "example prompt",
- "styles": [],
- "seed": -1,
- "subseed": -1,
- "subseed_strength": 0,
- "seed_resize_from_h": -1,
- "seed_resize_from_w": -1,
- "batch_size": 1,
- "n_iter": 1,
- "steps": 3,
- "cfg_scale": 7,
- "width": 64,
- "height": 64,
- "restore_faces": False,
- "tiling": False,
- "negative_prompt": "",
- "eta": 0,
- "s_churn": 0,
- "s_tmax": 0,
- "s_tmin": 0,
- "s_noise": 1,
- "sampler_index": "Euler a"
- }
-
- def test_txt2img_simple_performed(self):
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_negative_prompt_performed(self):
- self.simple_txt2img["negative_prompt"] = "example negative prompt"
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_complex_prompt_performed(self):
- self.simple_txt2img["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]"
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_not_square_image_performed(self):
- self.simple_txt2img["height"] = 128
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_hrfix_performed(self):
- self.simple_txt2img["enable_hr"] = True
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_tiling_performed(self):
- self.simple_txt2img["tiling"] = True
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_restore_faces_performed(self):
- self.simple_txt2img["restore_faces"] = True
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_vanilla_sampler_performed(self):
- self.simple_txt2img["sampler_index"] = "PLMS"
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
- self.simple_txt2img["sampler_index"] = "DDIM"
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_multiple_batches_performed(self):
- self.simple_txt2img["n_iter"] = 2
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_batch_performed(self):
- self.simple_txt2img["batch_size"] = 2
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
-
-if __name__ == "__main__":
- unittest.main()
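Reviewer note: most of the deleted txt2img cases differ from the base request by a single field. A compact sketch that drives the same variations from one table, assuming the /sdapi/v1/txt2img endpoint is unchanged:

    import requests

    URL = "http://localhost:7860/sdapi/v1/txt2img"
    BASE = {
        "prompt": "example prompt",
        "steps": 3,
        "cfg_scale": 7,
        "width": 64,
        "height": 64,
        "sampler_index": "Euler a",
    }

    # One override per deleted test case.
    VARIATIONS = [
        {},                                              # simple
        {"negative_prompt": "example negative prompt"},
        {"prompt": "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]"},
        {"height": 128},                                 # non-square
        {"enable_hr": True},                             # hires fix
        {"tiling": True},
        {"restore_faces": True},
        {"sampler_index": "PLMS"},
        {"sampler_index": "DDIM"},
        {"n_iter": 2},                                   # multiple batches
        {"batch_size": 2},
    ]

    def run_txt2img_variations():
        for override in VARIATIONS:
            assert requests.post(URL, json={**BASE, **override}).status_code == 200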
diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py
deleted file mode 100644
index 0bfc28a0d30c070c292ff8154e9b93a74abecb85..0000000000000000000000000000000000000000
--- a/test/basic_features/utils_test.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import unittest
-import requests
-
-class UtilsTests(unittest.TestCase):
- def setUp(self):
- self.url_options = "http://localhost:7860/sdapi/v1/options"
- self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
- self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
- self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
- self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
- self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
- self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
- self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
- self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
- self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings"
-
- def test_options_get(self):
- self.assertEqual(requests.get(self.url_options).status_code, 200)
-
- def test_options_write(self):
- response = requests.get(self.url_options)
- self.assertEqual(response.status_code, 200)
-
- pre_value = response.json()["send_seed"]
-
-        self.assertEqual(requests.post(self.url_options, json={"send_seed": not pre_value}).status_code, 200)
-
- response = requests.get(self.url_options)
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.json()["send_seed"], not pre_value)
-
- requests.post(self.url_options, json={"send_seed": pre_value})
-
- def test_cmd_flags(self):
- self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
-
- def test_samplers(self):
- self.assertEqual(requests.get(self.url_samplers).status_code, 200)
-
- def test_upscalers(self):
- self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
-
- def test_sd_models(self):
- self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
-
- def test_hypernetworks(self):
- self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
-
- def test_face_restorers(self):
- self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
-
- def test_realesrgan_models(self):
- self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
-
- def test_prompt_styles(self):
- self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
-
- def test_embeddings(self):
- self.assertEqual(requests.get(self.url_embeddings).status_code, 200)
-
-if __name__ == "__main__":
- unittest.main()
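Reviewer note: the deleted test_options_write flips send_seed and only restores it when the assertions pass, so a failing run can leave the setting changed. A hedged sketch of the same round trip that always restores the original value, assuming the same /sdapi/v1/options endpoint:

    import requests

    OPTIONS_URL = "http://localhost:7860/sdapi/v1/options"

    def toggle_and_restore_send_seed():
        # Read the current value of the option.
        original = requests.get(OPTIONS_URL).json()["send_seed"]
        try:
            # Flip it and confirm the API reports the new value.
            requests.post(OPTIONS_URL, json={"send_seed": not original}).raise_for_status()
            assert requests.get(OPTIONS_URL).json()["send_seed"] == (not original)
        finally:
            # Put the setting back even if the assertion above failed.
            requests.post(OPTIONS_URL, json={"send_seed": original})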
diff --git a/test/server_poll.py b/test/server_poll.py
deleted file mode 100644
index 42d56a4caacfc40d686dc99668d72238392448cd..0000000000000000000000000000000000000000
--- a/test/server_poll.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import unittest
-import requests
-import time
-
-
-def run_tests(proc, test_dir):
- timeout_threshold = 240
- start_time = time.time()
- while time.time()-start_time < timeout_threshold:
- try:
- requests.head("http://localhost:7860/")
- break
- except requests.exceptions.ConnectionError:
- if proc.poll() is not None:
- break
- if proc.poll() is None:
- if test_dir is None:
- test_dir = "test"
- suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
- result = unittest.TextTestRunner(verbosity=2).run(suite)
- return len(result.failures) + len(result.errors)
- else:
- print("Launch unsuccessful")
- return 1
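Reviewer note: the deleted run_tests polls the server by retrying requests.head immediately after each ConnectionError. A sketch of a gentler readiness check with a short pause between attempts (function and parameter names here are illustrative, not from the original file):

    import time
    import requests

    def wait_for_server(url="http://localhost:7860/", timeout=240, interval=0.5):
        # Return True once the server answers, False if the timeout expires.
        deadline = time.time() + timeout
        while time.time() < deadline:
            try:
                requests.head(url)
                return True
            except requests.exceptions.ConnectionError:
                time.sleep(interval)  # back off instead of hammering the port
        return False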
diff --git a/test/test_files/empty.pt b/test/test_files/empty.pt
deleted file mode 100644
index 72f13b63cc3353c1c7cdc27ee23a2ad242cfdfba..0000000000000000000000000000000000000000
--- a/test/test_files/empty.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d030ad8db708280fcae77d87e973102039acd23a11bdecc3db8eb6c0ac940ee1
-size 431
diff --git a/test/test_files/img2img_basic.png b/test/test_files/img2img_basic.png
deleted file mode 100644
index 49a420482d0a70b9f5986d776a66cb3ea39d1a97..0000000000000000000000000000000000000000
Binary files a/test/test_files/img2img_basic.png and /dev/null differ
diff --git a/test/test_files/mask_basic.png b/test/test_files/mask_basic.png
deleted file mode 100644
index 0c2e9a6899e5c0381ce7c7364b31d684464ab423..0000000000000000000000000000000000000000
Binary files a/test/test_files/mask_basic.png and /dev/null differ
diff --git a/textual_inversion_templates/hypernetwork.txt b/textual_inversion_templates/hypernetwork.txt
deleted file mode 100644
index 91e06890571c7e4974d5a76c30fab62e8587c7d2..0000000000000000000000000000000000000000
--- a/textual_inversion_templates/hypernetwork.txt
+++ /dev/null
@@ -1,27 +0,0 @@
-a photo of a [filewords]
-a rendering of a [filewords]
-a cropped photo of the [filewords]
-the photo of a [filewords]
-a photo of a clean [filewords]
-a photo of a dirty [filewords]
-a dark photo of the [filewords]
-a photo of my [filewords]
-a photo of the cool [filewords]
-a close-up photo of a [filewords]
-a bright photo of the [filewords]
-a cropped photo of a [filewords]
-a photo of the [filewords]
-a good photo of the [filewords]
-a photo of one [filewords]
-a close-up photo of the [filewords]
-a rendition of the [filewords]
-a photo of the clean [filewords]
-a rendition of a [filewords]
-a photo of a nice [filewords]
-a good photo of a [filewords]
-a photo of the nice [filewords]
-a photo of the small [filewords]
-a photo of the weird [filewords]
-a photo of the large [filewords]
-a photo of a cool [filewords]
-a photo of a small [filewords]
diff --git a/textual_inversion_templates/none.txt b/textual_inversion_templates/none.txt
deleted file mode 100644
index f77af4612b289a56b718c3bee62c66a6151f75be..0000000000000000000000000000000000000000
--- a/textual_inversion_templates/none.txt
+++ /dev/null
@@ -1 +0,0 @@
-picture
diff --git a/textual_inversion_templates/style.txt b/textual_inversion_templates/style.txt
deleted file mode 100644
index 15af2d6b85f259d0bf41fbe0c8ca7a3340e1b259..0000000000000000000000000000000000000000
--- a/textual_inversion_templates/style.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-a painting, art by [name]
-a rendering, art by [name]
-a cropped painting, art by [name]
-the painting, art by [name]
-a clean painting, art by [name]
-a dirty painting, art by [name]
-a dark painting, art by [name]
-a picture, art by [name]
-a cool painting, art by [name]
-a close-up painting, art by [name]
-a bright painting, art by [name]
-a cropped painting, art by [name]
-a good painting, art by [name]
-a close-up painting, art by [name]
-a rendition, art by [name]
-a nice painting, art by [name]
-a small painting, art by [name]
-a weird painting, art by [name]
-a large painting, art by [name]
diff --git a/textual_inversion_templates/style_filewords.txt b/textual_inversion_templates/style_filewords.txt
deleted file mode 100644
index b3a8159a869a7890bdd42470664fadf015e0658d..0000000000000000000000000000000000000000
--- a/textual_inversion_templates/style_filewords.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-a painting of [filewords], art by [name]
-a rendering of [filewords], art by [name]
-a cropped painting of [filewords], art by [name]
-the painting of [filewords], art by [name]
-a clean painting of [filewords], art by [name]
-a dirty painting of [filewords], art by [name]
-a dark painting of [filewords], art by [name]
-a picture of [filewords], art by [name]
-a cool painting of [filewords], art by [name]
-a close-up painting of [filewords], art by [name]
-a bright painting of [filewords], art by [name]
-a cropped painting of [filewords], art by [name]
-a good painting of [filewords], art by [name]
-a close-up painting of [filewords], art by [name]
-a rendition of [filewords], art by [name]
-a nice painting of [filewords], art by [name]
-a small painting of [filewords], art by [name]
-a weird painting of [filewords], art by [name]
-a large painting of [filewords], art by [name]
diff --git a/textual_inversion_templates/subject.txt b/textual_inversion_templates/subject.txt
deleted file mode 100644
index 79f36aa0543fc2151b7f7e28725309c0c9a4912a..0000000000000000000000000000000000000000
--- a/textual_inversion_templates/subject.txt
+++ /dev/null
@@ -1,27 +0,0 @@
-a photo of a [name]
-a rendering of a [name]
-a cropped photo of the [name]
-the photo of a [name]
-a photo of a clean [name]
-a photo of a dirty [name]
-a dark photo of the [name]
-a photo of my [name]
-a photo of the cool [name]
-a close-up photo of a [name]
-a bright photo of the [name]
-a cropped photo of a [name]
-a photo of the [name]
-a good photo of the [name]
-a photo of one [name]
-a close-up photo of the [name]
-a rendition of the [name]
-a photo of the clean [name]
-a rendition of a [name]
-a photo of a nice [name]
-a good photo of a [name]
-a photo of the nice [name]
-a photo of the small [name]
-a photo of the weird [name]
-a photo of the large [name]
-a photo of a cool [name]
-a photo of a small [name]
diff --git a/textual_inversion_templates/subject_filewords.txt b/textual_inversion_templates/subject_filewords.txt
deleted file mode 100644
index 008652a6bf4277f12a1759f5f3c815ae754dcfcf..0000000000000000000000000000000000000000
--- a/textual_inversion_templates/subject_filewords.txt
+++ /dev/null
@@ -1,27 +0,0 @@
-a photo of a [name], [filewords]
-a rendering of a [name], [filewords]
-a cropped photo of the [name], [filewords]
-the photo of a [name], [filewords]
-a photo of a clean [name], [filewords]
-a photo of a dirty [name], [filewords]
-a dark photo of the [name], [filewords]
-a photo of my [name], [filewords]
-a photo of the cool [name], [filewords]
-a close-up photo of a [name], [filewords]
-a bright photo of the [name], [filewords]
-a cropped photo of a [name], [filewords]
-a photo of the [name], [filewords]
-a good photo of the [name], [filewords]
-a photo of one [name], [filewords]
-a close-up photo of the [name], [filewords]
-a rendition of the [name], [filewords]
-a photo of the clean [name], [filewords]
-a rendition of a [name], [filewords]
-a photo of a nice [name], [filewords]
-a good photo of a [name], [filewords]
-a photo of the nice [name], [filewords]
-a photo of the small [name], [filewords]
-a photo of the weird [name], [filewords]
-a photo of the large [name], [filewords]
-a photo of a cool [name], [filewords]
-a photo of a small [name], [filewords]
diff --git a/webui-macos-env.sh b/webui-macos-env.sh
deleted file mode 100644
index 37cac4fb02c5ea53fa270336729b90ac2ef3ac74..0000000000000000000000000000000000000000
--- a/webui-macos-env.sh
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/bin/bash
-####################################################################
-# macOS defaults #
-# Please modify webui-user.sh to change these instead of this file #
-####################################################################
-
-if [[ -x "$(command -v python3.10)" ]]
-then
- python_cmd="python3.10"
-fi
-
-export install_dir="$HOME"
-export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
-export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
-export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
-export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
-export PYTORCH_ENABLE_MPS_FALLBACK=1
-
-####################################################################
diff --git a/webui-user.bat b/webui-user.bat
deleted file mode 100644
index e5a257bef06f5bfcaff1c8b33c64a767eb8b3fe5..0000000000000000000000000000000000000000
--- a/webui-user.bat
+++ /dev/null
@@ -1,8 +0,0 @@
-@echo off
-
-set PYTHON=
-set GIT=
-set VENV_DIR=
-set COMMANDLINE_ARGS=
-
-call webui.bat
diff --git a/webui-user.sh b/webui-user.sh
deleted file mode 100644
index bfa53cb7c67083ec0a01bfa420269af4d85c6c94..0000000000000000000000000000000000000000
--- a/webui-user.sh
+++ /dev/null
@@ -1,46 +0,0 @@
-#!/bin/bash
-#########################################################
-# Uncomment and change the variables below to your need:#
-#########################################################
-
-# Install directory without trailing slash
-#install_dir="/home/$(whoami)"
-
-# Name of the subdirectory
-#clone_dir="stable-diffusion-webui"
-
-# Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
-#export COMMANDLINE_ARGS=""
-
-# python3 executable
-#python_cmd="python3"
-
-# git executable
-#export GIT="git"
-
-# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
-#venv_dir="venv"
-
-# script to launch to start the app
-#export LAUNCH_SCRIPT="launch.py"
-
-# install command for torch
-#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"
-
-# Requirements file to use for stable-diffusion-webui
-#export REQS_FILE="requirements_versions.txt"
-
-# Fixed git repos
-#export K_DIFFUSION_PACKAGE=""
-#export GFPGAN_PACKAGE=""
-
-# Fixed git commits
-#export STABLE_DIFFUSION_COMMIT_HASH=""
-#export TAMING_TRANSFORMERS_COMMIT_HASH=""
-#export CODEFORMER_COMMIT_HASH=""
-#export BLIP_COMMIT_HASH=""
-
-# Uncomment to enable accelerated launch
-#export ACCELERATE="True"
-
-###########################################
diff --git a/webui.bat b/webui.bat
deleted file mode 100644
index 209d972bd755293bc16b4c8b6fe0e5108b3244b3..0000000000000000000000000000000000000000
--- a/webui.bat
+++ /dev/null
@@ -1,85 +0,0 @@
-@echo off
-
-if not defined PYTHON (set PYTHON=python)
-if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
-
-
-set ERROR_REPORTING=FALSE
-
-mkdir tmp 2>NUL
-
-%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :check_pip
-echo Couldn't launch python
-goto :show_stdout_stderr
-
-:check_pip
-%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :start_venv
-if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
-%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :start_venv
-echo Couldn't install pip
-goto :show_stdout_stderr
-
-:start_venv
-if ["%VENV_DIR%"] == ["-"] goto :skip_venv
-if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
-
-dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :activate_venv
-
-for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
-echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
-%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :activate_venv
-echo Unable to create venv in directory "%VENV_DIR%"
-goto :show_stdout_stderr
-
-:activate_venv
-set PYTHON="%VENV_DIR%\Scripts\Python.exe"
-echo venv %PYTHON%
-
-:skip_venv
-if [%ACCELERATE%] == ["True"] goto :accelerate
-goto :launch
-
-:accelerate
-echo Checking for accelerate
-set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
-if EXIST %ACCELERATE% goto :accelerate_launch
-
-:launch
-%PYTHON% launch.py %*
-pause
-exit /b
-
-:accelerate_launch
-echo Accelerating
-%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
-pause
-exit /b
-
-:show_stdout_stderr
-
-echo.
-echo exit code: %errorlevel%
-
-for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
-if %size% equ 0 goto :show_stderr
-echo.
-echo stdout:
-type tmp\stdout.txt
-
-:show_stderr
-for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
-if %size% equ 0 goto :endofscript
-echo.
-echo stderr:
-type tmp\stderr.txt
-
-:endofscript
-
-echo.
-echo Launch unsuccessful. Exiting.
-pause
diff --git a/webui.py b/webui.py
deleted file mode 100644
index 32acbe0f2dcba652148e0d3a67563ece172c7bf3..0000000000000000000000000000000000000000
--- a/webui.py
+++ /dev/null
@@ -1,287 +0,0 @@
-import os
-import sys
-import time
-import importlib
-import signal
-import re
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.middleware.gzip import GZipMiddleware
-from packaging import version
-
-import logging
-logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-
-from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
-from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
-from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
-
-import torch
-
-# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
-if ".dev" in torch.__version__ or "+git" in torch.__version__:
- torch.__long_version__ = torch.__version__
- torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
-import modules.codeformer_model as codeformer
-import modules.face_restoration
-import modules.gfpgan_model as gfpgan
-import modules.img2img
-
-import modules.lowvram
-import modules.paths
-import modules.scripts
-import modules.sd_hijack
-import modules.sd_models
-import modules.sd_vae
-import modules.txt2img
-import modules.script_callbacks
-import modules.textual_inversion.textual_inversion
-import modules.progress
-
-import modules.ui
-from modules import modelloader
-from modules.shared import cmd_opts
-import modules.hypernetworks.hypernetwork
-
-
-if cmd_opts.server_name:
- server_name = cmd_opts.server_name
-else:
- server_name = "0.0.0.0" if cmd_opts.listen else None
-
-
-def check_versions():
- if shared.cmd_opts.skip_version_check:
- return
-
- expected_torch_version = "1.13.1"
-
- if version.parse(torch.__version__) < version.parse(expected_torch_version):
- errors.print_error_explanation(f"""
-You are running torch {torch.__version__}.
-The program is tested to work with torch {expected_torch_version}.
-To reinstall the desired version, run with commandline flag --reinstall-torch.
-Beware that this will cause a lot of large files to be downloaded, and that
-there are reports of issues with the training tab on the latest version.
-
-Use --skip-version-check commandline argument to disable this check.
- """.strip())
-
- expected_xformers_version = "0.0.16rc425"
- if shared.xformers_available:
- import xformers
-
- if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
- errors.print_error_explanation(f"""
-You are running xformers {xformers.__version__}.
-The program is tested to work with xformers {expected_xformers_version}.
-To reinstall the desired version, run with commandline flag --reinstall-xformers.
-
-Use --skip-version-check commandline argument to disable this check.
- """.strip())
-
-
-def initialize():
- check_versions()
-
- extensions.list_extensions()
- localization.list_localizations(cmd_opts.localizations_dir)
-
- if cmd_opts.ui_debug_mode:
- shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
- modules.scripts.load_scripts()
- return
-
- modelloader.cleanup_models()
- modules.sd_models.setup_model()
- codeformer.setup_model(cmd_opts.codeformer_models_path)
- gfpgan.setup_model(cmd_opts.gfpgan_models_path)
-
- modelloader.list_builtin_upscalers()
- modules.scripts.load_scripts()
- modelloader.load_upscalers()
-
- modules.sd_vae.refresh_vae_list()
-
- modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
-
- try:
- modules.sd_models.load_model()
- except Exception as e:
- errors.display(e, "loading stable diffusion model")
- print("", file=sys.stderr)
- print("Stable diffusion model failed to load, exiting", file=sys.stderr)
- exit(1)
-
- shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
-
- shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
- shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
-
- shared.reload_hypernetworks()
-
- ui_extra_networks.intialize()
- ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
- ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
- ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
-
- extra_networks.initialize()
- extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
-
-    if cmd_opts.tls_keyfile is not None and cmd_opts.tls_certfile is not None:
-
- try:
- if not os.path.exists(cmd_opts.tls_keyfile):
- print("Invalid path to TLS keyfile given")
- if not os.path.exists(cmd_opts.tls_certfile):
- print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
- except TypeError:
- cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
- print("TLS setup invalid, running webui without TLS")
- else:
- print("Running with TLS")
-
- # make the program just exit at ctrl+c without waiting for anything
- def sigint_handler(sig, frame):
- print(f'Interrupted with signal {sig} in {frame}')
- os._exit(0)
-
- signal.signal(signal.SIGINT, sigint_handler)
-
-
-def setup_cors(app):
- if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
- elif cmd_opts.cors_allow_origins:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
- elif cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
-
-
-def create_api(app):
- from modules.api.api import Api
- api = Api(app, queue_lock)
- return api
-
-
-def wait_on_server(demo=None):
- while 1:
- time.sleep(0.5)
- if shared.state.need_restart:
- shared.state.need_restart = False
- time.sleep(0.5)
- demo.close()
- time.sleep(0.5)
- break
-
-
-def api_only():
- initialize()
-
- app = FastAPI()
- setup_cors(app)
- app.add_middleware(GZipMiddleware, minimum_size=1000)
- api = create_api(app)
-
- modules.script_callbacks.app_started_callback(None, app)
-
- api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
-
-
-def webui():
- launch_api = cmd_opts.api
- initialize()
-
- while 1:
- if shared.opts.clean_temp_dir_at_start:
- ui_tempdir.cleanup_tmpdr()
-
- modules.script_callbacks.before_ui_callback()
-
- shared.demo = modules.ui.create_ui()
-        shared.demo.queue(concurrency_count=999999, status_update_rate=0.1)
-
- if cmd_opts.gradio_queue:
- shared.demo.queue(64)
-
- gradio_auth_creds = []
- if cmd_opts.gradio_auth:
- gradio_auth_creds += cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',')
- if cmd_opts.gradio_auth_path:
- with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
- for line in file.readlines():
- gradio_auth_creds += [x.strip() for x in line.split(',')]
-
- app, local_url, share_url = shared.demo.launch(
- share=cmd_opts.share,
- server_name=server_name,
- server_port=cmd_opts.port,
- ssl_keyfile=cmd_opts.tls_keyfile,
- ssl_certfile=cmd_opts.tls_certfile,
- debug=cmd_opts.gradio_debug,
- auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
- inbrowser=cmd_opts.autolaunch,
- prevent_thread_lock=True
- )
- # after initial launch, disable --autolaunch for subsequent restarts
- cmd_opts.autolaunch = False
-
- # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
-        # an attacker to trick the user into opening a malicious HTML page that makes requests to the
-        # running web ui and does whatever the attacker wants, including installing an extension and
- # running its code. We disable this here. Suggested by RyotaK.
- app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
-
- setup_cors(app)
-
- app.add_middleware(GZipMiddleware, minimum_size=1000)
-
- modules.progress.setup_progress_api(app)
-
- if launch_api:
- create_api(app)
-
- ui_extra_networks.add_pages_to_demo(app)
-
- modules.script_callbacks.app_started_callback(shared.demo, app)
-
- wait_on_server(shared.demo)
- print('Restarting UI...')
-
- sd_samplers.set_samplers()
-
- modules.script_callbacks.script_unloaded_callback()
- extensions.list_extensions()
-
- localization.list_localizations(cmd_opts.localizations_dir)
-
- modelloader.forbid_loaded_nonbuiltin_upscalers()
- modules.scripts.reload_scripts()
- modules.script_callbacks.model_loaded_callback(shared.sd_model)
- modelloader.load_upscalers()
-
- for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
- importlib.reload(module)
-
- modules.sd_models.list_models()
-
- shared.reload_hypernetworks()
-
- ui_extra_networks.intialize()
- ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
- ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
- ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
-
- extra_networks.initialize()
- extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
-
-
-if __name__ == "__main__":
- if cmd_opts.nowebui:
- api_only()
- else:
- webui()
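Reviewer note: the version-truncation block near the top of the deleted webui.py relies on re.search(r'[\d.]+[\d]', ...) to strip nightly/local suffixes before other packages parse torch.__version__. A small worked example of what that regex keeps (the version strings below are made up for illustration):

    import re

    for raw in ("1.14.0.dev20221231+cu117", "2.0.0+git1234abc", "1.13.1"):
        # Keeps the leading dotted-digit run and drops the ".dev..."/"+git..." tail.
        print(raw, "->", re.search(r'[\d.]+[\d]', raw).group(0))
    # Prints 1.14.0, 2.0.0 and 1.13.1 respectively.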