# DiffmorpherXAp-adapter / utils / model_utils.py
# CSH-1220
# Add application file
# d57e374
import torch
import torch.nn.functional as F
from torchvision import transforms
def calc_mean_std(feat, eps=1e-5):
    """Compute the per-sample, per-channel mean and standard deviation.

    Every dimension after the first two is flattened and reduced, so one
    mean and one standard deviation are produced for each ``(N, C)`` pair.

    Args:
        feat: Tensor with at least two dimensions, laid out ``(N, C, ...)``.
        eps: Small value added to the variance before the square root to
            avoid divide-by-zero when the result is used as a divisor.

    Returns:
        Tuple ``(feat_mean, feat_std)``. Each has shape ``(N, C, 1, ...)``
        with one singleton axis per trailing dimension of ``feat``, so both
        broadcast against ``feat`` (``(N, C, 1)`` for 3-D input and
        ``(N, C, 1, 1)`` for 4-D input).

    Raises:
        ValueError: If ``feat`` has fewer than two dimensions.
    """
    size = feat.size()
    if len(size) < 2:
        raise ValueError(
            f"calc_mean_std expects a tensor of shape (N, C, ...), "
            f"got {len(size)} dimension(s)"
        )
    n, c = size[:2]

    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed or permuted tensors.
    flat = feat.reshape(n, c, -1)

    # The unbiased variance of a single value is NaN; fall back to the
    # biased estimator (which yields 0) when there is only one element.
    unbiased = flat.size(2) > 1
    feat_var = flat.var(dim=2, unbiased=unbiased) + eps

    # One singleton axis per reduced dimension keeps the statistics
    # broadcastable against the input for any rank.
    stat_shape = (n, c) + (1,) * (len(size) - 2)
    feat_std = feat_var.sqrt().reshape(stat_shape)
    feat_mean = flat.mean(dim=2).reshape(stat_shape)
    return feat_mean, feat_std
def get_img(img, resolution=512):
    """Turn an image into a normalized, batched tensor.

    The image is resized to ``resolution x resolution``, converted with
    ``transforms.ToTensor``, normalized per channel as ``(x - 0.5) / 0.5``
    over three channels, and given a leading batch axis.

    Args:
        img: Input image accepted by the torchvision transforms used here
            (presumably a 3-channel PIL image -- confirm against callers).
        resolution: Side length of the square output, in pixels.

    Returns:
        The transformed image with an extra leading dimension of size 1.
    """
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]
    pipeline = transforms.Compose([
        transforms.Resize((resolution, resolution)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    # unsqueeze(0) prepends the batch axis.
    return pipeline(img).unsqueeze(0)
@torch.no_grad()
def slerp(p0, p1, fract_mixing: float, adain=True):
    r"""Spherically interpolate between two tensors.

    Copied from lunarring/latentblending.
    Helper function to correctly mix two random variables using spherical
    interpolation. The computation is always carried out in float64 for the
    sake of extra precision, and the result is cast back afterwards.

    Args:
        p0:
            First tensor for interpolation.
        p1:
            Second tensor for interpolation.
        fract_mixing: float
            Mixing coefficient in the interval [0, 1].
            0 will return p0,
            1 will return p1,
            0.x will return a mix between both, preserving angular velocity.
        adain: bool
            If True, the interpolated tensor is instance-normalized and then
            rescaled and shifted with the linearly interpolated per-channel
            standard deviation and mean of ``p0`` and ``p1``.

    Returns:
        The interpolated tensor. float16 and bfloat16 inputs (judged by
        ``p0``) are returned in that dtype; every other dtype is returned as
        float32, as before.
    """
    # Half-precision inputs get their dtype back; everything else is
    # returned as float32 to stay compatible with existing callers.
    if p0.dtype in (torch.float16, torch.bfloat16):
        out_dtype = p0.dtype
    else:
        out_dtype = torch.float32

    p0 = p0.double()
    p1 = p1.double()

    if adain:
        mean1, std1 = calc_mean_std(p0)
        mean2, std2 = calc_mean_std(p1)
        mean = mean1 * (1 - fract_mixing) + mean2 * fract_mixing
        std = std1 * (1 - fract_mixing) + std2 * fract_mixing

    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
    if norm == 0:
        # The angle to an all-zero tensor is undefined (0 / 0 would turn the
        # whole result into NaN), so fall back to linear interpolation.
        interp = p0 * (1 - fract_mixing) + p1 * fract_mixing
    else:
        # Clamp strictly inside (-1, 1) so that sin(theta_0) is never zero.
        epsilon = 1e-7
        dot = torch.sum(p0 * p1) / norm
        dot = dot.clamp(-1 + epsilon, 1 - epsilon)

        theta_0 = torch.arccos(dot)
        sin_theta_0 = torch.sin(theta_0)
        theta_t = theta_0 * fract_mixing
        s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
        s1 = torch.sin(theta_t) / sin_theta_0
        interp = p0 * s0 + p1 * s1

    if adain:
        interp = F.instance_norm(interp) * std + mean

    return interp.to(out_dtype)
def do_replace_attn(key: str):
    """Return True when ``key`` begins with ``'up'``.

    A narrower rule that matched only the ``'up_blocks.2'`` and
    ``'up_blocks.3'`` prefixes exists as a commented-out alternative in the
    original; the current rule accepts every key starting with ``'up'``.
    """
    prefix = 'up'
    return key.startswith(prefix)