# BioMike's picture
# Upload 9 files
# 5a9c9b2 verified
# raw
# history blame
# 5.4 kB
import json
import torch
import torch.nn as nn
import os
from pathlib import Path
from typing import Optional, Union, Dict
from huggingface_hub import snapshot_download
import warnings
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel images.

    The encoder is four stride-2 convolutions followed by two linear heads
    that produce the mean and log-variance of the latent distribution. The
    decoder mirrors it with four stride-2 transposed convolutions and a final
    ``Tanh``, so reconstructions lie in ``[-1, 1]``.

    The linear layers are sized for a ``512 x 8 x 8`` bottleneck, which (per
    the per-layer shape comments below) corresponds to ``128 x 128`` inputs.

    Args:
        latent_size: Dimensionality of the latent vector ``z``.
    """

    # (channels, height, width) of the encoder output / decoder input.
    _FEATURE_SHAPE = (512, 8, 8)

    # File names looked up inside the model directory by ``from_pretrained``.
    CONFIG_NAME = "config.json"
    WEIGHTS_NAME = "model_conv_vae_256_epoch_304.pth"

    def __init__(self, latent_size):
        super().__init__()
        channels, height, width = self._FEATURE_SHAPE
        flat_features = channels * height * width

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),  # (batch, 64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # (batch, 128, 32, 32)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),  # (batch, 256, 16, 16)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, channels, 3, stride=2, padding=1),  # (batch, 512, 8, 8)
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

        # Latent heads: flattened features -> mean / log-variance.
        self.fc_mu = nn.Linear(flat_features, latent_size)
        self.fc_logvar = nn.Linear(flat_features, latent_size)

        # Projects a latent vector back to the flattened bottleneck.
        self.fc2 = nn.Linear(latent_size, flat_features)

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(channels, 256, 4, stride=2, padding=1),  # (batch, 256, 16, 16)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # (batch, 128, 32, 32)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # (batch, 64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),  # (batch, 3, 128, 128)
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Image batch of shape ``(batch, 3, H, W)``.

        Returns:
            Tuple ``(decoded, mu, logvar)``: the reconstruction and the
            parameters of the latent distribution it was sampled from.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z)
        return decoded, mu, logvar

    def encode(self, x):
        """Map an image batch to latent mean and log-variance.

        Args:
            x: Image batch of shape ``(batch, 3, H, W)``.

        Returns:
            Tuple ``(mu, logvar)``, each of shape ``(batch, latent_size)``.
        """
        features = self.encoder(x)
        # flatten (rather than view) also accepts non-contiguous feature maps,
        # e.g. when the model runs in channels_last memory format.
        features = torch.flatten(features, start_dim=1)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is ``exp(0.5 * logvar)``. Noise is drawn on every call; the
        module's train/eval mode is not consulted.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Reconstruct an image batch from latent vectors.

        Args:
            z: Latent batch of shape ``(batch, latent_size)``.

        Returns:
            Tensor produced by the decoder; values are in ``[-1, 1]`` because
            the last layer is ``Tanh``.
        """
        x = self.fc2(z)
        x = x.view(-1, *self._FEATURE_SHAPE)
        return self.decoder(x)

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        force_download: bool = False,
        proxies: Optional[Dict] = None,
        resume_download: bool = False,
        local_files_only: bool = False,
        token: Union[str, bool, None] = None,
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load a pretrained model from a local directory or a Hub repo ID.

        If ``model_id`` is an existing local path it is used directly;
        otherwise it is treated as a repo ID and fetched with
        ``snapshot_download``. The directory must contain ``CONFIG_NAME``
        (a JSON file with a ``latent_size`` entry) and ``WEIGHTS_NAME``
        (a ``state_dict`` checkpoint).

        ``eval()` is not called here, so the returned module is in whatever
        mode a freshly constructed ``nn.Module`` is in; call ``.eval()``
        before inference.

        Args:
            model_id (str): Local directory path or Hub repo identifier.
            revision (Optional[str]): Specific model revision to use.
            cache_dir (Optional[Union[str, Path]]): Directory to store
                downloaded models.
            force_download (bool): Force re-download even if the model exists.
            proxies (Optional[Dict]): Proxy configuration for downloads.
            resume_download (bool): Resume interrupted downloads.
            local_files_only (bool): Use only local files, don't download.
            token (Union[str, bool, None]): Token for API authentication.
            map_location (str): Device the checkpoint is loaded onto and the
                model is moved to. Defaults to "cpu".
            strict (bool): Enforce strict state_dict loading. With the default
                ``False``, missing or unexpected keys are ignored.
            **model_kwargs: Additional keyword arguments passed to the model
                constructor after ``latent_size``.

        Returns:
            An instance of the model loaded from the pretrained weights.

        Raises:
            FileNotFoundError: If the config or checkpoint file is missing.
            ValueError: If the config has no ``latent_size`` entry.
        """
        model_dir = Path(model_id)
        if not model_dir.exists():
            # Not a local path: resolve it as a Hub repo and download it.
            model_dir = Path(
                snapshot_download(
                    repo_id=model_id,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            )

        config_file = model_dir / cls.CONFIG_NAME
        if not config_file.exists():
            raise FileNotFoundError(
                f"The configuration file '{config_file}' does not exist."
            )
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)

        latent_size = config.get("latent_size")
        if latent_size is None:
            raise ValueError("The configuration file is missing the 'latent_size' key.")

        model = cls(latent_size, **model_kwargs)

        model_file = model_dir / cls.WEIGHTS_NAME
        if not model_file.exists():
            raise FileNotFoundError(f"The model checkpoint '{model_file}' does not exist.")

        # NOTE(security): torch.load unpickles the file. On PyTorch versions
        # where ``weights_only`` does not default to True, loading a checkpoint
        # from an untrusted repo can execute arbitrary code; consider passing
        # ``weights_only=True`` once the minimum supported version allows it.
        state_dict = torch.load(model_file, map_location=map_location)

        # Checkpoints saved from a torch.compile-wrapped module carry an
        # "_orig_mod." prefix on every key; strip it so keys match this class.
        prefix = "_orig_mod."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        model.load_state_dict(state_dict, strict=strict)
        model.to(map_location)
        return model