edemana committed on
Commit
73e4fea
·
verified ·
1 Parent(s): 6713576

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +84 -0
train.py CHANGED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.optim import Adam
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torchvision import transforms
5
+ import os
6
+ from PIL import Image
7
+ import pickle
8
+
9
# Load the pre-trained StyleGAN networks from the pickled snapshot.
# NOTE(review): pickle.load can execute arbitrary code — only load
# checkpoints from a trusted source.
with open('ffhq.pkl', 'rb') as checkpoint:
    snapshot = pickle.load(checkpoint)

# EMA generator and discriminator, under the keys the StyleGAN
# training scripts use when pickling a snapshot.
G, D = snapshot['G_ema'], snapshot['D']

# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G, D = G.to(device), D.to(device)
20
+
21
# Custom dataset class
class CustomDataset(Dataset):
    """Folder-of-images dataset yielding (optionally transformed) RGB images.

    Parameters
    ----------
    image_dir : str
        Directory containing the training images.
    transform : callable, optional
        Transform applied to each PIL image (e.g. a torchvision Compose).
    """

    # Accepted image extensions, matched case-insensitively.
    _EXTENSIONS = ('.jpg', '.png')

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # sorted() gives a deterministic ordering (os.listdir order is
        # filesystem-dependent); lower() also accepts ".JPG"/".PNG".
        self.image_files = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(self._EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        # Number of image files discovered in image_dir.
        return len(self.image_files)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale/RGBA inputs still yield 3 channels.
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
37
+
38
# Data loading: resize every image to the generator's native resolution,
# convert to a tensor, then map pixel values into [-1, 1]
# (presumably the range the checkpoint was trained with — TODO confirm).
_side = G.img_resolution
_pipeline = [
    transforms.Resize((_side, _side)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_pipeline)

dataset = CustomDataset("/path/to/your/image_dir", transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
47
+
48
# Fine-tuning setup.
#
# NOTE(review): StyleGAN discriminators return raw logits (no sigmoid) —
# confirm against the checkpoint. The losses below therefore use the
# numerically stable softplus form of the non-saturating GAN loss:
#   softplus(-x) == -log(sigmoid(x)),  softplus(x) == -log(1 - sigmoid(x))
# Taking torch.log of a raw logit, as the probability-based loss did,
# yields NaN as soon as a logit goes negative.
optimizer_g = Adam(G.parameters(), lr=0.0001, betas=(0, 0.99))
optimizer_d = Adam(D.parameters(), lr=0.0001, betas=(0, 0.99))
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.99)
softplus = torch.nn.functional.softplus

num_epochs = 100
for epoch in range(num_epochs):
    for batch in dataloader:
        real_images = batch.to(device)

        # Sample latents and generate fake images.
        z = torch.randn([batch.size(0), G.z_dim], device=device)
        fake_images = G(z, None)

        # --- Discriminator step -------------------------------------
        # fake_images is detached so the D loss does not backpropagate
        # into G and G's graph stays intact for the G step below.
        # (The original called backward() a second time through the
        # already-freed generator graph, which raises a RuntimeError.)
        d_loss_real = torch.mean(softplus(-D(real_images, None)))
        d_loss_fake = torch.mean(softplus(D(fake_images.detach(), None)))
        d_loss = d_loss_real + d_loss_fake

        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # --- Generator step -----------------------------------------
        # Non-saturating generator loss: push D(fake) logits upward.
        g_loss = torch.mean(softplus(-D(fake_images, None)))

        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

    # Decay both learning rates once per epoch.
    scheduler_g.step()
    scheduler_d.step()

    print(f"Epoch {epoch+1}/{num_epochs}, G Loss: {g_loss.item()}, D Loss: {d_loss.item()}")

# Save the fine-tuned model
torch.save(G.state_dict(), 'fine_tuned_stylegan.pth')