BioMike committed (verified)
Commit 2c480a0 · Parent(s): 5d66e0c

Upload 16 files

app.py ADDED
@@ -0,0 +1,14 @@
+import gradio as gr
+from encoding import get_interface as encoding_page
+from generation import get_interface as generation_page
+from interpolation import get_interface as interpolation_page
+
+with gr.Blocks() as demo:
+    with gr.Tab("Encode & Reconstruct"):
+        encoding_page()
+    with gr.Tab("Generate from Noise"):
+        generation_page()
+    with gr.Tab("Interpolate"):
+        interpolation_page()
+
+demo.launch()
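
Run locally with python app.py; demo.launch() serves the three tabs on Gradio's default port, 7860.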
encoding.py ADDED
@@ -0,0 +1,51 @@
+import torch
+import gradio as gr
+from torchvision import transforms
+from PIL import Image
+import numpy as np
+from model import model
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+resize_input = transforms.Resize((32, 32))
+to_tensor = transforms.ToTensor()
+
+def reconstruct_image(image):
+    image = Image.fromarray(image).convert('RGB')
+    image_32 = resize_input(image)
+    image_tensor = to_tensor(image_32).unsqueeze(0).to(device)
+    with torch.no_grad():
+        mu, _ = model.encode(image_tensor)
+        recon = model.decode(mu)
+    recon_np = recon.squeeze(0).permute(1, 2, 0).cpu().numpy()
+    recon_img = Image.fromarray((recon_np * 255).astype(np.uint8)).resize((512, 512))
+    orig_resized = image_32.resize((512, 512))
+    return orig_resized, recon_img
+
+def get_interface():
+    with gr.Blocks() as iface:
+        gr.Markdown("## Encoding & Reconstruction")
+        with gr.Row():
+            input_image = gr.Image(label="Input (Downsampled to 32x32)", type="numpy")
+            output_image = gr.Image(label="Reconstructed", type="pil")
+        run_button = gr.Button("Run Reconstruction")
+
+        run_button.click(fn=reconstruct_image, inputs=input_image, outputs=[input_image, output_image])
+
+        examples = [
+            ["example_images/image1.jpg"],
+            ["example_images/image2.jpg"],
+            ["example_images/image3.jpg"],
+            ["example_images/image10.jpg"],
+            ["example_images/image4.jpg"],
+            ["example_images/image5.jpg"],
+            ["example_images/image6.jpg"],
+            ["example_images/image7.jpg"],
+            ["example_images/image8.jpg"],
+        ]
+
+        gr.Examples(
+            examples=examples,
+            inputs=[input_image],
+        )
+    return iface
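
A minimal standalone sketch of the encode-decode path, outside the Gradio UI. It assumes the repo's example images are on disk and the BioMike/emoji-vae-init weights load; the filename is taken from the examples list above.

    import numpy as np
    from PIL import Image
    from encoding import reconstruct_image

    # reconstruct_image expects a numpy array, as the gr.Image component supplies
    arr = np.array(Image.open("example_images/image1.jpg").convert("RGB"))
    orig_512, recon_512 = reconstruct_image(arr)  # both come back as 512x512 PIL images
    recon_512.save("reconstruction.png")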
example_images/image0.jpg ADDED
example_images/image1.jpg ADDED
example_images/image10.jpg ADDED
example_images/image2.jpg ADDED
example_images/image3.jpg ADDED
example_images/image4.jpg ADDED
example_images/image5.jpg ADDED
example_images/image6.jpg ADDED
example_images/image7.jpg ADDED
example_images/image8.jpg ADDED
generation.py ADDED
@@ -0,0 +1,34 @@
+import torch
+import gradio as gr
+from torchvision import transforms
+from PIL import Image
+from model import model
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+latent_dim = model.config.latent_dim
+
+def generate_from_noise():
+    z = torch.randn(1, latent_dim).to(device)
+    with torch.no_grad():
+        generated = model.decode(z)
+    gen_img = generated.squeeze(0).permute(1, 2, 0).cpu().numpy()
+    gen_pil = Image.fromarray((gen_img * 255).astype("uint8")).resize((512, 512))
+    return gen_pil
+
+def get_interface():
+    with gr.Blocks() as iface:
+        gr.Markdown("## Generate from Random Noise")
+        generate_button = gr.Button("Generate Image")
+        output_image = gr.Image(label="Generated Image", type="pil")
+        generate_button.click(fn=generate_from_noise, inputs=[], outputs=output_image)
+
+        examples = [[]]
+
+        gr.Examples(
+            examples=examples,
+            inputs=[],
+            outputs=output_image,
+            fn=generate_from_noise,
+            cache_examples=False
+        )
+    return iface
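
A quick smoke test for the sampler, assuming the model weights load: generate_from_noise draws z ~ N(0, I) in the latent_dim-dimensional latent space and decodes it.

    from generation import generate_from_noise

    img = generate_from_noise()  # 512x512 PIL image decoded from a random latent
    img.save("sample.png")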
interpolation.py ADDED
@@ -0,0 +1,84 @@
+import torch
+import gradio as gr
+from torchvision import transforms
+from PIL import Image
+import numpy as np
+from model import model
+import tempfile
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+transform = transforms.Compose([
+    transforms.Resize((32, 32)),
+    transforms.ToTensor()
+])
+
+resize_output = transforms.Resize((512, 512))
+
+def interpolate_vectors(v1, v2, num_steps):
+    return [v1 * (1 - alpha) + v2 * alpha for alpha in np.linspace(0, 1, num_steps)]
+
+def to_pil(img_tensor):
+    img = img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
+    img = (img * 255).clip(0, 255).astype(np.uint8)
+    return Image.fromarray(img)
+
+def interpolate_images_gif(img1, img2, num_interpolations=10, duration=100):
+    img1 = Image.fromarray(img1).convert('RGB')
+    img2 = Image.fromarray(img2).convert('RGB')
+    img1_tensor = transform(img1).unsqueeze(0).to(device)
+    img2_tensor = transform(img2).unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        mu1, _ = model.encode(img1_tensor)
+        mu2, _ = model.encode(img2_tensor)
+        interpolated = interpolate_vectors(mu1, mu2, num_interpolations)
+        decoded_images = []
+        for z in interpolated:
+            out = model.decode(z)
+            img = to_pil(out)
+            img_resized = resize_output(img)
+            decoded_images.append(img_resized)
+
+    tmp_file = tempfile.NamedTemporaryFile(suffix=".gif", delete=False)
+    decoded_images[0].save(
+        tmp_file.name,
+        save_all=True,
+        append_images=decoded_images[1:],
+        duration=duration,
+        loop=0
+    )
+    return tmp_file.name
+
+def get_interface():
+    with gr.Blocks() as iface:
+        gr.Markdown("## Latent Space Interpolation (GIF Output)")
+        with gr.Row():
+            img1 = gr.Image(label="First Image", type="numpy")
+            img2 = gr.Image(label="Second Image", type="numpy")
+        slider_steps = gr.Slider(5, 30, value=10, step=1, label="Number of Interpolations")
+        slider_duration = gr.Slider(50, 500, value=100, step=10, label="Duration per Frame (ms)")
+        output_gif = gr.Image(label="Interpolation GIF")
+        run_button = gr.Button("Interpolate")
+
+        run_button.click(
+            fn=interpolate_images_gif,
+            inputs=[img1, img2, slider_steps, slider_duration],
+            outputs=output_gif
+        )
+
+        examples = [
+            ["example_images/image1.jpg", "example_images/image2.jpg", 10, 100],
+            ["example_images/image3.jpg", "example_images/image4.jpg", 15, 150],
+            ["example_images/image5.jpg", "example_images/image6.jpg", 20, 200],
+            ["example_images/image7.jpg", "example_images/image8.jpg", 25, 250],
+        ]
+
+        gr.Examples(
+            examples=examples,
+            inputs=[img1, img2, slider_steps, slider_duration],
+            outputs=output_gif,
+            fn=interpolate_images_gif,
+            cache_examples=False
+        )
+    return iface
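
The same interpolation can be driven outside the UI; a hedged sketch, assuming the example images and weights are available. Note the frames follow a straight line between the two posterior means, since interpolate_vectors is linear rather than spherical.

    import numpy as np
    from PIL import Image
    from interpolation import interpolate_images_gif

    a = np.array(Image.open("example_images/image1.jpg").convert("RGB"))
    b = np.array(Image.open("example_images/image2.jpg").convert("RGB"))
    path = interpolate_images_gif(a, b, num_interpolations=16, duration=80)
    print(path)  # temp-file path to a 16-frame looping GIF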
model.py ADDED
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PreTrainedModel, PretrainedConfig
+
+
+class BaseVAE(nn.Module):
+    def __init__(self, latent_dim=16):
+        super(BaseVAE, self).__init__()
+        self.latent_dim = latent_dim
+
+        self.encoder = nn.Sequential(
+            nn.Conv2d(3, 32, 4, 2, 1),    # 32x32 -> 16x16
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, 4, 2, 1),   # 16x16 -> 8x8
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.Conv2d(64, 128, 4, 2, 1),  # 8x8 -> 4x4
+            nn.BatchNorm2d(128),
+            nn.ReLU(),
+            nn.Flatten()
+        )
+        self.fc_mu = nn.Linear(128 * 4 * 4, latent_dim)
+        self.fc_logvar = nn.Linear(128 * 4 * 4, latent_dim)
+
+        self.decoder_input = nn.Linear(latent_dim, 128 * 4 * 4)
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 4x4 -> 8x8
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 32, 4, 2, 1),   # 8x8 -> 16x16
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.ConvTranspose2d(32, 3, 4, 2, 1),    # 16x16 -> 32x32
+            nn.Sigmoid()
+        )
+
+    def encode(self, x):
+        x = self.encoder(x)
+        mu = self.fc_mu(x)
+        logvar = self.fc_logvar(x)
+        return mu, logvar
+
+    def reparameterize(self, mu, logvar):
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def decode(self, z):
+        x = self.decoder_input(z)
+        x = x.view(-1, 128, 4, 4)
+        return self.decoder(x)
+
+    def forward(self, x):
+        mu, logvar = self.encode(x)
+        z = self.reparameterize(mu, logvar)
+        recon = self.decode(z)
+        return recon, mu, logvar
+
+class VAEConfig(PretrainedConfig):
+    model_type = "vae"
+
+    def __init__(self, latent_dim=16, **kwargs):
+        super().__init__(**kwargs)
+        self.latent_dim = latent_dim
+
+class VAEModel(PreTrainedModel):
+    config_class = VAEConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.vae = BaseVAE(latent_dim=config.latent_dim)
+        self.post_init()
+
+    def forward(self, x):
+        return self.vae(x)
+
+    def encode(self, x):
+        return self.vae.encode(x)
+
+    def decode(self, z):
+        return self.vae.decode(z)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = VAEModel.from_pretrained("BioMike/emoji-vae-init").to(device)
+model.eval()
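
The training loop is not part of this commit, but reparameterize implies the standard VAE objective. A hedged sketch of that loss, with BCE matching the Sigmoid decoder output and the closed-form KL term against N(0, I); the beta weight and dummy batch are illustrative only.

    import torch
    import torch.nn.functional as F
    from model import model

    def vae_loss(x, recon, mu, logvar, beta=1.0):
        # Reconstruction term: BCE suits decoder outputs in [0, 1]
        recon_loss = F.binary_cross_entropy(recon, x, reduction="sum")
        # KL( N(mu, sigma^2) || N(0, I) ) in closed form
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + beta * kl

    x = torch.rand(4, 3, 32, 32, device=model.device)  # dummy batch in [0, 1]
    recon, mu, logvar = model(x)                       # forward samples z via reparameterize
    loss = vae_loss(x, recon, mu, logvar)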
requirements.txt ADDED
@@ -0,0 +1,8 @@
+torch
+torchvision
+gradio
+transformers
+datasets
+huggingface_hub
+pillow
+numpy