# my-stylegan-model / train.py
# Source: "my-stylegan-model" repository by edemana; last commit 73e4fea
# ("Update train.py", verified), 2.58 kB.
import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import pickle
# Load the pre-trained StyleGAN networks from a pickled checkpoint.
# SECURITY: pickle.load can execute arbitrary code embedded in the file --
# only load a checkpoint obtained from a source you trust.
# NOTE(review): StyleGAN2-ADA pickles reference classes from that repo's
# `dnnlib` / `torch_utils` packages, which presumably must be importable here.
with open('ffhq.pkl', 'rb') as f:
    data = pickle.load(f)
# NOTE(review): this fine-tunes the EMA generator ('G_ema') directly rather
# than 'G'; confirm that is intended.
# StyleGAN2-ADA snapshots are saved in eval mode with requires_grad disabled,
# so switch back to training mode and re-enable gradients before optimizing.
G = data['G_ema'].train().requires_grad_(True)
D = data['D'].train().requires_grad_(True)
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
# Move both networks' parameters and buffers onto the chosen device.
G, D = G.to(device), D.to(device)
# Custom dataset class
class CustomDataset(Dataset):
    """Unlabeled images read from a single flat directory.

    Each item is one image converted to RGB, optionally passed through
    ``transform``. No labels are returned.

    Args:
        image_dir: Directory scanned (non-recursively) for image files.
        transform: Optional callable applied to each loaded PIL image.
        extensions: File suffixes to include, matched case-insensitively.
    """

    def __init__(self, image_dir, transform=None, *, extensions=('.jpg', '.png')):
        self.image_dir = image_dir
        suffixes = tuple(ext.lower() for ext in extensions)
        # Case-insensitive match so e.g. "IMG_0001.JPG" is not silently
        # skipped; sorted because os.listdir order is arbitrary, which would
        # otherwise make index -> file mapping non-deterministic.
        self.image_files = sorted(
            name for name in os.listdir(image_dir)
            if name.lower().endswith(suffixes)
        )
        self.transform = transform

    def __len__(self):
        """Return the number of matched image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # The context manager closes the file handle; convert() returns a
        # separate in-memory image that outlives the `with` block.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
# Data loading: resize to the generator's output resolution, then map pixels
# from ToTensor's [0, 1] range to [-1, 1] via (x - 0.5) / 0.5 per channel.
# NOTE(review): Resize((h, w)) stretches non-square images; consider
# Resize + CenterCrop if the source images are not already square.
transform = transforms.Compose([
    transforms.Resize((G.img_resolution, G.img_resolution)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# TODO: point this at the directory containing your training images.
image_dir = "/path/to/your/image_dir"
dataset = CustomDataset(image_dir, transform=transform)
if len(dataset) == 0:
    # Fail early with an actionable message; a shuffled DataLoader over an
    # empty dataset would otherwise fail with a far less informative error.
    raise ValueError(f"No images found in {image_dir!r}")
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# Fine-tuning setup: one Adam optimizer per network with identical
# hyper-parameters (beta1 = 0, beta2 = 0.99). Each learning rate is multiplied
# by gamma = 0.99 on every scheduler.step().
adam_kwargs = dict(lr=0.0001, betas=(0, 0.99))
optimizer_g = Adam(G.parameters(), **adam_kwargs)
optimizer_d = Adam(D.parameters(), **adam_kwargs)
scheduler_g, scheduler_d = (
    torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.99)
    for opt in (optimizer_g, optimizer_d)
)
def _train_step(generator, discriminator, real_images, optimizer_g, optimizer_d, device):
    """Run one generator update followed by one discriminator update.

    Returns the detached ``(g_loss, d_loss)`` tensors for logging.
    """
    softplus = torch.nn.functional.softplus
    batch_size = real_images.size(0)

    # Generator step: freeze D so g_loss only accumulates gradients in G
    # (gradients still flow *through* D into the fake images).
    generator.requires_grad_(True)
    discriminator.requires_grad_(False)
    z = torch.randn([batch_size, generator.z_dim], device=device)
    fake_images = generator(z, None)
    # Non-saturating loss: -log(sigmoid(D(G(z)))) == softplus(-D(G(z))).
    # NOTE(review): StyleGAN2-ADA's D returns unbounded logits (its own loss.py
    # applies softplus), so torch.log(D(x)) is NaN whenever a logit is < 0.
    g_loss = softplus(-discriminator(fake_images, None)).mean()
    optimizer_g.zero_grad()
    g_loss.backward()
    optimizer_g.step()

    # Discriminator step. Detach the fakes: G's graph was freed by
    # g_loss.backward() and its weights were modified in place by
    # optimizer_g.step(), so backpropagating into G again would raise.
    discriminator.requires_grad_(True)
    fake_images = fake_images.detach()
    d_loss_real = softplus(-discriminator(real_images, None)).mean()  # -log(sigmoid(D(x)))
    d_loss_fake = softplus(discriminator(fake_images, None)).mean()   # -log(1 - sigmoid(D(G(z))))
    d_loss = d_loss_real + d_loss_fake
    optimizer_d.zero_grad()
    d_loss.backward()
    optimizer_d.step()

    return g_loss.detach(), d_loss.detach()


num_epochs = 100
for epoch in range(num_epochs):
    for batch in dataloader:
        real_images = batch.to(device)
        g_loss, d_loss = _train_step(G, D, real_images, optimizer_g, optimizer_d, device)
    # Learning rates decay once per epoch, not per batch.
    scheduler_g.step()
    scheduler_d.step()
    # Reports the losses of the epoch's last batch.
    print(f"Epoch {epoch+1}/{num_epochs}, G Loss: {g_loss.item()}, D Loss: {d_loss.item()}")
# Persist only the fine-tuned generator's weights (its state_dict, not the
# module object); reloading needs a generator of the same architecture.
fine_tuned_path = 'fine_tuned_stylegan.pth'
generator_weights = G.state_dict()
torch.save(generator_weights, fine_tuned_path)