Spaces:
Sleeping
Sleeping
File size: 2,984 Bytes
bbddfb0 753553a c6494da 753553a 5f56afe 753553a c6494da 753553a c6494da 753553a 5f56afe c6494da 5f56afe 753553a c6494da 5f56afe c6494da 5f56afe c6494da 5f56afe c6494da 753553a c6494da 753553a 5f56afe c6494da 753553a c6494da 753553a c6494da 5f56afe c6494da 5f56afe c6494da 5f56afe c6494da 5f56afe c6494da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
# initialize vegans
import torch
import huggingface_hub
from .generator import Generator
from .discriminator import Discriminator
from .utils import save_image, display_image, log, generate_noise
from .dataset import TrainingSet
from .config import Config
class Vegans:
    """GAN wrapper owning a Generator, a Discriminator and one Adam optimiser each.

    NOTE(review): ``generate`` passes a label to the generator, which suggests a
    conditional GAN; confirm against the Generator/Discriminator modules.
    """

    def __init__(self, config: "Config", dataset):
        """Build both networks on the available device plus their optimisers.

        Args:
            config: Hyper-parameters. ``learning_rate`` and ``betas`` are read
                here, ``channels`` and ``image_size`` by ``_dummy_dataset``,
                and the optional ``num_epoch`` is used by ``train``.
            dataset: Iterable of ``(image, label)`` items used by
                ``train(real=True)``.

        Raises:
            ValueError: If ``config`` is None.
        """
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if config is None:
            raise ValueError("config must not be None")
        self.config = config
        self.dataset = dataset
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # train() needs an epoch count; take it from the config when present.
        # It can still be overridden by assigning self.num_epoch or by passing
        # train(num_epoch=...).
        self.num_epoch = getattr(config, "num_epoch", None)
        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)
        self.g_optim = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.config.learning_rate,
            betas=self.config.betas,
        )
        self.d_optim = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.config.learning_rate,
            betas=self.config.betas,
        )

    def train(self, upload=False, real=True, num_epoch=None):
        """Run the training loop, optionally publishing the generator afterwards.

        Args:
            upload: After training, save the generator locally and push it to
                the Hugging Face Hub as ``"ve-gans"`` (via ``save_pretrained``).
            real: Iterate ``self.dataset`` when True, otherwise the random
                pairs produced by ``_dummy_dataset``.
            num_epoch: Number of epochs; defaults to ``self.num_epoch``.

        Raises:
            ValueError: If no epoch count is available.
            NotImplementedError: Raised by ``_train_step`` on the first item
                until the GAN update is implemented.
        """
        epochs = num_epoch
        if epochs is None:
            epochs = getattr(self, "num_epoch", None)
        if epochs is None:
            raise ValueError(
                "num_epoch is not set: pass train(num_epoch=...) "
                "or provide config.num_epoch"
            )
        loss_fn = torch.nn.BCELoss()
        log("Started Training Loop")
        dataset = self.dataset if real else self._dummy_dataset()
        for epoch in range(epochs):
            # Stays None when the dataset yields nothing, so the log below
            # never touches an unbound / stale loss.
            loss = None
            for image, label in dataset:
                loss = self._train_step(image, label, loss_fn)
            if loss is None:
                log(f"Epoch {epoch} done. No batches were processed")
            else:
                log(f"Epoch {epoch} done. Loss is {loss.item()}")
        log("Finish Training")
        if upload:
            self.save_pretrained("ve-gans")

    def _train_step(self, image, label, loss_fn):
        """Perform one generator/discriminator update and return the loss tensor.

        Placeholder: the update is not written yet.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: implement the generator and discriminator updates using
        # self.g_optim / self.d_optim and loss_fn. The previous placeholder,
        # loss_fn(0, 0), raised inside BCELoss (it expects tensors), so training
        # already failed on the first item; this makes that failure explicit.
        raise NotImplementedError("GAN training step is not implemented yet")

    def eval(self):
        """Placeholder for evaluation; currently does nothing and returns None."""
        pass

    def _dummy_dataset(self, num_samples=100):
        """Build ``num_samples`` random ``(image, label)`` pairs for smoke runs.

        Each image is ``torch.randn`` with shape
        ``(1, config.channels, config.image_size, config.image_size)``; each
        label is a shape-``(1,)`` integer tensor drawn from {0, 1}.
        """
        size = self.config.image_size
        # Tuple elements evaluate left to right, so the RNG is consumed in the
        # same randn-then-randint order per sample as a manual loop would.
        return [
            (
                torch.randn(1, self.config.channels, size, size),
                torch.randint(0, 2, (1,)),
            )
            for _ in range(num_samples)
        ]

    def generate(self, label, save=True):
        """Run the generator on fresh noise for ``label``.

        Args:
            label: Passed unchanged as the generator's second argument.
            save: When True, also pass the output to ``save_image``.

        Returns:
            Whatever ``self.generator(noise, label)`` returns.
        """
        # TODO: generate_noise() is called with no arguments, so the noise
        # shape and device are whatever it defaults to -- confirm they match
        # the generator's expected input and self.device.
        noise = generate_noise()
        output = self.generator(noise, label)
        if save:
            save_image(output)
        return output

    def save_pretrained(self, name="ve-gans"):
        """Save the generator locally as ``name`` and push it to the Hub.

        The push needs Hugging Face credentials (e.g. via ``huggingface_login``).
        """
        print("Uploading model...")
        # The generator is the network that gets published (the class has no
        # ``model`` attribute).
        self.generator.save_pretrained(name)
        print(f"Model saved locally as '{name}'")
        self.generator.push_to_hub(name)
        print(f"Model '{name}' uploaded to the Hugging Face Model Hub")

    def load_pretrained(self, model_id="zaibutcooler/ve-gans"):
        """Load a generator from the Hugging Face Hub and return it.

        The loaded model is returned; ``self.generator`` is left unchanged.

        NOTE(review): relies on ``Generator.from_pretrained`` existing --
        presumably from huggingface_hub's ``PyTorchModelHubMixin``, since
        Generator instances expose ``save_pretrained``/``push_to_hub``; confirm.
        """
        print("Loading model...")
        model = Generator.from_pretrained(model_id)
        print(f"Model '{model_id}' loaded successfully")
        return model

    def huggingface_login(self, token):
        """Authenticate this process with the Hugging Face Hub using ``token``."""
        huggingface_hub.login(token=token)