# NOTE(review): replaced non-Python paste artifacts ("Spaces:" / "Runtime error" x2)
# that preceded the imports and would have caused a SyntaxError.
import os | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from tqdm import tqdm | |
import random | |
from scipy.ndimage import gaussian_filter, map_coordinates # Add this line | |
import PIL | |
class ResidualConvBlock(nn.Module):
    """Two 3x3 conv layers with instance norm, LeakyReLU, and a residual skip.

    When in_channels != out_channels, a 1x1 conv projects the input so the
    residual addition is shape-compatible. Spatial size is preserved
    (kernel_size=3, padding=1).
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.in1 = nn.InstanceNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.in2 = nn.InstanceNorm2d(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)
        # 1x1 projection for the skip path; None when channel counts already match.
        self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None

    def forward(self, x):
        residual = x
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        # Explicit None check instead of relying on nn.Module truthiness,
        # which is not a documented way to test for an optional submodule.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return self.relu(out)
class AttentionGate(nn.Module):
    """Additive attention gate: weights skip features x by a (0, 1) mask
    computed from the decoder gating signal g (Attention U-Net style)."""

    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        # Project gating signal and skip features to a common intermediate width.
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(F_int),
        )
        # Collapse to a single-channel sigmoid attention mask.
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, g, x):
        combined = self.relu(self.W_g(g) + self.W_x(x))
        mask = self.psi(combined)
        return x * mask
class EnhancedUNet(nn.Module):
    """U-Net variant with residual conv blocks, attention-gated skip
    connections, a dilated bottleneck, and dropout between stages.

    Input: (N, n_channels, H, W); output: raw logits (N, n_classes, H, W).
    NOTE(review): four MaxPool2d stages mean H and W should be divisible
    by 16 -- confirm against the dataset's image size.
    """
    def __init__(self, n_channels, n_classes):
        super(EnhancedUNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        # Encoder: channels double at each 2x downsampling stage (64 -> 1024).
        self.inc = ResidualConvBlock(n_channels, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(256, 512))
        self.down4 = nn.Sequential(nn.MaxPool2d(2), ResidualConvBlock(512, 1024))
        # Bottleneck: two dilated 3x3 convs widen the receptive field without
        # further downsampling (padding equals dilation, so H/W are unchanged).
        self.dilation = nn.Sequential(
            nn.Conv2d(1024, 1024, kernel_size=3, padding=2, dilation=2),
            nn.InstanceNorm2d(1024),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=4, dilation=4),
            nn.InstanceNorm2d(1024),
            nn.LeakyReLU(inplace=True)
        )
        # Decoder: each stage upsamples 2x, gates the matching encoder skip,
        # concatenates it, then fuses with a residual block (hence the doubled
        # in_channels on each up_conv).
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.att4 = AttentionGate(F_g=512, F_l=512, F_int=256)
        self.up_conv4 = ResidualConvBlock(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att3 = AttentionGate(F_g=256, F_l=256, F_int=128)
        self.up_conv3 = ResidualConvBlock(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att2 = AttentionGate(F_g=128, F_l=128, F_int=64)
        self.up_conv2 = ResidualConvBlock(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att1 = AttentionGate(F_g=64, F_l=64, F_int=32)
        self.up_conv1 = ResidualConvBlock(128, 64)
        # Final 1x1 conv maps features to per-class logits (no softmax here;
        # the losses apply softmax/cross-entropy themselves).
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
        # One shared dropout module, applied after most stages in forward().
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        # ----- encoder (keep x1..x4 for the skip connections) -----
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x2 = self.dropout(x2)
        x3 = self.down2(x2)
        x3 = self.dropout(x3)
        x4 = self.down3(x3)
        x4 = self.dropout(x4)
        x5 = self.down4(x4)
        x5 = self.dilation(x5)
        x5 = self.dropout(x5)
        # ----- decoder: upsample, attention-gate the skip, concat, fuse -----
        x = self.up4(x5)
        x4 = self.att4(g=x, x=x4)
        x = torch.cat([x4, x], dim=1)
        x = self.up_conv4(x)
        x = self.dropout(x)
        x = self.up3(x)
        x3 = self.att3(g=x, x=x3)
        x = torch.cat([x3, x], dim=1)
        x = self.up_conv3(x)
        x = self.dropout(x)
        x = self.up2(x)
        x2 = self.att2(g=x, x=x2)
        x = torch.cat([x2, x], dim=1)
        x = self.up_conv2(x)
        x = self.dropout(x)
        x = self.up1(x)
        x1 = self.att1(g=x, x=x1)
        x = torch.cat([x1, x], dim=1)
        x = self.up_conv1(x)
        logits = self.outc(x)
        return logits
class MoS2Dataset(Dataset):
    """Dataset pairing grayscale PNGs in '<root>/images' with integer label
    maps ('.npy') in '<root>/labels'.

    Expects image filenames of the form 'image_<id>.png' with a matching
    label file 'image_<id>.npy'. Unreadable images are skipped at scan time.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images_dir = os.path.join(root_dir, 'images')
        self.labels_dir = os.path.join(root_dir, 'labels')
        self.image_files = []
        for f in sorted(os.listdir(self.images_dir)):
            if f.endswith('.png'):
                try:
                    # verify() detects truncated/corrupt files without decoding pixels.
                    Image.open(os.path.join(self.images_dir, f)).verify()
                    self.image_files.append(f)
                except Exception:
                    # Narrowed from a bare 'except:' so KeyboardInterrupt and
                    # SystemExit are no longer swallowed during the scan.
                    print(f"Skipping unreadable image: {f}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        if not os.path.exists(img_path):
            print(f"Image file does not exist: {img_path}")
            return None, None
        # 'image_0001.png' -> 'image_0001.npy' (assumes the 'image_<id>' naming
        # scheme -- TODO confirm against the dataset generator).
        label_name = f"image_{img_name.split('_')[1].replace('.png', '.npy')}"
        label_path = os.path.join(self.labels_dir, label_name)
        try:
            # Normalize 8-bit grayscale to [0, 1].
            image = np.array(Image.open(img_path).convert('L'), dtype=np.float32) / 255.0
            label = np.load(label_path).astype(np.int64)
        except (PIL.UnidentifiedImageError, FileNotFoundError, IOError) as e:
            print(f"Error loading image {img_path}: {str(e)}")
            # NOTE(review): returning (None, None) breaks the default DataLoader
            # collate_fn; callers should filter these or supply a custom collate.
            return None, None
        if self.transform:
            image, label = self.transform(image, label)
        image = torch.from_numpy(image).float().unsqueeze(0)  # add channel dim -> (1, H, W)
        label = torch.from_numpy(label).long()
        return image, label
class AugmentationTransform:
    """Random photometric and geometric augmentations for (image, label) pairs.

    The image is a float array in [0, 1]; the label is an integer map with the
    same spatial shape. Each augmentation is applied independently with p=0.5.
    """

    def __init__(self):
        self.aug_functions = [
            self.random_brightness_contrast,
            self.random_gamma,
            self.random_noise,
            self.random_elastic_deform
        ]

    def __call__(self, image, label):
        for aug_func in self.aug_functions:
            if random.random() < 0.5:  # 50% chance to apply each augmentation
                image, label = aug_func(image, label)
        return image.astype(np.float32), label  # ensure float32 for torch.from_numpy

    def random_brightness_contrast(self, image, label):
        brightness = random.uniform(0.7, 1.3)
        contrast = random.uniform(0.7, 1.3)
        # Scale contrast about the 0.5 midpoint, then apply the brightness
        # factor. The previous formula (brightness*image + contrast*(image-0.5)
        # + 0.5) double-counted the image term: identity factors (1.0, 1.0)
        # mapped image -> 2*image instead of leaving it unchanged.
        image = np.clip(brightness * ((image - 0.5) * contrast + 0.5), 0, 1)
        return image, label

    def random_gamma(self, image, label):
        gamma = random.uniform(0.7, 1.3)
        image = np.power(image, gamma)
        return image, label

    def random_noise(self, image, label):
        # Additive Gaussian noise, clipped back to the valid [0, 1] range.
        noise = np.random.normal(0, 0.05, image.shape)
        image = np.clip(image + noise, 0, 1)
        return image, label

    def random_elastic_deform(self, image, label):
        # Classic elastic deformation: a smoothed random displacement field.
        alpha = random.uniform(10, 20)  # displacement magnitude
        sigma = random.uniform(3, 5)    # smoothing of the displacement field
        shape = image.shape
        dx = np.random.rand(*shape) * 2 - 1
        dy = np.random.rand(*shape) * 2 - 1
        dx = gaussian_filter(dx, sigma, mode="constant", cval=0) * alpha
        dy = gaussian_filter(dy, sigma, mode="constant", cval=0) * alpha
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))
        # order=1 (bilinear) for the image; order=0 (nearest) so labels stay integral.
        image = map_coordinates(image, indices, order=1).reshape(shape)
        label = map_coordinates(label, indices, order=0).reshape(shape)
        return image, label
