Spaces:
Runtime error
Runtime error
File size: 1,474 Bytes
560a1b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# Based on a script from https://github.com/rosinality/stylegan2-pytorch
# ==========================================================================================
#
# Adobe’s modifications are Copyright 2023 Adobe Research. All rights reserved.
# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
# LICENSE.md.
#
# ==========================================================================================
import argparse
import numpy as np
import torch
from pathlib import Path
import dnnlib
import legacy
def factorize(G):
    """Compute closed-form factorization directions from a generator's affine weights.

    Collects every parameter whose name contains both "affine" and "weight",
    skipping names that contain "torgb" unless they also contain "b4". The
    selected matrices are stacked row-wise (in ``named_parameters()`` order)
    into a single matrix ``W``, and the right singular vectors of ``W`` are
    returned.

    Args:
        G: a ``torch.nn.Module`` (in this script, the loaded ``G_ema``
            generator) whose parameter names follow the pattern above. The
            selected weights are assumed to share the same number of columns
            so they can be concatenated along dim 0.

    Returns:
        torch.Tensor: ``torch.svd(W).V``. Its columns are the right singular
        vectors of ``W`` (torch returns singular values in descending order).
        The result is detached from autograd and lives on the same device as
        the weights.

    Raises:
        RuntimeError: if no parameter matches the selection filter.
    """
    # Only the numeric SVD is needed. Running under no_grad (and detaching)
    # avoids building an autograd graph through cat/svd for parameters that
    # still have requires_grad=True.
    with torch.no_grad():
        weights = [
            v.detach()
            for k, v in G.named_parameters()
            # torgb affines are excluded except in block "b4".
            # NOTE(review): presumably intentional, to mirror the reference
            # script's layer set — confirm.
            if ("b4" in k or "torgb" not in k) and ("affine" in k and "weight" in k)
        ]
        if not weights:
            raise RuntimeError(
                "factorize: no parameters whose names contain both 'affine' and "
                "'weight' (after torgb filtering) were found in G"
            )
        W = torch.cat(weights, 0)
        eigvec = torch.svd(W).V
    return eigvec
if __name__ == "__main__":
    # CLI entry point: load the G_ema generator from a checkpoint, compute
    # its closed-form factorization directions, and save them with torch.save.
    parser = argparse.ArgumentParser(
        description="Extract factor/eigenvectors of latent spaces using closed form factorization"
    )
    parser.add_argument("--out", type=str, required=True, help="path to output file")
    parser.add_argument("ckpt", type=str, help="name of the model checkpoint")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="device to load the generator on (default: cuda if available, else cpu)",
    )
    args = parser.parse_args()

    with dnnlib.util.open_url(args.ckpt) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(args.device)

    eigvec = factorize(G)

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Save a CPU copy so the file can be loaded on machines without a GPU.
    torch.save(eigvec.cpu(), out_path)
|