Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# Copyright (c) XiMing Xing. All rights reserved. | |
# Author: XiMing Xing | |
# Description: | |
import math | |
import torch | |
def identity(t, *args, **kwargs): | |
"""return t""" | |
return t | |
def exists(x): | |
"""whether x is None or not""" | |
return x is not None | |
def default(val, d): | |
"""ternary judgment: val != None ? val : d""" | |
if exists(val): | |
return val | |
return d() if callable(d) else d | |
def has_int_squareroot(num): | |
return (math.sqrt(num) ** 2) == num | |
def num_to_groups(num, divisor): | |
groups = num // divisor | |
remainder = num % divisor | |
arr = [divisor] * groups | |
if remainder > 0: | |
arr.append(remainder) | |
return arr | |
################################################################################# | |
# Model Utils # | |
################################################################################# | |
def sum_params(model: torch.nn.Module, eps: float = 1e6): | |
return sum(p.numel() for p in model.parameters()) / eps | |
################################################################################# | |
# DataLoader Utils # | |
################################################################################# | |
def cycle(dl): | |
while True: | |
for data in dl: | |
yield data | |
################################################################################# | |
# Diffusion Model Utils # | |
################################################################################# | |
def extract(a, t, x_shape): | |
b, *_ = t.shape | |
assert x_shape[0] == b | |
out = a.gather(-1, t) # 1-D tensor, shape: (b,) | |
return out.reshape(b, *((1,) * (len(x_shape) - 1))) # shape: [b, 1, 1, 1] | |
def unnormalize(x): | |
"""unnormalize_to_zero_to_one""" | |
x = (x + 1) * 0.5 # Map the data interval to [0, 1] | |
return torch.clamp(x, 0.0, 1.0) | |
def normalize(x): | |
"""normalize_to_neg_one_to_one""" | |
x = x * 2 - 1 # Map the data interval to [-1, 1] | |
return torch.clamp(x, -1.0, 1.0) | |