from abc import ABC, abstractmethod

import numpy as np
import torch as th


def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")
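

# Example usage (a sketch, not part of this module): `diffusion` stands in for
# whatever diffusion object the training loop builds, and is assumed to expose
# a `num_timesteps` attribute, which UniformSampler below relies on:
#
#   sampler = create_named_schedule_sampler("uniform", diffusion)
#   t, weights = sampler.sample(batch_size=8, device=th.device("cpu"))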


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to put the sampled tensors on.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        # Draw timestep indices according to the normalized weights.
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        # Compensating importance weights keep the weighted loss unbiased.
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights
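
    # Note on unbiasedness (explanatory, not in the original code): for any
    # per-step loss f and sampling distribution p over N steps,
    #     E[f(t) / (N * p[t])] = sum_i p_i * f(i) / (N * p_i) = (1/N) sum_i f(i),
    # so multiplying each sampled loss by its weight recovers the uniform
    # average objective in expectation.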


class UniformSampler(ScheduleSampler):
    def __init__(self, diffusion):
        self.diffusion = diffusion
        # Equal weight for every timestep: sample() reduces to uniform draws
        # with all importance weights equal to 1.
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights
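

# Hypothetical example (not in the original module): the weights() contract is
# all a subclass needs to implement. This sketch up-weights later timesteps
# linearly; the base class's sample() handles drawing indices and computing
# the compensating importance weights.
class LateStepSampler(ScheduleSampler):
    def __init__(self, num_timesteps):
        # Unnormalized positive weights growing from 1 to 2 over the schedule.
        self._weights = np.linspace(1.0, 2.0, num_timesteps)

    def weights(self):
        return self._weights


if __name__ == "__main__":
    # Sanity check: the importance-weighted mean of an arbitrary per-step loss
    # should approximate the plain uniform average over timesteps.
    sampler = LateStepSampler(100)
    losses = th.linspace(0.0, 1.0, 100)  # an arbitrary per-step loss curve
    t, w = sampler.sample(batch_size=100_000, device=th.device("cpu"))
    print(f"weighted mean: {(w * losses[t]).mean().item():.4f}")
    print(f"uniform mean:  {losses.mean().item():.4f}")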