File size: 3,747 Bytes
560a1b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.

import torch


def project_to_subspaces(latent: torch.Tensor, basis: torch.Tensor,
                        repurposed_dims: torch.Tensor, base_dims: torch.Tensor = None,
                        step_size=None, mean=None):
    """
    Project a latent onto the base subspace and optionally traverse along repurposed directions.

    The latent is first projected onto the base subspace (Z_base), spanned by the
    columns of ``basis`` selected by ``base_dims``. If ``mean`` is given, its component
    in the repurposed subspace (spanned by the ``repurposed_dims`` columns) is added
    to the projection. The result is then shifted along the repurposed directions
    according to ``step_size``:

    * a number, or a 0-D/1-D tensor: each value is used as a separate step along
      each repurposed direction independently;
    * an N-D tensor with N >= 2 whose last dimension equals the number of repurposed
      dims: every entry over the leading N-1 dims is one joint step over all
      repurposed directions. A 3-D ``step_size`` yields a 2-D grid of traversals,
      which can be visualized as an image grid.

    Args:
        latent: W latent of shape [batch, latent_dim] or W+ latent of shape
            [batch, num_ws, latent_dim]. A list holding exactly one such tensor is
            also accepted.
        basis: matrix whose columns are the directions, indexed as ``basis[:, dims]``.
            The projection is orthogonal only if these columns are orthonormal
            (assumed, not checked).
        repurposed_dims: integer indices (tensor, list or array) of the basis columns
            to traverse along. Their order is preserved in the output.
        base_dims: integer indices of the basis columns spanning the base subspace.
            Defaults to every basis column not listed in ``repurposed_dims``.
        step_size: None, a Python number, or a tensor as described above.
        mean: optional tensor whose projection onto the repurposed subspace is added
            to the base latent.

    Returns:
        ``(base_latent, edit_latents)``. ``base_latent`` has the shape of ``latent``.
        ``edit_latents`` is None when ``step_size`` is None, otherwise:
            1-D case: [num_steps, num_repurposed, *latent.shape]
            N-D case: [*step_size.shape[:-1], *latent.shape]

    Raises:
        ValueError: if ``latent`` is a list whose length is not 1, or is not 2-D/3-D.
    """
    if isinstance(latent, list):
        if len(latent) != 1:
            raise ValueError('Latent wrapped by list should be of length 1')
        latent = latent[0]

    latent_in_w = False
    if latent.dim() == 2:
        # Lift to W+ just for now
        latent = w_to_wplus(latent)
        latent_in_w = True
    elif latent.dim() != 3:
        raise ValueError('Latent is expected to be 2D (W space) or 3D (W+ space)')

    # Integer index tensors on the basis' device: torch rejects float indices, and
    # going through .numpy() would fail for CUDA tensors and for list inputs.
    repurposed_index = torch.as_tensor(repurposed_dims, dtype=torch.long,
                                       device=basis.device).reshape(-1)

    if base_dims is None:
        # Take all non-repurposed basis columns to span the base subspace -- default mode
        repurposed_set = set(repurposed_index.tolist())
        base_dims = [d for d in range(basis.shape[1]) if d not in repurposed_set]
    base_index = torch.as_tensor(base_dims, dtype=torch.long,
                                 device=basis.device).reshape(-1)

    # Use values instead of boolean to change order as needed
    repurposed_directions = basis[:, repurposed_index]
    base_directions = basis[:, base_index]

    base_latent = (latent @ base_directions) @ base_directions.T

    if mean is not None:
        # Restore the mean's component inside the repurposed subspace
        base_latent = base_latent + (mean @ repurposed_directions) @ repurposed_directions.T

    if step_size is None:
        if latent_in_w:
            base_latent = wplus_to_w(base_latent)
        return base_latent, None

    # One direction per row: [num_repurposed, latent_dim]
    repurposed_directions = repurposed_directions.T

    # Accept numbers, sequences or tensors; align dtype/device with the directions
    step_size = torch.as_tensor(step_size, dtype=repurposed_directions.dtype,
                                device=repurposed_directions.device)
    if step_size.dim() == 0:
        step_size = step_size.reshape(1)

    if step_size.dim() == 1:
        # Separate same-sized steps on every dim: [num_steps, num_repurposed, latent_dim]
        edits = torch.einsum('a, df -> adf', step_size, repurposed_directions)
    else:
        # Compound steps on multiple dims; the last axis of step_size indexes the
        # repurposed directions: [*step_size.shape[:-1], latent_dim]
        edits = step_size @ repurposed_directions

    # Insert singleton axes for the latent's leading dims (batch, num_ws) so the
    # edits broadcast against base_latent for any number of step axes.
    num_latent_lead_dims = base_latent.dim() - 1
    edits = edits.reshape(*edits.shape[:-1], *([1] * num_latent_lead_dims), edits.shape[-1])
    edit_latents = base_latent + edits

    if latent_in_w:
        # Bring back to W space
        base_latent, edit_latents = wplus_to_w(base_latent), edit_latents[..., 0, :]

    return base_latent, edit_latents

def w_to_wplus(w_latent: torch.Tensor, num_ws=18):
    """
    Lift a W latent to W+ by repeating it ``num_ws`` times on a new second-to-last axis.

    Args:
        w_latent: tensor of shape (..., latent_dim). The common case is
            (batch, latent_dim), but any number of leading dims is accepted.
        num_ws: number of copies placed along the new axis.

    Returns:
        Tensor of shape (..., num_ws, latent_dim). It is a materialized copy
        (``repeat``), not a view, so it can be modified without touching ``w_latent``.
    """
    # Keep every leading dim as-is; repeat only along the inserted num_ws axis
    repeats = [1] * (w_latent.dim() - 1) + [num_ws, 1]
    return w_latent.unsqueeze(-2).repeat(repeats)

def wplus_to_w(latents: torch.Tensor):
    """
    Collapse a W+ latent whose per-layer codes are all identical back to W.

    Args:
        latents: tensor of shape (..., num_ws, latent_dim), e.g.
            (batch, num_ws, latent_dim); any number of leading dims is accepted.

    Returns:
        The shared code, of shape (..., latent_dim).

    Raises:
        ValueError: if ``num_ws`` is 0, or the ``num_ws`` codes are not exactly
            identical for every leading index, i.e. the input is not a lifted W code.
    """
    # Element-wise comparison against the first code is a single linear pass
    # (no sorting needed); comparisons are not tracked by autograd.
    first_code = latents[..., :1, :]
    is_w_code = latents.shape[-2] > 0 and bool((latents == first_code).all())
    if not is_w_code:
        raise ValueError('input latent is not a W code, conversion from W+ is undefined')

    return latents[..., 0, :]