File size: 2,984 Bytes
bbddfb0
753553a
c6494da
 
 
 
 
 
 
753553a
5f56afe
753553a
c6494da
 
 
 
 
753553a
 
c6494da
 
 
 
 
 
 
 
 
 
 
 
753553a
5f56afe
 
c6494da
 
 
 
5f56afe
753553a
c6494da
5f56afe
c6494da
 
5f56afe
c6494da
5f56afe
c6494da
753553a
 
c6494da
753553a
5f56afe
c6494da
 
 
753553a
c6494da
 
753553a
c6494da
 
 
 
 
 
 
 
 
 
 
5f56afe
c6494da
5f56afe
c6494da
 
 
 
 
5f56afe
c6494da
 
 
 
 
 
5f56afe
c6494da
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Vegans: wires the GAN generator, discriminator and their optimizers together.
import torch
import huggingface_hub

from .generator import Generator
from .discriminator import Discriminator
from .utils import save_image, display_image, log, generate_noise
from .dataset import TrainingSet
from .config import Config


class Vegans:
    """Bundle a generator/discriminator pair with their Adam optimizers.

    The instance owns both networks (moved to CUDA when available, otherwise
    CPU), one Adam optimizer per network, and helpers to train, sample from,
    save and load the generator.
    """

    def __init__(self, config: "Config", dataset):
        """Build both networks and their optimizers.

        Args:
            config: Project ``Config``; this class reads ``learning_rate``,
                ``betas``, ``channels`` and ``image_size`` from it.
            dataset: Iterable yielding ``(image, label)`` pairs, consumed by
                :meth:`train` when ``real=True``.

        Raises:
            ValueError: If ``config`` is ``None``.
        """
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if config is None:
            raise ValueError("config must not be None")
        self.config = config
        self.dataset = dataset
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.g_optim = self._make_optimizer(self.generator)
        self.d_optim = self._make_optimizer(self.discriminator)

    def _make_optimizer(self, network):
        """Return an Adam optimizer for ``network`` using the config's lr/betas."""
        return torch.optim.Adam(
            network.parameters(),
            lr=self.config.learning_rate,
            betas=self.config.betas,
        )

    def _resolve_num_epoch(self, num_epoch):
        """Pick the epoch count: explicit argument, then instance, then config."""
        if num_epoch is None:
            num_epoch = getattr(self, "num_epoch", None)
        if num_epoch is None:
            # NOTE(review): the attribute name on Config is not visible from
            # this file; ``num_epoch`` mirrors the name the loop used before.
            num_epoch = getattr(self.config, "num_epoch", None)
        if num_epoch is None:
            raise ValueError(
                "Number of epochs is not set: pass num_epoch to train() "
                "or define num_epoch on the instance or its config"
            )
        return num_epoch

    def train(self, upload=False, real=True, num_epoch=None):
        """Run the training loop.

        NOTE: the adversarial losses are not implemented yet. The loop only
        exercises the plumbing (iteration, gradient reset, logging) with a
        placeholder loss, and neither optimizer is stepped.

        Args:
            upload: When true, save the generator as ``"ve-gans"`` and push it
                to the hub after the loop finishes.
            real: Iterate ``self.dataset`` when true, otherwise a freshly
                generated random dataset (see :meth:`_dummy_dataset`).
            num_epoch: Number of epochs. When ``None``, falls back to
                ``self.num_epoch`` and then ``self.config.num_epoch``.

        Raises:
            ValueError: If no epoch count can be determined.
        """
        # Fail fast, before any logging or dataset generation.
        num_epoch = self._resolve_num_epoch(num_epoch)

        loss_fn = torch.nn.BCELoss()
        log("Started Training Loop")
        dataset = self.dataset if real else self._dummy_dataset()

        # BCELoss needs tensors; bare ints raise inside the loss call.
        placeholder = torch.zeros(1, device=self.device)

        for epoch in range(num_epoch):
            loss = None  # stays None when the dataset yields no batches
            for _image, _label in dataset:
                # Train the generator
                self.g_optim.zero_grad()

                # TODO: replace with the real generator/discriminator losses
                # followed by backward() and optimizer steps.
                loss = loss_fn(placeholder, placeholder)

                # Train the discriminator
                self.d_optim.zero_grad()

            if loss is None:
                log(f"Epoch {epoch} done. No batches were processed")
            else:
                log(f"Epoch {epoch} done. Loss is {loss.item()}")

        log("Finish Training")
        if upload:
            self.generator.save_pretrained("ve-gans")
            self.generator.push_to_hub("ve-gans")

    def eval(self):
        """Evaluation hook; not implemented yet."""
        pass

    def _dummy_dataset(self):
        """Return 100 random ``(image, label)`` pairs for smoke-testing.

        Each image is a standard-normal tensor of shape
        ``(1, channels, image_size, image_size)`` and each label is a
        one-element integer tensor drawn from ``{0, 1}``.
        """
        size = self.config.image_size
        return [
            (
                torch.randn(1, self.config.channels, size, size),
                torch.randint(0, 2, (1,)),
            )
            for _ in range(100)
        ]

    def generate(self, label, save=True):
        """Run the generator on fresh noise for ``label``.

        Args:
            label: Conditioning input passed to the generator as-is.
            save: When true, hand the output to ``save_image``.

        Returns:
            The generator's output.
        """
        # TODO
        noise = generate_noise()
        output = self.generator(noise, label)
        if save:
            save_image(output)
        return output

    def save_pretrained(self, name="ve-gans"):
        """Save the generator locally under ``name`` and push it to the hub.

        Args:
            name: Local save path and hub repository name.
        """
        print("Uploading model...")
        # The generator is the network train() saves and pushes; this class
        # has no ``model`` attribute.
        self.generator.save_pretrained(name)
        print(f"Model saved locally as '{name}'")
        self.generator.push_to_hub(name)
        print(f"Model '{name}' uploaded to the Hugging Face Model Hub")

    def load_pretrained(self, model_id="zaibutcooler/ve-gans"):
        """Load a pretrained generator and return it.

        The loaded model is returned only; ``self.generator`` and its
        optimizer are left untouched.

        Args:
            model_id: Identifier passed to the generator's ``from_pretrained``.

        Returns:
            Whatever the generator's ``from_pretrained`` returns.
        """
        print("Loading model...")
        # NOTE(review): assumes the generator exposes ``from_pretrained``
        # alongside the save_pretrained/push_to_hub used above - confirm.
        model = self.generator.from_pretrained(model_id)
        print(f"Model '{model_id}' loaded successfully")
        return model

    def huggingface_login(self, token):
        """Log in to the Hugging Face Hub with ``token``."""
        huggingface_hub.login(token)