def focal_loss(output, target, alpha=0.25, gamma=2):
    """Focal loss: cross-entropy modulated to down-weight easy pixels."""
    per_pixel_ce = nn.CrossEntropyLoss(reduction='none')(output, target)
    prob_true = torch.exp(-per_pixel_ce)  # p_t, probability of the correct class
    modulated = alpha * (1 - prob_true) ** gamma * per_pixel_ce
    return modulated.mean()
def dice_loss(output, target, smooth=1e-5):
    """Multiclass soft Dice loss, averaged over every class channel."""
    probs = torch.softmax(output, dim=1)
    n_classes = probs.shape[1]
    per_class = []
    for cls in range(n_classes):
        pred_cls = probs[:, cls, :, :]
        true_cls = (target == cls).float()
        overlap = (pred_cls * true_cls).sum()
        total = pred_cls.sum() + true_cls.sum()
        per_class.append((2. * overlap + smooth) / (total + smooth))
    return 1 - sum(per_class) / n_classes
def combined_loss(output, target):
    """Equal-weight blend of focal loss and Dice loss."""
    focal_term = focal_loss(output, target)
    dice_term = dice_loss(output, target)
    return 0.5 * focal_term + 0.5 * dice_term
def iou_score(output, target):
    """Mean IoU over all class channels.

    The previous implementation applied bitwise &/| directly to the integer
    class maps, which is only meaningful for binary {0, 1} labels; with the
    4-class labels used in this script it performed bitwise arithmetic on
    class indices and produced a meaningless score. This version computes a
    per-class IoU and averages.
    """
    smooth = 1e-5
    pred = torch.argmax(output, dim=1)
    num_classes = output.shape[1]
    ious = []
    for c in range(num_classes):
        pred_c = (pred == c)
        target_c = (target == c)
        intersection = (pred_c & target_c).float().sum((1, 2))
        union = (pred_c | target_c).float().sum((1, 2))
        # smooth keeps classes absent from both maps (union == 0) at IoU 1
        # instead of producing NaN.
        ious.append((intersection + smooth) / (union + smooth))
    return torch.stack(ious).mean()
