echen01
working demo
2fec875
raw
history blame
1.01 kB
import sys
import os
import base64
import torch
from PIL import Image
import dnnlib
import legacy
def load_stylegan2(model_path, device):
    """
    Loads the stylegan2 generator.

    Arguments:
        model_path (str): Path to model; it is opened through
            dnnlib.util.open_url.
        device (str): Device to load model on

    Returns:
        G (nn.Module): Stylegan generator (the "G_ema" entry of the network
            pickle), moved to ``device``
        w_avg (Tensor): The average style vector in W space, repeated once
            per style layer of the generator and moved to ``device``
    """
    # NOTE(review): network pickles are presumably deserialized with pickle,
    # which can run arbitrary code -- only load checkpoints you trust.
    with dnnlib.util.open_url(model_path) as f:
        G = legacy.load_network_pkl(f)["G_ema"]

    # The number of style layers depends on the generator's output
    # resolution, so read it from the model instead of hard-coding it.
    # Fall back to the previous constant (14) when the attribute is missing.
    num_ws = getattr(G, "num_ws", None)
    if num_ws is None:
        num_ws = 14

    # Tile the average latent along a new leading axis: one row per layer.
    w_avg = G.mapping.w_avg.repeat(num_ws, 1)
    w_avg = w_avg.to(device)
    G = G.to(device)
    return G, w_avg
def tensor2im(var):
    """
    Converts a tensor image to PIL Image. Taken from the stylegan2-ada-pytorch repo

    Arguments:
        var (Tensor): Tensor representing the input image. The first axis is
            moved to the end before conversion, so a channels-first
            (3, H, W) layout is presumably expected, with values roughly in
            [-1, 1] -- TODO confirm against callers.

    Returns:
        image (PIL.Image): RGB image with 8-bit channels
    """
    # detach() lets tensors that are still part of an autograd graph be
    # converted; Tensor.numpy() raises on tensors that require grad.
    pixels = var.detach().permute(1, 2, 0)
    # Affine map x -> x * 127.5 + 127.5, then clamp into the valid byte range
    # before the (truncating) cast to uint8.
    pixels = (pixels * 127.5 + 127.5).clamp(0, 255).to(torch.uint8)
    return Image.fromarray(pixels.cpu().numpy(), "RGB")