"""Fine-tune a pre-trained StyleGAN generator/discriminator pair on an image folder.

Loads ``G_ema`` and ``D`` from ``ffhq.pkl``, trains both adversarially with the
non-saturating logistic GAN loss, and saves the generator's ``state_dict`` to
``fine_tuned_stylegan.pth``.
"""

import os
import pickle

import torch
import torch.nn.functional as F
from PIL import Image
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# File suffixes accepted by CustomDataset (compared case-insensitively).
IMAGE_EXTENSIONS = (".jpg", ".png")

# Load pre-trained model.
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load checkpoints obtained from a trusted source.
# NOTE(review): StyleGAN2-ADA pickles usually require that repo's `dnnlib` and
# `torch_utils` packages to be importable for unpickling — confirm.
with open("ffhq.pkl", "rb") as f:
    data = pickle.load(f)
G = data["G_ema"]
D = data["D"]

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Snapshot modules may be stored in eval mode with gradients disabled; without
# re-enabling them, backward() has nothing to differentiate and the optimizers
# would never update anything.
G = G.to(device).train().requires_grad_(True)
D = D.to(device).train().requires_grad_(True)


class CustomDataset(Dataset):
    """Unlabelled image folder dataset returning (optionally transformed) RGB images.

    Args:
        image_dir: Directory scanned (non-recursively) for .jpg/.png files.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Sorted for a deterministic index -> file mapping across runs.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.lower().endswith(IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image


# Data loading: resize to the generator's native resolution and scale pixels to
# [-1, 1] via Normalize(0.5, 0.5).
transform = transforms.Compose([
    transforms.Resize((G.img_resolution, G.img_resolution)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = CustomDataset("/path/to/your/image_dir", transform=transform)
if len(dataset) == 0:
    raise ValueError(f"No .jpg/.png images found in {dataset.image_dir!r}")
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Fine-tuning setup
optimizer_g = Adam(G.parameters(), lr=0.0001, betas=(0, 0.99))
optimizer_d = Adam(D.parameters(), lr=0.0001, betas=(0, 0.99))
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.99)

# NOTE(review): D is assumed to return raw (unbounded) logits, as StyleGAN2-ADA's
# discriminator does. The logistic loss is therefore written with softplus:
#   -log(sigmoid(x))     == softplus(-x)
#   -log(1 - sigmoid(x)) == softplus(x)
# Applying torch.log directly to logits yields NaN for negative outputs.
num_epochs = 100
for epoch in range(num_epochs):
    g_loss_sum = 0.0
    d_loss_sum = 0.0
    num_batches = 0
    for batch in dataloader:
        real_images = batch.to(device)
        batch_size = real_images.size(0)

        # --- Discriminator step ---
        # Fakes are generated without a graph: D's update must not backprop into
        # G, and G's graph must not be shared with the later G backward pass.
        z = torch.randn([batch_size, G.z_dim], device=device)
        with torch.no_grad():
            fake_images = G(z, None)
        d_loss_real = F.softplus(-D(real_images, None)).mean()
        d_loss_fake = F.softplus(D(fake_images, None)).mean()
        d_loss = d_loss_real + d_loss_fake
        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # --- Generator step (non-saturating loss) ---
        # A fresh forward pass through G and the just-updated D. Gradients this
        # also leaves on D's parameters are cleared by optimizer_d.zero_grad()
        # before the next discriminator backward.
        z = torch.randn([batch_size, G.z_dim], device=device)
        fake_images = G(z, None)
        g_loss = F.softplus(-D(fake_images, None)).mean()
        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        g_loss_sum += g_loss.item()
        d_loss_sum += d_loss.item()
        num_batches += 1

    scheduler_g.step()
    scheduler_d.step()
    # Report epoch means rather than only the last batch's (noisy) losses.
    print(
        f"Epoch {epoch+1}/{num_epochs}, "
        f"G Loss: {g_loss_sum / num_batches}, D Loss: {d_loss_sum / num_batches}"
    )

# Save the fine-tuned model
torch.save(G.state_dict(), "fine_tuned_stylegan.pth")