# domain-expansion/expansion_utils/latent_operations.py
# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
import torch
def project_to_subspaces(latent: torch.Tensor, basis: torch.Tensor,
repurposed_dims: torch.Tensor, base_dims: torch.Tensor = None,
step_size=None, mean=None):
"""
Project latent on the base subspace (Z_base) - spanned by the base_dims.
Then, traverses the projected latent along the repurposed directions.
If the step_size parameter can be interpreted as some 1D structure,
then traversal is performed separately for each repurposed dim with these as step sizes.
Otherwise, it defines a joint traversal of multiple dimensions at once.
Usually, it would be 3D, so the output can be visualized in a 2D image grid.
Returns:
traversals.shape, 1D case -[num_steps, num_repurposed, shape of input]
traversals.shape, ND case -[num_steps_1, ..., num_steps_N, shape of input]
"""
    if isinstance(latent, list):
if len(latent) != 1:
raise ValueError('Latent wrapped by list should be of length 1')
latent = latent[0]
latent_in_w = False
if latent.dim() == 2:
# Lift to W+ just for now
latent = w_to_wplus(latent)
latent_in_w = True
elif latent.dim() != 3:
raise ValueError('Latent is expected to be 2D (W space) or 3D (W+ space)')
latent_dim = latent.shape[-1]
    if base_dims is None:
        # Take all non-repurposed dims to span the base subspace -- default mode
        base_dims = torch.tensor([x for x in range(latent_dim) if x not in repurposed_dims],
                                 dtype=torch.long)
    # Index with the dim values (rather than a boolean mask) so the column order can be controlled
    repurposed_directions = basis[:, repurposed_dims.numpy()]
    base_directions = basis[:, base_dims.numpy()]
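    # Project onto the base subspace: take coordinates along the base directions, then map back
    # to the ambient latent space; any component along the repurposed directions is dropped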
projected_latent = latent @ base_directions
base_latent = projected_latent @ base_directions.T
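    # Optionally re-insert the mean's component along the repurposed directions,
    # so traversals start from the mean rather than from zero on those dims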
if mean is not None:
base_latent += (mean @ repurposed_directions) @ repurposed_directions.T
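    # No traversal requested: return only the projected (base) latent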
if step_size is None:
if latent_in_w:
base_latent = wplus_to_w(base_latent)
return base_latent, None
    if isinstance(step_size, (float, int)):
        step_size = torch.Tensor([step_size]).to(latent.device)
repurposed_directions = repurposed_directions.T
num_repurposed = len(repurposed_dims)
if step_size.dim() == 1:
# separate same-sized steps on all dims
num_steps = step_size.shape[0]
output_shape = [num_steps, num_repurposed, *latent.shape]
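        # edits[a, d, :] = step_size[a] * repurposed_directions[d, :] -- one edit vector per (step, dim) pair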
edits = torch.einsum('a, df -> adf', step_size, repurposed_directions)
elif step_size.dim() == 3:
# compound steps, on multiple dims
steps_in_directions = step_size.shape[:-1]
output_shape = [*steps_in_directions, *latent.shape]
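        # Each step vector holds a coefficient per repurposed dim; matmul turns it into a latent-space edit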
edits = step_size @ repurposed_directions
else:
        raise NotImplementedError('step_size must be a scalar, a 1D tensor of step sizes, or a 3D tensor of step vectors')
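    # Broadcast the base latent over the grid of steps and add the edit vectors
    # (the two unsqueezes align the edits with the latent's leading batch/num_ws dims)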
edit_latents = base_latent.expand(output_shape) + edits.unsqueeze(2).unsqueeze(2).expand(output_shape)
if latent_in_w:
        # Bring back to W space
base_latent, edit_latents = wplus_to_w(base_latent), edit_latents[..., 0, :]
return base_latent, edit_latents
def w_to_wplus(w_latent: torch.Tensor, num_ws=18):
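    """Broadcast a W latent of shape (batch, latent_dim) to a W+ latent of shape (batch, num_ws, latent_dim)."""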
return w_latent.unsqueeze(1).repeat([1, num_ws, 1])
def wplus_to_w(latents: torch.Tensor):
"""
latents is expected to have shape (...,num_ws,512) or
"""
with torch.no_grad():
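        # A genuine W code lifted to W+ has exactly one unique row along the num_ws dimension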
_, counts = torch.unique(latents, dim=-2, return_counts=True)
if len(counts) != 1:
raise ValueError('input latent is not a W code, conversion from W+ is undefined')
return latents[..., 0, :]
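

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module). The orthonormal
    # basis, the chosen repurposed dims, and the step sizes below are stand-in assumptions;
    # in the actual domain-expansion pipeline the basis and repurposed dims come from the method itself.
    torch.manual_seed(0)
    latent_dim = 512
    basis = torch.linalg.qr(torch.randn(latent_dim, latent_dim))[0]  # stand-in orthonormal basis
    latent = torch.randn(1, latent_dim)                              # a single W-space latent
    repurposed_dims = torch.tensor([510, 511])                       # two dims to traverse
    steps = torch.linspace(-3.0, 3.0, steps=7)                       # 1D step sizes -> per-dim traversal

    base, traversals = project_to_subspaces(latent, basis, repurposed_dims, step_size=steps)
    print(base.shape)        # torch.Size([1, 512])
    print(traversals.shape)  # torch.Size([7, 2, 1, 512])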