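# NOTE(review): the first lines of this file originally read "Spaces: / Sleeping / Sleeping".
# They appear to be page-header residue captured along with the source and are not program code.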
import json
import os
import warnings
from pathlib import Path
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel images.

    The encoder applies four stride-2 convolutions (64 -> 128 -> 256 -> 512 channels).
    The fully connected heads expect the resulting feature map to be 512x8x8, which
    is what a 128x128 input produces (128 -> 64 -> 32 -> 16 -> 8). Those features are
    projected to the mean and log-variance of a diagonal Gaussian over a
    ``latent_size``-dimensional latent.

    The decoder maps a latent vector back to a 512x8x8 map and upsamples it four
    times to (3, 128, 128). A final Tanh keeps reconstructions in [-1, 1].
    Inputs are presumably normalized to the same range; confirm against the
    training pipeline.

    Args:
        latent_size: Dimensionality of the latent space.
    """

    def __init__(self, latent_size):
        super().__init__()
        # Encoder: each conv halves the spatial size.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),  # (batch, 64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # (batch, 128, 32, 32)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),  # (batch, 256, 16, 16)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, 3, stride=2, padding=1),  # (batch, 512, 8, 8)
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
        # Heads producing the posterior mean and log-variance from the flattened features.
        self.fc_mu = nn.Linear(512 * 8 * 8, latent_size)
        self.fc_logvar = nn.Linear(512 * 8 * 8, latent_size)
        # Projects a latent vector back to the flattened 512x8x8 decoder input.
        self.fc2 = nn.Linear(latent_size, 512 * 8 * 8)
        # Decoder: each transposed conv doubles the spatial size.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # (batch, 256, 16, 16)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # (batch, 128, 32, 32)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # (batch, 64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),  # (batch, 3, 128, 128)
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, sample a latent with the reparameterization trick, and decode it.

        Returns:
            A tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape (batch, latent_size), for images ``x``."""
        features = torch.flatten(self.encoder(x), start_dim=1)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` of shape (batch, latent_size) to images in [-1, 1]."""
        x = self.fc2(z).view(-1, 512, 8, 8)
        return self.decoder(x)

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        force_download: bool = False,
        proxies: Optional[Dict] = None,
        resume_download: bool = False,
        local_files_only: bool = False,
        token: Union[str, bool, None] = None,
        map_location: str = "cpu",
        strict: bool = False,
        checkpoint_name: str = "model_conv_vae_256_epoch_304.pth",
        **model_kwargs,
    ) -> "ConvVAE":
        """Load a pretrained model from a local directory or a Hugging Face Hub repo.

        If ``model_id`` is an existing local path, it is used directly. Otherwise it is
        treated as a Hub repo id and downloaded with ``snapshot_download``. The
        directory must contain ``config.json`` (with a ``latent_size`` key) and the
        checkpoint file named by ``checkpoint_name``.

        Args:
            model_id: Local directory or Hub repo id.
            revision: Specific model revision to use.
            cache_dir: Directory to store downloaded models.
            force_download: Force re-download even if the model exists.
            proxies: Proxy configuration for downloads.
            resume_download: Resume interrupted downloads.
            local_files_only: Use only local files; don't download.
            token: Token for Hub authentication.
            map_location: Device to load the weights onto and move the model to.
            strict: Passed to ``load_state_dict``. When False, missing or unexpected
                keys are tolerated but reported with a ``UserWarning``.
            checkpoint_name: File name of the state-dict checkpoint inside the model
                directory.
            **model_kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A model instance with the pretrained weights loaded, moved to ``map_location``.

        Raises:
            ValueError: If ``config.json`` has no ``latent_size`` key.
            FileNotFoundError: If the checkpoint file does not exist.
        """
        model_dir = Path(model_id)
        if not model_dir.exists():
            # Not a local path: treat model_id as a Hub repo id and fetch a snapshot.
            model_dir = Path(
                snapshot_download(
                    repo_id=model_id,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            )

        with open(model_dir / "config.json", "r", encoding="utf-8") as f:
            config = json.load(f)
        latent_size = config.get('latent_size')
        if latent_size is None:
            raise ValueError("The configuration file is missing the 'latent_size' key.")

        model = cls(latent_size, **model_kwargs)

        model_file = model_dir / checkpoint_name
        if not model_file.exists():
            raise FileNotFoundError(f"The model checkpoint '{model_file}' does not exist.")

        # NOTE(review): torch.load unpickles the file, and with older torch defaults it
        # can execute arbitrary code. Only load checkpoints from trusted repos, and
        # consider weights_only=True where the installed torch supports it.
        state_dict = torch.load(model_file, map_location=map_location)

        # Strip the "_orig_mod." key prefix. Presumably the checkpoint was saved from a
        # torch.compile() wrapper, which nests the real module under that name.
        prefix = '_orig_mod.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        incompatible = model.load_state_dict(state_dict, strict=strict)
        # With strict=False, load_state_dict reports mismatches instead of raising.
        # Surface them so a wrong checkpoint does not load silently.
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"Loaded '{model_file.name}' with strict=False: "
                f"missing keys {list(incompatible.missing_keys)}, "
                f"unexpected keys {list(incompatible.unexpected_keys)}",
                stacklevel=2,
            )
        model.to(map_location)
        return model