BioMike committed
Commit 5a9c9b2 · verified · 1 parent: f71074a

Upload 9 files

app.py ADDED
@@ -0,0 +1,20 @@
+ import gradio as gr
+ from vae import vae
+ from morphing import morphing
+ from model import ConvVAE
+
+ model = ConvVAE.from_pretrained(
+     model_id="BioMike/classical_portrait_vae",
+     cache_dir="./model_cache",
+     map_location="cpu",
+     strict=True
+ ).eval()
+
+
+ demo = gr.TabbedInterface([vae, morphing],
+                           ["Image to Portrait", "Image to Image (Morphing)"],
+                           title="Classical Portraits VAE",
+                           theme=gr.themes.Base())
+
+ demo.queue()
+ demo.launch(debug=True, share=True)
example_images/image1.jpg ADDED
example_images/image2.png ADDED
example_images/image3.jpg ADDED
example_images/image4.jpg ADDED
model.py ADDED
@@ -0,0 +1,148 @@
+ import json
+ import torch
+ import torch.nn as nn
+ import os
+ from pathlib import Path
+ from typing import Optional, Union, Dict
+ from huggingface_hub import snapshot_download
+ import warnings
+
+ class ConvVAE(nn.Module):
+     def __init__(self, latent_size):
+         super(ConvVAE, self).__init__()
+
+         # Encoder
+         self.encoder = nn.Sequential(
+             nn.Conv2d(3, 64, 3, stride=2, padding=1),    # (batch, 64, 64, 64)
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.Conv2d(64, 128, 3, stride=2, padding=1),  # (batch, 128, 32, 32)
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Conv2d(128, 256, 3, stride=2, padding=1), # (batch, 256, 16, 16)
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+             nn.Conv2d(256, 512, 3, stride=2, padding=1), # (batch, 512, 8, 8)
+             nn.BatchNorm2d(512),
+             nn.ReLU()
+         )
+
+         self.fc_mu = nn.Linear(512 * 8 * 8, latent_size)
+         self.fc_logvar = nn.Linear(512 * 8 * 8, latent_size)
+
+         self.fc2 = nn.Linear(latent_size, 512 * 8 * 8)
+
+         self.decoder = nn.Sequential(
+             nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1), # (batch, 256, 16, 16)
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+             nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), # (batch, 128, 32, 32)
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # (batch, 64, 64, 64)
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),    # (batch, 3, 128, 128)
+             nn.Tanh()
+         )
+
+     def forward(self, x):
+         mu, logvar = self.encode(x)
+         z = self.reparameterize(mu, logvar)
+         decoded = self.decode(z)
+         return decoded, mu, logvar
+
+     def encode(self, x):
+         x = self.encoder(x)
+         x = x.view(x.size(0), -1)
+         mu = self.fc_mu(x)
+         logvar = self.fc_logvar(x)
+         return mu, logvar
+
+     def reparameterize(self, mu, logvar):
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return mu + eps * std
+
+     def decode(self, z):
+         x = self.fc2(z)
+         x = x.view(-1, 512, 8, 8)
+         decoded = self.decoder(x)
+         return decoded
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_id: str,
+         revision: Optional[str] = None,
+         cache_dir: Optional[Union[str, Path]] = None,
+         force_download: bool = False,
+         proxies: Optional[Dict] = None,
+         resume_download: bool = False,
+         local_files_only: bool = False,
+         token: Union[str, bool, None] = None,
+         map_location: str = "cpu",
+         strict: bool = False,
+         **model_kwargs,
+     ):
+         """
+         Load a pretrained model from a given model ID.
+
+         Args:
+             model_id (str): Identifier of the model to load.
+             revision (Optional[str]): Specific model revision to use.
+             cache_dir (Optional[Union[str, Path]]): Directory to store downloaded models.
+             force_download (bool): Force re-download even if the model exists.
+             proxies (Optional[Dict]): Proxy configuration for downloads.
+             resume_download (bool): Resume interrupted downloads.
+             local_files_only (bool): Use only local files, don't download.
+             token (Union[str, bool, None]): Token for API authentication.
+             map_location (str): Device to map model to. Defaults to "cpu".
+             strict (bool): Enforce strict state_dict loading.
+             **model_kwargs: Additional keyword arguments for model initialization.
+
+         Returns:
+             An instance of the model loaded from the pretrained weights.
+         """
+         model_dir = Path(model_id)
+         if not model_dir.exists():
+             model_dir = Path(
+                 snapshot_download(
+                     repo_id=model_id,
+                     revision=revision,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     proxies=proxies,
+                     resume_download=resume_download,
+                     token=token,
+                     local_files_only=local_files_only,
+                 )
+             )
+
+         config_file = model_dir / "config.json"
+         with open(config_file, 'r') as f:
+             config = json.load(f)
+
+         latent_size = config.get('latent_size')
+         if latent_size is None:
+             raise ValueError("The configuration file is missing the 'latent_size' key.")
+
+         model = cls(latent_size, **model_kwargs)
+
+         model_file = model_dir / "model_conv_vae_256_epoch_304.pth"
+         if not model_file.exists():
+             raise FileNotFoundError(f"The model checkpoint '{model_file}' does not exist.")
+
+         state_dict = torch.load(model_file, map_location=map_location)
+
+         new_state_dict = {}  # strip the '_orig_mod.' prefix that torch.compile adds to checkpoint keys
+         for k, v in state_dict.items():
+             if k.startswith('_orig_mod.'):
+                 new_state_dict[k[len('_orig_mod.'):]] = v
+             else:
+                 new_state_dict[k] = v
+
+         model.load_state_dict(new_state_dict, strict=strict)
+         model.to(map_location)
+
+         return model
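
For reference, a minimal usage sketch of `ConvVAE.from_pretrained` (the local path and the latent_size value below are illustrative; a local directory must hold `config.json` with a `latent_size` key plus the `.pth` checkpoint):

from model import ConvVAE

# Resolve from the Hub: downloads config.json and the checkpoint on first use
model = ConvVAE.from_pretrained("BioMike/classical_portrait_vae", map_location="cpu").eval()

# Or resolve from a local directory, e.g. one containing
#   config.json                       -> {"latent_size": 256}   (value illustrative)
#   model_conv_vae_256_epoch_304.pth  -> the weights
model = ConvVAE.from_pretrained("./my_local_checkpoint", strict=True).eval()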
morphing.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ from torchvision import transforms
+ from PIL import Image, ImageFilter
+ import gradio as gr
+ import numpy as np
+ import uuid
+
+ from model import ConvVAE
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the pretrained VAE used by the handlers below
+ # (same checkpoint as app.py).
+ model = ConvVAE.from_pretrained(
+     model_id="BioMike/classical_portrait_vae",
+     cache_dir="./model_cache",
+     map_location=str(device),
+     strict=True
+ ).eval()
+
+ transform = transforms.Compose([
+     transforms.Resize((128, 128)),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5,), (0.5,))
+ ])
+
+ resize_transform = transforms.Resize((512, 512))
+
+ def load_image(image):
+     image = Image.fromarray(image).convert('RGB')
+     image = transform(image)
+     return image.unsqueeze(0).to(device)
+
+ def interpolate_vectors(v1, v2, num_steps):
+     return [v1 * (1 - alpha) + v2 * alpha for alpha in np.linspace(0, 1, num_steps)]
+
+ def infer_and_interpolate(image1, image2, num_interpolations=24):
+     image1 = load_image(image1)
+     image2 = load_image(image2)
+
+     with torch.no_grad():
+         mu1, logvar1 = model.encode(image1)
+         mu2, logvar2 = model.encode(image2)
+         interpolated_vectors = interpolate_vectors(mu1, mu2, num_interpolations)
+         decoded_images = [model.decode(vec).squeeze(0) for vec in interpolated_vectors]
+
+     return decoded_images
+
+ def create_gif(decoded_images, duration=200, apply_blur=False):
+     # Play the frames forward then backward so the GIF loops seamlessly
+     reversed_images = decoded_images[::-1]
+     all_images = decoded_images + reversed_images
+
+     pil_images = []
+     for img in all_images:
+         img = (img - img.min()) / (img.max() - img.min())
+         img = (img * 255).byte()
+         pil_img = transforms.ToPILImage()(img.cpu()).convert("RGB")
+         pil_img = resize_transform(pil_img)
+         if apply_blur:
+             pil_img = pil_img.filter(ImageFilter.GaussianBlur(radius=1))
+         pil_images.append(pil_img)
+
+     gif_filename = f"/tmp/morphing_{uuid.uuid4().hex}.gif"
+     pil_images[0].save(gif_filename, save_all=True, append_images=pil_images[1:], duration=duration, loop=0)
+
+     return gif_filename
+
+ def create_morphing_gif(image1, image2, num_interpolations=24, duration=200):
+     decoded_images = infer_and_interpolate(image1, image2, num_interpolations)
+     gif_path = create_gif(decoded_images, duration)
+
+     return gif_path
+
+ examples = [
+     ["example_images/image1.jpg", "example_images/image2.png", 24, 200],
+     ["example_images/image3.jpg", "example_images/image4.jpg", 30, 150],
+ ]
+
+ with gr.Blocks() as morphing:
+     with gr.Column():
+         with gr.Column():
+             num_interpolations = gr.Slider(minimum=2, maximum=50, value=24, step=1, label="Number of interpolations")
+             duration = gr.Slider(minimum=100, maximum=1000, value=200, step=50, label="Duration per frame (ms)")
+             generate_button = gr.Button("Generate Morphing GIF")
+             output_gif = gr.Image(label="Morphing GIF")
+         with gr.Row():
+             image1 = gr.Image(label="Upload first image", type="numpy")
+             image2 = gr.Image(label="Upload second image", type="numpy")
+
+     generate_button.click(fn=create_morphing_gif, inputs=[image1, image2, num_interpolations, duration], outputs=output_gif)
+
+     gr.Examples(examples=examples, inputs=[image1, image2, num_interpolations, duration])
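
The morphing pipeline also runs outside the Gradio UI; a small sketch driving `create_morphing_gif` directly (importing `morphing` builds the Blocks UI and loads the model as a side effect):

import numpy as np
from PIL import Image
from morphing import create_morphing_gif

# Gradio's type="numpy" images are HxWx3 uint8 arrays, so mimic that here
img_a = np.array(Image.open("example_images/image1.jpg").convert("RGB"))
img_b = np.array(Image.open("example_images/image4.jpg").convert("RGB"))

gif_path = create_morphing_gif(img_a, img_b, num_interpolations=30, duration=150)
print(gif_path)  # a file under /tmp, e.g. /tmp/morphing_<hex>.gif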
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ torchvision
+ Pillow  # installs the `PIL` import used by the app
+ gradio
+ huggingface_hub
+ # uuid and pathlib come from the standard library; the PyPI packages of the same name are obsolete shims
vae.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ import gradio as gr
+ import numpy as np
+
+ from model import ConvVAE
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the pretrained VAE used by the handlers below
+ # (same checkpoint as app.py).
+ model = ConvVAE.from_pretrained(
+     model_id="BioMike/classical_portrait_vae",
+     cache_dir="./model_cache",
+     map_location=str(device),
+     strict=True
+ ).eval()
+
+ transform1 = transforms.Compose([
+     transforms.Resize((128, 128)),  # Resize the image to 128x128 for the model
+     transforms.ToTensor(),
+     transforms.Normalize((0.5,), (0.5,))
+ ])
+
+ transform2 = transforms.Compose([
+     transforms.Resize((512, 512))  # Resize the image to 512x512 for display
+ ])
+
+ def load_image(image):
+     image = Image.fromarray(image).convert('RGB')
+     image = transform1(image)
+     return image.unsqueeze(0).to(device)
+
+ def infer_image(image, noise_level):
+     image = load_image(image)
+     with torch.no_grad():
+         mu, logvar = model.encode(image)
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std) * noise_level
+         z = mu + eps * std
+         decoded_image = model.decode(z)
+
+     # Undo the (0.5, 0.5) normalization, then clamp to [0, 1]
+     decoded_image = decoded_image.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.float32) * 0.5 + 0.5
+     decoded_image = np.clip(decoded_image, 0, 1)
+
+     decoded_image = Image.fromarray((decoded_image * 255).astype(np.uint8))
+     decoded_image = transform2(decoded_image)
+     return np.array(decoded_image)
+
+ examples = [
+     ["example_images/image1.jpg", 0.1],
+     ["example_images/image2.png", 0.5],
+     ["example_images/image3.jpg", 1.0],
+ ]
+
+ with gr.Blocks() as vae:
+     noise_slider = gr.Slider(0, 10, value=0.01, step=0.01, label="Noise Level")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(label="Upload an image", type="numpy")
+         with gr.Column():
+             output_image = gr.Image(label="Reconstructed Image")
+
+     input_image.change(fn=infer_image, inputs=[input_image, noise_slider], outputs=output_image)
+     noise_slider.change(fn=infer_image, inputs=[input_image, noise_slider], outputs=output_image)
+
+     gr.Examples(examples=examples, inputs=[input_image, noise_slider])
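
As a sanity check on the reparameterization, `noise_level=0` zeroes `eps`, so `z = mu + eps * std` collapses to the posterior mean and the reconstruction becomes deterministic (a sketch; the example image ships with this commit):

import numpy as np
from PIL import Image
from vae import infer_image

portrait = np.array(Image.open("example_images/image2.png").convert("RGB"))

# With zero noise, repeated calls return the identical reconstruction
recon_a = infer_image(portrait, 0.0)
recon_b = infer_image(portrait, 0.0)
assert np.array_equal(recon_a, recon_b)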