function_name,docstring,function_body,file_path
diffusion_from_config,,"def diffusion_from_config(config: Dict[str, Any]) ->GaussianDiffusion:
    schedule = config['schedule']
    steps = config['timesteps']
    respace = config.get('respacing', None)
    mean_type = config.get('mean_type', 'epsilon')
    betas = get_named_beta_schedule(schedule, steps)
    channel_scales = config.get('channel_scales', None)
    channel_biases = config.get('channel_biases', None)
    if channel_scales is not None:
        channel_scales = np.array(channel_scales)
    if channel_biases is not None:
        channel_biases = np.array(channel_biases)
    kwargs = dict(betas=betas, model_mean_type=mean_type, model_var_type=
        'learned_range', loss_type='mse', channel_scales=channel_scales,
        channel_biases=channel_biases)
    if respace is None:
        return GaussianDiffusion(**kwargs)
    else:
        return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace
            ), **kwargs)
",point_e\diffusion\configs.py
get_beta_schedule,"This is the deprecated API for creating beta schedules.

See get_named_beta_schedule() for the new library of schedules.","def get_beta_schedule(beta_schedule, *, beta_start, beta_end,
    num_diffusion_timesteps):
    """"""""""""
    if beta_schedule == 'linear':
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps,
            dtype=np.float64)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas
",point_e\diffusion\gaussian_diffusion.py
get_named_beta_schedule,"Get a pre-defined beta schedule for the given name.

The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.","def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """"""""""""
    if schedule_name == 'linear':
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule('linear', beta_start=scale * 0.0001,
            beta_end=scale * 0.02, num_diffusion_timesteps=
            num_diffusion_timesteps)
    elif schedule_name == 'cosine':
        return betas_for_alpha_bar(num_diffusion_timesteps, lambda t: math.
            cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)
    else:
        raise NotImplementedError(f'unknown beta schedule: {schedule_name}')
",point_e\diffusion\gaussian_diffusion.py
betas_for_alpha_bar,"Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].

:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                  produces the cumulative product of (1-beta) up to that
                  part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
                 prevent singularities.","def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """"""""""""
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
",point_e\diffusion\gaussian_diffusion.py
space_timesteps,"Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there are 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
:param num_timesteps: the number of diffusion steps in the original
                      process to divide up.
:param section_counts: either a list of numbers, or a string containing
                       comma-separated numbers, indicating the step count
                       per section. As a special case, use ""ddimN"" where N
                       is a number of steps to use the striding from the
                       DDIM paper.
:return: a set of diffusion steps from the original process to use.","def space_timesteps(num_timesteps, section_counts):
    """"""""""""
    if isinstance(section_counts, str):
        if section_counts.startswith('ddim'):
            desired_count = int(section_counts[len('ddim'):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f'cannot create exactly {num_timesteps} steps with an integer stride'
                )
        elif section_counts.startswith('exact'):
            res = set(int(x) for x in section_counts[len('exact'):].split(','))
            for x in res:
                if x < 0 or x >= num_timesteps:
                    raise ValueError(f'timestep out of bounds: {x}')
            return res
        section_counts = [int(x) for x in section_counts.split(',')]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f'cannot divide section of {size} steps into {section_count}')
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
",point_e\diffusion\gaussian_diffusion.py
_extract_into_tensor,"Extract values from a 1-D numpy array for a batch of indices.

:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
                        dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.","def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """"""""""""
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + th.zeros(broadcast_shape, device=timesteps.device)
",point_e\diffusion\gaussian_diffusion.py
normal_kl,"Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.","def normal_kl(mean1, logvar1, mean2, logvar2):
    """"""""""""
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, 'at least one argument must be a Tensor'
    logvar1, logvar2 = [(x if isinstance(x, th.Tensor) else th.tensor(x).to
        (tensor)) for x in (logvar1, logvar2)]
    return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + (
        mean1 - mean2) ** 2 * th.exp(-logvar2))
",point_e\diffusion\gaussian_diffusion.py
approx_standard_normal_cdf,"A fast approximation of the cumulative distribution function of the
standard normal.","def approx_standard_normal_cdf(x):
    """"""""""""
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.
        pow(x, 3))))
",point_e\diffusion\gaussian_diffusion.py
discretized_gaussian_log_likelihood,"Compute the log-likelihood of a Gaussian distribution discretizing to a
given image.
:param x: the target images. It is assumed that these were uint8 values,
          rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).","def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """"""""""""
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(x < -0.999, log_cdf_plus, th.where(x > 0.999,
        log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))))
    assert log_probs.shape == x.shape
    return log_probs
",point_e\diffusion\gaussian_diffusion.py
mean_flat,Take the mean over all non-batch dimensions.,"def mean_flat(tensor):
    """"""""""""
    return tensor.flatten(1).mean(1)
",point_e\diffusion\gaussian_diffusion.py
__init__,,"def __init__(self, *, betas: Sequence[float], model_mean_type: str,
    model_var_type: str, loss_type: str, discretized_t0: bool=False,
    channel_scales: Optional[np.ndarray]=None, channel_biases: Optional[np.
    ndarray]=None):
    self.model_mean_type = model_mean_type
    self.model_var_type = model_var_type
    self.loss_type = loss_type
    self.discretized_t0 = discretized_t0
    self.channel_scales = channel_scales
    self.channel_biases = channel_biases
    betas = np.array(betas, dtype=np.float64)
    self.betas = betas
    assert len(betas.shape) == 1, 'betas must be 1-D'
    assert (betas > 0).all() and (betas <= 1).all()
    self.num_timesteps = int(betas.shape[0])
    alphas = 1.0 - betas
    self.alphas_cumprod = np.cumprod(alphas, axis=0)
    self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
    self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
    assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
    self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
    self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
    self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
    self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
    self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
    self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
        1.0 - self.alphas_cumprod)
    self.posterior_log_variance_clipped = np.log(np.append(self.
        posterior_variance[1], self.posterior_variance[1:]))
    self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (
        1.0 - self.alphas_cumprod)
    self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(
        alphas) / (1.0 - self.alphas_cumprod)
",point_e\diffusion\gaussian_diffusion.py
get_sigmas,,"def get_sigmas(self, t):
    return _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, t.shape)
",point_e\diffusion\gaussian_diffusion.py
q_mean_variance,"Get the distribution q(x_t | x_0).

:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.","def q_mean_variance(self, x_start, t):
    """"""""""""
    mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape
        ) * x_start
    variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape
        )
    log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod,
        t, x_start.shape)
    return mean, variance, log_variance
",point_e\diffusion\gaussian_diffusion.py
q_sample,"Diffuse the data for a given number of diffusion steps.

In other words, sample from q(x_t | x_0).

:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:return: A noisy version of x_start.","def q_sample(self, x_start, t, noise=None):
    """"""""""""
    if noise is None:
        noise = th.randn_like(x_start)
    assert noise.shape == x_start.shape
    return _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape
        ) * x_start + _extract_into_tensor(self.
        sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
",point_e\diffusion\gaussian_diffusion.py
q_posterior_mean_variance,"Compute the mean and variance of the diffusion posterior:

    q(x_{t-1} | x_t, x_0)","def q_posterior_mean_variance(self, x_start, x_t, t):
    """"""""""""
    assert x_start.shape == x_t.shape
    posterior_mean = _extract_into_tensor(self.posterior_mean_coef1, t, x_t
        .shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2,
        t, x_t.shape) * x_t
    posterior_variance = _extract_into_tensor(self.posterior_variance, t,
        x_t.shape)
    posterior_log_variance_clipped = _extract_into_tensor(self.
        posterior_log_variance_clipped, t, x_t.shape)
    assert posterior_mean.shape[0] == posterior_variance.shape[0
        ] == posterior_log_variance_clipped.shape[0] == x_start.shape[0]
    return posterior_mean, posterior_variance, posterior_log_variance_clipped
",point_e\diffusion\gaussian_diffusion.py
p_mean_variance,"Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.

:param model: the model, which takes a signal and a batch of timesteps
              as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param denoised_fn: if not None, a function which applies to the
    x_start prediction before it is used to sample. Applies before
    clip_denoised.
:param model_kwargs: if not None, a dict of extra keyword arguments to
    pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
         - 'mean': the model mean output.
         - 'variance': the model variance output.
         - 'log_variance': the log of 'variance'.
         - 'pred_xstart': the prediction for x_0.","def p_mean_variance(self, model, x, t, clip_denoised=False, denoised_fn=
    None, model_kwargs=None):
    """"""""""""
    if model_kwargs is None:
        model_kwargs = {}
    B, C = x.shape[:2]
    assert t.shape == (B,)
    model_output = model(x, t, **model_kwargs)
    if isinstance(model_output, tuple):
        model_output, extra = model_output
    else:
        extra = None
    if self.model_var_type in ['learned', 'learned_range']:
        assert model_output.shape == (B, C * 2, *x.shape[2:])
        model_output, model_var_values = th.split(model_output, C, dim=1)
        if self.model_var_type == 'learned':
            model_log_variance = model_var_values
            model_variance = th.exp(model_log_variance)
        else:
            min_log = _extract_into_tensor(self.
                posterior_log_variance_clipped, t, x.shape)
            max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = th.exp(model_log_variance)
    else:
        model_variance, model_log_variance = {'fixed_large': (np.append(
            self.posterior_variance[1], self.betas[1:]), np.log(np.append(
            self.posterior_variance[1], self.betas[1:]))), 'fixed_small': (
            self.posterior_variance, self.posterior_log_variance_clipped)}[self
            .model_var_type]
        model_variance = _extract_into_tensor(model_variance, t, x.shape)
        model_log_variance = _extract_into_tensor(model_log_variance, t, x.
            shape)

    def process_xstart(x):
        if denoised_fn is not None:
            x = denoised_fn(x)
        if clip_denoised:
            return x.clamp(-1, 1)
        return x
    if self.model_mean_type == 'x_prev':
        pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x,
            t=t, xprev=model_output))
        model_mean = model_output
    elif self.model_mean_type in ['x_start', 'epsilon']:
        if self.model_mean_type == 'x_start':
            pred_xstart = process_xstart(model_output)
        else:
            pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=
                x, t=t, eps=model_output))
        model_mean, _, _ = self.q_posterior_mean_variance(x_start=
            pred_xstart, x_t=x, t=t)
    else:
        raise NotImplementedError(self.model_mean_type)
    assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
    return {'mean': model_mean, 'variance': model_variance, 'log_variance':
        model_log_variance, 'pred_xstart': pred_xstart, 'extra': extra}
",point_e\diffusion\gaussian_diffusion.py
_predict_xstart_from_eps,,"def _predict_xstart_from_eps(self, x_t, t, eps):
    assert x_t.shape == eps.shape
    return _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape
        ) * x_t - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t,
        x_t.shape) * eps
",point_e\diffusion\gaussian_diffusion.py
_predict_xstart_from_xprev,,"def _predict_xstart_from_xprev(self, x_t, t, xprev):
    assert x_t.shape == xprev.shape
    return _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape
        ) * xprev - _extract_into_tensor(self.posterior_mean_coef2 / self.
        posterior_mean_coef1, t, x_t.shape) * x_t
",point_e\diffusion\gaussian_diffusion.py
_predict_eps_from_xstart,,"def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.
        shape) * x_t - pred_xstart) / _extract_into_tensor(self.
        sqrt_recipm1_alphas_cumprod, t, x_t.shape)
",point_e\diffusion\gaussian_diffusion.py
condition_mean,"Compute the mean for the previous step, given a function cond_fn that
computes the gradient of a conditional log probability with respect to
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
condition on y.

This uses the conditioning strategy from Sohl-Dickstein et al. (2015).","def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """"""""""""
    gradient = cond_fn(x, t, **model_kwargs)
    new_mean = p_mean_var['mean'].float() + p_mean_var['variance'
        ] * gradient.float()
    return new_mean
",point_e\diffusion\gaussian_diffusion.py
condition_score,"Compute what the p_mean_variance output would have been, should the
model's score function be conditioned by cond_fn.

See condition_mean() for details on cond_fn.

Unlike condition_mean(), this instead uses the conditioning strategy
from Song et al (2020).","def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """"""""""""
    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
    eps = self._predict_eps_from_xstart(x, t, p_mean_var['pred_xstart'])
    eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
    out = p_mean_var.copy()
    out['pred_xstart'] = self._predict_xstart_from_eps(x, t, eps)
    out['mean'], _, _ = self.q_posterior_mean_variance(x_start=out[
        'pred_xstart'], x_t=x, t=t)
    return out
",point_e\diffusion\gaussian_diffusion.py
p_sample,"Sample x_{t-1} from the model at the given timestep.

:param model: the model to sample from.
:param x: the current tensor at x_t.
:param t: the value of t, starting at 0 for the first diffusion step.
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
    x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
                similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
    pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
         - 'sample': a random sample from the model.
         - 'pred_xstart': a prediction of x_0.","def p_sample(self, model, x, t, clip_denoised=False, denoised_fn=None,
    cond_fn=None, model_kwargs=None):
    """"""""""""
    out = self.p_mean_variance(model, x, t, clip_denoised=clip_denoised,
        denoised_fn=denoised_fn, model_kwargs=model_kwargs)
    noise = th.randn_like(x)
    nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    if cond_fn is not None:
        out['mean'] = self.condition_mean(cond_fn, out, x, t, model_kwargs=
            model_kwargs)
    sample = out['mean'] + nonzero_mask * th.exp(0.5 * out['log_variance']
        ) * noise
    return {'sample': sample, 'pred_xstart': out['pred_xstart']}
",point_e\diffusion\gaussian_diffusion.py
p_sample_loop,"Generate samples from the model.

:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
              Should be of the same shape as `shape`.
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
    x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
                similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
    pass to the model. This can be used for conditioning.
:param device: if specified, the device to create the samples on.
               If not specified, use a model parameter's device.
:param progress: if True, show a tqdm progress bar.
:return: a non-differentiable batch of samples.","def p_sample_loop(self, model, shape, noise=None, clip_denoised=False,
    denoised_fn=None, cond_fn=None, model_kwargs=None, device=None,
    progress=False, temp=1.0):
    """"""""""""
    final = None
    for sample in self.p_sample_loop_progressive(model, shape, noise=noise,
        clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=
        cond_fn, model_kwargs=model_kwargs, device=device, progress=
        progress, temp=temp):
        final = sample
    return final['sample']
",point_e\diffusion\gaussian_diffusion.py
p_sample_loop_progressive,"Generate samples from the model and yield intermediate samples from
each timestep of diffusion.

Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().","def p_sample_loop_progressive(self, model, shape, noise=None, clip_denoised
    =False, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None,
    progress=False, temp=1.0):
    """"""""""""
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))
    if noise is not None:
        img = noise
    else:
        img = th.randn(*shape, device=device) * temp
    indices = list(range(self.num_timesteps))[::-1]
    if progress:
        from tqdm.auto import tqdm
        indices = tqdm(indices)
    for i in indices:
        t = th.tensor([i] * shape[0], device=device)
        with th.no_grad():
            out = self.p_sample(model, img, t, clip_denoised=clip_denoised,
                denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=
                model_kwargs)
            yield self.unscale_out_dict(out)
            img = out['sample']
",point_e\diffusion\gaussian_diffusion.py
ddim_sample,"Sample x_{t-1} from the model using DDIM.

Same usage as p_sample().","def ddim_sample(self, model, x, t, clip_denoised=False, denoised_fn=None,
    cond_fn=None, model_kwargs=None, eta=0.0):
    """"""""""""
    out = self.p_mean_variance(model, x, t, clip_denoised=clip_denoised,
        denoised_fn=denoised_fn, model_kwargs=model_kwargs)
    if cond_fn is not None:
        out = self.condition_score(cond_fn, out, x, t, model_kwargs=
            model_kwargs)
    eps = self._predict_eps_from_xstart(x, t, out['pred_xstart'])
    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
    alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
    sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(
        1 - alpha_bar / alpha_bar_prev)
    noise = th.randn_like(x)
    mean_pred = out['pred_xstart'] * th.sqrt(alpha_bar_prev) + th.sqrt(1 -
        alpha_bar_prev - sigma ** 2) * eps
    nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    sample = mean_pred + nonzero_mask * sigma * noise
    return {'sample': sample, 'pred_xstart': out['pred_xstart']}
",point_e\diffusion\gaussian_diffusion.py
ddim_reverse_sample,Sample x_{t+1} from the model using DDIM reverse ODE.,"def ddim_reverse_sample(self, model, x, t, clip_denoised=False, denoised_fn
    =None, cond_fn=None, model_kwargs=None, eta=0.0):
    """"""""""""
    assert eta == 0.0, 'Reverse ODE only for deterministic path'
    out = self.p_mean_variance(model, x, t, clip_denoised=clip_denoised,
        denoised_fn=denoised_fn, model_kwargs=model_kwargs)
    if cond_fn is not None:
        out = self.condition_score(cond_fn, out, x, t, model_kwargs=
            model_kwargs)
    eps = (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) *
        x - out['pred_xstart']) / _extract_into_tensor(self.
        sqrt_recipm1_alphas_cumprod, t, x.shape)
    alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
    mean_pred = out['pred_xstart'] * th.sqrt(alpha_bar_next) + th.sqrt(1 -
        alpha_bar_next) * eps
    return {'sample': mean_pred, 'pred_xstart': out['pred_xstart']}
",point_e\diffusion\gaussian_diffusion.py
ddim_sample_loop,"Generate samples from the model using DDIM.

Same usage as p_sample_loop().","def ddim_sample_loop(self, model, shape, noise=None, clip_denoised=False,
    denoised_fn=None, cond_fn=None, model_kwargs=None, device=None,
    progress=False, eta=0.0, temp=1.0):
    """"""""""""
    final = None
    for sample in self.ddim_sample_loop_progressive(model, shape, noise=
        noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn,
        cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress
        =progress, eta=eta, temp=temp):
        final = sample
    return final['sample']
",point_e\diffusion\gaussian_diffusion.py
ddim_sample_loop_progressive,"Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.

Same usage as p_sample_loop_progressive().","def ddim_sample_loop_progressive(self, model, shape, noise=None,
    clip_denoised=False, denoised_fn=None, cond_fn=None, model_kwargs=None,
    device=None, progress=False, eta=0.0, temp=1.0):
    """"""""""""
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))
    if noise is not None:
        img = noise
    else:
        img = th.randn(*shape, device=device) * temp
    indices = list(range(self.num_timesteps))[::-1]
    if progress:
        from tqdm.auto import tqdm
        indices = tqdm(indices)
    for i in indices:
        t = th.tensor([i] * shape[0], device=device)
        with th.no_grad():
            out = self.ddim_sample(model, img, t, clip_denoised=
                clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn,
                model_kwargs=model_kwargs, eta=eta)
            yield self.unscale_out_dict(out)
            img = out['sample']
",point_e\diffusion\gaussian_diffusion.py
_vb_terms_bpd,"Get a term for the variational lower-bound.

The resulting units are bits (rather than nats, as one might expect).
This allows for comparison to other papers.

:return: a dict with the following keys:
         - 'output': a shape [N] tensor of NLLs or KLs.
         - 'pred_xstart': the x_0 predictions.","def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=False,
    model_kwargs=None):
    """"""""""""
    true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
        x_start=x_start, x_t=x_t, t=t)
    out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised,
        model_kwargs=model_kwargs)
    kl = normal_kl(true_mean, true_log_variance_clipped, out['mean'], out[
        'log_variance'])
    kl = mean_flat(kl) / np.log(2.0)
    decoder_nll = -discretized_gaussian_log_likelihood(x_start, means=out[
        'mean'], log_scales=0.5 * out['log_variance'])
    if not self.discretized_t0:
        decoder_nll = th.zeros_like(decoder_nll)
    assert decoder_nll.shape == x_start.shape
    decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
    output = th.where(t == 0, decoder_nll, kl)
    return {'output': output, 'pred_xstart': out['pred_xstart'], 'extra':
        out['extra']}
",point_e\diffusion\gaussian_diffusion.py
training_losses,"Compute training losses for a single timestep.

:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
:param model_kwargs: if not None, a dict of extra keyword arguments to
    pass to the model. This can be used for conditioning.
:param noise: if specified, the specific Gaussian noise to try to remove.
:return: a dict with the key ""loss"" containing a tensor of shape [N].
         Some mean or variance settings may also have other keys.","def training_losses(self, model, x_start, t, model_kwargs=None, noise=None
    ) ->Dict[str, th.Tensor]:
    """"""""""""
    x_start = self.scale_channels(x_start)
    if model_kwargs is None:
        model_kwargs = {}
    if noise is None:
        noise = th.randn_like(x_start)
    x_t = self.q_sample(x_start, t, noise=noise)
    terms = {}
    if self.loss_type == 'kl' or self.loss_type == 'rescaled_kl':
        vb_terms = self._vb_terms_bpd(model=model, x_start=x_start, x_t=x_t,
            t=t, clip_denoised=False, model_kwargs=model_kwargs)
        terms['loss'] = vb_terms['output']
        if self.loss_type == 'rescaled_kl':
            terms['loss'] *= self.num_timesteps
        extra = vb_terms['extra']
    elif self.loss_type == 'mse' or self.loss_type == 'rescaled_mse':
        model_output = model(x_t, t, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = {}
        if self.model_var_type in ['learned', 'learned_range']:
            B, C = x_t.shape[:2]
            assert model_output.shape == (B, C * 2, *x_t.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            frozen_out = th.cat([model_output.detach(), model_var_values],
                dim=1)
            terms['vb'] = self._vb_terms_bpd(model=lambda *args, r=
                frozen_out: r, x_start=x_start, x_t=x_t, t=t, clip_denoised
                =False)['output']
            if self.loss_type == 'rescaled_mse':
                terms['vb'] *= self.num_timesteps / 1000.0
        target = {'x_prev': self.q_posterior_mean_variance(x_start=x_start,
            x_t=x_t, t=t)[0], 'x_start': x_start, 'epsilon': noise}[self.
            model_mean_type]
        assert model_output.shape == target.shape == x_start.shape
        terms['mse'] = mean_flat((target - model_output) ** 2)
        if 'vb' in terms:
            terms['loss'] = terms['mse'] + terms['vb']
        else:
            terms['loss'] = terms['mse']
    else:
        raise NotImplementedError(self.loss_type)
    if 'losses' in extra:
        terms.update({k: loss for k, (loss, _scale) in extra['losses'].items()}
            )
        for loss, scale in extra['losses'].values():
            terms['loss'] = terms['loss'] + loss * scale
    return terms
",point_e\diffusion\gaussian_diffusion.py
_prior_bpd,"Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.

This term can't be optimized, as it only depends on the encoder.

:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.","def _prior_bpd(self, x_start):
    """"""""""""
    batch_size = x_start.shape[0]
    t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
    qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
    kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0,
        logvar2=0.0)
    return mean_flat(kl_prior) / np.log(2.0)
",point_e\diffusion\gaussian_diffusion.py
calc_bpd_loop,"Compute the entire variational lower-bound, measured in bits-per-dim,
as well as other related quantities.

:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param clip_denoised: if True, clip denoised samples.
:param model_kwargs: if not None, a dict of extra keyword arguments to
    pass to the model. This can be used for conditioning.

:return: a dict containing the following keys:
         - total_bpd: the total variational lower-bound, per batch element.
         - prior_bpd: the prior term in the lower-bound.
         - vb: an [N x T] tensor of terms in the lower-bound.
         - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
         - mse: an [N x T] tensor of epsilon MSEs for each timestep.","def calc_bpd_loop(self, model, x_start, clip_denoised=False, model_kwargs=None
    ):
    """"""""""""
    device = x_start.device
    batch_size = x_start.shape[0]
    vb = []
    xstart_mse = []
    mse = []
    for t in list(range(self.num_timesteps))[::-1]:
        t_batch = th.tensor([t] * batch_size, device=device)
        noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
        with th.no_grad():
            out = self._vb_terms_bpd(model, x_start=x_start, x_t=x_t, t=
                t_batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs
                )
        vb.append(out['output'])
        xstart_mse.append(mean_flat((out['pred_xstart'] - x_start) ** 2))
        eps = self._predict_eps_from_xstart(x_t, t_batch, out['pred_xstart'])
        mse.append(mean_flat((eps - noise) ** 2))
    vb = th.stack(vb, dim=1)
    xstart_mse = th.stack(xstart_mse, dim=1)
    mse = th.stack(mse, dim=1)
    prior_bpd = self._prior_bpd(x_start)
    total_bpd = vb.sum(dim=1) + prior_bpd
    return {'total_bpd': total_bpd, 'prior_bpd': prior_bpd, 'vb': vb,
        'xstart_mse': xstart_mse, 'mse': mse}
",point_e\diffusion\gaussian_diffusion.py
scale_channels,,"def scale_channels(self, x: th.Tensor) ->th.Tensor:
    if self.channel_scales is not None:
        x = x * th.from_numpy(self.channel_scales).to(x).reshape([1, -1, *(
            [1] * (len(x.shape) - 2))])
    if self.channel_biases is not None:
        x = x + th.from_numpy(self.channel_biases).to(x).reshape([1, -1, *(
            [1] * (len(x.shape) - 2))])
    return x
",point_e\diffusion\gaussian_diffusion.py
unscale_channels,,"def unscale_channels(self, x: th.Tensor) ->th.Tensor:
    if self.channel_biases is not None:
        x = x - th.from_numpy(self.channel_biases).to(x).reshape([1, -1, *(
            [1] * (len(x.shape) - 2))])
    if self.channel_scales is not None:
        x = x / th.from_numpy(self.channel_scales).to(x).reshape([1, -1, *(
            [1] * (len(x.shape) - 2))])
    return x
",point_e\diffusion\gaussian_diffusion.py
unscale_out_dict,,"def unscale_out_dict(self, out: Dict[str, Union[th.Tensor, Any]]) ->Dict[
    str, Union[th.Tensor, Any]]:
    return {k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v
        ) for k, v in out.items()}
",point_e\diffusion\gaussian_diffusion.py
__init__,,"def __init__(self, use_timesteps: Iterable[int], **kwargs):
    self.use_timesteps = set(use_timesteps)
    self.timestep_map = []
    self.original_num_steps = len(kwargs['betas'])
    base_diffusion = GaussianDiffusion(**kwargs)
    last_alpha_cumprod = 1.0
    new_betas = []
    for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
        if i in self.use_timesteps:
            new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
            last_alpha_cumprod = alpha_cumprod
            self.timestep_map.append(i)
    kwargs['betas'] = np.array(new_betas)
    super().__init__(**kwargs)
",point_e\diffusion\gaussian_diffusion.py
p_mean_variance,,"def p_mean_variance(self, model, *args, **kwargs):
    return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
",point_e\diffusion\gaussian_diffusion.py
training_losses,,"def training_losses(self, model, *args, **kwargs):
    return super().training_losses(self._wrap_model(model), *args, **kwargs)
",point_e\diffusion\gaussian_diffusion.py
condition_mean,,"def condition_mean(self, cond_fn, *args, **kwargs):
    return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
",point_e\diffusion\gaussian_diffusion.py
condition_score,,"def condition_score(self, cond_fn, *args, **kwargs):
    return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
",point_e\diffusion\gaussian_diffusion.py
_wrap_model,,"def _wrap_model(self, model):
    if isinstance(model, _WrappedModel):
        return model
    return _WrappedModel(model, self.timestep_map, self.original_num_steps)
",point_e\diffusion\gaussian_diffusion.py
__init__,,"def __init__(self, model, timestep_map, original_num_steps):
    self.model = model
    self.timestep_map = timestep_map
    self.original_num_steps = original_num_steps
",point_e\diffusion\gaussian_diffusion.py
__call__,,"def __call__(self, x, ts, **kwargs):
    map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
    new_ts = map_tensor[ts]
    return self.model(x, new_ts, **kwargs)
",point_e\diffusion\gaussian_diffusion.py
karras_sample,,"def karras_sample(*args, **kwargs):
    last = None
    for x in karras_sample_progressive(*args, **kwargs):
        last = x['x']
    return last
",point_e\diffusion\k_diffusion.py
karras_sample_progressive,,"def karras_sample_progressive(diffusion, model, shape, steps, clip_denoised
    =True, progress=False, model_kwargs=None, device=None, sigma_min=0.002,
    sigma_max=80, rho=7.0, sampler='heun', s_churn=0.0, s_tmin=0.0, s_tmax=
    float('inf'), s_noise=1.0, guidance_scale=0.0):
    sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
    x_T = th.randn(*shape, device=device) * sigma_max
    sample_fn = {'heun': sample_heun, 'dpm': sample_dpm, 'ancestral':
        sample_euler_ancestral}[sampler]
    if sampler != 'ancestral':
        sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax,
            s_noise=s_noise)
    else:
        sampler_args = {}
    if isinstance(diffusion, KarrasDenoiser):

        def denoiser(x_t, sigma):
            _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs)
            if clip_denoised:
                denoised = denoised.clamp(-1, 1)
            return denoised
    elif isinstance(diffusion, GaussianDiffusion):
        model = GaussianToKarrasDenoiser(model, diffusion)

        def denoiser(x_t, sigma):
            _, denoised = model.denoise(x_t, sigma, clip_denoised=
                clip_denoised, model_kwargs=model_kwargs)
            return denoised
    else:
        raise NotImplementedError
    if guidance_scale != 0 and guidance_scale != 1:

        def guided_denoiser(x_t, sigma):
            x_t = th.cat([x_t, x_t], dim=0)
            sigma = th.cat([sigma, sigma], dim=0)
            x_0 = denoiser(x_t, sigma)
            cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0)
            x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0)
            return x_0
    else:
        guided_denoiser = denoiser
    for obj in sample_fn(guided_denoiser, x_T, sigmas, progress=progress,
        **sampler_args):
        if isinstance(diffusion, GaussianDiffusion):
            yield diffusion.unscale_out_dict(obj)
        else:
            yield obj
",point_e\diffusion\k_diffusion.py
get_sigmas_karras,Constructs the noise schedule of Karras et al. (2022).,"def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device='cpu'):
    """"""""""""
    ramp = th.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)
",point_e\diffusion\k_diffusion.py
to_d,Converts a denoiser output to a Karras ODE derivative.,"def to_d(x, sigma, denoised):
    """"""""""""
    return (x - denoised) / append_dims(sigma, x.ndim)
",point_e\diffusion\k_diffusion.py
get_ancestral_step,"Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step.","def get_ancestral_step(sigma_from, sigma_to):
    """"""""""""
    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / 
        sigma_from ** 2) ** 0.5
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
",point_e\diffusion\k_diffusion.py
sample_euler_ancestral,Ancestral sampling with Euler method steps.,"@th.no_grad()
def sample_euler_ancestral(model, x, sigmas, progress=False):
    """"""""""""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm
        indices = tqdm(indices)
    for i in indices:
        denoised = model(x, sigmas[i] * s_in)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        yield {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i],
            'pred_xstart': denoised}
        d = to_d(x, sigmas[i], denoised)
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        x = x + th.randn_like(x) * sigma_up
    yield {'x': x, 'pred_xstart': x}
",point_e\diffusion\k_diffusion.py
sample_heun,Implements Algorithm 2 (Heun steps) from Karras et al. (2022).,"@th.no_grad()
def sample_heun(denoiser, x, sigmas, progress=False, s_churn=0.0, s_tmin=
    0.0, s_tmax=float('inf'), s_noise=1.0):
    """"""""""""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm
        indices = tqdm(indices)
    for i in indices:
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1
            ) if s_tmin <= sigmas[i] <= s_tmax else 0.0
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        yield {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat,
            'pred_xstart': denoised}
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            x = x + d * dt
        else:
            x_2 = x + d * dt
            denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    yield {'x': x, 'pred_xstart': denoised}
",point_e\diffusion\k_diffusion.py
sample_dpm,A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).,"@th.no_grad()
def sample_dpm(denoiser, x, sigmas, progress=False, s_churn=0.0, s_tmin=0.0,
    s_tmax=float('inf'), s_noise=1.0):
    """"""""""""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm
        indices = tqdm(indices)
    for i in indices:
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1
            ) if s_tmin <= sigmas[i] <= s_tmax else 0.0
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        yield {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat,
            'denoised': denoised}
        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2
            ) ** 3
        dt_1 = sigma_mid - sigma_hat
        dt_2 = sigmas[i + 1] - sigma_hat
        x_2 = x + d * dt_1
        denoised_2 = denoiser(x_2, sigma_mid * s_in)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
    yield {'x': x, 'pred_xstart': denoised}
",point_e\diffusion\k_diffusion.py
append_dims,Appends dimensions to the end of a tensor until it has target_dims dimensions.,"def append_dims(x, target_dims):
    """"""""""""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f'input has {x.ndim} dims but target_dims is {target_dims}, which is less'
            )
    return x[(...,) + (None,) * dims_to_append]
",point_e\diffusion\k_diffusion.py
append_zero,,"def append_zero(x):
    return th.cat([x, x.new_zeros([1])])
",point_e\diffusion\k_diffusion.py
__init__,,"def __init__(self, sigma_data: float=0.5):
    self.sigma_data = sigma_data
",point_e\diffusion\k_diffusion.py
get_snr,,"def get_snr(self, sigmas):
    return sigmas ** -2
",point_e\diffusion\k_diffusion.py
get_sigmas,,"def get_sigmas(self, sigmas):
    return sigmas
",point_e\diffusion\k_diffusion.py
get_scalings,,"def get_scalings(self, sigma):
    c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
    c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2
        ) ** 0.5
    c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
    return c_skip, c_out, c_in
",point_e\diffusion\k_diffusion.py
training_losses,,"def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None
    ):
    if model_kwargs is None:
        model_kwargs = {}
    if noise is None:
        noise = th.randn_like(x_start)
    terms = {}
    dims = x_start.ndim
    x_t = x_start + noise * append_dims(sigmas, dims)
    c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)
        ]
    model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs)
    target = (x_start - c_skip * x_t) / c_out
    terms['mse'] = mean_flat((model_output - target) ** 2)
    terms['xs_mse'] = mean_flat((denoised - x_start) ** 2)
    if 'vb' in terms:
        terms['loss'] = terms['mse'] + terms['vb']
    else:
        terms['loss'] = terms['mse']
    return terms
",point_e\diffusion\k_diffusion.py
denoise,,"def denoise(self, model, x_t, sigmas, **model_kwargs):
    c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.
        get_scalings(sigmas)]
    rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
    model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
    denoised = c_out * model_output + c_skip * x_t
    return model_output, denoised
",point_e\diffusion\k_diffusion.py
__init__,,"def __init__(self, model, diffusion):
    from scipy import interpolate
    self.model = model
    self.diffusion = diffusion
    self.alpha_cumprod_to_t = interpolate.interp1d(diffusion.alphas_cumprod,
        np.arange(0, diffusion.num_timesteps))
",point_e\diffusion\k_diffusion.py
sigma_to_t,,"def sigma_to_t(self, sigma):
    alpha_cumprod = 1.0 / (sigma ** 2 + 1)
    if alpha_cumprod > self.diffusion.alphas_cumprod[0]:
        return 0
    elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]:
        return self.diffusion.num_timesteps - 1
    else:
        return float(self.alpha_cumprod_to_t(alpha_cumprod))
",point_e\diffusion\k_diffusion.py
denoise,,"def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None):
    t = th.tensor([self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()
        ], dtype=th.long, device=sigmas.device)
    c_in = append_dims(1.0 / (sigmas ** 2 + 1) ** 0.5, x_t.ndim)
    out = self.diffusion.p_mean_variance(self.model, x_t * c_in, t,
        clip_denoised=clip_denoised, model_kwargs=model_kwargs)
    return None, out['pred_xstart']
",point_e\diffusion\k_diffusion.py
__init__,,"def __init__(self, device: torch.device, models: Sequence[nn.Module],
    diffusions: Sequence[GaussianDiffusion], num_points: Sequence[int],
    aux_channels: Sequence[str], model_kwargs_key_filter: Sequence[str]=(
    '*',), guidance_scale: Sequence[float]=(3.0, 3.0), clip_denoised: bool=
    True, use_karras: Sequence[bool]=(True, True), karras_steps: Sequence[
    int]=(64, 64), sigma_min: Sequence[float]=(0.001, 0.001), sigma_max:
    Sequence[float]=(120, 160), s_churn: Sequence[float]=(3, 0)):
    n = len(models)
    assert n > 0
    if n > 1:
        if len(guidance_scale) == 1:
            guidance_scale = list(guidance_scale) + [1.0] * (n - 1)
        if len(use_karras) == 1:
            use_karras = use_karras * n
        if len(karras_steps) == 1:
            karras_steps = karras_steps * n
        if len(sigma_min) == 1:
            sigma_min = sigma_min * n
        if len(sigma_max) == 1:
            sigma_max = sigma_max * n
        if len(s_churn) == 1:
            s_churn = s_churn * n
        if len(model_kwargs_key_filter) == 1:
            model_kwargs_key_filter = model_kwargs_key_filter * n
    if len(model_kwargs_key_filter) == 0:
        model_kwargs_key_filter = ['*'] * n
    assert len(guidance_scale) == n
    assert len(use_karras) == n
    assert len(karras_steps) == n
    assert len(sigma_min) == n
    assert len(sigma_max) == n
    assert len(s_churn) == n
    assert len(model_kwargs_key_filter) == n
    self.device = device
    self.num_points = num_points
    self.aux_channels = aux_channels
    self.model_kwargs_key_filter = model_kwargs_key_filter
    self.guidance_scale = guidance_scale
    self.clip_denoised = clip_denoised
    self.use_karras = use_karras
    self.karras_steps = karras_steps
    self.sigma_min = sigma_min
    self.sigma_max = sigma_max
    self.s_churn = s_churn
    self.models = models
    self.diffusions = diffusions
",point_e\diffusion\sampler.py
num_stages,,"@property
def num_stages(self) ->int:
    return len(self.models)
",point_e\diffusion\sampler.py
sample_batch,,"def sample_batch(self, batch_size: int, model_kwargs: Dict[str, Any]
    ) ->torch.Tensor:
    samples = None
    for x in self.sample_batch_progressive(batch_size, model_kwargs):
        samples = x
    return samples
",point_e\diffusion\sampler.py
sample_batch_progressive,,"def sample_batch_progressive(self, batch_size: int, model_kwargs: Dict[str,
    Any]) ->Iterator[torch.Tensor]:
    samples = None
    for model, diffusion, stage_num_points, stage_guidance_scale, stage_use_karras, stage_karras_steps, stage_sigma_min, stage_sigma_max, stage_s_churn, stage_key_filter in zip(
        self.models, self.diffusions, self.num_points, self.guidance_scale,
        self.use_karras, self.karras_steps, self.sigma_min, self.sigma_max,
        self.s_churn, self.model_kwargs_key_filter):
        stage_model_kwargs = model_kwargs.copy()
        if stage_key_filter != '*':
            use_keys = set(stage_key_filter.split(','))
            stage_model_kwargs = {k: v for k, v in stage_model_kwargs.items
                () if k in use_keys}
        if samples is not None:
            stage_model_kwargs['low_res'] = samples
        if hasattr(model, 'cached_model_kwargs'):
            stage_model_kwargs = model.cached_model_kwargs(batch_size,
                stage_model_kwargs)
        sample_shape = batch_size, 3 + len(self.aux_channels), stage_num_points
        if stage_guidance_scale != 1 and stage_guidance_scale != 0:
            for k, v in stage_model_kwargs.copy().items():
                stage_model_kwargs[k] = torch.cat([v, torch.zeros_like(v)],
                    dim=0)
        if stage_use_karras:
            samples_it = karras_sample_progressive(diffusion=diffusion,
                model=model, shape=sample_shape, steps=stage_karras_steps,
                clip_denoised=self.clip_denoised, model_kwargs=
                stage_model_kwargs, device=self.device, sigma_min=
                stage_sigma_min, sigma_max=stage_sigma_max, s_churn=
                stage_s_churn, guidance_scale=stage_guidance_scale)
        else:
            internal_batch_size = batch_size
            if stage_guidance_scale:
                model = self._uncond_guide_model(model, stage_guidance_scale)
                internal_batch_size *= 2
            samples_it = diffusion.p_sample_loop_progressive(model, shape=(
                internal_batch_size, *sample_shape[1:]), model_kwargs=
                stage_model_kwargs, device=self.device, clip_denoised=self.
                clip_denoised)
        for x in samples_it:
            samples = x['pred_xstart'][:batch_size]
            if 'low_res' in stage_model_kwargs:
                samples = torch.cat([stage_model_kwargs['low_res'][:len(
                    samples)], samples], dim=-1)
            yield samples
",point_e\diffusion\sampler.py
combine,,"@classmethod
def combine(cls, *samplers: 'PointCloudSampler') ->'PointCloudSampler':
    assert all(x.device == samplers[0].device for x in samplers[1:])
    assert all(x.aux_channels == samplers[0].aux_channels for x in samplers[1:]
        )
    assert all(x.clip_denoised == samplers[0].clip_denoised for x in
        samplers[1:])
    return cls(device=samplers[0].device, models=[x for y in samplers for x in
        y.models], diffusions=[x for y in samplers for x in y.diffusions],
        num_points=[x for y in samplers for x in y.num_points],
        aux_channels=samplers[0].aux_channels, model_kwargs_key_filter=[x for
        y in samplers for x in y.model_kwargs_key_filter], guidance_scale=[
        x for y in samplers for x in y.guidance_scale], clip_denoised=
        samplers[0].clip_denoised, use_karras=[x for y in samplers for x in
        y.use_karras], karras_steps=[x for y in samplers for x in y.
        karras_steps], sigma_min=[x for y in samplers for x in y.sigma_min],
        sigma_max=[x for y in samplers for x in y.sigma_max], s_churn=[x for
        y in samplers for x in y.s_churn])
",point_e\diffusion\sampler.py
_uncond_guide_model,,"def _uncond_guide_model(self, model: Callable[..., torch.Tensor], scale: float
    ) ->Callable[..., torch.Tensor]:

    def model_fn(x_t, ts, **kwargs):
        half = x_t[:len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
        half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
    return model_fn
",point_e\diffusion\sampler.py
split_model_output,,"def split_model_output(self, output: torch.Tensor, rescale_colors: bool=False
    ) ->Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    assert len(self.aux_channels) + 3 == output.shape[1
        ], 'there must be three spatial channels before aux'
    pos, joined_aux = output[:, :3], output[:, 3:]
    aux = {}
    for i, name in enumerate(self.aux_channels):
        v = joined_aux[:, i]
        if name in {'R', 'G', 'B', 'A'}:
            v = v.clamp(0, 255).round()
            if rescale_colors:
                v = v / 255.0
        aux[name] = v
    return pos, aux
",point_e\diffusion\sampler.py
output_to_point_clouds,,"def output_to_point_clouds(self, output: torch.Tensor) ->List[PointCloud]:
    res = []
    for sample in output:
        xyz, aux = self.split_model_output(sample[None], rescale_colors=True)
        res.append(PointCloud(coords=xyz[0].t().cpu().numpy(), channels={k:
            v[0].cpu().numpy() for k, v in aux.items()}))
    return res
",point_e\diffusion\sampler.py
with_options,,"def with_options(self, guidance_scale: float, clip_denoised: bool,
    use_karras: Sequence[bool]=(True, True), karras_steps: Sequence[int]=(
    64, 64), sigma_min: Sequence[float]=(0.001, 0.001), sigma_max: Sequence
    [float]=(120, 160), s_churn: Sequence[float]=(3, 0)) ->'PointCloudSampler':
    return PointCloudSampler(device=self.device, models=self.models,
        diffusions=self.diffusions, num_points=self.num_points,
        aux_channels=self.aux_channels, model_kwargs_key_filter=self.
        model_kwargs_key_filter, guidance_scale=guidance_scale,
        clip_denoised=clip_denoised, use_karras=use_karras, karras_steps=
        karras_steps, sigma_min=sigma_min, sigma_max=sigma_max, s_churn=s_churn
        )
",point_e\diffusion\sampler.py
get_torch_devices,,"def get_torch_devices() ->List[Union[str, torch.device]]:
    if torch.cuda.is_available():
        return [torch.device(f'cuda:{i}') for i in range(torch.cuda.
            device_count())]
    else:
        return ['cpu']
",point_e\evals\feature_extractor.py
normalize_point_clouds,,"def normalize_point_clouds(pc: np.ndarray) ->np.ndarray:
    centroids = np.mean(pc, axis=1, keepdims=True)
    pc = pc - centroids
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=-1, keepdims=True)), axis=1,
        keepdims=True)
    pc = pc / m
    return pc
",point_e\evals\feature_extractor.py
supports_predictions,,"@property
@abstractmethod
def supports_predictions(self) ->bool:
    pass
",point_e\evals\feature_extractor.py
feature_dim,,"@property
@abstractmethod
def feature_dim(self) ->int:
    pass
",point_e\evals\feature_extractor.py
num_classes,,"@property
@abstractmethod
def num_classes(self) ->int:
    pass
",point_e\evals\feature_extractor.py
features_and_preds,"For a stream of point cloud batches, compute feature vectors and class
predictions.

:param streamer: a streamer for a sample batch. Typically, arr_0
                 will contain the XYZ coordinates.
:return: a tuple (features, predictions)
         - features: a [B x feature_dim] array of feature vectors.
         - predictions: a [B x num_classes] array of probabilities.","@abstractmethod
def features_and_preds(self, streamer: NpzStreamer) ->Tuple[np.ndarray, np.
    ndarray]:
    """"""""""""
",point_e\evals\feature_extractor.py
__init__,,"def __init__(self, devices: List[Union[str, torch.device]],
    device_batch_size: int=64, cache_dir: Optional[str]=None):
    state_dict = load_checkpoint('pointnet', device=torch.device('cpu'),
        cache_dir=cache_dir)['model_state_dict']
    self.device_batch_size = device_batch_size
    self.devices = devices
    self.models = []
    for device in devices:
        model = get_model(num_class=40, normal_channel=False, width_mult=2)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        self.models.append(model)
",point_e\evals\feature_extractor.py
supports_predictions,,"@property
def supports_predictions(self) ->bool:
    return True
",point_e\evals\feature_extractor.py
feature_dim,,"@property
def feature_dim(self) ->int:
    return 256
",point_e\evals\feature_extractor.py
num_classes,,"@property
def num_classes(self) ->int:
    return 40
",point_e\evals\feature_extractor.py
features_and_preds,,"def features_and_preds(self, streamer: NpzStreamer) ->Tuple[np.ndarray, np.
    ndarray]:
    batch_size = self.device_batch_size * len(self.devices)
    point_clouds = (x['arr_0'] for x in streamer.stream(batch_size, ['arr_0']))
    output_features = []
    output_predictions = []
    with ThreadPool(len(self.devices)) as pool:
        for batch in point_clouds:
            batch = normalize_point_clouds(batch)
            batches = []
            for i, device in zip(range(0, len(batch), self.
                device_batch_size), self.devices):
                batches.append(torch.from_numpy(batch[i:i + self.
                    device_batch_size]).permute(0, 2, 1).to(dtype=torch.
                    float32, device=device))

            def compute_features(i_batch):
                i, batch = i_batch
                with torch.no_grad():
                    return self.models[i](batch, features=True)
            for logits, _, features in pool.imap(compute_features,
                enumerate(batches)):
                output_features.append(features.cpu().numpy())
                output_predictions.append(logits.exp().cpu().numpy())
    return np.concatenate(output_features, axis=0), np.concatenate(
        output_predictions, axis=0)
",point_e\evals\feature_extractor.py
compute_statistics,,"def compute_statistics(feats: np.ndarray) ->FIDStatistics:
    mu = np.mean(feats, axis=0)
    sigma = np.cov(feats, rowvar=False)
    return FIDStatistics(mu, sigma)
",point_e\evals\fid_is.py
compute_inception_score,,"def compute_inception_score(preds: np.ndarray, split_size: int=5000) ->float:
    scores = []
    for i in range(0, len(preds), split_size):
        part = preds[i:i + split_size]
        kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))
            )
        kl = np.mean(np.sum(kl, 1))
        scores.append(np.exp(kl))
    return float(np.mean(scores))
",point_e\evals\fid_is.py
__init__,,"def __init__(self, mu: np.ndarray, sigma: np.ndarray):
    self.mu = mu
    self.sigma = sigma
",point_e\evals\fid_is.py
frechet_distance,Compute the Frechet distance between two sets of statistics.,"def frechet_distance(self, other, eps=1e-06):
    """"""""""""
    mu1, sigma1 = self.mu, self.sigma
    mu2, sigma2 = other.mu, other.sigma
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)
    assert mu1.shape == mu2.shape, f'Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}'
    assert sigma1.shape == sigma2.shape, f'Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}'
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = (
            'fid calculation produces singular product; adding %s to diagonal of cov estimates'
             % eps)
        warnings.warn(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=0.001):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real
    tr_covmean = np.trace(covmean)
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2
        ) - 2 * tr_covmean
",point_e\evals\fid_is.py
_npz_paths_and_length,,"def _npz_paths_and_length(glob_path: str) ->Tuple[List[str], Optional[int]]:
    count_match = re.match('^(.*)\\[:([0-9]*)\\]$', glob_path)
    if count_match:
        raw_path = count_match[1]
        max_count = int(count_match[2])
    else:
        raw_path = glob_path
        max_count = None
    paths = sorted(glob.glob(raw_path))
    if not len(paths):
        raise ValueError(f'no paths found matching: {glob_path}')
    return paths, max_count
",point_e\evals\npz_stream.py
open_npz_arrays,,"@contextmanager
def open_npz_arrays(path: str, arr_names: Sequence[str]) ->List[NpzArrayReader
    ]:
    if not len(arr_names):
        yield []
        return
    arr_name = arr_names[0]
    with open_array(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        header = None
        if version == (1, 0):
            header = np.lib.format.read_array_header_1_0(arr_f)
        elif version == (2, 0):
            header = np.lib.format.read_array_header_2_0(arr_f)
        if header is None:
            reader = MemoryNpzArrayReader.load(path, arr_name)
        else:
            shape, fortran, dtype = header
            if fortran or dtype.hasobject:
                reader = MemoryNpzArrayReader.load(path, arr_name)
            else:
                reader = StreamingNpzArrayReader(arr_f, shape, dtype)
        with open_npz_arrays(path, arr_names[1:]) as next_readers:
            yield [reader] + next_readers
",point_e\evals\npz_stream.py
_read_bytes,"Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886

Read from file-like object until size bytes are read.
Raises ValueError if EOF is encountered before size bytes are read.
Non-blocking objects are only supported if they derive from io objects.
This function is required because, e.g., ZipExtFile in Python 2.6 can return less data than
requested.","def _read_bytes(fp, size, error_template='ran out of data'):
    """"""""""""
    data = bytes()
    while True:
        try:
            r = fp.read(size - len(data))
            data += r
            if len(r) == 0 or len(data) == size:
                break
        except io.BlockingIOError:
            pass
    if len(data) != size:
        msg = 'EOF: reading %s, expected %d bytes got %d'
        raise ValueError(msg % (error_template, size, len(data)))
    else:
        return data
",point_e\evals\npz_stream.py
open_array,,"@contextmanager
def open_array(path: str, arr_name: str):
    with open(path, 'rb') as f:
        with zipfile.ZipFile(f, 'r') as zip_f:
            if f'{arr_name}.npy' not in zip_f.namelist():
                raise ValueError(f'missing {arr_name} in npz file')
            with zip_f.open(f'{arr_name}.npy', 'r') as arr_f:
                yield arr_f
",point_e\evals\npz_stream.py
_dict_batch_size,,"def _dict_batch_size(objs: Dict[str, np.ndarray]) ->int:
    return len(next(iter(objs.values())))
",point_e\evals\npz_stream.py
infos_from_first_file,,"@classmethod
def infos_from_first_file(cls, glob_path: str) ->Dict[str, 'NumpyArrayInfo']:
    paths, _ = _npz_paths_and_length(glob_path)
    return cls.infos_from_file(paths[0])
",point_e\evals\npz_stream.py
infos_from_file,Extract the info of every array in an npz file.,"@classmethod
def infos_from_file(cls, npz_path: str) ->Dict[str, 'NumpyArrayInfo']:
    """"""""""""
    if not os.path.exists(npz_path):
        raise FileNotFoundError(f'batch of samples was not found: {npz_path}')
    results = {}
    with open(npz_path, 'rb') as f:
        with zipfile.ZipFile(f, 'r') as zip_f:
            for name in zip_f.namelist():
                if not name.endswith('.npy'):
                    continue
                key_name = name[:-len('.npy')]
                with zip_f.open(name, 'r') as arr_f:
                    version = np.lib.format.read_magic(arr_f)
                    if version == (1, 0):
                        header = np.lib.format.read_array_header_1_0(arr_f)
                    elif version == (2, 0):
                        header = np.lib.format.read_array_header_2_0(arr_f)
                    else:
                        raise ValueError(
                            f'unknown numpy array version: {version}')
                    shape, _, dtype = header
                    results[key_name] = cls(name=key_name, dtype=dtype,
                        shape=shape)
    return results
",point_e\evals\npz_stream.py
elem_shape,,"@property
def elem_shape(self) ->Tuple[int]:
    return self.shape[1:]
",point_e\evals\npz_stream.py
validate,,"def validate(self):
    if self.name in {'R', 'G', 'B'}:
        if len(self.shape) != 2:
            raise ValueError(
                f""expecting exactly 2-D shape for '{self.name}' but got: {self.shape}""
                )
    elif self.name == 'arr_0':
        if len(self.shape) < 2:
            raise ValueError(
                f'expecting at least 2-D shape but got: {self.shape}')
        elif len(self.shape) == 3:
            if not np.issubdtype(self.dtype, np.floating):
                raise ValueError(
                    f'invalid dtype for audio batch: {self.dtype} (expected float)'
                    )
        elif self.dtype != np.uint8:
            raise ValueError(
                f'invalid dtype for image batch: {self.dtype} (expected uint8)'
                )
",point_e\evals\npz_stream.py
__init__,,"def __init__(self, glob_path: str):
    self.paths, self.trunc_length = _npz_paths_and_length(glob_path)
    self.infos = NumpyArrayInfo.infos_from_file(self.paths[0])
",point_e\evals\npz_stream.py
keys,,"def keys(self) ->List[str]:
    return list(self.infos.keys())
",point_e\evals\npz_stream.py
stream,,"def stream(self, batch_size: int, keys: Sequence[str]) ->Iterator[Dict[str,
    np.ndarray]]:
    cur_batch = None
    num_remaining = self.trunc_length
    for path in self.paths:
        if num_remaining is not None and num_remaining <= 0:
            break
        with open_npz_arrays(path, keys) as readers:
            combined_reader = CombinedReader(keys, readers)
            while num_remaining is None or num_remaining > 0:
                read_bs = batch_size
                if cur_batch is not None:
                    read_bs -= _dict_batch_size(cur_batch)
                if num_remaining is not None:
                    read_bs = min(read_bs, num_remaining)
                batch = combined_reader.read_batch(read_bs)
                if batch is None:
                    break
                if num_remaining is not None:
                    num_remaining -= _dict_batch_size(batch)
                if cur_batch is None:
                    cur_batch = batch
                else:
                    cur_batch = {k: np.concatenate([cur_batch[k], v], axis=
                        0) for k, v in batch.items()}
                if _dict_batch_size(cur_batch) == batch_size:
                    yield cur_batch
                    cur_batch = None
    if cur_batch is not None:
        yield cur_batch
",point_e\evals\npz_stream.py
read_batch,,"@abstractmethod
def read_batch(self, batch_size: int) ->Optional[np.ndarray]:
    pass
",point_e\evals\npz_stream.py
__init__,,"def __init__(self, arr_f, shape, dtype):
    self.arr_f = arr_f
    self.shape = shape
    self.dtype = dtype
    self.idx = 0
",point_e\evals\npz_stream.py
read_batch,,"def read_batch(self, batch_size: int) ->Optional[np.ndarray]:
    if self.idx >= self.shape[0]:
        return None
    bs = min(batch_size, self.shape[0] - self.idx)
    self.idx += bs
    if self.dtype.itemsize == 0:
        return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
    read_count = bs * np.prod(self.shape[1:])
    read_size = int(read_count * self.dtype.itemsize)
    data = _read_bytes(self.arr_f, read_size, 'array data')
    return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
",point_e\evals\npz_stream.py
__init__,,"def __init__(self, arr):
    self.arr = arr
    self.idx = 0
",point_e\evals\npz_stream.py
load,,"@classmethod
def load(cls, path: str, arr_name: str):
    with open(path, 'rb') as f:
        arr = np.load(f)[arr_name]
    return cls(arr)
",point_e\evals\npz_stream.py
read_batch,,"def read_batch(self, batch_size: int) ->Optional[np.ndarray]:
    if self.idx >= self.arr.shape[0]:
        return None
    res = self.arr[self.idx:self.idx + batch_size]
    self.idx += batch_size
    return res
",point_e\evals\npz_stream.py
__init__,,"def __init__(self, keys: List[str], readers: List[NpzArrayReader]):
    self.keys = keys
    self.readers = readers
",point_e\evals\npz_stream.py
read_batch,,"def read_batch(self, batch_size: int) ->Optional[Dict[str, np.ndarray]]:
    batches = [r.read_batch(batch_size) for r in self.readers]
    any_none = any(x is None for x in batches)
    all_none = all(x is None for x in batches)
    if any_none != all_none:
        raise RuntimeError('different keys had different numbers of elements')
    if any_none:
        return None
    if any(len(x) != len(batches[0]) for x in batches):
        raise RuntimeError('different keys had different numbers of elements')
    return dict(zip(self.keys, batches))
",point_e\evals\npz_stream.py
__init__,,"def __init__(self, num_class, normal_channel=True, width_mult=1):
    super(get_model, self).__init__()
    self.width_mult = width_mult
    in_channel = 6 if normal_channel else 3
    self.normal_channel = normal_channel
    self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32,
        in_channel=in_channel, mlp=[64 * width_mult, 64 * width_mult, 128 *
        width_mult], group_all=False)
    self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64,
        in_channel=128 * width_mult + 3, mlp=[128 * width_mult, 128 *
        width_mult, 256 * width_mult], group_all=False)
    self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=
        None, in_channel=256 * width_mult + 3, mlp=[256 * width_mult, 512 *
        width_mult, 1024 * width_mult], group_all=True)
    self.fc1 = nn.Linear(1024 * width_mult, 512 * width_mult)
    self.bn1 = nn.BatchNorm1d(512 * width_mult)
    self.drop1 = nn.Dropout(0.4)
    self.fc2 = nn.Linear(512 * width_mult, 256 * width_mult)
    self.bn2 = nn.BatchNorm1d(256 * width_mult)
    self.drop2 = nn.Dropout(0.4)
    self.fc3 = nn.Linear(256 * width_mult, num_class)
",point_e\evals\pointnet2_cls_ssg.py
forward,,"def forward(self, xyz, features=False):
    B, _, _ = xyz.shape
    if self.normal_channel:
        norm = xyz[:, 3:, :]
        xyz = xyz[:, :3, :]
    else:
        norm = None
    l1_xyz, l1_points = self.sa1(xyz, norm)
    l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
    l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
    x = l3_points.view(B, 1024 * self.width_mult)
    x = self.drop1(F.relu(self.bn1(self.fc1(x))))
    result_features = self.bn2(self.fc2(x))
    x = self.drop2(F.relu(result_features))
    x = self.fc3(x)
    x = F.log_softmax(x, -1)
    if features:
        return x, l3_points, result_features
    else:
        return x, l3_points
",point_e\evals\pointnet2_cls_ssg.py
__init__,,"def __init__(self):
    super(get_loss, self).__init__()
",point_e\evals\pointnet2_cls_ssg.py
forward,,"def forward(self, pred, target, trans_feat):
    total_loss = F.nll_loss(pred, target)
    return total_loss
",point_e\evals\pointnet2_cls_ssg.py
timeit,,"def timeit(tag, t):
    print('{}: {}s'.format(tag, time() - t))
    return time()
",point_e\evals\pointnet2_utils.py
pc_normalize,,"def pc_normalize(pc):
    l = pc.shape[0]
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc
",point_e\evals\pointnet2_utils.py
square_distance,"Calculate the squared Euclidean distance between each pair of points.

src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
     = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst

Input:
    src: source points, [B, N, C]
    dst: target points, [B, M, C]
Output:
    dist: per-point square distance, [B, N, M]","def square_distance(src, dst):
    """"""""""""
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist
",point_e\evals\pointnet2_utils.py
index_points,"Input:
    points: input points data, [B, N, C]
    idx: sample index data, [B, S] (extra trailing index dims such as [B, S, K] are also supported)
Return:
    new_points: indexed points data, [B, S, C] (or [B, S, K, C] for a 3-D idx)","def index_points(points, idx):
    """"""""""""
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(
        view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points
",point_e\evals\pointnet2_utils.py
farthest_point_sample,"Input:
    xyz: pointcloud data, [B, N, 3]
    npoint: number of samples
    deterministic: if True, batch element b starts from point index b instead of a random point
Return:
    centroids: sampled pointcloud index, [B, npoint]","def farthest_point_sample(xyz, npoint, deterministic=False):
    """"""""""""
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 10000000000.0
    if deterministic:
        farthest = torch.arange(0, B, dtype=torch.long).to(device)
    else:
        farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
",point_e\evals\pointnet2_utils.py
query_ball_point,"Input:
    radius: local region radius
    nsample: max sample number in local region
    xyz: all points, [B, N, 3]
    new_xyz: query points, [B, S, 3]
Return:
    group_idx: grouped points index, [B, S, nsample]","def query_ball_point(radius, nsample, xyz, new_xyz):
    """"""""""""
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N
        ).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx
",point_e\evals\pointnet2_utils.py
sample_and_group,"Input:
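    npoint: number of points to sample with farthest point sampling
    radius: ball-query radius around each sampled point
    nsample: number of neighbor indices gathered per ball (padded by repeating the first neighbor when fewer fall inside the radius)
    xyz: input points position data, [B, N, 3]
    points: input points data, [B, N, D]
Return:
    new_xyz: sampled points position data, [B, npoint, 3]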
    new_points: sampled points data, [B, npoint, nsample, 3+D]","def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False,
    deterministic=False):
    """"""""""""
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint, deterministic=deterministic)
    new_xyz = index_points(xyz, fps_idx)
    idx = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, idx)
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points
",point_e\evals\pointnet2_utils.py
sample_and_group_all,"Input:
    xyz: input points position data, [B, N, 3]
    points: input points data, [B, N, D]
Return:
    new_xyz: sampled points position data, [B, 1, 3]
    new_points: sampled points data, [B, 1, N, 3+D]","def sample_and_group_all(xyz, points):
    """"""""""""
    device = xyz.device
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is not None:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points
",point_e\evals\pointnet2_utils.py
__init__,,"def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
    super(PointNetSetAbstraction, self).__init__()
    self.npoint = npoint
    self.radius = radius
    self.nsample = nsample
    self.mlp_convs = nn.ModuleList()
    self.mlp_bns = nn.ModuleList()
    last_channel = in_channel
    for out_channel in mlp:
        self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
        self.mlp_bns.append(nn.BatchNorm2d(out_channel))
        last_channel = out_channel
    self.group_all = group_all
",point_e\evals\pointnet2_utils.py
forward,"Input:
    xyz: input points position data, [B, C, N]
    points: input points data, [B, D, N]
Return:
    new_xyz: sampled points position data, [B, C, S]
    new_points_concat: sample points feature data, [B, D', S]","def forward(self, xyz, points):
    """"""""""""
    xyz = xyz.permute(0, 2, 1)
    if points is not None:
        points = points.permute(0, 2, 1)
    if self.group_all:
        new_xyz, new_points = sample_and_group_all(xyz, points)
    else:
        new_xyz, new_points = sample_and_group(self.npoint, self.radius,
            self.nsample, xyz, points, deterministic=not self.training)
    new_points = new_points.permute(0, 3, 2, 1)
    for i, conv in enumerate(self.mlp_convs):
        bn = self.mlp_bns[i]
        new_points = F.relu(bn(conv(new_points)))
    new_points = torch.max(new_points, 2)[0]
    new_xyz = new_xyz.permute(0, 2, 1)
    return new_xyz, new_points
",point_e\evals\pointnet2_utils.py
__init__,,"def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
    super(PointNetSetAbstractionMsg, self).__init__()
    self.npoint = npoint
    self.radius_list = radius_list
    self.nsample_list = nsample_list
    self.conv_blocks = nn.ModuleList()
    self.bn_blocks = nn.ModuleList()
    for i in range(len(mlp_list)):
        convs = nn.ModuleList()
        bns = nn.ModuleList()
        last_channel = in_channel + 3
        for out_channel in mlp_list[i]:
            convs.append(nn.Conv2d(last_channel, out_channel, 1))
            bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.conv_blocks.append(convs)
        self.bn_blocks.append(bns)
",point_e\evals\pointnet2_utils.py
forward,"Input:
    xyz: input points position data, [B, C, N]
    points: input points data, [B, D, N]
Return:
    new_xyz: sampled points position data, [B, C, S]
    new_points_concat: sample points feature data, [B, D', S]","def forward(self, xyz, points):
    """"""""""""
    xyz = xyz.permute(0, 2, 1)
    if points is not None:
        points = points.permute(0, 2, 1)
    B, N, C = xyz.shape
    S = self.npoint
    new_xyz = index_points(xyz, farthest_point_sample(xyz, S, deterministic
        =not self.training))
    new_points_list = []
    for i, radius in enumerate(self.radius_list):
        K = self.nsample_list[i]
        group_idx = query_ball_point(radius, K, xyz, new_xyz)
        grouped_xyz = index_points(xyz, group_idx)
        grouped_xyz -= new_xyz.view(B, S, 1, C)
        if points is not None:
            grouped_points = index_points(points, group_idx)
            grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
        else:
            grouped_points = grouped_xyz
        grouped_points = grouped_points.permute(0, 3, 2, 1)
        for j in range(len(self.conv_blocks[i])):
            conv = self.conv_blocks[i][j]
            bn = self.bn_blocks[i][j]
            grouped_points = F.relu(bn(conv(grouped_points)))
        new_points = torch.max(grouped_points, 2)[0]
        new_points_list.append(new_points)
    new_xyz = new_xyz.permute(0, 2, 1)
    new_points_concat = torch.cat(new_points_list, dim=1)
    return new_xyz, new_points_concat
",point_e\evals\pointnet2_utils.py
__init__,,"def __init__(self, in_channel, mlp):
    super(PointNetFeaturePropagation, self).__init__()
    self.mlp_convs = nn.ModuleList()
    self.mlp_bns = nn.ModuleList()
    last_channel = in_channel
    for out_channel in mlp:
        self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
        self.mlp_bns.append(nn.BatchNorm1d(out_channel))
        last_channel = out_channel
",point_e\evals\pointnet2_utils.py
forward,"Input:
    xyz1: input points position data, [B, C, N]
    xyz2: sampled input points position data, [B, C, S]
    points1: input points data, [B, D, N]
    points2: input points data, [B, D, S]
Return:
    new_points: upsampled points data, [B, D', N]","def forward(self, xyz1, xyz2, points1, points2):
    """"""""""""
    xyz1 = xyz1.permute(0, 2, 1)
    xyz2 = xyz2.permute(0, 2, 1)
    points2 = points2.permute(0, 2, 1)
    B, N, C = xyz1.shape
    _, S, _ = xyz2.shape
    if S == 1:
        interpolated_points = points2.repeat(1, N, 1)
    else:
        dists = square_distance(xyz1, xyz2)
        dists, idx = dists.sort(dim=-1)
        dists, idx = dists[:, :, :3], idx[:, :, :3]
        dist_recip = 1.0 / (dists + 1e-08)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight = dist_recip / norm
        interpolated_points = torch.sum(index_points(points2, idx) * weight
            .view(B, N, 3, 1), dim=2)
    if points1 is not None:
        points1 = points1.permute(0, 2, 1)
        new_points = torch.cat([points1, interpolated_points], dim=-1)
    else:
        new_points = interpolated_points
    new_points = new_points.permute(0, 2, 1)
    for i, conv in enumerate(self.mlp_convs):
        bn = self.mlp_bns[i]
        new_points = F.relu(bn(conv(new_points)))
    return new_points
",point_e\evals\pointnet2_utils.py
clear_scene,,"def clear_scene():
    bpy.ops.object.select_all(action='SELECT')
    bpy.ops.object.delete()
",point_e\evals\scripts\blender_script.py
clear_lights,,"def clear_lights():
    bpy.ops.object.select_all(action='DESELECT')
    for obj in bpy.context.scene.objects.values():
        if isinstance(obj.data, bpy.types.Light):
            obj.select_set(True)
    bpy.ops.object.delete()
",point_e\evals\scripts\blender_script.py
import_model,,"def import_model(path):
    clear_scene()
    _, ext = os.path.splitext(path)
    ext = ext.lower()
    if ext == '.obj':
        bpy.ops.import_scene.obj(filepath=path)
    elif ext in ['.glb', '.gltf']:
        bpy.ops.import_scene.gltf(filepath=path)
    elif ext == '.stl':
        bpy.ops.import_mesh.stl(filepath=path)
    elif ext == '.fbx':
        bpy.ops.import_scene.fbx(filepath=path)
    elif ext == '.dae':
        bpy.ops.wm.collada_import(filepath=path)
    elif ext == '.ply':
        bpy.ops.import_mesh.ply(filepath=path)
    else:
        raise RuntimeError(f'unexpected extension: {ext}')
",point_e\evals\scripts\blender_script.py
scene_root_objects,,"def scene_root_objects():
    for obj in bpy.context.scene.objects.values():
        if not obj.parent:
            yield obj
",point_e\evals\scripts\blender_script.py
scene_bbox,,"def scene_bbox(single_obj=None, ignore_matrix=False):
    bbox_min = (math.inf,) * 3
    bbox_max = (-math.inf,) * 3
    found = False
    for obj in (scene_meshes() if single_obj is None else [single_obj]):
        found = True
        for coord in obj.bound_box:
            coord = Vector(coord)
            if not ignore_matrix:
                coord = obj.matrix_world @ coord
            bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
            bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
    if not found:
        raise RuntimeError('no objects in scene to compute bounding box for')
    return Vector(bbox_min), Vector(bbox_max)
",point_e\evals\scripts\blender_script.py
scene_meshes,,"def scene_meshes():
    for obj in bpy.context.scene.objects.values():
        if isinstance(obj.data, bpy.types.Mesh):
            yield obj
",point_e\evals\scripts\blender_script.py
normalize_scene,,"def normalize_scene():
    bbox_min, bbox_max = scene_bbox()
    scale = 1 / max(bbox_max - bbox_min)
    for obj in scene_root_objects():
        obj.scale = obj.scale * scale
    bpy.context.view_layer.update()
    bbox_min, bbox_max = scene_bbox()
    offset = -(bbox_min + bbox_max) / 2
    for obj in scene_root_objects():
        obj.matrix_world.translation += offset
    bpy.ops.object.select_all(action='DESELECT')
",point_e\evals\scripts\blender_script.py
create_camera,,"def create_camera():
    camera_data = bpy.data.cameras.new(name='Camera')
    camera_object = bpy.data.objects.new('Camera', camera_data)
    bpy.context.scene.collection.objects.link(camera_object)
    bpy.context.scene.camera = camera_object
",point_e\evals\scripts\blender_script.py
set_camera,,"def set_camera(direction, camera_dist=2.0):
    camera_pos = -camera_dist * direction
    bpy.context.scene.camera.location = camera_pos
    rot_quat = direction.to_track_quat('-Z', 'Y')
    bpy.context.scene.camera.rotation_euler = rot_quat.to_euler()
    bpy.context.view_layer.update()
",point_e\evals\scripts\blender_script.py
randomize_camera,,"def randomize_camera(camera_dist=2.0):
    direction = random_unit_vector()
    set_camera(direction, camera_dist=camera_dist)
",point_e\evals\scripts\blender_script.py
pan_camera,,"def pan_camera(time, axis='Z', camera_dist=2.0, elevation=-0.1):
    angle = time * math.pi * 2
    direction = [-math.cos(angle), -math.sin(angle), -elevation]
    assert axis in ['X', 'Y', 'Z']
    if axis == 'X':
        direction = [direction[2], *direction[:2]]
    elif axis == 'Y':
        direction = [direction[0], -elevation, direction[1]]
    direction = Vector(direction).normalized()
    set_camera(direction, camera_dist=camera_dist)
",point_e\evals\scripts\blender_script.py
place_camera,,"def place_camera(time, camera_pose_mode='random', camera_dist_min=2.0,
    camera_dist_max=2.0):
    camera_dist = random.uniform(camera_dist_min, camera_dist_max)
    if camera_pose_mode == 'random':
        randomize_camera(camera_dist=camera_dist)
    elif camera_pose_mode == 'z-circular':
        pan_camera(time, axis='Z', camera_dist=camera_dist)
    elif camera_pose_mode == 'z-circular-elevated':
        pan_camera(time, axis='Z', camera_dist=camera_dist, elevation=
            0.2617993878)
    else:
        raise ValueError(f'Unknown camera pose mode: {camera_pose_mode}')
",point_e\evals\scripts\blender_script.py
create_light,,"def create_light(location, energy=1.0, angle=0.5 * math.pi / 180):
    light_data = bpy.data.lights.new(name='Light', type='SUN')
    light_data.energy = energy
    light_data.angle = angle
    light_object = bpy.data.objects.new(name='Light', object_data=light_data)
    direction = -location
    rot_quat = direction.to_track_quat('-Z', 'Y')
    light_object.rotation_euler = rot_quat.to_euler()
    bpy.context.view_layer.update()
    bpy.context.collection.objects.link(light_object)
    light_object.location = location
",point_e\evals\scripts\blender_script.py
create_random_lights,,"def create_random_lights(count=4, distance=2.0, energy=1.5):
    clear_lights()
    for _ in range(count):
        create_light(random_unit_vector() * distance, energy=energy)
",point_e\evals\scripts\blender_script.py
create_camera_light,,"def create_camera_light():
    clear_lights()
    create_light(bpy.context.scene.camera.location, energy=5.0)
",point_e\evals\scripts\blender_script.py
create_uniform_light,,"def create_uniform_light(backend):
    clear_lights()
    pos = Vector(UNIFORM_LIGHT_DIRECTION)
    angle = 0.0092 if backend == 'CYCLES' else math.pi
    create_light(pos, energy=5.0, angle=angle)
    create_light(-pos, energy=5.0, angle=angle)
",point_e\evals\scripts\blender_script.py
create_vertex_color_shaders,,"def create_vertex_color_shaders():
    for obj in bpy.context.scene.objects.values():
        if not isinstance(obj.data, bpy.types.Mesh):
            continue
        if len(obj.data.materials):
            continue
        color_keys = (obj.data.vertex_colors or {}).keys()
        if not len(color_keys):
            continue
        mat = bpy.data.materials.new(name='VertexColored')
        mat.use_nodes = True
        bsdf_node = None
        for node in mat.node_tree.nodes:
            if node.type == 'BSDF_PRINCIPLED':
                bsdf_node = node
        assert bsdf_node is not None, 'material has no Principled BSDF node to modify'
        socket_map = {}
        for input in bsdf_node.inputs:
            socket_map[input.name] = input
        socket_map['Specular'].default_value = 0.0
        socket_map['Roughness'].default_value = 1.0
        v_color = mat.node_tree.nodes.new('ShaderNodeVertexColor')
        v_color.layer_name = color_keys[0]
        mat.node_tree.links.new(v_color.outputs[0], socket_map['Base Color'])
        obj.data.materials.append(mat)
",point_e\evals\scripts\blender_script.py
create_default_materials,,"def create_default_materials():
    for obj in bpy.context.scene.objects.values():
        if isinstance(obj.data, bpy.types.Mesh):
            if not len(obj.data.materials):
                mat = bpy.data.materials.new(name='DefaultMaterial')
                mat.use_nodes = True
                obj.data.materials.append(mat)
",point_e\evals\scripts\blender_script.py
find_materials,,"def find_materials():
    all_materials = set()
    for obj in bpy.context.scene.objects.values():
        if not isinstance(obj.data, bpy.types.Mesh):
            continue
        for mat in obj.data.materials:
            all_materials.add(mat)
    return all_materials
",point_e\evals\scripts\blender_script.py
get_socket_value,,"def get_socket_value(tree, socket):
    default = socket.default_value
    if not isinstance(default, float):
        default = list(default)
    for link in tree.links:
        if link.to_socket == socket:
            return link.from_socket, default
    return None, default
",point_e\evals\scripts\blender_script.py
clear_socket_input,,"def clear_socket_input(tree, socket):
    for link in list(tree.links):
        if link.to_socket == socket:
            tree.links.remove(link)
",point_e\evals\scripts\blender_script.py
set_socket_value,,"def set_socket_value(tree, socket, socket_and_default):
    clear_socket_input(tree, socket)
    old_source_socket, default = socket_and_default
    if isinstance(default, float) and not isinstance(socket.default_value,
        float):
        socket.default_value = [default] * 3 + [1.0]
    else:
        socket.default_value = default
    if old_source_socket is not None:
        tree.links.new(old_source_socket, socket)
",point_e\evals\scripts\blender_script.py
setup_nodes,,"def setup_nodes(output_path, capturing_material_alpha: bool=False):
    tree = bpy.context.scene.node_tree
    links = tree.links
    for node in tree.nodes:
        tree.nodes.remove(node)

    def node_op(op: str, *args, clamp=False):
        node = tree.nodes.new(type='CompositorNodeMath')
        node.operation = op
        if clamp:
            node.use_clamp = True
        for i, arg in enumerate(args):
            if isinstance(arg, (int, float)):
                node.inputs[i].default_value = arg
            else:
                links.new(arg, node.inputs[i])
        return node.outputs[0]

    def node_clamp(x, maximum=1.0):
        return node_op('MINIMUM', x, maximum)

    def node_mul(x, y, **kwargs):
        return node_op('MULTIPLY', x, y, **kwargs)
    input_node = tree.nodes.new(type='CompositorNodeRLayers')
    input_node.scene = bpy.context.scene
    input_sockets = {}
    for output in input_node.outputs:
        input_sockets[output.name] = output
    if capturing_material_alpha:
        color_socket = input_sockets['Image']
    else:
        raw_color_socket = input_sockets['Image']
        color_node = tree.nodes.new(type='CompositorNodeConvertColorSpace')
        color_node.from_color_space = 'Linear'
        color_node.to_color_space = 'sRGB'
        tree.links.new(raw_color_socket, color_node.inputs[0])
        color_socket = color_node.outputs[0]
    split_node = tree.nodes.new(type='CompositorNodeSepRGBA')
    tree.links.new(color_socket, split_node.inputs[0])
    for i, channel in (enumerate('rgba') if not capturing_material_alpha else
        [(0, 'MatAlpha')]):
        output_node = tree.nodes.new(type='CompositorNodeOutputFile')
        output_node.base_path = f'{output_path}_{channel}'
        links.new(split_node.outputs[i], output_node.inputs[0])
    if capturing_material_alpha:
        return
    depth_out = node_clamp(node_mul(input_sockets['Depth'], 1 / MAX_DEPTH))
    output_node = tree.nodes.new(type='CompositorNodeOutputFile')
    output_node.base_path = f'{output_path}_depth'
    links.new(depth_out, output_node.inputs[0])
",point_e\evals\scripts\blender_script.py
render_scene,,"def render_scene(output_path, fast_mode: bool):
    use_workbench = bpy.context.scene.render.engine == 'BLENDER_WORKBENCH'
    if use_workbench:
        bpy.context.scene.render.engine = 'BLENDER_EEVEE'
        bpy.context.scene.eevee.taa_render_samples = 1
    if fast_mode:
        if bpy.context.scene.render.engine == 'BLENDER_EEVEE':
            bpy.context.scene.eevee.taa_render_samples = 1
        elif bpy.context.scene.render.engine == 'CYCLES':
            bpy.context.scene.cycles.samples = 256
    elif bpy.context.scene.render.engine == 'CYCLES':
        bpy.context.scene.cycles.time_limit = 40
    bpy.context.view_layer.update()
    bpy.context.scene.use_nodes = True
    bpy.context.scene.view_layers['ViewLayer'].use_pass_z = True
    bpy.context.scene.view_settings.view_transform = 'Raw'
    bpy.context.scene.render.film_transparent = True
    bpy.context.scene.render.resolution_x = 512
    bpy.context.scene.render.resolution_y = 512
    bpy.context.scene.render.image_settings.file_format = 'PNG'
    bpy.context.scene.render.image_settings.color_mode = 'BW'
    bpy.context.scene.render.image_settings.color_depth = '16'
    bpy.context.scene.render.filepath = output_path
    setup_nodes(output_path)
    bpy.ops.render.render(write_still=True)
    for channel_name in ['r', 'g', 'b', 'a', 'depth']:
        sub_dir = f'{output_path}_{channel_name}'
        image_path = os.path.join(sub_dir, os.listdir(sub_dir)[0])
        name, ext = os.path.splitext(output_path)
        if channel_name == 'depth' or not use_workbench:
            os.rename(image_path, f'{name}_{channel_name}{ext}')
        else:
            os.remove(image_path)
        os.removedirs(sub_dir)
    if use_workbench:
        bpy.context.scene.use_nodes = False
        bpy.context.scene.render.engine = 'BLENDER_WORKBENCH'
        bpy.context.scene.render.image_settings.color_mode = 'RGBA'
        bpy.context.scene.render.image_settings.color_depth = '8'
        bpy.context.scene.display.shading.color_type = 'TEXTURE'
        bpy.context.scene.display.shading.light = 'FLAT'
        if fast_mode:
            bpy.context.scene.display.render_aa = 'FXAA'
        os.remove(output_path)
        bpy.ops.render.render(write_still=True)
        bpy.context.scene.render.image_settings.color_mode = 'BW'
        bpy.context.scene.render.image_settings.color_depth = '16'
",point_e\evals\scripts\blender_script.py
scene_fov,,"def scene_fov():
    x_fov = bpy.context.scene.camera.data.angle_x
    y_fov = bpy.context.scene.camera.data.angle_y
    width = bpy.context.scene.render.resolution_x
    height = bpy.context.scene.render.resolution_y
    if bpy.context.scene.camera.data.angle == x_fov:
        y_fov = 2 * math.atan(math.tan(x_fov / 2) * height / width)
    else:
        x_fov = 2 * math.atan(math.tan(y_fov / 2) * width / height)
    return x_fov, y_fov
",point_e\evals\scripts\blender_script.py
write_camera_metadata,,"def write_camera_metadata(path):
    x_fov, y_fov = scene_fov()
    bbox_min, bbox_max = scene_bbox()
    matrix = bpy.context.scene.camera.matrix_world
    with open(path, 'w') as f:
        json.dump(dict(format_version=FORMAT_VERSION, max_depth=MAX_DEPTH,
            bbox=[list(bbox_min), list(bbox_max)], origin=list(matrix.col[3
            ])[:3], x_fov=x_fov, y_fov=y_fov, x=list(matrix.col[0])[:3], y=
            list(-matrix.col[1])[:3], z=list(-matrix.col[2])[:3]), f)
",point_e\evals\scripts\blender_script.py
save_rendering_dataset,,"def save_rendering_dataset(input_path: str, output_path: str, num_images:
    int, backend: str, light_mode: str, camera_pose: str, camera_dist_min:
    float, camera_dist_max: float, fast_mode: bool):
    assert light_mode in ['random', 'uniform', 'camera']
    assert camera_pose in ['random', 'z-circular', 'z-circular-elevated']
    import_model(input_path)
    bpy.context.scene.render.engine = backend
    normalize_scene()
    if light_mode == 'random':
        create_random_lights()
    elif light_mode == 'uniform':
        create_uniform_light(backend)
    create_camera()
    create_vertex_color_shaders()
    for i in range(num_images):
        t = i / max(num_images - 1, 1)
        place_camera(t, camera_pose_mode=camera_pose, camera_dist_min=
            camera_dist_min, camera_dist_max=camera_dist_max)
        if light_mode == 'camera':
            create_camera_light()
        render_scene(os.path.join(output_path, f'{i:05}.png'), fast_mode=
            fast_mode)
        write_camera_metadata(os.path.join(output_path, f'{i:05}.json'))
    with open(os.path.join(output_path, 'info.json'), 'w') as f:
        info = dict(backend=backend, light_mode=light_mode, fast_mode=
            fast_mode, format_version=FORMAT_VERSION, channels=['R', 'G',
            'B', 'A', 'D'], scale=0.5)
        json.dump(info, f)
",point_e\evals\scripts\blender_script.py
main,,"def main():
    try:
        dash_index = sys.argv.index('--')
    except ValueError as exc:
        raise ValueError(""arguments must be preceded by '--'"") from exc
    raw_args = sys.argv[dash_index + 1:]
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', required=True, type=str)
    parser.add_argument('--output_path', required=True, type=str)
    parser.add_argument('--num_images', type=int, default=20)
    parser.add_argument('--backend', type=str, default='BLENDER_EEVEE')
    parser.add_argument('--light_mode', type=str, default='uniform')
    parser.add_argument('--camera_pose', type=str, default='random')
    parser.add_argument('--camera_dist_min', type=float, default=2.0)
    parser.add_argument('--camera_dist_max', type=float, default=2.0)
    parser.add_argument('--fast_mode', action='store_true')
    args = parser.parse_args(raw_args)
    save_rendering_dataset(input_path=args.input_path, output_path=args.
        output_path, num_images=args.num_images, backend=args.backend,
        light_mode=args.light_mode, camera_pose=args.camera_pose,
        camera_dist_min=args.camera_dist_min, camera_dist_max=args.
        camera_dist_max, fast_mode=args.fast_mode)
",point_e\evals\scripts\blender_script.py
main,,"def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cache_dir', type=str, default=None)
    parser.add_argument('batch_1', type=str)
    parser.add_argument('batch_2', type=str)
    args = parser.parse_args()
    print('creating classifier...')
    clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.
        cache_dir)
    print('computing first batch activations')
    features_1, _ = clf.features_and_preds(NpzStreamer(args.batch_1))
    stats_1 = compute_statistics(features_1)
    del features_1
    features_2, _ = clf.features_and_preds(NpzStreamer(args.batch_2))
    stats_2 = compute_statistics(features_2)
    del features_2
    print(f'P-FID: {stats_1.frechet_distance(stats_2)}')
",point_e\evals\scripts\evaluate_pfid.py
main,,"def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cache_dir', type=str, default=None)
    parser.add_argument('batch', type=str)
    args = parser.parse_args()
    print('creating classifier...')
    clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.
        cache_dir)
    print('computing batch predictions')
    _, preds = clf.features_and_preds(NpzStreamer(args.batch))
    print(f'P-IS: {compute_inception_score(preds)}')
",point_e\evals\scripts\evaluate_pis.py
checkpoint,"Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
               explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.","def checkpoint(func: Callable[..., Union[torch.Tensor, Sequence[torch.
    Tensor]]], inputs: Sequence[torch.Tensor], params: Iterable[torch.
    Tensor], flag: bool):
    """"""""""""
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)
",point_e\models\checkpoint.py
forward,,"@staticmethod
def forward(ctx, run_function, length, *args):
    ctx.run_function = run_function
    ctx.input_tensors = list(args[:length])
    ctx.input_params = list(args[length:])
    with torch.no_grad():
        output_tensors = ctx.run_function(*ctx.input_tensors)
    return output_tensors
",point_e\models\checkpoint.py
backward,,"@staticmethod
def backward(ctx, *output_grads):
    ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.
        input_tensors]
    with torch.enable_grad():
        shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
        output_tensors = ctx.run_function(*shallow_copies)
    input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors +
        ctx.input_params, output_grads, allow_unused=True)
    del ctx.input_tensors
    del ctx.input_params
    del output_tensors
    return (None, None) + input_grads
",point_e\models\checkpoint.py
model_from_config,,"def model_from_config(config: Dict[str, Any], device: torch.device
    ) ->nn.Module:
    config = config.copy()
    name = config.pop('name')
    if name == 'PointDiffusionTransformer':
        return PointDiffusionTransformer(device=device, dtype=torch.float32,
            **config)
    elif name == 'CLIPImagePointDiffusionTransformer':
        return CLIPImagePointDiffusionTransformer(device=device, dtype=
            torch.float32, **config)
    elif name == 'CLIPImageGridPointDiffusionTransformer':
        return CLIPImageGridPointDiffusionTransformer(device=device, dtype=
            torch.float32, **config)
    elif name == 'UpsamplePointDiffusionTransformer':
        return UpsamplePointDiffusionTransformer(device=device, dtype=torch
            .float32, **config)
    elif name == 'CLIPImageGridUpsamplePointDiffusionTransformer':
        return CLIPImageGridUpsamplePointDiffusionTransformer(device=device,
            dtype=torch.float32, **config)
    elif name == 'CrossAttentionPointCloudSDFModel':
        return CrossAttentionPointCloudSDFModel(device=device, dtype=torch.
            float32, **config)
    raise ValueError(f'unknown model name: {name}')
",point_e\models\configs.py
default_cache_dir,,"@lru_cache()
def default_cache_dir() ->str:
    return os.path.join(os.path.abspath(os.getcwd()), 'point_e_model_cache')
",point_e\models\download.py
fetch_file_cached,"Download the file at the given URL into a local file and return the path.
If cache_dir is specified, the file is downloaded into that directory.
Otherwise, default_cache_dir() is used.","def fetch_file_cached(url: str, progress: bool=True, cache_dir: Optional[
    str]=None, chunk_size: int=4096) ->str:
    """"""""""""
    if cache_dir is None:
        cache_dir = default_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, url.split('/')[-1])
    if os.path.exists(local_path):
        return local_path
    response = requests.get(url, stream=True)
    size = int(response.headers.get('content-length', '0'))
    with FileLock(local_path + '.lock'):
        if progress:
            pbar = tqdm(total=size, unit='iB', unit_scale=True)
        tmp_path = local_path + '.tmp'
        with open(tmp_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size):
                if progress:
                    pbar.update(len(chunk))
                f.write(chunk)
        os.rename(tmp_path, local_path)
        if progress:
            pbar.close()
        return local_path
",point_e\models\download.py
load_checkpoint,,"def load_checkpoint(checkpoint_name: str, device: torch.device, progress:
    bool=True, cache_dir: Optional[str]=None, chunk_size: int=4096) ->Dict[
    str, torch.Tensor]:
    if checkpoint_name not in MODEL_PATHS:
        raise ValueError(
            f'Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}.'
            )
    path = fetch_file_cached(MODEL_PATHS[checkpoint_name], progress=
        progress, cache_dir=cache_dir, chunk_size=chunk_size)
    return torch.load(path, map_location=device)
",point_e\models\download.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, n_data: int,
    width: int, heads: int, init_scale: float, data_width: Optional[int]=None):
    super().__init__()
    self.n_data = n_data
    self.width = width
    self.heads = heads
    self.data_width = width if data_width is None else data_width
    self.c_q = nn.Linear(width, width, device=device, dtype=dtype)
    self.c_kv = nn.Linear(self.data_width, width * 2, device=device, dtype=
        dtype)
    self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
    self.attention = QKVMultiheadCrossAttention(device=device, dtype=dtype,
        heads=heads, n_data=n_data)
    init_linear(self.c_q, init_scale)
    init_linear(self.c_kv, init_scale)
    init_linear(self.c_proj, init_scale)
",point_e\models\perceiver.py
forward,,"def forward(self, x, data):
    x = self.c_q(x)
    data = self.c_kv(data)
    x = checkpoint(self.attention, (x, data), (), True)
    x = self.c_proj(x)
    return x
",point_e\models\perceiver.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
    n_data: int):
    super().__init__()
    self.device = device
    self.dtype = dtype
    self.heads = heads
    self.n_data = n_data
",point_e\models\perceiver.py
forward,,"def forward(self, q, kv):
    _, n_ctx, _ = q.shape
    bs, n_data, width = kv.shape
    attn_ch = width // self.heads // 2
    scale = 1 / math.sqrt(math.sqrt(attn_ch))
    q = q.view(bs, n_ctx, self.heads, -1)
    kv = kv.view(bs, n_data, self.heads, -1)
    k, v = torch.split(kv, attn_ch, dim=-1)
    weight = torch.einsum('bthc,bshc->bhts', q * scale, k * scale)
    wdtype = weight.dtype
    weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
    return torch.einsum('bhts,bshc->bthc', weight, v).reshape(bs, n_ctx, -1)
",point_e\models\perceiver.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, n_data: int,
    width: int, heads: int, data_width: Optional[int]=None, init_scale:
    float=1.0):
    super().__init__()
    if data_width is None:
        data_width = width
    self.attn = MultiheadCrossAttention(device=device, dtype=dtype, n_data=
        n_data, width=width, heads=heads, data_width=data_width, init_scale
        =init_scale)
    self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
    self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
    self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=
        init_scale)
    self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
",point_e\models\perceiver.py
forward,,"def forward(self, x: torch.Tensor, data: torch.Tensor):
    x = x + self.attn(self.ln_1(x), self.ln_2(data))
    x = x + self.mlp(self.ln_3(x))
    return x
",point_e\models\perceiver.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, n_data: int,
    width: int, layers: int, heads: int, init_scale: float=0.25, data_width:
    Optional[int]=None):
    super().__init__()
    self.width = width
    self.layers = layers
    init_scale = init_scale * math.sqrt(1.0 / width)
    self.resblocks = nn.ModuleList([ResidualCrossAttentionBlock(device=
        device, dtype=dtype, n_data=n_data, width=width, heads=heads,
        init_scale=init_scale, data_width=data_width) for _ in range(layers)])
",point_e\models\perceiver.py
forward,,"def forward(self, x: torch.Tensor, data: torch.Tensor):
    for block in self.resblocks:
        x = block(x, data)
    return x
",point_e\models\perceiver.py
_image_to_pil,,"def _image_to_pil(obj: Optional[ImageType]) ->Image.Image:
    if obj is None:
        return Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8))
    if isinstance(obj, np.ndarray):
        return Image.fromarray(obj.astype(np.uint8))
    elif isinstance(obj, torch.Tensor):
        return Image.fromarray(obj.detach().cpu().numpy().astype(np.uint8))
    else:
        return obj
",point_e\models\pretrained_clip.py
__init__,,"def __init__(self, device: torch.device, dtype: Optional[torch.dtype]=torch
    .float32, ensure_used_params: bool=True, clip_name: str='ViT-L/14',
    cache_dir: Optional[str]=None):
    super().__init__()
    assert clip_name in ['ViT-L/14', 'ViT-B/32']
    self.device = device
    self.ensure_used_params = ensure_used_params
    import clip
    self.clip_model, self.preprocess = clip.load(clip_name, device=device,
        download_root=cache_dir or default_cache_dir())
    self.clip_name = clip_name
    if dtype is not None:
        self.clip_model.to(dtype)
    self._tokenize = clip.tokenize
",point_e\models\pretrained_clip.py
feature_dim,,"@property
def feature_dim(self) ->int:
    if self.clip_name == 'ViT-L/14':
        return 768
    else:
        return 512
",point_e\models\pretrained_clip.py
grid_size,,"@property
def grid_size(self) ->int:
    if self.clip_name == 'ViT-L/14':
        return 16
    else:
        return 7
",point_e\models\pretrained_clip.py
grid_feature_dim,,"@property
def grid_feature_dim(self) ->int:
    if self.clip_name == 'ViT-L/14':
        return 1024
    else:
        return 768
",point_e\models\pretrained_clip.py
forward,"Generate a batch of embeddings from a mixture of images, texts,
precomputed embeddings, and possibly empty values.

For each batch element, at most one of images, texts, and embeddings
should have a non-None value. Embeddings from multiple modalities
cannot be mixed for a single batch element. If no modality is provided,
a zero embedding will be used for the batch element.","def forward(self, batch_size: int, images: Optional[Iterable[Optional[
    ImageType]]]=None, texts: Optional[Iterable[Optional[str]]]=None,
    embeddings: Optional[Iterable[Optional[torch.Tensor]]]=None
    ) ->torch.Tensor:
    """"""""""""
    image_seq = [None] * batch_size if images is None else list(images)
    text_seq = [None] * batch_size if texts is None else list(texts)
    embedding_seq = [None] * batch_size if embeddings is None else list(
        embeddings)
    assert len(image_seq
        ) == batch_size, 'number of images should match batch size'
    assert len(text_seq
        ) == batch_size, 'number of texts should match batch size'
    assert len(embedding_seq
        ) == batch_size, 'number of embeddings should match batch size'
    if self.ensure_used_params:
        return self._static_multimodal_embed(images=image_seq, texts=
            text_seq, embeddings=embedding_seq)
    result = torch.zeros((batch_size, self.feature_dim), device=self.device)
    index_images = []
    index_texts = []
    for i, (image, text, emb) in enumerate(zip(image_seq, text_seq,
        embedding_seq)):
        assert sum([int(image is not None), int(text is not None), int(emb
             is not None)]
            ) < 2, 'only one modality may be non-None per batch element'
        if image is not None:
            index_images.append((i, image))
        elif text is not None:
            index_texts.append((i, text))
        elif emb is not None:
            result[i] = emb.to(result)
    if len(index_images):
        embs = self.embed_images(img for _, img in index_images)
        for (i, _), emb in zip(index_images, embs):
            result[i] = emb.to(result)
    if len(index_texts):
        embs = self.embed_text(text for _, text in index_texts)
        for (i, _), emb in zip(index_texts, embs):
            result[i] = emb.to(result)
    return result
",point_e\models\pretrained_clip.py
_static_multimodal_embed,"Like forward(), but always runs all encoders to ensure that
the forward graph looks the same on every rank.","def _static_multimodal_embed(self, images: List[Optional[ImageType]]=None,
    texts: List[Optional[str]]=None, embeddings: List[Optional[torch.Tensor
    ]]=None) ->torch.Tensor:
    """"""""""""
    image_emb = self.embed_images(images)
    text_emb = self.embed_text(t if t else '' for t in texts)
    joined_embs = torch.stack([(emb.to(device=self.device, dtype=torch.
        float32) if emb is not None else torch.zeros(self.feature_dim,
        device=self.device)) for emb in embeddings], dim=0)
    image_flag = torch.tensor([(x is not None) for x in images], device=
        self.device)[:, None].expand_as(image_emb)
    text_flag = torch.tensor([(x is not None) for x in texts], device=self.
        device)[:, None].expand_as(image_emb)
    emb_flag = torch.tensor([(x is not None) for x in embeddings], device=
        self.device)[:, None].expand_as(image_emb)
    return image_flag.float() * image_emb + text_flag.float(
        ) * text_emb + emb_flag.float(
        ) * joined_embs + self.clip_model.logit_scale * 0
",point_e\models\pretrained_clip.py
embed_images,":param xs: N images, stored as numpy arrays, tensors, or PIL images.
:return: an [N x D] tensor of features.","def embed_images(self, xs: Iterable[Optional[ImageType]]) ->torch.Tensor:
    """"""""""""
    clip_inputs = self.images_to_tensor(xs)
    results = self.clip_model.encode_image(clip_inputs).float()
    return results / torch.linalg.norm(results, dim=-1, keepdim=True)
",point_e\models\pretrained_clip.py
embed_text,Embed text prompts as an [N x D] tensor.,"def embed_text(self, prompts: Iterable[str]) ->torch.Tensor:
    """"""""""""
    enc = self.clip_model.encode_text(self._tokenize(list(prompts),
        truncate=True).to(self.device)).float()
    return enc / torch.linalg.norm(enc, dim=-1, keepdim=True)
",point_e\models\pretrained_clip.py
embed_images_grid,"Embed images into latent grids.

:param xs: an iterable of images to embed.
:return: a tensor of shape [N x C x L], where L = self.grid_size**2.","def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) ->torch.Tensor:
    """"""""""""
    if self.ensure_used_params:
        extra_value = 0.0
        for p in self.parameters():
            extra_value = extra_value + p.mean() * 0.0
    else:
        extra_value = 0.0
    x = self.images_to_tensor(xs).to(self.clip_model.dtype)
    vt = self.clip_model.visual
    x = vt.conv1(x)
    x = x.reshape(x.shape[0], x.shape[1], -1)
    x = x.permute(0, 2, 1)
    x = torch.cat([vt.class_embedding.to(x.dtype) + torch.zeros(x.shape[0],
        1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
    x = x + vt.positional_embedding.to(x.dtype)
    x = vt.ln_pre(x)
    x = x.permute(1, 0, 2)
    x = vt.transformer(x)
    x = x.permute(1, 2, 0)
    return x[..., 1:].contiguous().float() + extra_value
",point_e\models\pretrained_clip.py
images_to_tensor,,"def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) ->torch.Tensor:
    return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0
        ).to(self.device)
",point_e\models\pretrained_clip.py
__init__,,"def __init__(self, device: torch.device, **kwargs):
    self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **
        kwargs)
    for parameter in self.model.parameters():
        parameter.requires_grad_(False)
",point_e\models\pretrained_clip.py
feature_dim,,"@property
def feature_dim(self) ->int:
    return self.model.feature_dim
",point_e\models\pretrained_clip.py
grid_size,,"@property
def grid_size(self) ->int:
    return self.model.grid_size
",point_e\models\pretrained_clip.py
grid_feature_dim,,"@property
def grid_feature_dim(self) ->int:
    return self.model.grid_feature_dim
",point_e\models\pretrained_clip.py
__call__,,"def __call__(self, batch_size: int, images: Optional[Iterable[Optional[
    ImageType]]]=None, texts: Optional[Iterable[Optional[str]]]=None,
    embeddings: Optional[Iterable[Optional[torch.Tensor]]]=None
    ) ->torch.Tensor:
    return self.model(batch_size=batch_size, images=images, texts=texts,
        embeddings=embeddings)
",point_e\models\pretrained_clip.py
embed_images,,"def embed_images(self, xs: Iterable[Optional[ImageType]]) ->torch.Tensor:
    with torch.no_grad():
        return self.model.embed_images(xs)
",point_e\models\pretrained_clip.py
embed_text,,"def embed_text(self, prompts: Iterable[str]) ->torch.Tensor:
    with torch.no_grad():
        return self.model.embed_text(prompts)
",point_e\models\pretrained_clip.py
embed_images_grid,,"def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) ->torch.Tensor:
    with torch.no_grad():
        return self.model.embed_images_grid(xs)
",point_e\models\pretrained_clip.py
device,Get the device that should be used for input tensors.,"@property
@abstractmethod
def device(self) ->torch.device:
    """"""""""""
",point_e\models\sdf.py
default_batch_size,"Get a reasonable default number of query points for the model.
In some cases, this might be the only supported size.","@property
@abstractmethod
def default_batch_size(self) ->int:
    """"""""""""
",point_e\models\sdf.py
encode_point_clouds,"Encode a batch of point clouds to cache part of the SDF calculation
done by forward().

:param point_clouds: a batch of [batch x 3 x N] points.
:return: a state representing the encoded point cloud batch.","@abstractmethod
def encode_point_clouds(self, point_clouds: torch.Tensor) ->Dict[str, torch
    .Tensor]:
    """"""""""""
",point_e\models\sdf.py
forward,"Predict the SDF at the coordinates x, given a batch of point clouds.

Either point_clouds or encoded should be passed, but not both: exactly one
of these arguments must be non-None.

:param x: a [batch x 3 x N'] tensor of query points.
:param point_clouds: a [batch x 3 x N] batch of point clouds.
:param encoded: the result of calling encode_point_clouds().
:return: a [batch x N'] tensor of SDF predictions.","def forward(self, x: torch.Tensor, point_clouds: Optional[torch.Tensor]=
    None, encoded: Optional[Dict[str, torch.Tensor]]=None) ->torch.Tensor:
    """"""""""""
    assert point_clouds is not None or encoded is not None
    assert point_clouds is None or encoded is None
    if point_clouds is not None:
        encoded = self.encode_point_clouds(point_clouds)
    return self.predict_sdf(x, encoded)
",point_e\models\sdf.py
predict_sdf,"Predict the SDF at the query points given the encoded point clouds.

Each query point should be treated independently, only conditioning on
the point clouds themselves.","@abstractmethod
def predict_sdf(self, x: torch.Tensor, encoded: Optional[Dict[str, torch.
    Tensor]]) ->torch.Tensor:
    """"""""""""
",point_e\models\sdf.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, n_ctx: int=
    4096, width: int=512, encoder_layers: int=12, encoder_heads: int=8,
    decoder_layers: int=4, decoder_heads: int=8, init_scale: float=0.25):
    super().__init__()
    self._device = device
    self.n_ctx = n_ctx
    self.encoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype)
    self.encoder = Transformer(device=device, dtype=dtype, n_ctx=n_ctx,
        width=width, layers=encoder_layers, heads=encoder_heads, init_scale
        =init_scale)
    self.decoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype)
    self.decoder = SimplePerceiver(device=device, dtype=dtype, n_data=n_ctx,
        width=width, layers=decoder_layers, heads=decoder_heads, init_scale
        =init_scale)
    self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
    self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype)
",point_e\models\sdf.py
device,,"@property
def device(self) ->torch.device:
    return self._device
",point_e\models\sdf.py
default_batch_size,,"@property
def default_batch_size(self) ->int:
    return self.n_query
",point_e\models\sdf.py
encode_point_clouds,,"def encode_point_clouds(self, point_clouds: torch.Tensor) ->Dict[str, torch
    .Tensor]:
    h = self.encoder_input_proj(point_clouds.permute(0, 2, 1))
    h = self.encoder(h)
    return dict(latents=h)
",point_e\models\sdf.py
predict_sdf,,"def predict_sdf(self, x: torch.Tensor, encoded: Optional[Dict[str, torch.
    Tensor]]) ->torch.Tensor:
    data = encoded['latents']
    x = self.decoder_input_proj(x.permute(0, 2, 1))
    x = self.decoder(x, data)
    x = self.ln_post(x)
    x = self.output_proj(x)
    return x[..., 0]
",point_e\models\sdf.py
init_linear,,"def init_linear(l, stddev):
    nn.init.normal_(l.weight, std=stddev)
    if l.bias is not None:
        nn.init.constant_(l.bias, 0.0)
",point_e\models\transformer.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, n_ctx: int,
    width: int, heads: int, init_scale: float):
    super().__init__()
    self.n_ctx = n_ctx
    self.width = width
    self.heads = heads
    self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
    self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
    self.attention = QKVMultiheadAttention(device=device, dtype=dtype,
        heads=heads, n_ctx=n_ctx)
    init_linear(self.c_qkv, init_scale)
    init_linear(self.c_proj, init_scale)
",point_e\models\transformer.py
forward,,"def forward(self, x):
    x = self.c_qkv(x)
    x = checkpoint(self.attention, (x,), (), True)
    x = self.c_proj(x)
    return x
",point_e\models\transformer.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int,
    init_scale: float):
    super().__init__()
    self.width = width
    self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
    self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
    self.gelu = nn.GELU()
    init_linear(self.c_fc, init_scale)
    init_linear(self.c_proj, init_scale)
",point_e\models\transformer.py
forward,,"def forward(self, x):
    return self.c_proj(self.gelu(self.c_fc(x)))
",point_e\models\transformer.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
    n_ctx: int):
    super().__init__()
    self.device = device
    self.dtype = dtype
    self.heads = heads
    self.n_ctx = n_ctx
",point_e\models\transformer.py
forward,,"def forward(self, qkv):
    bs, n_ctx, width = qkv.shape
    attn_ch = width // self.heads // 3
    scale = 1 / math.sqrt(math.sqrt(attn_ch))
    qkv = qkv.view(bs, n_ctx, self.heads, -1)
    q, k, v = torch.split(qkv, attn_ch, dim=-1)
    weight = torch.einsum('bthc,bshc->bhts', q * scale, k * scale)
    wdtype = weight.dtype
    weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
    return torch.einsum('bhts,bshc->bthc', weight, v).reshape(bs, n_ctx, -1)
",point_e\models\transformer.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, n_ctx: int,
    width: int, heads: int, init_scale: float=1.0):
    super().__init__()
    self.attn = MultiheadAttention(device=device, dtype=dtype, n_ctx=n_ctx,
        width=width, heads=heads, init_scale=init_scale)
    self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
    self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=
        init_scale)
    self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
",point_e\models\transformer.py
forward,,"def forward(self, x: torch.Tensor):
    x = x + self.attn(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))
    return x
",point_e\models\transformer.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, n_ctx: int,
    width: int, layers: int, heads: int, init_scale: float=0.25):
    super().__init__()
    self.n_ctx = n_ctx
    self.width = width
    self.layers = layers
    init_scale = init_scale * math.sqrt(1.0 / width)
    self.resblocks = nn.ModuleList([ResidualAttentionBlock(device=device,
        dtype=dtype, n_ctx=n_ctx, width=width, heads=heads, init_scale=
        init_scale) for _ in range(layers)])
",point_e\models\transformer.py
forward,,"def forward(self, x: torch.Tensor):
    for block in self.resblocks:
        x = block(x)
    return x
",point_e\models\transformer.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype,
    input_channels: int=3, output_channels: int=3, n_ctx: int=1024, width:
    int=512, layers: int=12, heads: int=8, init_scale: float=0.25,
    time_token_cond: bool=False):
    super().__init__()
    self.input_channels = input_channels
    self.output_channels = output_channels
    self.n_ctx = n_ctx
    self.time_token_cond = time_token_cond
    self.time_embed = MLP(device=device, dtype=dtype, width=width,
        init_scale=init_scale * math.sqrt(1.0 / width))
    self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
    self.backbone = Transformer(device=device, dtype=dtype, n_ctx=n_ctx +
        int(time_token_cond), width=width, layers=layers, heads=heads,
        init_scale=init_scale)
    self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
    self.input_proj = nn.Linear(input_channels, width, device=device, dtype
        =dtype)
    self.output_proj = nn.Linear(width, output_channels, device=device,
        dtype=dtype)
    with torch.no_grad():
        self.output_proj.weight.zero_()
        self.output_proj.bias.zero_()
",point_e\models\transformer.py
forward,":param x: an [N x C x T] tensor.
:param t: an [N] tensor.
:return: an [N x C' x T] tensor.","def forward(self, x: torch.Tensor, t: torch.Tensor):
    """"""""""""
    assert x.shape[-1] == self.n_ctx
    t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
    return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])
",point_e\models\transformer.py
_forward_with_cond,,"def _forward_with_cond(self, x: torch.Tensor, cond_as_token: List[Tuple[
    torch.Tensor, bool]]) ->torch.Tensor:
    h = self.input_proj(x.permute(0, 2, 1))
    for emb, as_token in cond_as_token:
        if not as_token:
            h = h + emb[:, None]
    extra_tokens = [(emb[:, None] if len(emb.shape) == 2 else emb) for emb,
        as_token in cond_as_token if as_token]
    if len(extra_tokens):
        h = torch.cat(extra_tokens + [h], dim=1)
    h = self.ln_pre(h)
    h = self.backbone(h)
    h = self.ln_post(h)
    if len(extra_tokens):
        h = h[:, sum(h.shape[1] for h in extra_tokens):]
    h = self.output_proj(h)
    return h.permute(0, 2, 1)
",point_e\models\transformer.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, n_ctx: int=
    1024, token_cond: bool=False, cond_drop_prob: float=0.0, frozen_clip:
    bool=True, cache_dir: Optional[str]=None, **kwargs):
    super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + int(
        token_cond), **kwargs)
    self.n_ctx = n_ctx
    self.token_cond = token_cond
    self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device,
        cache_dir=cache_dir)
    self.clip_embed = nn.Linear(self.clip.feature_dim, self.backbone.width,
        device=device, dtype=dtype)
    self.cond_drop_prob = cond_drop_prob
",point_e\models\transformer.py
cached_model_kwargs,,"def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]
    ) ->Dict[str, Any]:
    with torch.no_grad():
        return dict(embeddings=self.clip(batch_size, **model_kwargs))
",point_e\models\transformer.py
forward,":param x: an [N x C x T] tensor.
:param t: an [N] tensor.
:param images: a batch of images to condition on.
:param texts: a batch of texts to condition on.
:param embeddings: a batch of CLIP embeddings to condition on.
:return: an [N x C' x T] tensor.","def forward(self, x: torch.Tensor, t: torch.Tensor, images: Optional[
    Iterable[Optional[ImageType]]]=None, texts: Optional[Iterable[Optional[
    str]]]=None, embeddings: Optional[Iterable[Optional[torch.Tensor]]]=None):
    """"""""""""
    assert x.shape[-1] == self.n_ctx
    t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
    clip_out = self.clip(batch_size=len(x), images=images, texts=texts,
        embeddings=embeddings)
    assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0]
    if self.training:
        mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
        clip_out = clip_out * mask[:, None].to(clip_out)
    clip_out = math.sqrt(clip_out.shape[1]) * clip_out
    clip_embed = self.clip_embed(clip_out)
    cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)]
    return self._forward_with_cond(x, cond)
",point_e\models\transformer.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, n_ctx: int=
    1024, cond_drop_prob: float=0.0, frozen_clip: bool=True, cache_dir:
    Optional[str]=None, **kwargs):
    clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device,
        cache_dir=cache_dir)
    super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.
        grid_size ** 2, **kwargs)
    self.n_ctx = n_ctx
    self.clip = clip
    self.clip_embed = nn.Sequential(nn.LayerNorm(normalized_shape=(self.
        clip.grid_feature_dim,), device=device, dtype=dtype), nn.Linear(
        self.clip.grid_feature_dim, self.backbone.width, device=device,
        dtype=dtype))
    self.cond_drop_prob = cond_drop_prob
",point_e\models\transformer.py
cached_model_kwargs,,"def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]
    ) ->Dict[str, Any]:
    _ = batch_size
    with torch.no_grad():
        return dict(embeddings=self.clip.embed_images_grid(model_kwargs[
            'images']))
",point_e\models\transformer.py
forward,":param x: an [N x C x T] tensor.
:param t: an [N] tensor.
:param images: a batch of images to condition on.
:param embeddings: a batch of CLIP latent grids to condition on.
:return: an [N x C' x T] tensor.","def forward(self, x: torch.Tensor, t: torch.Tensor, images: Optional[
    Iterable[ImageType]]=None, embeddings: Optional[Iterable[torch.Tensor]]
    =None):
    """"""""""""
    assert images is not None or embeddings is not None, 'must specify images or embeddings'
    assert images is None or embeddings is None, 'cannot specify both images and embeddings'
    assert x.shape[-1] == self.n_ctx
    t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
    if images is not None:
        clip_out = self.clip.embed_images_grid(images)
    else:
        clip_out = embeddings
    if self.training:
        mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
        clip_out = clip_out * mask[:, None, None].to(clip_out)
    clip_out = clip_out.permute(0, 2, 1)
    clip_embed = self.clip_embed(clip_out)
    cond = [(t_embed, self.time_token_cond), (clip_embed, True)]
    return self._forward_with_cond(x, cond)
",point_e\models\transformer.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype,
    cond_input_channels: Optional[int]=None, cond_ctx: int=1024, n_ctx: int
    =4096 - 1024, channel_scales: Optional[Sequence[float]]=None,
    channel_biases: Optional[Sequence[float]]=None, **kwargs):
    super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **
        kwargs)
    self.n_ctx = n_ctx
    self.cond_input_channels = cond_input_channels or self.input_channels
    self.cond_point_proj = nn.Linear(self.cond_input_channels, self.
        backbone.width, device=device, dtype=dtype)
    self.register_buffer('channel_scales', torch.tensor(channel_scales,
        dtype=dtype, device=device) if channel_scales is not None else None)
    self.register_buffer('channel_biases', torch.tensor(channel_biases,
        dtype=dtype, device=device) if channel_biases is not None else None)
",point_e\models\transformer.py
forward,":param x: an [N x C1 x T] tensor.
:param t: an [N] tensor.
:param low_res: an [N x C2 x T'] tensor of conditioning points.
:return: an [N x C3 x T] tensor.","def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor):
    """"""""""""
    assert x.shape[-1] == self.n_ctx
    t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
    low_res_embed = self._embed_low_res(low_res)
    cond = [(t_embed, self.time_token_cond), (low_res_embed, True)]
    return self._forward_with_cond(x, cond)
",point_e\models\transformer.py
_embed_low_res,,"def _embed_low_res(self, x: torch.Tensor) ->torch.Tensor:
    if self.channel_scales is not None:
        x = x * self.channel_scales[None, :, None]
    if self.channel_biases is not None:
        x = x + self.channel_biases[None, :, None]
    return self.cond_point_proj(x.permute(0, 2, 1))
",point_e\models\transformer.py
__init__,,"def __init__(self, *, device: torch.device, dtype: torch.dtype, n_ctx: int=
    4096 - 1024, cond_drop_prob: float=0.0, frozen_clip: bool=True,
    cache_dir: Optional[str]=None, **kwargs):
    clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device,
        cache_dir=cache_dir)
    super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.
        grid_size ** 2, **kwargs)
    self.n_ctx = n_ctx
    self.clip = clip
    self.clip_embed = nn.Sequential(nn.LayerNorm(normalized_shape=(self.
        clip.grid_feature_dim,), device=device, dtype=dtype), nn.Linear(
        self.clip.grid_feature_dim, self.backbone.width, device=device,
        dtype=dtype))
    self.cond_drop_prob = cond_drop_prob
",point_e\models\transformer.py
cached_model_kwargs,,"def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]
    ) ->Dict[str, Any]:
    if 'images' not in model_kwargs:
        zero_emb = torch.zeros([batch_size, self.clip.grid_feature_dim, 
            self.clip.grid_size ** 2], device=next(self.parameters()).device)
        return dict(embeddings=zero_emb, low_res=model_kwargs['low_res'])
    with torch.no_grad():
        return dict(embeddings=self.clip.embed_images_grid(model_kwargs[
            'images']), low_res=model_kwargs['low_res'])
",point_e\models\transformer.py
forward,":param x: an [N x C1 x T] tensor.
:param t: an [N] tensor.
:param low_res: an [N x C2 x T'] tensor of conditioning points.
:param images: a batch of images to condition on.
:param embeddings: a batch of CLIP latent grids to condition on.
:return: an [N x C3 x T] tensor.","def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.
    Tensor, images: Optional[Iterable[ImageType]]=None, embeddings:
    Optional[Iterable[torch.Tensor]]=None):
    """"""""""""
    assert x.shape[-1] == self.n_ctx
    t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
    low_res_embed = self._embed_low_res(low_res)
    if images is not None:
        clip_out = self.clip.embed_images_grid(images)
    elif embeddings is not None:
        clip_out = embeddings
    else:
        clip_out = torch.zeros([len(x), self.clip.grid_feature_dim, self.
            clip.grid_size ** 2], dtype=x.dtype, device=x.device)
    if self.training:
        mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
        clip_out = clip_out * mask[:, None, None].to(clip_out)
    clip_out = clip_out.permute(0, 2, 1)
    clip_embed = self.clip_embed(clip_out)
    cond = [(t_embed, self.time_token_cond), (clip_embed, True), (
        low_res_embed, True)]
    return self._forward_with_cond(x, cond)
",point_e\models\transformer.py
timestep_embedding,"Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.","def timestep_embedding(timesteps, dim, max_period=10000):
    """"""""""""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=
        half, dtype=torch.float32) / half).to(device=timesteps.device)
    args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1]
            )], dim=-1)
    return embedding
",point_e\models\util.py
load,Load the mesh from a .npz file.,"@classmethod
def load(cls, f: Union[str, BinaryIO]) ->'TriMesh':
    """"""""""""
    if isinstance(f, str):
        with open(f, 'rb') as reader:
            return cls.load(reader)
    else:
        obj = np.load(f)
        keys = list(obj.keys())
        verts = obj['verts']
        faces = obj['faces']
        normals = obj['normals'] if 'normals' in keys else None
        vertex_channels = {}
        face_channels = {}
        for key in keys:
            if key.startswith('v_'):
                vertex_channels[key[2:]] = obj[key]
            elif key.startswith('f_'):
                face_channels[key[2:]] = obj[key]
        return cls(verts=verts, faces=faces, normals=normals,
            vertex_channels=vertex_channels, face_channels=face_channels)
",point_e\util\mesh.py
save,Save the mesh to a .npz file.,"def save(self, f: Union[str, BinaryIO]):
    """"""""""""
    if isinstance(f, str):
        with open(f, 'wb') as writer:
            self.save(writer)
    else:
        obj_dict = dict(verts=self.verts, faces=self.faces)
        if self.normals is not None:
            obj_dict['normals'] = self.normals
        for k, v in self.vertex_channels.items():
            obj_dict[f'v_{k}'] = v
        for k, v in self.face_channels.items():
            obj_dict[f'f_{k}'] = v
        np.savez(f, **obj_dict)
",point_e\util\mesh.py
has_vertex_colors,,"def has_vertex_colors(self) ->bool:
    return self.vertex_channels is not None and all(x in self.
        vertex_channels for x in 'RGB')
",point_e\util\mesh.py
write_ply,,"def write_ply(self, raw_f: BinaryIO):
    write_ply(raw_f, coords=self.verts, rgb=np.stack([self.vertex_channels[
        x] for x in 'RGB'], axis=1) if self.has_vertex_colors() else None,
        faces=self.faces)
",point_e\util\mesh.py
marching_cubes_mesh,"Run marching cubes on the SDF predicted from a point cloud to produce a
mesh representing the 3D surface.

:param pc: the point cloud to apply marching cubes to.
:param model: the model to use to predict SDF values.
:param batch_size: the number of grid points whose SDF values are predicted
                   in each call to the model.
:param grid_size: the number of samples along each axis. A total of
                  grid_size**3 function evaluations are performed.
:param side_length: the size of the cube containing the model, which is
                    assumed to be centered at the origin.
:param fill_vertex_channels: if True, use the nearest neighbor of each mesh
                             vertex in the point cloud to compute vertex
                             data (e.g. colors).
:param progress: if True, show a tqdm progress bar over the batches.
:return: the extracted TriMesh.","def marching_cubes_mesh(pc: PointCloud, model: PointCloudSDFModel,
    batch_size: int=4096, grid_size: int=128, side_length: float=1.02,
    fill_vertex_channels: bool=True, progress: bool=False) ->TriMesh:
    """"""""""""
    voxel_size = side_length / (grid_size - 1)
    min_coord = -side_length / 2

    def int_coord_to_float(int_coords: torch.Tensor) ->torch.Tensor:
        return int_coords.float() * voxel_size + min_coord
    with torch.no_grad():
        cond = model.encode_point_clouds(torch.from_numpy(pc.coords).
            permute(1, 0).to(model.device)[None])
    indices = range(0, grid_size ** 3, batch_size)
    if progress:
        indices = tqdm(indices)
    volume = []
    for i in indices:
        indices = torch.arange(i, min(i + batch_size, grid_size ** 3), step
            =1, dtype=torch.int64, device=model.device)
        zs = int_coord_to_float(indices % grid_size)
        ys = int_coord_to_float(torch.div(indices, grid_size, rounding_mode
            ='trunc') % grid_size)
        xs = int_coord_to_float(torch.div(indices, grid_size ** 2,
            rounding_mode='trunc'))
        coords = torch.stack([xs, ys, zs], dim=0)
        with torch.no_grad():
            volume.append(model(coords[None], encoded=cond)[0])
    volume_np = torch.cat(volume).view(grid_size, grid_size, grid_size).cpu(
        ).numpy()
    if np.all(volume_np < 0) or np.all(volume_np > 0):
        volume_np -= np.mean(volume_np)
    verts, faces, normals, _ = skimage.measure.marching_cubes(volume=
        volume_np, level=0, allow_degenerate=False, spacing=(voxel_size,) * 3)
    old_f1 = faces[:, 0].copy()
    faces[:, 0] = faces[:, 1]
    faces[:, 1] = old_f1
    verts += min_coord
    return TriMesh(verts=verts, faces=faces, normals=normals,
        vertex_channels=None if not fill_vertex_channels else
        _nearest_vertex_channels(pc, verts))
",point_e\util\pc_to_mesh.py
_nearest_vertex_channels,,"def _nearest_vertex_channels(pc: PointCloud, verts: np.ndarray) ->Dict[str,
    np.ndarray]:
    nearest = pc.nearest_points(verts)
    return {ch: arr[nearest] for ch, arr in pc.channels.items()}
",point_e\util\pc_to_mesh.py
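plot_point_cloud,"Render a point cloud as a 3D scatter plot and return the matplotlib figure.

:param pc: the PointCloud to plot.
:param color: if True, show the RGB colors from the point cloud.
:param grid_size: the number of subplot rows and columns; grid_size**2 views
                  are rendered, each rotated about the z axis by an evenly
                  spaced angle.
:param fixed_bounds: the (min, max) corners used as axis limits, or None to
                     fit cubic limits around each view's points.
:return: the matplotlib figure.","def plot_point_cloud(pc: PointCloud, color: bool=True, grid_size: int=1,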
    fixed_bounds: Optional[Tuple[Tuple[float, float, float], Tuple[float,
    float, float]]]=((-0.75, -0.75, -0.75), (0.75, 0.75, 0.75))):
    """"""""""""
    fig = plt.figure(figsize=(8, 8))
    for i in range(grid_size):
        for j in range(grid_size):
            ax = fig.add_subplot(grid_size, grid_size, 1 + j + i *
                grid_size, projection='3d')
            color_args = {}
            if color:
                color_args['c'] = np.stack([pc.channels['R'], pc.channels[
                    'G'], pc.channels['B']], axis=-1)
            c = pc.coords
            if grid_size > 1:
                theta = np.pi * 2 * (i * grid_size + j) / grid_size ** 2
                rotation = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                    [np.sin(theta), np.cos(theta), 0.0], [0.0, 0.0, 1.0]])
                c = c @ rotation
            ax.scatter(c[:, 0], c[:, 1], c[:, 2], **color_args)
            if fixed_bounds is None:
                min_point = c.min(0)
                max_point = c.max(0)
                size = (max_point - min_point).max() / 2
                center = (min_point + max_point) / 2
                ax.set_xlim3d(center[0] - size, center[0] + size)
                ax.set_ylim3d(center[1] - size, center[1] + size)
                ax.set_zlim3d(center[2] - size, center[2] + size)
            else:
                ax.set_xlim3d(fixed_bounds[0][0], fixed_bounds[1][0])
                ax.set_ylim3d(fixed_bounds[0][1], fixed_bounds[1][1])
                ax.set_zlim3d(fixed_bounds[0][2], fixed_bounds[1][2])
    return fig
",point_e\util\plotting.py
write_ply,"Write a PLY file for a mesh or a point cloud.

:param raw_f: a binary file-like object to write the PLY data to.
:param coords: an [N x 3] array of floating point coordinates.
:param rgb: an [N x 3] array of vertex colors, in the range [0.0, 1.0].
:param faces: an [M x 3] array of triangles, each row holding three integer
              indices into coords.","def write_ply(raw_f: BinaryIO, coords: np.ndarray, rgb: Optional[np.ndarray
    ]=None, faces: Optional[np.ndarray]=None):
    """"""""""""
    with buffered_writer(raw_f) as f:
        f.write(b'ply\n')
        f.write(b'format binary_little_endian 1.0\n')
        f.write(bytes(f'element vertex {len(coords)}\n', 'ascii'))
        f.write(b'property float x\n')
        f.write(b'property float y\n')
        f.write(b'property float z\n')
        if rgb is not None:
            f.write(b'property uchar red\n')
            f.write(b'property uchar green\n')
            f.write(b'property uchar blue\n')
        if faces is not None:
            f.write(bytes(f'element face {len(faces)}\n', 'ascii'))
            f.write(b'property list uchar int vertex_index\n')
        f.write(b'end_header\n')
        if rgb is not None:
            rgb = (rgb * 255.499).round().astype(int)
            vertices = [(*coord, *rgb) for coord, rgb in zip(coords.tolist(
                ), rgb.tolist())]
            format = struct.Struct('<3f3B')
            for item in vertices:
                f.write(format.pack(*item))
        else:
            format = struct.Struct('<3f')
            for vertex in coords.tolist():
                f.write(format.pack(*vertex))
        if faces is not None:
            format = struct.Struct('<B3I')
            for tri in faces.tolist():
                f.write(format.pack(len(tri), *tri))
",point_e\util\ply_util.py
buffered_writer,,"@contextmanager
def buffered_writer(raw_f: BinaryIO) ->Iterator[io.BufferedIOBase]:
    if isinstance(raw_f, io.BufferedIOBase):
        yield raw_f
    else:
        f = io.BufferedWriter(raw_f)
        yield f
        f.flush()
",point_e\util\ply_util.py
preprocess,,"def preprocess(data, channel):
    if channel in COLORS:
        return np.round(data * 255.0)
    return data
",point_e\util\point_cloud.py
load,Load the point cloud from a .npz file.,"@classmethod
def load(cls, f: Union[str, BinaryIO]) ->'PointCloud':
    """"""""""""
    if isinstance(f, str):
        with open(f, 'rb') as reader:
            return cls.load(reader)
    else:
        obj = np.load(f)
        keys = list(obj.keys())
        return PointCloud(coords=obj['coords'], channels={k: obj[k] for k in
            keys if k != 'coords'})
",point_e\util\point_cloud.py
save,Save the point cloud to a .npz file.,"def save(self, f: Union[str, BinaryIO]):
    """"""""""""
    if isinstance(f, str):
        with open(f, 'wb') as writer:
            self.save(writer)
    else:
        np.savez(f, coords=self.coords, **self.channels)
",point_e\util\point_cloud.py
write_ply,,"def write_ply(self, raw_f: BinaryIO):
    write_ply(raw_f, coords=self.coords, rgb=np.stack([self.channels[x] for
        x in 'RGB'], axis=1) if all(x in self.channels for x in 'RGB') else
        None)
",point_e\util\point_cloud.py
random_sample,"Sample a random subset of this PointCloud.

:param num_points: maximum number of points to sample.
:param subsample_kwargs: arguments to self.subsample().
:return: a reduced PointCloud, or self if num_points is not less than
         the current number of points.","def random_sample(self, num_points: int, **subsample_kwargs) ->'PointCloud':
    """"""""""""
    if len(self.coords) <= num_points:
        return self
    indices = np.random.choice(len(self.coords), size=(num_points,),
        replace=False)
    return self.subsample(indices, **subsample_kwargs)
",point_e\util\point_cloud.py
farthest_point_sample,"Sample a subset of the point cloud that is evenly distributed in space.

First, a random point is selected (or init_idx, if given). Then each
successive point is chosen to maximize its distance to the nearest
already-selected point.

The time complexity of this operation is O(NM), where N is the original
number of points and M is the reduced number. Therefore, performance
can be improved by randomly subsampling points with random_sample()
before running farthest_point_sample().

:param num_points: maximum number of points to sample.
:param init_idx: if specified, the first point to sample.
:param subsample_kwargs: arguments to self.subsample().
:return: a reduced PointCloud, or self if num_points is not less than
         the current number of points.","def farthest_point_sample(self, num_points: int, init_idx: Optional[int]=
    None, **subsample_kwargs) ->'PointCloud':
    """"""""""""
    if len(self.coords) <= num_points:
        return self
    init_idx = random.randrange(len(self.coords)
        ) if init_idx is None else init_idx
    indices = np.zeros([num_points], dtype=np.int64)
    indices[0] = init_idx
    sq_norms = np.sum(self.coords ** 2, axis=-1)

    def compute_dists(idx: int):
        return sq_norms + sq_norms[idx] - 2 * (self.coords @ self.coords[idx])
    cur_dists = compute_dists(init_idx)
    for i in range(1, num_points):
        idx = np.argmax(cur_dists)
        indices[i] = idx
        cur_dists = np.minimum(cur_dists, compute_dists(idx))
    return self.subsample(indices, **subsample_kwargs)
",point_e\util\point_cloud.py
subsample,,"def subsample(self, indices: np.ndarray, average_neighbors: bool=False
    ) ->'PointCloud':
    if not average_neighbors:
        return PointCloud(coords=self.coords[indices], channels={k: v[
            indices] for k, v in self.channels.items()})
    new_coords = self.coords[indices]
    neighbor_indices = PointCloud(coords=new_coords, channels={}
        ).nearest_points(self.coords)
    neighbor_indices[indices] = np.arange(len(indices))
    new_channels = {}
    for k, v in self.channels.items():
        v_sum = np.zeros_like(v[:len(indices)])
        v_count = np.zeros_like(v[:len(indices)])
        np.add.at(v_sum, neighbor_indices, v)
        np.add.at(v_count, neighbor_indices, 1)
        new_channels[k] = v_sum / v_count
    return PointCloud(coords=new_coords, channels=new_channels)
",point_e\util\point_cloud.py
select_channels,,"def select_channels(self, channel_names: List[str]) ->np.ndarray:
    data = np.stack([preprocess(self.channels[name], name) for name in
        channel_names], axis=-1)
    return data
",point_e\util\point_cloud.py
nearest_points,"For each point in another set of points, compute the point in this
pointcloud which is closest.

:param points: an [N x 3] array of points.
:param batch_size: the number of neighbor distances to compute at once.
                   Smaller values save memory, while larger values may
                   make the computation faster.
:return: an [N] array of indices into self.coords.","def nearest_points(self, points: np.ndarray, batch_size: int=16384
    ) ->np.ndarray:
    """"""""""""
    norms = np.sum(self.coords ** 2, axis=-1)
    all_indices = []
    for i in range(0, len(points), batch_size):
        batch = points[i:i + batch_size]
        dists = norms + np.sum(batch ** 2, axis=-1)[:, None] - 2 * (batch @
            self.coords.T)
        all_indices.append(np.argmin(dists, axis=-1))
    return np.concatenate(all_indices, axis=0)
",point_e\util\point_cloud.py
combine,,"def combine(self, other: 'PointCloud') ->'PointCloud':
    assert self.channels.keys() == other.channels.keys()
    return PointCloud(coords=np.concatenate([self.coords, other.coords],
        axis=0), channels={k: np.concatenate([v, other.channels[k]], axis=0
        ) for k, v in self.channels.items()})
",point_e\util\point_cloud.py