Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.optim import Adam
|
3 |
+
from torch.utils.data import Dataset, DataLoader
|
4 |
+
from torchvision import transforms
|
5 |
+
import os
|
6 |
+
from PIL import Image
|
7 |
+
import pickle
|
8 |
+
|
9 |
+
# Load pre-trained model
|
10 |
+
with open('ffhq.pkl', 'rb') as f:
|
11 |
+
data = pickle.load(f)
|
12 |
+
|
13 |
+
G = data['G_ema']
|
14 |
+
D = data['D']
|
15 |
+
|
16 |
+
# Check if CUDA is available
|
17 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
18 |
+
G = G.to(device)
|
19 |
+
D = D.to(device)
|
20 |
+
|
21 |
+
# Custom dataset class
|
22 |
+
class CustomDataset(Dataset):
|
23 |
+
def __init__(self, image_dir, transform=None):
|
24 |
+
self.image_dir = image_dir
|
25 |
+
self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.jpg') or f.endswith('.png')]
|
26 |
+
self.transform = transform
|
27 |
+
|
28 |
+
def __len__(self):
|
29 |
+
return len(self.image_files)
|
30 |
+
|
31 |
+
def __getitem__(self, idx):
|
32 |
+
img_path = os.path.join(self.image_dir, self.image_files[idx])
|
33 |
+
image = Image.open(img_path).convert('RGB')
|
34 |
+
if self.transform:
|
35 |
+
image = self.transform(image)
|
36 |
+
return image
|
37 |
+
|
38 |
+
# Data loading
|
39 |
+
transform = transforms.Compose([
|
40 |
+
transforms.Resize((G.img_resolution, G.img_resolution)),
|
41 |
+
transforms.ToTensor(),
|
42 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
43 |
+
])
|
44 |
+
|
45 |
+
dataset = CustomDataset("/path/to/your/image_dir", transform=transform)
|
46 |
+
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
|
47 |
+
|
48 |
+
# Fine-tuning setup
|
49 |
+
optimizer_g = Adam(G.parameters(), lr=0.0001, betas=(0, 0.99))
|
50 |
+
optimizer_d = Adam(D.parameters(), lr=0.0001, betas=(0, 0.99))
|
51 |
+
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)
|
52 |
+
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.99)
|
53 |
+
|
54 |
+
num_epochs = 100
|
55 |
+
for epoch in range(num_epochs):
|
56 |
+
for batch in dataloader:
|
57 |
+
real_images = batch.to(device)
|
58 |
+
|
59 |
+
# Generate fake images
|
60 |
+
z = torch.randn([batch.size(0), G.z_dim]).to(device)
|
61 |
+
fake_images = G(z, None)
|
62 |
+
|
63 |
+
# Compute losses
|
64 |
+
g_loss = -torch.mean(torch.log(D(fake_images, None)))
|
65 |
+
d_loss_real = -torch.mean(torch.log(D(real_images, None)))
|
66 |
+
d_loss_fake = -torch.mean(torch.log(1 - D(fake_images, None)))
|
67 |
+
d_loss = d_loss_real + d_loss_fake
|
68 |
+
|
69 |
+
# Update models
|
70 |
+
optimizer_g.zero_grad()
|
71 |
+
g_loss.backward()
|
72 |
+
optimizer_g.step()
|
73 |
+
|
74 |
+
optimizer_d.zero_grad()
|
75 |
+
d_loss.backward()
|
76 |
+
optimizer_d.step()
|
77 |
+
|
78 |
+
scheduler_g.step()
|
79 |
+
scheduler_d.step()
|
80 |
+
|
81 |
+
print(f"Epoch {epoch+1}/{num_epochs}, G Loss: {g_loss.item()}, D Loss: {d_loss.item()}")
|
82 |
+
|
83 |
+
# Save the fine-tuned model
|
84 |
+
torch.save(G.state_dict(), 'fine_tuned_stylegan.pth')
|