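"""Gradio demo (sketch of intent, inferred from the code below): sample faces from a
StyleGAN/StyleGAN2 FFHQ generator and edit attributes by shifting the latent codes
along pre-computed attribute boundaries."""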
import os

import torch
import PIL.Image
import numpy as np
import gradio as gr

from models.stylegan_generator import StyleGANGenerator
from models.stylegan2_generator import StyleGAN2Generator
from utils.constants import VALID_CHOICES, ENABLE_GPU, MODEL_NAME, OUTPUT_LIST, description, title
from utils.image_manip import tensor_to_pil, concat_images
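

# Build the face generator for the configured model name and move it to the GPU
# when one is available.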
def get_generator(model_name):
    if model_name == 'stylegan_ffhq':
        generator = StyleGANGenerator(model_name)
    elif model_name == 'stylegan2_ffhq':
        generator = StyleGAN2Generator(model_name)
    else:
        raise ValueError('Model name not recognized')
    if ENABLE_GPU:
        generator = generator.cuda()
    return generator

generator = get_generator(MODEL_NAME)
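
# Pre-load one latent-space boundary (editing direction) per supported attribute.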
boundaries = {
    boundary: np.squeeze(np.load(open(os.path.join('boundaries', MODEL_NAME, 'boundary_%s.npy' % boundary), 'rb')))
    for boundary in VALID_CHOICES
}

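
# Sample random latent codes, synthesize the original images, then shift each code
# along the selected attribute boundaries (scaled by `coef`) and synthesize the
# edited versions for a side-by-side comparison.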
@torch.no_grad()
def inference(seed, coef, nb_images, list_choices):
    global generator, boundaries
    np.random.seed(seed)
    latent_codes = generator.easy_sample(nb_images)
    if ENABLE_GPU:
        latent_codes = latent_codes.cuda()
        generator = generator.cuda()
    generated_images = generator.easy_synthesize(latent_codes)
    generated_images = tensor_to_pil(generated_images)
    new_latent_codes = latent_codes.copy()
    for i, _ in enumerate(generated_images):
        for choice in list_choices:
            new_latent_codes[i, :] += boundaries[choice] * coef
    modified_generated_images = generator.easy_synthesize(new_latent_codes)
    modified_generated_images = tensor_to_pil(modified_generated_images)
    concatenated_output = concat_images(generated_images, modified_generated_images)
    return concatenated_output

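
# Gradio UI: random seed, edit strength, number of images, and the attributes to modify.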
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Slider(
            minimum=0,
            maximum=1000,
            step=1,
            default=644,
            label="Random seed to use for the generation"
        ),
        gr.inputs.Slider(
            minimum=-3,
            maximum=3,
            step=0.1,
            default=0,
            label="Modification coefficient",
        ),
        gr.inputs.Slider(
            minimum=1,
            maximum=10,
            step=1,
            default=2,
            label="Number of images to generate",
        ),
        gr.inputs.CheckboxGroup(
            VALID_CHOICES,
            default=[],
            type="value",
            label="Select attributes to modify",
            optional=False
        )
    ],
    outputs=OUTPUT_LIST,
    layout="horizontal",
    theme="peach",
    description=description,
    title=title,
)

iface.launch()