# initialize vegans import torch import huggingface_hub from .generator import Generator from .discriminator import Discriminator from .utils import save_image, display_image, log, generate_noise from .dataset import TrainingSet from .config import Config class Vegans: def __init__(self, config: Config, dataset): assert config is not None self.config = config self.dataset = dataset self.device = "cuda" if torch.cuda.is_available() else "cpu" self.generator = Generator().to(self.device) self.discriminator = Discriminator().to(self.device) self.g_optim = torch.optim.Adam( self.generator.parameters(), lr=self.config.learning_rate, betas=self.config.betas, ) self.d_optim = torch.optim.Adam( self.discriminator.parameters(), lr=self.config.learning_rate, betas=self.config.betas, ) def train(self, upload=False, real=True): loss_fn = torch.nn.BCELoss() log("Started Training Loop") if real: dataset = self.dataset else: dataset = self._dummy_dataset() for epoch in range(self.num_epoch): for i, (image, label) in enumerate(dataset): # Train the generator self.g_optim.zero_grad() loss = loss_fn(0, 0) # Train the discriminator self.d_optim.zero_grad() log(f"Epoch {epoch} done. Loss is {loss.item()}") log("Finish Training") if upload: self.generator.save_pretrained("ve-gans") self.generator.push_to_hub("ve-gans") def eval(self): pass def _dummy_dataset(self): dummy_images = [] dummy_labels = [] for i in range(100): image = torch.randn( 1, self.config.channels, self.config.image_size, self.config.image_size ) label = torch.randint(0, 2, (1,)) dummy_images.append(image) dummy_labels.append(label) return [(image, label) for image, label in zip(dummy_images, dummy_labels)] def generate(self, label, save=True): # TODO noise = generate_noise() output = self.generator(noise, label) if save: save_image(output) return output def save_pretrained(self, name="ve-gans"): print("Uploading model...") self.model.save_pretrained(name) print(f"Model saved locally as '{name}'") self.model.push_to_hub(name) print(f"Model '{name}' uploaded to the Hugging Face Model Hub") def load_pretrained(self, model_id="zaibutcooler/ve-gans"): print("Loading model...") model = model.from_pretrained(model_id) print(f"Model '{model_id}' loaded successfully") return model def huggingface_login(self,token): huggingface_hub.login(token)