File size: 4,035 Bytes
bd87e2e
 
 
 
 
 
 
 
 
 
 
ae2d652
bd87e2e
 
6deedc6
bd87e2e
3927aba
6deedc6
 
 
bd87e2e
6deedc6
 
 
bd87e2e
6deedc6
 
 
bd87e2e
6deedc6
 
 
 
 
 
 
 
bd87e2e
 
6deedc6
 
bd87e2e
6deedc6
bd87e2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a94775f
bd87e2e
 
 
a94775f
 
 
 
 
 
 
bd87e2e
 
ae2d652
 
 
 
bd87e2e
 
 
 
 
 
 
 
eeeef15
 
 
 
 
bd87e2e
 
ae2d652
bd87e2e
 
 
ae2d652
bd87e2e
 
 
 
 
 
 
ae2d652
bd87e2e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import subprocess
from pathlib import Path

import einops
import gradio as gr
import numpy as np
import torch
from huggingface_hub import HfApi
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision.utils import save_image
# Hugging Face Hub API client; used at the bottom of the file to list the
# models published under the huggingnft organisation.
hfapi = HfApi()

class Generator(nn.Module):
    """Transposed-convolution generator that upsamples latent noise to 64x64.

    Given noise of shape (N, latent_dim, 1, 1) it produces images of shape
    (N, num_channels, 64, 64) with values in [-1, 1] (final Tanh).

    Args:
        num_channels: number of channels in the generated image.
        latent_dim: channel count of the input noise tensor.
        hidden_size: base feature width; the first stage uses 8x this.

    The order of modules inside ``self.model`` fixes the state_dict keys
    (``model.0.weight``, ``model.1.running_mean``, ...), so it must match the
    checkpoints loaded with ``load_state_dict``.
    """

    def __init__(self, num_channels=4, latent_dim=100, hidden_size=64):
        super().__init__()

        def up_stage(in_ch, out_ch, stride, padding):
            # Bias-free 4x4 transposed conv, then BatchNorm and in-place ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        widths = [hidden_size * factor for factor in (8, 4, 2, 1)]

        # 1x1 -> 4x4 projection of the noise vector.
        stages = up_stage(latent_dim, widths[0], 1, 0)
        # Each following stage doubles H and W: 4 -> 8 -> 16 -> 32.
        for w_in, w_out in zip(widths, widths[1:]):
            stages += up_stage(w_in, w_out, 2, 1)
        # Final 32 -> 64 upsample to image channels, squashed into [-1, 1].
        stages += [
            nn.ConvTranspose2d(widths[-1], num_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*stages)

    def forward(self, noise):
        """Decode a batch of noise tensors into images."""
        return self.model(noise)


@torch.no_grad()
def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8, model=None):
    """Render a looping GIF that blends between two random latent grids.

    Two batches of ``rows * cols`` latent vectors are drawn from torch's global
    RNG and linearly blended over ``frames`` steps. The sequence is then played
    back in reverse so the animation returns to where it started. Each frame is
    written to ``save_dir`` as ``NNN.png``; the animation is written to
    ``out.gif`` in the current working directory.

    Args:
        save_dir: directory for the per-frame PNGs (created if missing).
        frames: number of forward interpolation steps (must be >= 1).
        rows: number of grid rows.
        cols: number of grid columns; ``rows * cols`` images per frame.
        model: generator used to decode the latents. Required.

    Returns:
        Path of the written GIF (``'out.gif'``).

    Raises:
        ValueError: if ``model`` is not given or ``frames`` < 1.
    """
    if model is None:
        raise ValueError("interpolate() requires a generator `model`")
    if frames < 1:
        raise ValueError("frames must be >= 1")

    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)

    z1 = torch.randn(rows * cols, 100, 1, 1)
    z2 = torch.randn(rows * cols, 100, 1, 1)

    # alpha = i / frames covers [0, 1) in `frames` evenly spaced steps.
    zs = [(1 - i / frames) * z1 + (i / frames) * z2 for i in range(frames)]
    zs += zs[::-1]  # also go in reverse order to complete the loop

    images = []
    for i, z in enumerate(zs):
        frame_path = save_dir / f"{i:03}.png"
        # nrow=cols so the grid actually has `cols` columns (save_image
        # otherwise defaults to 8 per row regardless of the argument).
        save_image(model(z), frame_path, nrow=cols, normalize=True)
        # Reload the written PNG, force full opacity, and overwrite it.
        with Image.open(frame_path) as png:
            img = png.convert('RGBA')
        img.putalpha(255)
        img.save(frame_path)
        images.append(img)

    gif_path = 'out.gif'
    # append_images lists only the frames *after* the first one; loop=0 makes
    # the GIF repeat indefinitely, matching the forward/backward sequence.
    images[0].save(gif_path, format="GIF", append_images=images[1:],
                   save_all=True, duration=100, loop=0)
    return gif_path


def predict(model_name, choice, seed):
    """Generate an image grid or an interpolation GIF from a huggingnft model.

    Args:
        model_name: repository name under the ``huggingnft`` Hub organisation.
        choice: ``'interpolation'`` for an animated GIF; any other value
            produces a static grid of 64 images.
        seed: seed for torch's global RNG, making the output reproducible.

    Returns:
        Path of the written file: ``'out.gif'`` or ``'image.png'``.
    """
    model = Generator()
    weights_path = hf_hub_download(f'huggingnft/{model_name}', 'pytorch_model.bin')
    # NOTE(review): torch.load unpickles the downloaded file, which is only safe
    # for trusted repositories; consider weights_only=True if the installed
    # torch version supports it.
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    # NOTE(review): the model is never switched to eval(), so BatchNorm uses
    # per-batch statistics; confirm this is the intended sampling mode.
    torch.manual_seed(seed)

    if choice == 'interpolation':
        # Pass the loaded generator explicitly; interpolate has no other way
        # to reach this function's local `model`.
        return interpolate(model=model)

    with torch.no_grad():  # pure inference; no autograd graph needed
        punks = model(torch.randn(64, 100, 1, 1))
    save_image(punks, "image.png", normalize=True)
    with Image.open("image.png") as png:
        img = png.convert('RGBA')
    img.putalpha(255)
    img.save("image.png")
    return 'image.png'


# Repository names (without the "huggingnft/" prefix) of every model the
# huggingnft organisation publishes; these populate the Model dropdown.
models = [info.modelId.split("/", 1)[-1] for info in hfapi.list_models(author="huggingnft")]

# predict() takes (model_name, choice, seed), so every example row must supply
# all three inputs. Prefer the CryptoPunks model the title/description refer to.
if "cryptopunks" in models:
    example_model = "cryptopunks"
elif models:
    example_model = models[0]
else:
    example_model = None

examples = None
if example_model is not None:
    examples = [
        [example_model, choice, seed]
        for choice in ("interpolation", "image")
        for seed in (100, 500)
    ]

gr.Interface(
    predict,
    inputs=[
        gr.inputs.Dropdown(models, label='Model'),
        gr.inputs.Dropdown(['image', 'interpolation'], label='Output Type'),
        gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
    ],
    outputs="image",
    title="Cryptopunks GAN",
    description="These CryptoPunks do not exist. You have the choice of either generating random punks, or a gif showing the interpolation between two random punk grids.",
    article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
    examples=examples,
).launch(cache_examples=examples is not None)