# StyleGAN fine-tuning script: loads a pickled FFHQ generator/discriminator
# pair and fine-tunes it on a custom image folder.
import os
import pickle

import torch
import torch.nn.functional as F
from PIL import Image
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Load the pre-trained networks from a pickled snapshot (keys 'G_ema'/'D'
# suggest a StyleGAN-style checkpoint — confirm against the producing repo).
# NOTE(review): pickle.load executes arbitrary code embedded in the file —
# only load checkpoints from a trusted source.
with open('ffhq.pkl', 'rb') as f:
    data = pickle.load(f)

G = data['G_ema']  # exponential-moving-average generator snapshot
D = data['D']      # discriminator

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = G.to(device)
D = D.to(device)

# Custom dataset class
class CustomDataset(Dataset):
    """Flat image-folder dataset yielding (optionally transformed) RGB images.

    Parameters:
        image_dir: directory scanned (non-recursively) for image files.
        transform: optional callable applied to each PIL image before return.
    """

    # Recognized image extensions, matched case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Sort for a deterministic index -> file mapping (os.listdir order is
        # filesystem-dependent). Lowercasing makes the extension check
        # case-insensitive, and '.jpeg' is accepted alongside '.jpg'.
        self.image_files = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        # Number of recognized image files found in image_dir.
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load the idx-th image, force 3-channel RGB, then apply transform.
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Preprocessing: resize to the generator's native resolution and map pixel
# values from [0, 1] into [-1, 1] via Normalize(0.5, 0.5) per channel.
transform = transforms.Compose([
    transforms.Resize((G.img_resolution, G.img_resolution)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

dataset = CustomDataset("/path/to/your/image_dir", transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=4)

# One optimizer/scheduler pair per network. beta1=0 disables Adam's
# first-moment momentum; both share the same hyperparameters.
adam_kwargs = dict(lr=0.0001, betas=(0, 0.99))
optimizer_g = Adam(G.parameters(), **adam_kwargs)
optimizer_d = Adam(D.parameters(), **adam_kwargs)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.99)

# Fine-tuning loop. Uses the softplus form of the non-saturating logistic
# GAN loss, which is numerically stable on raw discriminator logits — the
# original torch.log(D(...)) assumed D returns probabilities in (0, 1) and
# produces NaN for any non-positive logit. The discriminator is updated
# first on detached fakes (so its gradients do not flow into G), then the
# generator is updated with a fresh D forward pass; this also avoids the
# original's "backward through a freed graph" RuntimeError caused by
# calling d_loss.backward() after g_loss.backward() on the same graph.
num_epochs = 100
for epoch in range(num_epochs):
    for batch in dataloader:
        real_images = batch.to(device)

        # Sample latents and synthesize a batch of fakes.
        z = torch.randn([real_images.size(0), G.z_dim], device=device)
        fake_images = G(z, None)

        # --- Discriminator update ---
        # softplus(-logit) = -log(sigmoid(logit)); detach() blocks
        # gradient flow from D's loss into the generator.
        d_loss_real = torch.mean(F.softplus(-D(real_images, None)))
        d_loss_fake = torch.mean(F.softplus(D(fake_images.detach(), None)))
        d_loss = d_loss_real + d_loss_fake

        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # --- Generator update ---
        # Re-score the non-detached fakes so gradients reach G. D also
        # accumulates gradients here, but they are cleared by the
        # optimizer_d.zero_grad() at the start of the next D update.
        g_loss = torch.mean(F.softplus(-D(fake_images, None)))

        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

    # Decay both learning rates once per epoch.
    scheduler_g.step()
    scheduler_d.step()

    # NOTE(review): g_loss/d_loss are undefined if the dataloader is empty —
    # assumes the dataset contains at least one image.
    print(f"Epoch {epoch+1}/{num_epochs}, G Loss: {g_loss.item()}, D Loss: {d_loss.item()}")

# Save only the fine-tuned generator weights.
torch.save(G.state_dict(), 'fine_tuned_stylegan.pth')