Zai commited on
Commit
5f56afe
·
1 Parent(s): 753553a

sample and train.py implementations

Browse files
sample.py CHANGED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from vegans import Vegans
2
+
3
+ vegans = Vegans
4
+
5
+ vegans.load_pretrained()
6
+
7
+ text = 'something'
8
+
9
+ vegans.generate(text)
train.py CHANGED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from vegans import Vegans
2
+
3
+ # To Start training loop, it's simple
4
+
5
+ vegans = Vegans()
6
+
7
+ vegans.train()
8
+
9
+ # or you can simply just
10
+ # vegans.load_pretrained()
11
+
vegans/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ from .vegans import Vegans
vegans/dataset.py CHANGED
@@ -5,14 +5,22 @@ from datasets import load_dataset
5
 
6
 
7
  class TrainingSet(Dataset):
8
- def __init__(self,dataset_name='',image_size=(128,128)):
9
  super().__init__()
10
  self.transforms = tran.Compose([
11
-
 
 
 
12
  ])
 
 
 
 
13
  def __getitem__(self, index):
14
- return self
 
 
15
 
16
  def __len__(self):
17
- return self
18
-
 
5
 
6
 
7
  class TrainingSet(Dataset):
8
+ def __init__(self,image_size=(128,128),gray_scale=True):
9
  super().__init__()
10
  self.transforms = tran.Compose([
11
+ tran.Resize(image_size),
12
+ tran.ToTensor(),
13
+ tran.Normalize([0.5],[0.5]),
14
+ tran.Grayscale() if gray_scale==True else None,
15
  ])
16
+
17
+ self.images = None
18
+ self.labels = None
19
+
20
  def __getitem__(self, index):
21
+ image = self.transforms(self.images[index])
22
+ label = self.labels[index]
23
+ return image,label
24
 
25
  def __len__(self):
26
+ return len(self.images)
 
vegans/discriminator.py CHANGED
@@ -5,6 +5,9 @@ import torch
5
  class Discriminator(nn.Module):
6
  def __init__(self):
7
  super(Discriminator, self).__init__()
 
8
 
9
- def forward(self):
10
- pass
 
 
 
5
  class Discriminator(nn.Module):
6
  def __init__(self):
7
  super(Discriminator, self).__init__()
8
+ self.model = nn.Sequential(
9
 
10
+ )
11
+
12
+ def forward(self,x,y):
13
+ return x
vegans/generator.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  class Generator(nn.Module):
6
  def __init__(self):
7
  super(Generator, self).__init__()
8
- def block():
9
  pass
10
 
11
  self.model = nn.Sequential(
@@ -14,6 +14,3 @@ class Generator(nn.Module):
14
 
15
  def forward(self):
16
  pass
17
-
18
- def generate(self,text):
19
- pass
 
5
  class Generator(nn.Module):
6
  def __init__(self):
7
  super(Generator, self).__init__()
8
+ def block(in_feat,out_feat,norm=False):
9
  pass
10
 
11
  self.model = nn.Sequential(
 
14
 
15
  def forward(self):
16
  pass
 
 
 
vegans/utils.py CHANGED
@@ -1 +1,10 @@
 
 
1
 
 
 
 
 
 
 
 
 
1
+ def display_image():
2
+ pass
3
 
4
+ def save_image(image,location):
5
+ pass
6
+
7
+ def log(text):
8
+ print("#############################################\n")
9
+ print(f"{text}\n")
10
+ print("#############################################\n")
vegans/vegans.py CHANGED
@@ -3,9 +3,10 @@ import torch
3
  from generator import Generator
4
  from discriminator import Discriminator
5
  from discriminator import Discriminator
6
- from utils import save_image,display_image
7
  from dataset import TrainingSet
8
 
 
9
  class Vegans:
10
  def __init__(self,lr=0.01,b1=0.02,b2=0.02):
11
  self.learning_rate = lr
@@ -19,15 +20,35 @@ class Vegans:
19
  self.d_optim = torch.optim.Adam(self.discriminator.parameters(),lr=self.learning_rate,betas=self.betas)
20
 
21
  def train(self):
 
 
 
22
  for epoch in range(self.num_epoch):
23
  for i,(image,label) in enumerate(self.dataset):
 
24
  self.g_optim.zero_grad()
 
 
 
 
25
  self.d_optim.zero_grad()
26
 
 
27
 
 
28
 
29
  def generate(self,label):
30
- pass
 
 
31
 
32
  def load_pretrained(self,checkpoint):
33
- pass
 
 
 
 
 
 
 
 
 
3
  from generator import Generator
4
  from discriminator import Discriminator
5
  from discriminator import Discriminator
6
+ from utils import save_image,display_image,log
7
  from dataset import TrainingSet
8
 
9
+
10
  class Vegans:
11
  def __init__(self,lr=0.01,b1=0.02,b2=0.02):
12
  self.learning_rate = lr
 
20
  self.d_optim = torch.optim.Adam(self.discriminator.parameters(),lr=self.learning_rate,betas=self.betas)
21
 
22
  def train(self):
23
+ loss_fn = torch.nn.BCELoss()
24
+ log("Started Training Loop")
25
+
26
  for epoch in range(self.num_epoch):
27
  for i,(image,label) in enumerate(self.dataset):
28
+
29
  self.g_optim.zero_grad()
30
+
31
+ loss = loss_fn(0,0)
32
+
33
+
34
  self.d_optim.zero_grad()
35
 
36
+ print(f"Epoch {epoch} done. Loss is {loss.item()}`")
37
 
38
+ log("Finish Training")
39
 
40
  def generate(self,label):
41
+ # TODO
42
+ noise = 0
43
+ output = self.generator(noise,label)
44
 
45
  def load_pretrained(self,checkpoint):
46
+
47
+ # TODO
48
+ saved_checkpoint = checkpoint
49
+
50
+ model_checkpoint = torch.load('')
51
+ self.generator.load_state_dict(model_checkpoint)
52
+
53
+ log("Successfully loaded model")
54
+