Spaces:
Sleeping
Sleeping
Zai
commited on
Commit
·
5f56afe
1
Parent(s):
753553a
sample and train.py implementations
Browse files- sample.py +9 -0
- train.py +11 -0
- vegans/__init__.py +1 -0
- vegans/dataset.py +13 -5
- vegans/discriminator.py +5 -2
- vegans/generator.py +1 -4
- vegans/utils.py +9 -0
- vegans/vegans.py +24 -3
sample.py
CHANGED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from vegans import Vegans
|
2 |
+
|
3 |
+
vegans = Vegans
|
4 |
+
|
5 |
+
vegans.load_pretrained()
|
6 |
+
|
7 |
+
text = 'something'
|
8 |
+
|
9 |
+
vegans.generate(text)
|
train.py
CHANGED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from vegans import Vegans
|
2 |
+
|
3 |
+
# To Start training loop, it's simple
|
4 |
+
|
5 |
+
vegans = Vegans()
|
6 |
+
|
7 |
+
vegans.train()
|
8 |
+
|
9 |
+
# or you can simply just
|
10 |
+
# vegans.load_pretrained()
|
11 |
+
|
vegans/__init__.py
CHANGED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .vegans import Vegans
|
vegans/dataset.py
CHANGED
@@ -5,14 +5,22 @@ from datasets import load_dataset
|
|
5 |
|
6 |
|
7 |
class TrainingSet(Dataset):
|
8 |
-
def __init__(self,
|
9 |
super().__init__()
|
10 |
self.transforms = tran.Compose([
|
11 |
-
|
|
|
|
|
|
|
12 |
])
|
|
|
|
|
|
|
|
|
13 |
def __getitem__(self, index):
|
14 |
-
|
|
|
|
|
15 |
|
16 |
def __len__(self):
|
17 |
-
return self
|
18 |
-
|
|
|
5 |
|
6 |
|
7 |
class TrainingSet(Dataset):
|
8 |
+
def __init__(self,image_size=(128,128),gray_scale=True):
|
9 |
super().__init__()
|
10 |
self.transforms = tran.Compose([
|
11 |
+
tran.Resize(image_size),
|
12 |
+
tran.ToTensor(),
|
13 |
+
tran.Normalize([0.5],[0.5]),
|
14 |
+
tran.Grayscale() if gray_scale==True else None,
|
15 |
])
|
16 |
+
|
17 |
+
self.images = None
|
18 |
+
self.labels = None
|
19 |
+
|
20 |
def __getitem__(self, index):
|
21 |
+
image = self.transforms(self.images[index])
|
22 |
+
label = self.labels[index]
|
23 |
+
return image,label
|
24 |
|
25 |
def __len__(self):
|
26 |
+
return len(self.images)
|
|
vegans/discriminator.py
CHANGED
@@ -5,6 +5,9 @@ import torch
|
|
5 |
class Discriminator(nn.Module):
|
6 |
def __init__(self):
|
7 |
super(Discriminator, self).__init__()
|
|
|
8 |
|
9 |
-
|
10 |
-
|
|
|
|
|
|
5 |
class Discriminator(nn.Module):
|
6 |
def __init__(self):
|
7 |
super(Discriminator, self).__init__()
|
8 |
+
self.model = nn.Sequential(
|
9 |
|
10 |
+
)
|
11 |
+
|
12 |
+
def forward(self,x,y):
|
13 |
+
return x
|
vegans/generator.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
class Generator(nn.Module):
|
6 |
def __init__(self):
|
7 |
super(Generator, self).__init__()
|
8 |
-
def block():
|
9 |
pass
|
10 |
|
11 |
self.model = nn.Sequential(
|
@@ -14,6 +14,3 @@ class Generator(nn.Module):
|
|
14 |
|
15 |
def forward(self):
|
16 |
pass
|
17 |
-
|
18 |
-
def generate(self,text):
|
19 |
-
pass
|
|
|
5 |
class Generator(nn.Module):
|
6 |
def __init__(self):
|
7 |
super(Generator, self).__init__()
|
8 |
+
def block(in_feat,out_feat,norm=False):
|
9 |
pass
|
10 |
|
11 |
self.model = nn.Sequential(
|
|
|
14 |
|
15 |
def forward(self):
|
16 |
pass
|
|
|
|
|
|
vegans/utils.py
CHANGED
@@ -1 +1,10 @@
|
|
|
|
|
|
1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def display_image():
|
2 |
+
pass
|
3 |
|
4 |
+
def save_image(image,location):
|
5 |
+
pass
|
6 |
+
|
7 |
+
def log(text):
|
8 |
+
print("#############################################\n")
|
9 |
+
print(f"{text}\n")
|
10 |
+
print("#############################################\n")
|
vegans/vegans.py
CHANGED
@@ -3,9 +3,10 @@ import torch
|
|
3 |
from generator import Generator
|
4 |
from discriminator import Discriminator
|
5 |
from discriminator import Discriminator
|
6 |
-
from utils import save_image,display_image
|
7 |
from dataset import TrainingSet
|
8 |
|
|
|
9 |
class Vegans:
|
10 |
def __init__(self,lr=0.01,b1=0.02,b2=0.02):
|
11 |
self.learning_rate = lr
|
@@ -19,15 +20,35 @@ class Vegans:
|
|
19 |
self.d_optim = torch.optim.Adam(self.discriminator.parameters(),lr=self.learning_rate,betas=self.betas)
|
20 |
|
21 |
def train(self):
|
|
|
|
|
|
|
22 |
for epoch in range(self.num_epoch):
|
23 |
for i,(image,label) in enumerate(self.dataset):
|
|
|
24 |
self.g_optim.zero_grad()
|
|
|
|
|
|
|
|
|
25 |
self.d_optim.zero_grad()
|
26 |
|
|
|
27 |
|
|
|
28 |
|
29 |
def generate(self,label):
|
30 |
-
|
|
|
|
|
31 |
|
32 |
def load_pretrained(self,checkpoint):
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from generator import Generator
|
4 |
from discriminator import Discriminator
|
5 |
from discriminator import Discriminator
|
6 |
+
from utils import save_image,display_image,log
|
7 |
from dataset import TrainingSet
|
8 |
|
9 |
+
|
10 |
class Vegans:
|
11 |
def __init__(self,lr=0.01,b1=0.02,b2=0.02):
|
12 |
self.learning_rate = lr
|
|
|
20 |
self.d_optim = torch.optim.Adam(self.discriminator.parameters(),lr=self.learning_rate,betas=self.betas)
|
21 |
|
22 |
def train(self):
|
23 |
+
loss_fn = torch.nn.BCELoss()
|
24 |
+
log("Started Training Loop")
|
25 |
+
|
26 |
for epoch in range(self.num_epoch):
|
27 |
for i,(image,label) in enumerate(self.dataset):
|
28 |
+
|
29 |
self.g_optim.zero_grad()
|
30 |
+
|
31 |
+
loss = loss_fn(0,0)
|
32 |
+
|
33 |
+
|
34 |
self.d_optim.zero_grad()
|
35 |
|
36 |
+
print(f"Epoch {epoch} done. Loss is {loss.item()}`")
|
37 |
|
38 |
+
log("Finish Training")
|
39 |
|
40 |
def generate(self,label):
|
41 |
+
# TODO
|
42 |
+
noise = 0
|
43 |
+
output = self.generator(noise,label)
|
44 |
|
45 |
def load_pretrained(self,checkpoint):
|
46 |
+
|
47 |
+
# TODO
|
48 |
+
saved_checkpoint = checkpoint
|
49 |
+
|
50 |
+
model_checkpoint = torch.load('')
|
51 |
+
self.generator.load_state_dict(model_checkpoint)
|
52 |
+
|
53 |
+
log("Successfully loaded model")
|
54 |
+
|