Spaces:
Running
Running
File size: 2,143 Bytes
966ae59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
import math
import torch
def identity(t, *args, **kwargs):
"""return t"""
return t
def exists(x):
"""whether x is None or not"""
return x is not None
def default(val, d):
"""ternary judgment: val != None ? val : d"""
if exists(val):
return val
return d() if callable(d) else d
def has_int_squareroot(num):
return (math.sqrt(num) ** 2) == num
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
#################################################################################
# Model Utils #
#################################################################################
def sum_params(model: torch.nn.Module, eps: float = 1e6):
return sum(p.numel() for p in model.parameters()) / eps
#################################################################################
# DataLoader Utils #
#################################################################################
def cycle(dl):
while True:
for data in dl:
yield data
#################################################################################
# Diffusion Model Utils #
#################################################################################
def extract(a, t, x_shape):
b, *_ = t.shape
assert x_shape[0] == b
out = a.gather(-1, t) # 1-D tensor, shape: (b,)
return out.reshape(b, *((1,) * (len(x_shape) - 1))) # shape: [b, 1, 1, 1]
def unnormalize(x):
"""unnormalize_to_zero_to_one"""
x = (x + 1) * 0.5 # Map the data interval to [0, 1]
return torch.clamp(x, 0.0, 1.0)
def normalize(x):
"""normalize_to_neg_one_to_one"""
x = x * 2 - 1 # Map the data interval to [-1, 1]
return torch.clamp(x, -1.0, 1.0)
|