Tanusree88 commited on
Commit
aade1d5
·
verified ·
1 Parent(s): 21c7368

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -45
app.py CHANGED
@@ -1,48 +1,11 @@
 
1
  import os
2
- import streamlit as st
3
- import torch
4
- from torch.utils.data import DataLoader, TensorDataset
5
- from preprocessing import preprocess_image
6
- from train import fine_tune_model
7
 
8
- # Define a function to load and preprocess images
9
- def load_images(image_paths):
10
- images = []
11
- labels = [] # Assuming you have labels, adjust as needed
12
 
13
- for path in image_paths:
14
- # Preprocess each image and append to the list
15
- image_tensor = preprocess_image(path)
16
- images.append(image_tensor)
17
-
18
- # Append corresponding label (this is just an example; adjust as needed)
19
- label = ... # Replace with logic to get the label for the image
20
- labels.append(label)
21
-
22
- # Stack images into a single tensor and create a DataLoader
23
- images_tensor = torch.stack(images) # Stack into a single tensor
24
- labels_tensor = torch.tensor(labels) # Convert labels to tensor
25
- dataset = TensorDataset(images_tensor, labels_tensor)
26
- return DataLoader(dataset, batch_size=16, shuffle=True) # Adjust batch_size as needed
27
-
28
- # Streamlit UI
29
- st.title("MRI Image Fine-Tuning with ViT")
30
-
31
- # Upload images or provide a directory path
32
- image_directory = st.text_input("Enter the directory of images:")
33
-
34
- if st.button("Load Images"):
35
- # Get all image paths from the directory
36
- if image_directory and os.path.isdir(image_directory):
37
- image_paths = [os.path.join(image_directory, f) for f in os.listdir(image_directory) if f.endswith(('.nii', '.jpg', '.jpeg'))]
38
- st.success(f"Loaded {len(image_paths)} images.")
39
-
40
- # Preprocess images and create DataLoader
41
- train_loader = load_images(image_paths)
42
-
43
- # Button to start training
44
- if st.button("Start Training"):
45
- final_loss = fine_tune_model(train_loader) # Trigger fine-tuning
46
- st.write(f"Training complete with final loss: {final_loss}")
47
- else:
48
- st.error("Please enter a valid directory.")
 
1
+ import zipfile
2
  import os
 
 
 
 
 
3
 
4
+ def extract_zip(zip_file, extract_to):
5
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
6
+ zip_ref.extractall(extract_to)
 
7
 
8
+ # Example usage
9
+ zip_file_path = 'path_to_your_zip_file.zip'
10
+ extract_to = 'path_to_extracted_images/'
11
+ extract_zip(zip_file_path, extract_to)