def pixel_accuracy(output, target):
    """Fraction of pixels whose argmax prediction equals the target label."""
    predictions = torch.argmax(output, dim=1)
    matches = torch.eq(predictions, target).int()
    return float(matches.sum()) / float(matches.numel())
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one training epoch.

    Returns (mean loss, mean IoU, mean pixel accuracy) over the epoch's batches.
    """
    model.train()
    total_loss = 0
    total_iou = 0
    total_accuracy = 0
    pbar = tqdm(dataloader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # float() converts the metric tensor to a Python scalar so the running
        # totals (and the values returned below) are plain floats rather than
        # device tensors.
        total_iou += float(iou_score(outputs, labels))
        total_accuracy += pixel_accuracy(outputs, labels)
        # pbar.n is the number of batches consumed so far; +1 makes it a
        # running mean over batches seen.
        pbar.set_postfix({'Loss': total_loss / (pbar.n + 1),
                          'IoU': total_iou / (pbar.n + 1),
                          'Accuracy': total_accuracy / (pbar.n + 1)})
    return total_loss / len(dataloader), total_iou / len(dataloader), total_accuracy / len(dataloader)
def validate(model, dataloader, criterion, device):
    """Evaluate the model without gradients.

    Returns (mean loss, mean IoU, mean pixel accuracy) over the loader's batches.
    """
    model.eval()
    total_loss = 0
    total_iou = 0
    total_accuracy = 0
    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            # float() keeps the running totals (and returned means) as Python
            # scalars rather than device tensors.
            total_iou += float(iou_score(outputs, labels))
            total_accuracy += pixel_accuracy(outputs, labels)
            pbar.set_postfix({'Loss': total_loss / (pbar.n + 1),
                              'IoU': total_iou / (pbar.n + 1),
                              'Accuracy': total_accuracy / (pbar.n + 1)})
    return total_loss / len(dataloader), total_iou / len(dataloader), total_accuracy / len(dataloader)
def main():
    """Train EnhancedUNet on the MoS2 dataset, checkpointing every epoch."""
    # Hyperparameters
    num_classes = 4
    batch_size = 64
    num_epochs = 100
    learning_rate = 1e-4
    weight_decay = 1e-5
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Create datasets and data loaders
    # NOTE(review): 'transform' is constructed but never passed to the dataset
    # below, so no augmentation is applied -- confirm whether this is intended
    # (the commented-out line suggests it once was).
    transform = AugmentationTransform()
    # dataset = MoS2Dataset('MoS2_dataset_advanced_v2', transform=transform)
    dataset = MoS2Dataset('dataset_with_noise_npy')
    # 80/20 random train/validation split.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    # Create model
    model = EnhancedUNet(n_channels=1, n_classes=num_classes).to(device)
    # Loss and optimizer
    criterion = combined_loss
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # mode='max' because the scheduler tracks validation IoU (higher is better).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=10, verbose=True)
    # Create directory for saving models and visualizations
    save_dir = 'enhanced_training_results'
    os.makedirs(save_dir, exist_ok=True)
    # Training loop
    best_val_iou = 0.0
    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        train_loss, train_iou, train_accuracy = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_iou, val_accuracy = validate(model, val_loader, criterion, device)
        print(f"Train - Loss: {train_loss:.4f}, IoU: {train_iou:.4f}, Accuracy: {train_accuracy:.4f}")
        print(f"Val - Loss: {val_loss:.4f}, IoU: {val_iou:.4f}, Accuracy: {val_accuracy:.4f}")
        scheduler.step(val_iou)
        # Keep only the best-IoU weights in 'best_model.pth'.
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))
            print(f"New best model saved with IoU: {best_val_iou:.4f}")
        # Save a full resumable checkpoint (model + optimizer + scheduler) each epoch.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_iou': best_val_iou,
        }, os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pth'))
        # Visualize predictions every epoch (the original comment said "every
        # 5 epochs" but there is no epoch % 5 guard).
        visualize_prediction(model, val_loader, device, epoch, save_dir)
    print("Training completed!")
def visualize_prediction(model, val_loader, device, epoch, save_dir):
    """Save a figure comparing input, ground truth, and prediction for up to
    two samples from the first validation batch.

    The previous version indexed samples [0] and [1] unconditionally, which
    raised an IndexError whenever the batch contained a single sample.
    """
    model.eval()
    images, labels = next(iter(val_loader))
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
        outputs = model(images)
    images = images.cpu().numpy()
    labels = labels.cpu().numpy()
    predictions = torch.argmax(outputs, dim=1).cpu().numpy()
    # Guard against a batch of size 1; squeeze=False keeps axs 2-D either way.
    n_rows = min(2, images.shape[0])
    fig, axs = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    for row in range(n_rows):
        axs[row, 0].imshow(images[row, 0], cmap='gray')
        axs[row, 1].imshow(labels[row], cmap='viridis')
        axs[row, 2].imshow(predictions[row], cmap='viridis')
    axs[0, 0].set_title('Input Image')
    axs[0, 1].set_title('True Label')
    axs[0, 2].set_title('Prediction')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'prediction_epoch_{epoch}.png'))
    plt.close()
if __name__ == "__main__": | |
main() | |