Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
from torchvision.utils import save_image | |
import torch.nn as nn | |
import torch | |
import torchvision.datasets as datasets | |
import torchvision.transforms as T | |
from torch.utils.data import DataLoader, ConcatDataset | |
import torchvision.utils as vutils | |
import random | |
LATENT_VECTOR_DIM = 8 # latent vector dimension | |
class Generator_128(nn.Module): | |
def __init__(self, GPU_COUNT): | |
super(Generator_128, self).__init__() | |
self.GPU_COUNT = GPU_COUNT | |
self.main = nn.Sequential( | |
# LATENT_VECTOR_DIM x 1 x 1 | |
nn.ConvTranspose2d(LATENT_VECTOR_DIM, 1024, 4, 1, 0, bias=False), | |
nn.BatchNorm2d(1024), | |
nn.ReLU(True), | |
nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False), | |
nn.BatchNorm2d(512), | |
nn.ReLU(True), | |
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), | |
nn.BatchNorm2d(256), | |
nn.ReLU(True), | |
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), | |
nn.BatchNorm2d(128), | |
nn.ReLU(True), | |
nn.ConvTranspose2d(128,64, 4, 2, 1, bias=False), | |
nn.BatchNorm2d(64), | |
nn.ReLU(True), | |
nn.ConvTranspose2d(64,3, 4, 2, 1, bias=False), | |
nn.Tanh() | |
# 128 x 128 x 3 | |
) | |
def forward(self, input): | |
return self.main(input) | |
trained_gen = Generator_128(0) | |
trained_gen.load_state_dict(torch.load("generator_epoch_3000v1-5.h5",map_location=torch.device('cpu'))) | |
def predict(seed, pokemon_count): | |
torch.manual_seed(seed) | |
z = torch.randn(pokemon_count, LATENT_VECTOR_DIM, 1, 1) | |
punks = trained_gen(z) | |
save_image(punks, "pokemon.png", normalize=True) | |
return 'pokemon.png' | |
gr.Interface( | |
predict, | |
inputs=[ | |
gr.Slider(0, 1000, label='Seed', default=42), | |
gr.Slider(1, 8, label='Number of pokemon', step=1, default=10), | |
], | |
outputs="image", | |
).launch() | |