# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
import torch


def project_to_subspaces(latent: torch.Tensor, basis: torch.Tensor,
                         repurposed_dims: torch.Tensor, base_dims: torch.Tensor = None,
                         step_size=None, mean=None):
""" | |
Project latent on the base subspace (Z_base) - spanned by the base_dims. | |
Then, traverses the projected latent along the repurposed directions. | |
If the step_size parameter can be interpreted as some 1D structure, | |
then traversal is performed separately for each repurposed dim with these as step sizes. | |
Otherwise, it defines a joint traversal of multiple dimensions at once. | |
Usually, it would be 3D, so the output can be visualized in a 2D image grid. | |
Returns: | |
traversals.shape, 1D case -[num_steps, num_repurposed, shape of input] | |
traversals.shape, ND case -[num_steps_1, ..., num_steps_N, shape of input] | |
""" | |
    if isinstance(latent, list):
        if len(latent) != 1:
            raise ValueError('Latent wrapped by list should be of length 1')
        latent = latent[0]
    latent_in_w = False
    if latent.dim() == 2:
        # Lift to W+ just for now
        latent = w_to_wplus(latent)
        latent_in_w = True
    elif latent.dim() != 3:
        raise ValueError('Latent is expected to be 2D (W space) or 3D (W+ space)')
    latent_dim = latent.shape[-1]
    if base_dims is None:
        # Take all non-repurposed dims to span the base subspace -- default mode
        base_dims = torch.tensor([x for x in range(latent_dim) if x not in repurposed_dims],
                                 dtype=torch.long)
    # Index with integer values (not a boolean mask) so the requested order is preserved
    repurposed_directions = basis[:, repurposed_dims.long()]
    base_directions = basis[:, base_dims.long()]
    # Project onto the base subspace, then map back into the ambient latent space
    projected_latent = latent @ base_directions
    base_latent = projected_latent @ base_directions.T
    if mean is not None:
        # Re-insert the mean's component along the repurposed directions
        base_latent += (mean @ repurposed_directions) @ repurposed_directions.T
    if step_size is None:
        if latent_in_w:
            base_latent = wplus_to_w(base_latent)
        return base_latent, None
    if isinstance(step_size, (int, float)):
        step_size = torch.tensor([float(step_size)], device=latent.device)
    repurposed_directions = repurposed_directions.T
    num_repurposed = len(repurposed_dims)
    if step_size.dim() == 1:
        # 1D case: apply the same schedule of steps to each repurposed dim separately
        num_steps = step_size.shape[0]
        output_shape = [num_steps, num_repurposed, *latent.shape]
        edits = torch.einsum('a, df -> adf', step_size, repurposed_directions)
    elif step_size.dim() == 3:
        # ND case: compound steps that move along multiple dims at once
        steps_in_directions = step_size.shape[:-1]
        output_shape = [*steps_in_directions, *latent.shape]
        edits = step_size @ repurposed_directions
    else:
        raise NotImplementedError('step_size must be a scalar, a 1D tensor, or a 3D tensor')
    edit_latents = base_latent.expand(output_shape) + edits.unsqueeze(2).unsqueeze(2).expand(output_shape)
    if latent_in_w:
        # Bring back to W space
        base_latent, edit_latents = wplus_to_w(base_latent), edit_latents[..., 0, :]
    return base_latent, edit_latents
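
# Shape walkthrough (illustrative, assuming a 512-dimensional latent space):
# given a W latent of shape (1, 512), two repurposed dims, and a 1D step_size
# such as torch.linspace(-3, 3, 5), the returned traversals have shape
# (5, 2, 1, 512), i.e. [num_steps, num_repurposed, *latent.shape].
# With a 3D step_size of shape (rows, cols, num_repurposed), the traversals
# have shape (rows, cols, 1, 512), ready to be laid out as a 2D image grid.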


def w_to_wplus(w_latent: torch.Tensor, num_ws=18):
    # Broadcast a W code of shape (batch, latent_dim) to a W+ code of shape (batch, num_ws, latent_dim)
    return w_latent.unsqueeze(1).repeat([1, num_ws, 1])


def wplus_to_w(latents: torch.Tensor):
    """
    latents are expected to have shape (..., num_ws, 512) with identical vectors
    repeated along the num_ws axis; otherwise the conversion to W is undefined.
    """
    with torch.no_grad():
        _, counts = torch.unique(latents, dim=-2, return_counts=True)
        if len(counts) != 1:
            raise ValueError('input latent is not a W code, conversion from W+ is undefined')
    return latents[..., 0, :]
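

if __name__ == '__main__':
    # Minimal usage sketch, assuming an orthonormal basis of the W space.
    # In practice the basis would be estimated elsewhere (e.g. from PCA of W codes);
    # a random orthogonal matrix stands in here purely for illustration, and the
    # repurposed dims and step sizes below are arbitrary choices.
    torch.manual_seed(0)
    latent_dim = 512
    basis, _ = torch.linalg.qr(torch.randn(latent_dim, latent_dim))  # orthonormal columns
    w = torch.randn(1, latent_dim)                                   # a single W code
    repurposed_dims = torch.tensor([3, 7])                           # two repurposed directions
    steps = torch.linspace(-3.0, 3.0, 5)                             # five step sizes per direction

    base, traversals = project_to_subspaces(w, basis, repurposed_dims, step_size=steps)
    print(base.shape)        # torch.Size([1, 512])       - projection onto Z_base
    print(traversals.shape)  # torch.Size([5, 2, 1, 512]) - [num_steps, num_repurposed, *w.shape]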