donut-AE / autoencoder.py
Mikus
Upload 4 files
d1ef8ee verified
raw
history blame
2.84 kB
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from model import aeModel
class ImageDataset(Dataset):
    """Dataset of RGB images read from a single flat folder.

    Files directly inside ``folder_path`` whose names end in .jpg, .jpeg or
    .png (case-insensitive) are included, in sorted order; sub-directories
    are not searched. Each item is the image converted to RGB, resized with
    ``transforms.Resize(image_size)`` and turned into a float tensor by
    ``transforms.ToTensor()``.
    """

    # Accepted file extensions, compared against the lower-cased file name.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, folder_path, image_size=(64, 64)):
        """Index the image files in ``folder_path``.

        Args:
            folder_path: directory containing the images.
            image_size: size passed to ``transforms.Resize``; the default
                keeps the original 64x64 behaviour.
        """
        self.folder_path = folder_path
        self.image_files = self._list_image_files(folder_path)
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
        ])

    @classmethod
    def _list_image_files(cls, folder_path):
        """Return the sorted names of regular image files in ``folder_path``."""
        # os.listdir() order is arbitrary, so sort for a reproducible
        # index -> file mapping; lower-case so e.g. 'IMG.JPG' is not skipped;
        # and skip directories that merely have an image-like name.
        return sorted(
            name
            for name in os.listdir(folder_path)
            if name.lower().endswith(cls._IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder_path, name))
        )

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th image, convert it to RGB and apply the transform."""
        img_path = os.path.join(self.folder_path, self.image_files[idx])
        # Close the file as soon as the pixels are decoded instead of leaving
        # the handle open until the Image object is garbage-collected.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        return self.transform(image)
def train(model, dataloader, num_epochs, device, lr=1e-3):
    """Train ``model`` to reconstruct its input by minimising MSE(model(x), x).

    A new Adam optimizer is created on every call, so optimizer state is not
    carried over between separate calls to this function.

    Args:
        model: module whose output has the same shape as its input (required
            by the MSE loss against the input batch).
        dataloader: iterable yielding batches of input tensors.
        num_epochs: number of full passes over ``dataloader``.
        device: device each batch is moved to; presumably the device the
            model already lives on — caller's responsibility.
        lr: Adam learning rate (default 1e-3, the previously hard-coded value).

    Returns:
        list[float]: the per-sample average loss of each epoch.

    Raises:
        ValueError: if ``dataloader`` yields no samples during an epoch.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_samples = 0
        for batch in tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            batch = batch.to(device)
            output = model(batch)
            loss = criterion(output, batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # MSELoss returns the batch mean; weight it by batch size so a
            # smaller final batch does not skew the epoch average.
            batch_size = batch.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size
        if num_samples == 0:
            # Fail with a clear message instead of dividing by zero.
            raise ValueError('dataloader yielded no samples; is the dataset empty?')
        avg_loss = total_loss / num_samples
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')
        epoch_losses.append(avg_loss)
    return epoch_losses
def visualize_results(model, dataloader, device, num_images=5):
    """Plot input images (top row) above the model's reconstructions (bottom row).

    Only the first batch drawn from ``dataloader`` is used, and the figure is
    displayed with ``plt.show()``.

    Args:
        model: autoencoder whose output has the same shape as its input.
        dataloader: iterable of image batches; each image is assumed to be
            channel-first (C, H, W), as the ``permute(1, 2, 0)`` below
            requires — TODO confirm for datasets other than ImageDataset.
        device: device the batch is moved to before inference.
        num_images: maximum number of image pairs to show (default 5); fewer
            are shown if the first batch is smaller.

    Raises:
        ValueError: if ``num_images`` < 1 or ``dataloader`` is empty.
    """
    if num_images < 1:
        raise ValueError('num_images must be at least 1')
    model.eval()
    with torch.no_grad():
        try:
            images = next(iter(dataloader))
        except StopIteration:
            raise ValueError('dataloader is empty; nothing to visualize') from None
        images = images.to(device)
        reconstructions = model(images)
    # Never index past the end of a short final/only batch.
    n = min(num_images, len(images))
    # squeeze=False keeps `axes` 2-D even when n == 1, so axes[row, i] works.
    fig, axes = plt.subplots(2, n, figsize=(2.4 * n, 6), squeeze=False)
    for i in range(n):
        axes[0, i].imshow(images[i].cpu().permute(1, 2, 0))
        axes[0, i].axis('off')
        # ToTensor() inputs lie in [0, 1], but the model output is not
        # guaranteed to; clamp so imshow does not clip with a warning.
        axes[1, i].imshow(reconstructions[i].cpu().clamp(0, 1).permute(1, 2, 0))
        axes[1, i].axis('off')
    plt.tight_layout()
    plt.show()
def main():
    """Train the autoencoder on 'dataset/images/', save it, then show samples.

    The trained weights are written to 'autoencoder.pth' in the current
    working directory before any plotting happens.
    """
    # Use a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    dataset = ImageDataset('dataset/images/')
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    model = aeModel().to(device)
    # To resume from an earlier checkpoint, load it here before training, e.g.
    # model.load_state_dict(torch.load('autoencoder_250.pth', map_location=device))
    num_epochs = 250
    train(model, dataloader, num_epochs, device)
    # Save before visualising: plt.show() may block until the window is
    # closed, and an error while plotting (e.g. no display available) must
    # not discard the freshly trained weights.
    torch.save(model.state_dict(), 'autoencoder.pth')
    visualize_results(model, dataloader, device)
if __name__ == '__main__':
    # Executed only when this file is run directly (e.g. `python autoencoder.py`),
    # never when it is imported as a module.
    main()