File size: 997 Bytes
c9224f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from torch.utils.checkpoint import checkpoint  
from taming.models.vqgan import VQModel
from omegaconf import OmegaConf
from taming.models.vqgan import GumbelVQ

# Module-wide default device: CUDA whenever PyTorch reports it is usable,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Generator:
    """Build a taming-transformers ``GumbelVQ`` model from an OmegaConf config.

    Args:
        config_path: Path to a config file loadable by ``OmegaConf.load``;
            its ``model.params`` section supplies the ``GumbelVQ`` arguments.
        device: Device the constructed model is moved to. Defaults to the
            module-level ``device`` (CUDA when available, otherwise CPU).
    """

    def __init__(self, config_path, device=device):
        self.config_path = config_path
        self.device = device

    def load_models(self):
        """Construct a ``GumbelVQ`` from ``model.params`` and move it to ``self.device``.

        Only the architecture is built: this method neither passes a
        checkpoint path nor loads a state dict, so the weights are whatever
        ``GumbelVQ`` initializes them to.

        Returns:
            The ``GumbelVQ`` instance, already moved to ``self.device``.

        Raises:
            Presumably an OmegaConf missing-key error if one of the required
            entries (``ddconfig``, ``lossconfig``, ``n_embed``, ``embed_dim``,
            ``temperature_scheduler_config``) is absent from ``model.params``.
        """
        config = OmegaConf.load(self.config_path)
        vq_params = config.model.params

        # kl_weight has a default in GumbelVQ's constructor (taming-transformers;
        # verify against the installed version), so only forward it when the
        # config actually provides it. A membership test is used instead of
        # attribute access so a missing key does not raise, even in struct mode.
        optional_kwargs = {}
        if "kl_weight" in vq_params:
            optional_kwargs["kl_weight"] = vq_params.kl_weight

        model = GumbelVQ(
            ddconfig=vq_params.ddconfig,
            lossconfig=vq_params.lossconfig,
            n_embed=vq_params.n_embed,
            embed_dim=vq_params.embed_dim,
            temperature_scheduler_config=vq_params.temperature_scheduler_config,
            **optional_kwargs,
        ).to(self.device)

        return model