# initial commit efd5df3 (umairahmad1789)
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
from scipy.ndimage import gaussian_filter, map_coordinates # Add this line
import PIL
class ResidualConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(ResidualConvBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.in1 = nn.InstanceNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.in2 = nn.InstanceNorm2d(out_channels)
self.relu = nn.LeakyReLU(inplace=True)
self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None
def forward(self, x):
residual = x
out = self.relu(self.in1(self.conv1(x)))
out = self.in2(self.conv2(out))
if self.downsample:
residual = self.downsample(x)
out += residual
return self.relu(out)
class AttentionGate(nn.Module):
    """Additive attention gate for U-Net skip connections.

    Projects the gating signal `g` (decoder feature) and the skip feature
    `x` to a common F_int-channel space, combines them, and produces a
    single-channel sigmoid mask that scales `x` element-wise.
    """

    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        # Projection of the gating (decoder) signal.
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(F_int)
        )
        # Projection of the skip-connection feature.
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(F_int)
        )
        # Collapse to one channel and squash to (0, 1).
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, g, x):
        gate = self.W_g(g)
        skip = self.W_x(x)
        attn = self.psi(self.relu(gate + skip))
        return x * attn
class EnhancedUNet(nn.Module):
    """Attention U-Net variant with residual conv blocks, a dilated
    bottleneck, attention-gated skip connections, and p=0.5 dropout
    between stages. Input: (N, n_channels, H, W); output: per-pixel
    logits of shape (N, n_classes, H, W).
    """

    def __init__(self, n_channels, n_classes):
        super(EnhancedUNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        # Encoder: 64 -> 128 -> 256 -> 512 -> 1024 channels, halving
        # spatial resolution at each stage.
        self.inc = ResidualConvBlock(n_channels, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(256, 512))
        self.down4 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(512, 1024))
        # Bottleneck: dilated convolutions widen the receptive field
        # without further pooling.
        self.dilation = nn.Sequential(
            nn.Conv2d(1024, 1024, kernel_size=3, padding=2, dilation=2),
            nn.InstanceNorm2d(1024),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=4, dilation=4),
            nn.InstanceNorm2d(1024),
            nn.LeakyReLU(inplace=True)
        )
        # Decoder: transpose-conv upsampling, attention-gated skips,
        # residual fusion blocks.
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.att4 = AttentionGate(F_g=512, F_l=512, F_int=256)
        self.up_conv4 = ResidualConvBlock(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att3 = AttentionGate(F_g=256, F_l=256, F_int=128)
        self.up_conv3 = ResidualConvBlock(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att2 = AttentionGate(F_g=128, F_l=128, F_int=64)
        self.up_conv2 = ResidualConvBlock(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att1 = AttentionGate(F_g=64, F_l=64, F_int=32)
        self.up_conv1 = ResidualConvBlock(128, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
        self.dropout = nn.Dropout(0.5)

    def _decode(self, x, skip, up, att, conv):
        """One decoder stage: upsample, attention-gate the skip, concat, fuse."""
        x = up(x)
        skip = att(g=x, x=skip)
        return conv(torch.cat([skip, x], dim=1))

    def forward(self, x):
        # Encoder path (dropout after each downsampling stage).
        x1 = self.inc(x)
        x2 = self.dropout(self.down1(x1))
        x3 = self.dropout(self.down2(x2))
        x4 = self.dropout(self.down3(x3))
        x5 = self.dropout(self.dilation(self.down4(x4)))
        # Decoder path (dropout after every stage except the last).
        x = self.dropout(self._decode(x5, x4, self.up4, self.att4, self.up_conv4))
        x = self.dropout(self._decode(x, x3, self.up3, self.att3, self.up_conv3))
        x = self.dropout(self._decode(x, x2, self.up2, self.att2, self.up_conv2))
        x = self._decode(x, x1, self.up1, self.att1, self.up_conv1)
        return self.outc(x)
class MoS2Dataset(Dataset):
    """Dataset of grayscale MoS2 images ('images/*.png') with integer label
    masks ('labels/*.npy').

    Image files are verified once at construction; unreadable files are
    skipped with a warning. __getitem__ returns (image, label) as a
    (1, H, W) float tensor in [0, 1] and a (H, W) long tensor.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # Optional callable applied as transform(image, label) on numpy arrays.
        self.transform = transform
        self.images_dir = os.path.join(root_dir, 'images')
        self.labels_dir = os.path.join(root_dir, 'labels')
        self.image_files = []
        for f in sorted(os.listdir(self.images_dir)):
            if f.endswith('.png'):
                try:
                    # verify() detects truncated/corrupt files without a full decode.
                    Image.open(os.path.join(self.images_dir, f)).verify()
                    self.image_files.append(f)
                except Exception:
                    # Was a bare `except:`; Exception no longer swallows
                    # KeyboardInterrupt/SystemExit.
                    print(f"Skipping unreadable image: {f}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        # Label file mirrors the image name: image_<id>.png -> image_<id>.npy
        label_name = f"image_{img_name.split('_')[1].replace('.png', '.npy')}"
        label_path = os.path.join(self.labels_dir, label_name)
        try:
            # A missing image now surfaces as FileNotFoundError here, so the
            # previous separate os.path.exists pre-check was redundant.
            image = np.array(Image.open(img_path).convert('L'), dtype=np.float32) / 255.0
            label = np.load(label_path).astype(np.int64)
        except (PIL.UnidentifiedImageError, FileNotFoundError, IOError) as e:
            # NOTE(review): returning (None, None) will crash the default
            # DataLoader collate_fn; callers should filter these out or use a
            # custom collate -- TODO confirm intended handling.
            print(f"Error loading image {img_path}: {str(e)}")
            return None, None
        if self.transform:
            image, label = self.transform(image, label)
        image = torch.from_numpy(image).float().unsqueeze(0)
        label = torch.from_numpy(label).long()
        return image, label
class AugmentationTransform:
    """Random augmentations for (image, label) training pairs.

    Each augmentation is applied independently with probability 0.5.
    Photometric ops (brightness/contrast, gamma, noise) only touch the
    image; the elastic deformation warps image and label with the same
    displacement field so the mask stays aligned. Images are expected as
    float arrays in [0, 1]; labels are integer class masks.
    """

    def __init__(self):
        self.aug_functions = [
            self.random_brightness_contrast,
            self.random_gamma,
            self.random_noise,
            self.random_elastic_deform
        ]

    def __call__(self, image, label):
        for aug_func in self.aug_functions:
            if random.random() < 0.5:  # 50% chance to apply each augmentation
                image, label = aug_func(image, label)
        return image.astype(np.float32), label  # Ensure float32

    def random_brightness_contrast(self, image, label):
        """Scale brightness, then stretch contrast about mid-gray 0.5."""
        brightness = random.uniform(0.7, 1.3)
        contrast = random.uniform(0.7, 1.3)
        # BUGFIX: the old formula `brightness*img + contrast*(img-0.5) + 0.5`
        # was not the identity at brightness=contrast=1 (it doubled the image).
        image = np.clip((brightness * image - 0.5) * contrast + 0.5, 0, 1)
        return image, label

    def random_gamma(self, image, label):
        """Apply a random gamma curve; maps [0, 1] onto [0, 1]."""
        gamma = random.uniform(0.7, 1.3)
        image = np.power(image, gamma)
        return image, label

    def random_noise(self, image, label):
        """Add Gaussian pixel noise (sigma=0.05), clipped back to [0, 1]."""
        noise = np.random.normal(0, 0.05, image.shape)
        image = np.clip(image + noise, 0, 1)
        return image, label

    def random_elastic_deform(self, image, label):
        """Elastic deformation: a Gaussian-smoothed random displacement field
        warps image (bilinear) and label (nearest-neighbor) identically."""
        alpha = random.uniform(10, 20)
        sigma = random.uniform(3, 5)
        shape = image.shape
        # Uniform noise in [-1, 1], smoothed into a coherent displacement field.
        dx = np.random.rand(*shape) * 2 - 1
        dy = np.random.rand(*shape) * 2 - 1
        dx = gaussian_filter(dx, sigma, mode="constant", cval=0) * alpha
        dy = gaussian_filter(dy, sigma, mode="constant", cval=0) * alpha
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))
        # order=1 (bilinear) for intensities; order=0 (nearest) keeps label
        # values as valid class ids.
        image = map_coordinates(image, indices, order=1).reshape(shape)
        label = map_coordinates(label, indices, order=0).reshape(shape)
        return image, label
def focal_loss(output, target, alpha=0.25, gamma=2):
    """Mean focal loss (Lin et al., 2017) on raw logits.

    Down-weights well-classified pixels through the (1 - p_t)**gamma factor
    so training focuses on hard examples.
    """
    per_pixel_ce = nn.functional.cross_entropy(output, target, reduction='none')
    p_t = torch.exp(-per_pixel_ce)  # probability assigned to the true class
    modulated = alpha * (1 - p_t) ** gamma * per_pixel_ce
    return modulated.mean()
def dice_loss(output, target, smooth=1e-5):
    """Soft Dice loss averaged over all classes.

    Softmaxes the logits and compares each class's probability map against
    the binary target mask for that class; `smooth` avoids division by zero
    when a class is absent from both prediction and target.
    """
    probs = torch.softmax(output, dim=1)
    num_classes = probs.shape[1]

    def _class_dice(cls):
        pred = probs[:, cls, :, :]
        truth = (target == cls).float()
        overlap = (pred * truth).sum()
        return (2. * overlap + smooth) / (pred.sum() + truth.sum() + smooth)

    mean_dice = sum(_class_dice(cls) for cls in range(num_classes)) / num_classes
    return 1 - mean_dice
def combined_loss(output, target):
    """Equal-weight blend of focal loss and soft Dice loss."""
    focal = focal_loss(output, target)
    dice = dice_loss(output, target)
    return 0.5 * focal + 0.5 * dice
def iou_score(output, target):
    """Mean IoU over all classes, averaged over the batch.

    BUGFIX: the previous implementation applied bitwise `&`/`|` directly to
    the raw predicted/target class indices and summed the resulting values.
    That is only a valid IoU for binary 0/1 masks; with multi-class labels
    (num_classes=4 here) the bitwise results are meaningless. This version
    computes a per-class binary IoU and averages. Classes absent from both
    prediction and target contribute ~1 via the smoothing term.
    """
    smooth = 1e-5
    num_classes = output.shape[1]
    preds = torch.argmax(output, dim=1)
    per_class = []
    for c in range(num_classes):
        pred_c = preds == c
        target_c = target == c
        intersection = (pred_c & target_c).float().sum((1, 2))
        union = (pred_c | target_c).float().sum((1, 2))
        per_class.append((intersection + smooth) / (union + smooth))
    return torch.stack(per_class).mean()
def pixel_accuracy(output, target):
    """Fraction of pixels whose argmax-predicted class equals the target."""
    predicted = output.argmax(dim=1)
    hits = torch.eq(predicted, target)
    return hits.sum().item() / hits.numel()
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one training epoch.

    Returns (mean_loss, mean_iou, mean_accuracy), each averaged per batch
    and returned as plain Python floats (matching `validate`).
    """
    model.train()
    total_loss = 0.0
    total_iou = 0.0
    total_accuracy = 0.0
    pbar = tqdm(dataloader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # FIX: iou_score returns a 0-dim tensor; accumulate a plain float so
        # the returned IoU matches the float accuracy/loss (and doesn't leak
        # tensors into logging/scheduler code).
        total_iou += iou_score(outputs, labels).item()
        total_accuracy += pixel_accuracy(outputs, labels)
        # Inside the loop body pbar.n is the 0-based index of the current
        # batch, so (pbar.n + 1) is the number of batches seen so far.
        pbar.set_postfix({'Loss': total_loss / (pbar.n + 1),
                          'IoU': total_iou / (pbar.n + 1),
                          'Accuracy': total_accuracy / (pbar.n + 1)})
    return total_loss / len(dataloader), total_iou / len(dataloader), total_accuracy / len(dataloader)
def validate(model, dataloader, criterion, device):
    """Evaluate the model on `dataloader` without gradient tracking.

    Returns (mean_loss, mean_iou, mean_accuracy), each averaged per batch
    and returned as plain Python floats (matching `train_one_epoch`).
    """
    model.eval()
    total_loss = 0.0
    total_iou = 0.0
    total_accuracy = 0.0
    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            # FIX: iou_score returns a 0-dim tensor; convert to a float so
            # the returned IoU is consistent with loss/accuracy.
            total_iou += iou_score(outputs, labels).item()
            total_accuracy += pixel_accuracy(outputs, labels)
            # pbar.n is the 0-based index of the current batch in the body.
            pbar.set_postfix({'Loss': total_loss / (pbar.n + 1),
                              'IoU': total_iou / (pbar.n + 1),
                              'Accuracy': total_accuracy / (pbar.n + 1)})
    return total_loss / len(dataloader), total_iou / len(dataloader), total_accuracy / len(dataloader)
def main():
    """Train EnhancedUNet on the MoS2 segmentation dataset.

    Saves the best model (by validation IoU), a per-epoch checkpoint, and a
    prediction visualization under 'enhanced_training_results/'.
    """
    # Hyperparameters
    num_classes = 4
    batch_size = 64
    num_epochs = 100
    learning_rate = 1e-4
    weight_decay = 1e-5
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Create datasets and data loaders
    transform = AugmentationTransform()
    # dataset = MoS2Dataset('MoS2_dataset_advanced_v2', transform=transform)
    # NOTE(review): `transform` is constructed but NOT passed to the dataset
    # below, so training runs without augmentation -- confirm intentional.
    dataset = MoS2Dataset('dataset_with_noise_npy')
    # Random 80/20 train/val split; not seeded, so it differs between runs.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    # Create model
    model = EnhancedUNet(n_channels=1, n_classes=num_classes).to(device)
    # Loss and optimizer
    criterion = combined_loss
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # mode='max' because the scheduler is stepped on validation IoU (higher
    # is better). NOTE(review): `verbose` is deprecated in recent torch.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=10, verbose=True)
    # Create directory for saving models and visualizations
    save_dir = 'enhanced_training_results'
    os.makedirs(save_dir, exist_ok=True)
    # Training loop
    best_val_iou = 0.0
    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        train_loss, train_iou, train_accuracy = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_iou, val_accuracy = validate(model, val_loader, criterion, device)
        print(f"Train - Loss: {train_loss:.4f}, IoU: {train_iou:.4f}, Accuracy: {train_accuracy:.4f}")
        print(f"Val - Loss: {val_loss:.4f}, IoU: {val_iou:.4f}, Accuracy: {val_accuracy:.4f}")
        scheduler.step(val_iou)
        # Track and persist the best model by validation IoU.
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))
            print(f"New best model saved with IoU: {best_val_iou:.4f}")
        # Save a full resumable checkpoint every epoch (one file per epoch --
        # disk usage grows linearly with num_epochs).
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_iou': best_val_iou,
        }, os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pth'))
        # Visualize predictions after every epoch (an earlier comment said
        # "every 5 epochs", but there is no epoch filter here).
        visualize_prediction(model, val_loader, device, epoch, save_dir)
    print("Training completed!")
def visualize_prediction(model, val_loader, device, epoch, save_dir):
    """Save a figure comparing inputs, ground-truth masks and predictions
    for the first validation batch to '<save_dir>/prediction_epoch_<epoch>.png'.

    FIX: the previous version unconditionally plotted a second sample
    (index 1), which raised IndexError whenever the first validation batch
    contained a single image; the number of rows now adapts to batch size.
    """
    model.eval()
    images, labels = next(iter(val_loader))
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
        outputs = model(images)
    images = images.cpu().numpy()
    labels = labels.cpu().numpy()
    predictions = torch.argmax(outputs, dim=1).cpu().numpy()
    # Show up to two samples; squeeze=False keeps axs 2-D even for one row.
    n_rows = min(2, images.shape[0])
    fig, axs = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    for row in range(n_rows):
        axs[row, 0].imshow(images[row, 0], cmap='gray')
        axs[row, 1].imshow(labels[row], cmap='viridis')
        axs[row, 2].imshow(predictions[row], cmap='viridis')
    axs[0, 0].set_title('Input Image')
    axs[0, 1].set_title('True Label')
    axs[0, 2].set_title('Prediction')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'prediction_epoch_{epoch}.png'))
    plt.close(fig)
if __name__ == "__main__":
main()