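# Streamlit app: fine-tune a Vision Transformer (google/vit-base-patch16-224-in21k)
# on MRI/CT images for a 3-class task (Alzheimer's, Parkinson's, MS).
# Assumes the downloaded zip unpacks into three folders named
# 'alzheimers_dataset', 'parkinsons_dataset', and 'ms' containing
# .nii/.nii.gz or .jpg/.jpeg files.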
import os
import zipfile
import urllib.request
import numpy as np
import torch
from torch.optim import AdamW  # AdamW is no longer exported by transformers; use the torch optimizer
import nibabel as nib
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import ViTForImageClassification
import streamlit as st

# Function to extract zip file
def extract_zip(zip_file, extract_to):
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

# Preprocess images: every branch must yield a (3, 224, 224) float tensor in
# [0, 1] so that batches collate cleanly and match the ViT input shape.
def preprocess_image(image_path):
    path_lower = image_path.lower()

    # os.path.splitext splits 'scan.nii.gz' into ('scan.nii', '.gz'), so match
    # the full suffix instead.
    if path_lower.endswith(('.nii', '.nii.gz')):
        nii_image = nib.load(image_path)
        image_data = nii_image.get_fdata()
        if image_data.ndim == 4:  # drop a trailing time/echo dimension if present
            image_data = image_data[..., 0]
        # A 3D volume cannot be fed to a 2D ViT directly; take the middle slice.
        slice_2d = image_data[:, :, image_data.shape[2] // 2]
        # NIfTI intensities are not bounded to [0, 255], so min-max normalize.
        slice_2d = (slice_2d - slice_2d.min()) / (slice_2d.max() - slice_2d.min() + 1e-8)
        img = Image.fromarray((slice_2d * 255).astype(np.uint8)).convert('RGB').resize((224, 224))
        image_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0

    elif path_lower.endswith(('.jpg', '.jpeg')):
        img = Image.open(image_path).convert('RGB').resize((224, 224))
        img_np = np.array(img)
        image_tensor = torch.tensor(img_np).permute(2, 0, 1).float() / 255.0  # normalize to [0, 1]

    else:
        raise ValueError(f"Unsupported format: {image_path}")

    return image_tensor
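
# e.g. preprocess_image('scan.nii') and preprocess_image('scan.jpg') both
# return a torch.FloatTensor of shape (3, 224, 224) with values in [0, 1].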

# Prepare dataset: walk the three class folders and collect (path, label) pairs
def prepare_dataset(extracted_folder):
    image_paths = []
    labels = []
    label_map = {'alzheimers_dataset': 0, 'parkinsons_dataset': 1, 'ms': 2}
    for disease_folder, label in label_map.items():
        folder_path = os.path.join(extracted_folder, disease_folder)
        for img_file in os.listdir(folder_path):
            # include '.nii.gz' explicitly; endswith('.nii') does not match it
            if img_file.lower().endswith(('.nii', '.nii.gz', '.jpg', '.jpeg')):
                image_paths.append(os.path.join(folder_path, img_file))
                labels.append(label)
    return image_paths, labels

# Custom Dataset class
class CustomImageDataset(Dataset):
    def __init__(self, image_paths, labels):
        self.image_paths = image_paths
        self.labels = labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = preprocess_image(self.image_paths[idx])
        label = self.labels[idx]
        return image, label
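
# With every image preprocessed to (3, 224, 224), DataLoader's default collate
# stacks a batch into a (batch, 3, 224, 224) tensor and converts the integer
# labels into a LongTensor, which is what CrossEntropyLoss expects.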

# Training function: fine-tunes a pretrained ViT for 10 epochs and returns the
# average loss over the final epoch
def fine_tune_model(train_loader):
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=3)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=1e-4)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(10):
        running_loss = 0.0  # reset each epoch, so the return value reflects the last epoch
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(pixel_values=images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
    return running_loss / len(train_loader)
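
# Note: the fine-tuned weights are discarded after training. To keep them, one
# could also return `model` here and persist it with
# model.save_pretrained('vit-mri-finetuned')  # hypothetical output directory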

# Streamlit UI for Fine-tuning
st.title("Fine-tune ViT on MRI/CT Scans for MS & Neurodegenerative Diseases")

# URL of the zip archive hosted alongside this Space
zip_url = "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/neuroniiimages.zip"

if st.button("Start Training"):
    extraction_dir = "extracted_files"
    os.makedirs(extraction_dir, exist_ok=True)

    # zipfile cannot open a URL directly, so download the archive to disk first
    local_zip = "neuroniiimages.zip"
    urllib.request.urlretrieve(zip_url, local_zip)
    extract_zip(local_zip, extraction_dir)

    # Prepare dataset
    image_paths, labels = prepare_dataset(extraction_dir)
    dataset = CustomImageDataset(image_paths, labels)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Fine-tune the model
    final_loss = fine_tune_model(train_loader)
    st.write(f"Training complete. Final epoch average loss: {final_loss:.4f}")