Tanusree88 commited on
Commit
88720a5
·
verified ·
1 Parent(s): 0f02599

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -137
app.py CHANGED
@@ -1,119 +1,23 @@
1
  import os
2
  import zipfile
3
- import requests
4
  import numpy as np
5
  import torch
6
- from transformers import ViTForImageClassification, AdamW
7
- import nibabel as nib
8
- from PIL import Image
9
- from torch.utils.data import Dataset, DataLoader
10
- import streamlit as st
11
-
12
- # Function to download the zip file from the URL
13
- def download_zip(url, save_path):
14
- response = requests.get(url)
15
- with open(save_path, 'wb') as f:
16
- f.write(response.content)
17
-
18
- # Function to extract zip file
19
- def extract_zip(zip_file, extract_to):
20
- with zipfile.ZipFile(zip_file, 'r') as zip_ref:
21
- zip_ref.extractall(extract_to)
22
-
23
- # Preprocess images
24
- def preprocess_image(image_path):
25
- ext = os.path.splitext(image_path)[-1].lower()
26
-
27
- if ext in ['.nii', '.nii.gz']:
28
- nii_image = nib.load(image_path)
29
- image_data = nii_image.get_fdata()
30
- image_tensor = torch.tensor(image_data).float()
31
- if len(image_tensor.shape) == 3:
32
- image_tensor = image_tensor.unsqueeze(0)
33
-
34
- elif ext in ['.jpg', '.jpeg']:
35
- img = Image.open(image_path).convert('RGB').resize((224, 224))
36
- img_np = np.array(img)
37
- image_tensor = torch.tensor(img_np).permute(2, 0, 1).float()
38
-
39
- else:
40
- raise ValueError(f"Unsupported format: {ext}")
41
-
42
- image_tensor /= 255.0 # Normalize to [0, 1]
43
- return image_tensor
44
-
45
- # Prepare dataset
46
- def prepare_dataset(extracted_folder):
47
- image_paths = []
48
- labels = []
49
- for disease_folder in ['alzheimers_dataset', 'parkinsons_dataset', 'MSjpg']:
50
- folder_path = os.path.join(extracted_folder, disease_folder)
51
- label = {'alzheimers_dataset': 0, 'parkinsons_dataset': 1, 'MSjpg': 2}[disease_folder]
52
- for img_file in os.listdir(folder_path):
53
- if img_file.endswith(('.nii', '.jpg', '.jpeg')):
54
- image_paths.append(os.path.join(folder_path, img_file))
55
- labels.append(label)
56
- return image_paths, labels
57
-
58
- # Custom Dataset class
59
- class CustomImageDataset(Dataset):
60
- def __init__(self, image_paths, labels):
61
- self.image_paths = image_paths
62
- self.labels = labels
63
-
64
- def __len__(self):
65
- return len(self.image_paths)
66
-
67
- def __getitem__(self, idx):
68
- image = preprocess_image(self.image_paths[idx])
69
- label = self.labels[idx]
70
- return image, label
71
-
72
- # Training function
73
- def fine_tune_model(train_loader):
74
- model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=3)
75
- model.train()
76
- optimizer = AdamW(model.parameters(), lr=1e-4)
77
- criterion = torch.nn.CrossEntropyLoss()
78
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
- model.to(device)
80
-
81
- for epoch in range(10):
82
- running_loss = 0.0
83
- for images, labels in train_loader:
84
- images, labels = images.to(device), labels.to(device)
85
- optimizer.zero_grad()
86
- outputs = model(pixel_values=images).logits
87
- loss = criterion(outputs, labels)
88
- loss.backward()
89
- optimizer.step()
90
- running_loss += loss.item()
91
- return running_loss / len(train_loader)
92
-
93
- # Streamlit UI for Fine-tuning
94
- st.title("Fine-tune ViT on MRI/CT Scans for MS & Neurodegenerative Diseases")
95
-
96
- zip_file_url = import os
97
- import zipfile
98
  import requests
99
- import numpy as np
100
- import torch
101
  from transformers import ViTForImageClassification, AdamW
102
  import nibabel as nib
103
  from PIL import Image
104
  from torch.utils.data import Dataset, DataLoader
105
  import streamlit as st
 
106
 
107
- # Function to download the zip file from the URL
108
- def download_zip(url, save_path):
109
  response = requests.get(url)
