import sys
import os
import base64

import torch
from PIL import Image

import dnnlib
import legacy


def load_stylegan2(model_path, device, num_ws=14):
    """
    Loads the stylegan2 generator.

    Arguments:
        model_path (str): Path (or URL) to the pickled model; it is opened
            through ``dnnlib.util.open_url``.
        device (str): Device to load model on
        num_ws (int): How many times the generator's average style vector
            is tiled along the first axis. Defaults to 14, the value that
            was previously hard-coded here.
            NOTE(review): presumably this should equal the number of style
            inputs the loaded generator expects -- confirm for checkpoints
            other than the one this default was chosen for.

    Returns:
        G (nn.Module): Stylegan generator (the ``"G_ema"`` entry of the
            loaded network pickle), moved to ``device``
        w_avg (Tensor): ``G.mapping.w_avg`` repeated ``num_ws`` times along
            the first axis, moved to ``device``
    """
    # NOTE(security): `load_network_pkl` presumably unpickles the file, which
    # can execute arbitrary code -- only load model files from trusted sources.
    with dnnlib.util.open_url(model_path) as f:
        G = legacy.load_network_pkl(f)["G_ema"]

    # Tile the average style vector so there is one copy per style input.
    w_avg = G.mapping.w_avg.repeat(num_ws, 1)
    w_avg = w_avg.to(device)
    G = G.to(device)
    return G, w_avg


def tensor2im(var):
    """
    Converts a tensor image to PIL Image. Taken from the
    stylegan2-ada-pytorch repo

    Arguments:
        var (Tensor): 3-D tensor representing the input image. Its first
            axis is moved to the last position before conversion.
            NOTE(review): assumes a (3, H, W) layout with values roughly
            in [-1, 1] -- confirm against callers.

    Returns:
        image (PIL.Image): The converted 8-bit RGB image
    """
    # Move the leading axis last so the array layout is (H, W, C).
    channels_last = var.permute(1, 2, 0)
    # Affine map x -> x * 127.5 + 127.5, then clamp anything out of range
    # into [0, 255] before truncating to 8-bit integers.
    pixels = (channels_last * 127.5 + 127.5).clamp(0, 255).to(torch.uint8)
    return Image.fromarray(pixels.cpu().numpy(), "RGB")