110
- with open(save_path, 'wb') as f:
111
- f.write(response.content)
112
-
113
- # Function to extract zip file
114
- def extract_zip(zip_file, extract_to):
115
- with zipfile.ZipFile(zip_file, 'r') as zip_ref:
116
- zip_ref.extractall(extract_to)
117
 
118
  # Preprocess images
119
  def preprocess_image(image_path):
@@ -188,20 +92,14 @@ def fine_tune_model(train_loader):
188
  # Streamlit UI for Fine-tuning
189
  st.title("Fine-tune ViT on MRI/CT Scans for MS & Neurodegenerative Diseases")
190
 
191
- zip_file_url = "https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/neuroniiimages.zip"
192
-
193
  if st.button("Start Training"):
194
  extraction_dir = "extracted_files"
195
- zip_file_path = "archive_5.zip"
196
  os.makedirs(extraction_dir, exist_ok=True)
197
 
198
- # Download the zip file
199
- st.write("Downloading the zip file...")
200
- download_zip(zip_file_url, zip_file_path)
201
-
202
- # Extract the zip file
203
- st.write("Extracting files...")
204
- extract_zip(zip_file_path, extraction_dir)
205
 
206
  # Prepare dataset
207
  image_paths, labels = prepare_dataset(extraction_dir)
@@ -209,31 +107,8 @@ if st.button("Start Training"):
209
  train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
210
 
211
  # Fine-tune the model
212
- st.write("Fine-tuning the model...")
213
  final_loss = fine_tune_model(train_loader)
214
  st.write(f"Training Complete with Final Loss: {final_loss}")
215
 
216
- if st.button("Start Training"):
217
- extraction_dir = "extracted_files"
218
- zip_file_path = "archive_5.zip"
219
- os.makedirs(extraction_dir, exist_ok=True)
220
-
221
- # Download the zip file
222
- st.write("Downloading the zip file...")
223
- download_zip(zip_file_url, zip_file_path)
224
-
225
- # Extract the zip file
226
- st.write("Extracting files...")
227
- extract_zip(zip_file_path, extraction_dir)
228
-
229
- # Prepare dataset
230
- image_paths, labels = prepare_dataset(extraction_dir)
231
- dataset = CustomImageDataset(image_paths, labels)
232
- train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
233
-
234
- # Fine-tune the model
235
- st.write("Fine-tuning the model...")
236
- final_loss = fine_tune_model(train_loader)
237
- st.write(f"Training Complete with Final Loss: {final_loss}")
238
 
239
 
 
1
  import os
2
  import zipfile
 
3
  import numpy as np
4
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import requests
 
 
6
  from transformers import ViTForImageClassification, AdamW
7
  import nibabel as nib
8
  from PIL import Image
9
  from torch.utils.data import Dataset, DataLoader
10
  import streamlit as st
11
+ from io import BytesIO
12
 
13
+ # Function to download and extract zip file from URL
14
+ def extract_zip_from_url(url, extract_to):
15
  response = requests.get(url)
16
+ if response.status_code == 200:
17
+ with zipfile.ZipFile(BytesIO(response.content)) as zip_ref:
18
+ zip_ref.extractall(extract_to)
19
+ else:
20
+ raise ValueError(f"Unable to download zip file: {url}")
 
 
21
 
22
  # Preprocess images
23
  def preprocess_image(image_path):
 
92
  # Streamlit UI for Fine-tuning
93
  st.title("Fine-tune ViT on MRI/CT Scans for MS & Neurodegenerative Diseases")
94
 
95
+ # Input zip file URL
96
+ zip_file_url = st.text_input("https://huggingface.co/spaces/Tanusree88/ViT-MRI-FineTuning/resolve/main/neuroniiimages.zip")
97
  if st.button("Start Training"):
98
  extraction_dir = "extracted_files"
 
99
  os.makedirs(extraction_dir, exist_ok=True)
100
 
101
+ # Extract the zip file from URL
102
+ extract_zip_from_url(zip_file_url, extraction_dir)
 
 
 
 
 
103
 
104
  # Prepare dataset
105
  image_paths, labels = prepare_dataset(extraction_dir)
 
107
  train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
108
 
109
  # Fine-tune the model
 
110
  final_loss = fine_tune_model(train_loader)
111
  st.write(f"Training Complete with Final Loss: {final_loss}")
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114