Spaces:
Sleeping
Sleeping
Upload 30 files
Browse files- Dockerfile +4 -2
- app.py +66 -0
- config/.kaggle/kaggle.json +1 -0
- config/__init__.py +0 -0
- config/__pycache__/__init__.cpython-311.pyc +0 -0
- config/__pycache__/configure.cpython-311.pyc +0 -0
- config/configure.py +5 -0
- data/__init__.py +0 -0
- main.py +62 -0
- notebook/MRI-Segmentation-Tutorial.ipynb +0 -0
- requirements.txt +18 -0
- src/__init__.py +20 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/data/__init__.py +0 -0
- src/data/__pycache__/__init__.cpython-311.pyc +0 -0
- src/data/__pycache__/data_ingestion.cpython-311.pyc +0 -0
- src/data/__pycache__/data_preprocess.cpython-311.pyc +0 -0
- src/data/data_ingestion.py +22 -0
- src/data/data_preprocess.py +113 -0
- src/model/__init__.py +0 -0
- src/model/__pycache__/__init__.cpython-311.pyc +0 -0
- src/model/__pycache__/unet.cpython-311.pyc +0 -0
- src/model/model/best_model.pth +3 -0
- src/model/unet.py +102 -0
- src/pipelines/__init__.py +0 -0
- src/pipelines/__pycache__/__init__.cpython-311.pyc +0 -0
- src/pipelines/__pycache__/predict.cpython-311.pyc +0 -0
- src/pipelines/__pycache__/training.cpython-311.pyc +0 -0
- src/pipelines/predict.py +70 -0
- src/pipelines/training.py +120 -0
Dockerfile
CHANGED
@@ -1,13 +1,15 @@
|
|
1 |
# lightweight python
|
2 |
FROM python:3.11
|
3 |
|
4 |
-
|
5 |
# Copy local code to the container image.
|
6 |
WORKDIR /app
|
7 |
COPY . /app
|
8 |
|
|
|
|
|
9 |
# Install dependencies
|
10 |
RUN pip install -r requirements.txt
|
11 |
-
EXPOSE
|
12 |
# Run the streamlit on container startup
|
13 |
CMD ["streamlit", "run", "app.py"]
|
|
|
1 |
# lightweight python
|
2 |
FROM python:3.11
|
3 |
|
4 |
+
RUN pip install --upgrade pip
|
5 |
# Copy local code to the container image.
|
6 |
WORKDIR /app
|
7 |
COPY . /app
|
8 |
|
9 |
+
|
10 |
+
|
11 |
# Install dependencies
|
12 |
RUN pip install -r requirements.txt
|
13 |
+
EXPOSE 80
|
14 |
# Run the streamlit on container startup
|
15 |
CMD ["streamlit", "run", "app.py"]
|
app.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.model.unet import UNet
|
2 |
+
import streamlit as st
|
3 |
+
import torch
|
4 |
+
from torchvision import transforms
|
5 |
+
import albumentations as A
|
6 |
+
from albumentations.pytorch import ToTensorV2
|
7 |
+
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
+
import config.configure as config
|
10 |
+
from src.pipelines.predict import predict_mask
|
11 |
+
|
12 |
+
model = UNet(3, 1, [64, 128, 256, 512])
|
13 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
14 |
+
|
15 |
+
model.load_state_dict(torch.load(config.SAVE_MODEL_PATH, map_location=torch.device(device)))
|
16 |
+
# Set up transformations for the input image
|
17 |
+
|
18 |
+
|
19 |
+
transform = A.Compose([
|
20 |
+
A.Resize(224, 224, p=1.0),
|
21 |
+
ToTensorV2(),
|
22 |
+
])
|
23 |
+
# Streamlit app
|
24 |
+
def main():
|
25 |
+
st.title("MRI segmenation App")
|
26 |
+
|
27 |
+
# Upload image through Streamlit
|
28 |
+
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
29 |
+
|
30 |
+
if uploaded_image is not None:
|
31 |
+
# Display the uploaded and processed images side by side
|
32 |
+
col1, col2 = st.columns(2) # Using beta_columns for side-by-side layout
|
33 |
+
|
34 |
+
# Display the uploaded image in the first column
|
35 |
+
col1.header("Original Image")
|
36 |
+
col1.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
|
37 |
+
|
38 |
+
# Process the image (replace this with your processing logic)
|
39 |
+
processed_image = generate_image(uploaded_image)
|
40 |
+
|
41 |
+
# Display the processed image in the second column
|
42 |
+
col2.header("Processed Image")
|
43 |
+
col2.image(processed_image, caption="Processed Image", use_column_width=True)
|
44 |
+
|
45 |
+
# Function to generate an image using the PyTorch model
|
46 |
+
def generate_image(uploaded_image):
|
47 |
+
# Load the uploaded image
|
48 |
+
input_image = Image.open(uploaded_image)
|
49 |
+
|
50 |
+
image = np.array(input_image).astype(np.float32) / 255.
|
51 |
+
# Apply transformations
|
52 |
+
input_tensor = transform(image=image)["image"].unsqueeze(0)
|
53 |
+
|
54 |
+
# Generate an image using the PyTorch model
|
55 |
+
mask = predict_mask(data=input_tensor, device=device, model=model, inference=True)
|
56 |
+
mask = mask[0].permute(1, 2, 0)
|
57 |
+
image = input_tensor[0].permute(1, 2, 0)
|
58 |
+
|
59 |
+
mask = image + mask*0.3
|
60 |
+
mask = mask.permute(2, 0, 1)
|
61 |
+
mask = transforms.ToPILImage()(mask)
|
62 |
+
return mask
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
main()
|
config/.kaggle/kaggle.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"username":"YOUR_KAGGLE_USERNAME","key":"YOUR_API"}
|
config/__init__.py
ADDED
File without changes
|
config/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (145 Bytes). View file
|
|
config/__pycache__/configure.cpython-311.pyc
ADDED
Binary file (529 Bytes). View file
|
|
config/configure.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
mask_images_path = os.getcwd() + "/data/lgg-mri-segmentation/kaggle_3m/*/*_mask.tif"
|
3 |
+
SAVE_MODEL_PATH = os.getcwd() + "/src/model/model/best_model.pth"
|
4 |
+
SAVE_DATA_PATH = os.getcwd() + "/data"
|
5 |
+
DATASET_NAME = "mateuszbuda/lgg-mri-segmentation"
|
data/__init__.py
ADDED
File without changes
|
main.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import torchsummary
|
6 |
+
import torchview
|
7 |
+
import config.configure as config
|
8 |
+
from src import logger
|
9 |
+
from src.data.data_ingestion import DataIngestion
|
10 |
+
from src.data.data_preprocess import data_loaders
|
11 |
+
from src.pipelines.training import model_fit
|
12 |
+
from src.model.unet import UNet
|
13 |
+
## graphviiz
|
14 |
+
STAGE_NAME = "Data Ingestion stage"
|
15 |
+
try:
|
16 |
+
logger.info(f">>>>>>>> Starting {STAGE_NAME} <<<<<<<<")
|
17 |
+
data_ingestion = DataIngestion()
|
18 |
+
data_ingestion.download()
|
19 |
+
except Exception as e:
|
20 |
+
logger.exception(e)
|
21 |
+
raise e
|
22 |
+
|
23 |
+
STAGE_NAME = 'Training'
|
24 |
+
BATCH_SIZE = 32
|
25 |
+
NUM_WORKERS = 3
|
26 |
+
EPOCHS = 50
|
27 |
+
PATH = config.SAVE_MODEL_PATH
|
28 |
+
|
29 |
+
try:
|
30 |
+
logger.info(f'Preparing DataLoders')
|
31 |
+
|
32 |
+
# getting the dataloaders
|
33 |
+
train_loader, valid_loader = data_loaders(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, train_split=True)
|
34 |
+
|
35 |
+
# fitting the model
|
36 |
+
loss_fn = nn.BCEWithLogitsLoss()
|
37 |
+
in_channels = 3
|
38 |
+
out_channels = 1
|
39 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
40 |
+
features = [64, 128, 256, 512]
|
41 |
+
model = UNet(in_channels=in_channels, out_channels=out_channels, features=features)
|
42 |
+
optimizer = torch.optim.AdamW(model.parameters(),lr=1e-4)
|
43 |
+
|
44 |
+
|
45 |
+
# starting the training stage
|
46 |
+
logger.info(f"Strating {STAGE_NAME} Stage \n\n ==============")
|
47 |
+
|
48 |
+
summary = model_fit(
|
49 |
+
epochs=EPOCHS,
|
50 |
+
model=model,
|
51 |
+
device=device,
|
52 |
+
train_loader=train_loader,
|
53 |
+
valid_loader=valid_loader,
|
54 |
+
criterion=loss_fn,
|
55 |
+
optimizer=optimizer,
|
56 |
+
PATH=PATH
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
except Exception as e:
|
61 |
+
logger.exception(e)
|
62 |
+
raise e
|
notebook/MRI-Segmentation-Tutorial.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Core Libraries
|
2 |
+
numpy>=1.21.0
|
3 |
+
pandas>=1.3.3
|
4 |
+
kaggle>=1.6.3
|
5 |
+
# Data Visualization
|
6 |
+
matplotlib>=3.4.3
|
7 |
+
seaborn>=0.13.0
|
8 |
+
pillow>=9.3.0
|
9 |
+
|
10 |
+
tqdm>=4.66.0
|
11 |
+
|
12 |
+
# Web Application Framework
|
13 |
+
streamlit>=1.28.0
|
14 |
+
|
15 |
+
scikit-learn>=0.24.2
|
16 |
+
torch==2.1.2
|
17 |
+
torchvision>=0.16.2
|
18 |
+
albumentations>=1.3.1
|
src/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
# Configure the logging settings
|
4 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
5 |
+
|
6 |
+
# Create a logger with the name of the current module or script
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
# Create a FileHandler and set the log file path
|
10 |
+
if not os.path.exists("logs/"):
|
11 |
+
os.makedirs('logs/')
|
12 |
+
log_file_path = 'logs/my_log_file.log'
|
13 |
+
file_handler = logging.FileHandler(log_file_path)
|
14 |
+
|
15 |
+
# Create a formatter and set it for the FileHandler
|
16 |
+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
17 |
+
file_handler.setFormatter(formatter)
|
18 |
+
|
19 |
+
# Add the FileHandler to the logger
|
20 |
+
logger.addHandler(file_handler)
|
src/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (927 Bytes). View file
|
|
src/data/__init__.py
ADDED
File without changes
|
src/data/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (147 Bytes). View file
|
|
src/data/__pycache__/data_ingestion.cpython-311.pyc
ADDED
Binary file (1.58 kB). View file
|
|
src/data/__pycache__/data_preprocess.cpython-311.pyc
ADDED
Binary file (6.05 kB). View file
|
|
src/data/data_ingestion.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from src import logger
|
3 |
+
import config.configure as config
|
4 |
+
print(os.getcwd())
|
5 |
+
os.environ["KAGGLE_CONFIG_DIR"] = "config/.kaggle/"
|
6 |
+
import kaggle
|
7 |
+
class DataIngestion():
|
8 |
+
def __init__(self):
|
9 |
+
pass
|
10 |
+
def download(self):
|
11 |
+
logger.info(f"Downloading dataset into data/")
|
12 |
+
dataset_name = config.DATASET_NAME
|
13 |
+
|
14 |
+
# Replace '/path/to/your/folder' with the path to the folder where you want to save the dataset
|
15 |
+
output_folder = config.SAVE_DATA_PATH
|
16 |
+
|
17 |
+
# Download the dataset to the specified folder
|
18 |
+
kaggle.api.dataset_download_files(dataset_name, path=output_folder, unzip=True)
|
19 |
+
logger.info(f"Download Completed")
|
20 |
+
|
21 |
+
if __name__ == '__main__':
|
22 |
+
data = DataIngestion()
|
src/data/data_preprocess.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.utils.data import Dataset, DataLoader
|
8 |
+
import albumentations as A
|
9 |
+
from albumentations.pytorch import ToTensorV2
|
10 |
+
from PIL import Image
|
11 |
+
from sklearn.model_selection import train_test_split
|
12 |
+
from config.configure import mask_images_path
|
13 |
+
from src import logger
|
14 |
+
def get_dataframe(path: str) -> pd.DataFrame:
|
15 |
+
"""
|
16 |
+
Create a DataFrame containing image paths, mask paths, and labels.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
path (str): path [mask_images]
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
pd.DataFrame: DataFrame with image paths, mask paths, and labels.
|
23 |
+
"""
|
24 |
+
|
25 |
+
image_masks = glob.glob(path)
|
26 |
+
image_paths = [file_path.replace("_mask", '') for file_path in image_masks]
|
27 |
+
|
28 |
+
def labels(mask_path):
|
29 |
+
label = []
|
30 |
+
for mask in mask_path:
|
31 |
+
img = Image.open(mask)
|
32 |
+
label.append(1) if np.array(img).sum() > 0 else label.append(0)
|
33 |
+
return label
|
34 |
+
|
35 |
+
mask_labels = labels(image_masks)
|
36 |
+
|
37 |
+
df = pd.DataFrame({
|
38 |
+
'image_path': image_paths,
|
39 |
+
'mask_path': image_masks,
|
40 |
+
'label': mask_labels
|
41 |
+
})
|
42 |
+
|
43 |
+
return df
|
44 |
+
|
45 |
+
class MRIDataset(Dataset):
|
46 |
+
def __init__(self, paths, transform):
|
47 |
+
"""
|
48 |
+
Custom dataset for MRI images.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
paths (pd.DataFrame): DataFrame containing mask paths.
|
52 |
+
transform: Data augmentation and transformation pipeline.
|
53 |
+
"""
|
54 |
+
self.paths = paths
|
55 |
+
self.transform = transform
|
56 |
+
|
57 |
+
def __len__(self):
|
58 |
+
return len(self.paths)
|
59 |
+
|
60 |
+
def __getitem__(self, idx):
|
61 |
+
image_path, mask_path = self.paths.iloc[idx]
|
62 |
+
image = Image.open(image_path)
|
63 |
+
mask = Image.open(mask_path)
|
64 |
+
|
65 |
+
image = np.array(image).astype(np.float32) / 255.
|
66 |
+
mask = np.array(mask).astype(np.float32) / 255.
|
67 |
+
|
68 |
+
if self.transform:
|
69 |
+
transformed = self.transform(image=image, mask=mask)
|
70 |
+
return transformed['image'], transformed['mask'].unsqueeze(0)
|
71 |
+
else:
|
72 |
+
transformed = ToTensorV2()(image=image, mask=mask)
|
73 |
+
return transformed['image'], transformed['mask'].unsqueeze(0)
|
74 |
+
|
75 |
+
|
76 |
+
def data_loaders(batch_size,num_workers, train_split=False) -> DataLoader:
|
77 |
+
|
78 |
+
logger.info(f"Preprocessing Data")
|
79 |
+
df = get_dataframe(mask_images_path)
|
80 |
+
|
81 |
+
train_transforms = A.Compose([
|
82 |
+
A.Resize(224, 224, p=1.0),
|
83 |
+
A.RandomBrightnessContrast(p=0.2),
|
84 |
+
A.HorizontalFlip(p=0.5),
|
85 |
+
A.VerticalFlip(p=0.5),
|
86 |
+
ToTensorV2(),
|
87 |
+
])
|
88 |
+
|
89 |
+
# Only reshape val and test data
|
90 |
+
val_transforms = A.Compose([
|
91 |
+
A.Resize(224, 224, p=1.0),
|
92 |
+
ToTensorV2(),
|
93 |
+
])
|
94 |
+
|
95 |
+
# splitting the dataset
|
96 |
+
train_x, val_x, train_y, val_y = train_test_split(df.drop('label',axis=1), df.label,test_size=0.3)
|
97 |
+
val_x , test_x, val_y, test_y = train_test_split(val_x, val_y, test_size = 0.2)
|
98 |
+
|
99 |
+
train_data = MRIDataset(train_x, train_transforms)
|
100 |
+
val_data = MRIDataset(val_x, val_transforms)
|
101 |
+
test_data = MRIDataset(test_x[test_y == 1], val_transforms)
|
102 |
+
|
103 |
+
|
104 |
+
# train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
|
105 |
+
|
106 |
+
if train_split:
|
107 |
+
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
|
108 |
+
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
|
109 |
+
|
110 |
+
return train_loader, val_loader
|
111 |
+
else:
|
112 |
+
test_loader = DataLoader(test_data, batch_size=32, shuffle=True)
|
113 |
+
return test_loader
|
src/model/__init__.py
ADDED
File without changes
|
src/model/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (148 Bytes). View file
|
|
src/model/__pycache__/unet.cpython-311.pyc
ADDED
Binary file (6.69 kB). View file
|
|
src/model/model/best_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:80bca9fca6dd62bb74e5072bcac8a4e2d232d43b6f45e0202bf6d5a353cd2b70
|
3 |
+
size 124203732
|
src/model/unet.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class DownSampling(nn.Module):
|
5 |
+
|
6 |
+
def __init__(self, in_channels, out_channels, max_pool):
|
7 |
+
"""
|
8 |
+
DownSampling block in the U-Net architecture.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
in_channels (int): Number of input channels.
|
12 |
+
out_channels (int): Number of output channels.
|
13 |
+
max_pool (bool): Whether to use max pooling.
|
14 |
+
"""
|
15 |
+
super(DownSampling, self).__init__()
|
16 |
+
self.max_pool = max_pool
|
17 |
+
self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
18 |
+
self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
19 |
+
self.batchnorm2d = nn.BatchNorm2d(out_channels)
|
20 |
+
self.relu = nn.ReLU()
|
21 |
+
self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
x = self.conv1(x)
|
25 |
+
x = self.conv2(x)
|
26 |
+
|
27 |
+
x = self.relu(self.batchnorm2d(x))
|
28 |
+
skip_connection = x
|
29 |
+
|
30 |
+
if self.max_pool:
|
31 |
+
next_layer = self.maxpool2d(x)
|
32 |
+
else:
|
33 |
+
return x
|
34 |
+
return next_layer, skip_connection
|
35 |
+
|
36 |
+
class UpSampling(nn.Module):
|
37 |
+
def __init__(self, in_channels, out_channels):
|
38 |
+
"""
|
39 |
+
UpSampling block in the U-Net architecture.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
in_channels (int): Number of input channels.
|
43 |
+
out_channels (int): Number of output channels.
|
44 |
+
"""
|
45 |
+
super(UpSampling, self).__init__()
|
46 |
+
self.up = nn.ConvTranspose2d(in_channels, out_channels=out_channels, kernel_size=2, stride=2)
|
47 |
+
self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
48 |
+
self.relu = nn.ReLU()
|
49 |
+
self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
50 |
+
self.batchnorm = nn.BatchNorm2d(out_channels)
|
51 |
+
|
52 |
+
def forward(self, x, prev_skip):
|
53 |
+
x = self.up(x)
|
54 |
+
x = torch.cat((x, prev_skip), dim=1)
|
55 |
+
x = self.conv1(x)
|
56 |
+
x = self.conv2(x)
|
57 |
+
next_layer = self.relu(self.batchnorm(x))
|
58 |
+
return next_layer
|
59 |
+
|
60 |
+
class UNet(nn.Module):
|
61 |
+
|
62 |
+
"""
|
63 |
+
U-Net architecture.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
in_channels (int): Number of input channels.
|
67 |
+
out_channels (int): Number of output channels.
|
68 |
+
features (list): List of feature sizes for downsampling and upsampling.
|
69 |
+
"""
|
70 |
+
def __init__(self, in_channels, out_channels, features):
|
71 |
+
super(UNet, self).__init__()
|
72 |
+
self.ups = nn.ModuleList()
|
73 |
+
self.downs = nn.ModuleList()
|
74 |
+
|
75 |
+
for feature in features:
|
76 |
+
self.downs.append(DownSampling(in_channels, feature, True))
|
77 |
+
in_channels = feature
|
78 |
+
|
79 |
+
for feature in reversed(features):
|
80 |
+
self.ups.append(UpSampling(2 * feature, feature))
|
81 |
+
|
82 |
+
self.bottleneck = DownSampling(features[-1], 2 * features[-1], False)
|
83 |
+
self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
skip_connections = []
|
87 |
+
for down in self.downs:
|
88 |
+
x, skip_connection = down(x)
|
89 |
+
skip_connections.append(skip_connection)
|
90 |
+
skip_connections = skip_connections[::-1]
|
91 |
+
x = self.bottleneck(x)
|
92 |
+
for i, up in enumerate(self.ups):
|
93 |
+
x = up(x, skip_connections[i])
|
94 |
+
|
95 |
+
return self.final_conv(x)
|
96 |
+
|
97 |
+
if __name__ == "__main__":
|
98 |
+
#Example Usage
|
99 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
100 |
+
features = [64, 128, 256, 512]
|
101 |
+
model = UNet(1, 1, features=features).to(device)
|
102 |
+
print(model(torch.rand(1, 1, 512, 512)).shape)
|
src/pipelines/__init__.py
ADDED
File without changes
|
src/pipelines/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (152 Bytes). View file
|
|
src/pipelines/__pycache__/predict.cpython-311.pyc
ADDED
Binary file (3.66 kB). View file
|
|
src/pipelines/__pycache__/training.cpython-311.pyc
ADDED
Binary file (5.27 kB). View file
|
|
src/pipelines/predict.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from tqdm import tqdm
|
4 |
+
from typing import *
|
5 |
+
from src import logger
|
6 |
+
|
7 |
+
def predict_mask(
|
8 |
+
data: Any,
|
9 |
+
device: Any,
|
10 |
+
model: nn.Module,
|
11 |
+
inference: bool,
|
12 |
+
valid_loader=None,
|
13 |
+
criterion=None,
|
14 |
+
):
|
15 |
+
"""
|
16 |
+
predicts mask for the image
|
17 |
+
Args:
|
18 |
+
data (Any): image data for predicting
|
19 |
+
model (nn.Module): model for training
|
20 |
+
device (0/'cud'/'cpu'/Any): name of device
|
21 |
+
inference (bool): Whether to evaluate or predict
|
22 |
+
valid_Loader (nn.Module): test loader for training
|
23 |
+
criterion (nn.Module): loss criteria
|
24 |
+
|
25 |
+
Example:
|
26 |
+
>>> train(
|
27 |
+
>>> data = torch.FloatTensor,
|
28 |
+
>>> model=model,
|
29 |
+
>>> device=0/'cuda'/'cpu'
|
30 |
+
>>> ingerence=0
|
31 |
+
>>> valid_loader= test_loader
|
32 |
+
>>> criterion= fn_loss
|
33 |
+
"""
|
34 |
+
|
35 |
+
if inference:
|
36 |
+
|
37 |
+
with torch.no_grad():
|
38 |
+
image = data.type(torch.FloatTensor).to(device)
|
39 |
+
model = model.to(device)
|
40 |
+
pred = model(image)
|
41 |
+
pred = torch.sigmoid(pred)
|
42 |
+
mask = (pred > 0.6).float()
|
43 |
+
|
44 |
+
return mask.cpu().detach()
|
45 |
+
else:
|
46 |
+
with torch.no_grad():
|
47 |
+
val_Loss = 0
|
48 |
+
val_Dicescore = 0
|
49 |
+
model.eval()
|
50 |
+
for x, y in tqdm(valid_loader):
|
51 |
+
x = x.type(torch.cuda.FloatTensor).to(device)
|
52 |
+
y = y.type(torch.cuda.FloatTensor).to(device)
|
53 |
+
|
54 |
+
predict = model(x)
|
55 |
+
loss = criterion(predict, y)
|
56 |
+
val_Loss += loss.item()
|
57 |
+
|
58 |
+
predict = torch.sigmoid(predict)
|
59 |
+
predict = (predict > 0.5).float()
|
60 |
+
|
61 |
+
dice_score = (2 * (y*predict).sum() + 1e-8)/((y+predict).sum() + 1e-8)
|
62 |
+
try:
|
63 |
+
val_Dicescore += dice_score.cpu().item()
|
64 |
+
except:
|
65 |
+
val_Dicescore += dice_score
|
66 |
+
|
67 |
+
val_Loss /= len(valid_loader)
|
68 |
+
val_Dicescore /= len(valid_loader)
|
69 |
+
|
70 |
+
logger.info(f"Test Loss: {val_Loss} - Dice Score: {val_Dicescore}")
|
src/pipelines/training.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
import torch.nn as nn
|
4 |
+
from src import logger
|
5 |
+
from typing import *
|
6 |
+
import warnings
|
7 |
+
warnings.filterwarnings('ignore')
|
8 |
+
def model_fit(
|
9 |
+
epochs: int,
|
10 |
+
model: nn.Module,
|
11 |
+
device: Any,
|
12 |
+
train_loader: Any,
|
13 |
+
valid_loader: Any,
|
14 |
+
criterion: nn.Module,
|
15 |
+
optimizer: nn.Module,
|
16 |
+
PATH: str
|
17 |
+
):
|
18 |
+
"""
|
19 |
+
|
20 |
+
Args:
|
21 |
+
epochs (int): # of epochs
|
22 |
+
model (nn.Module): model for training
|
23 |
+
device (Union[int, str]): number or name of device
|
24 |
+
train_loader (Any): pytorch loader for trainset
|
25 |
+
valid_loader (Any): pytorch loader for testset
|
26 |
+
criterion (nn.Module): loss critiria
|
27 |
+
optimizer (nn.Module): optimizer for model training
|
28 |
+
path (str): path for saving model
|
29 |
+
|
30 |
+
|
31 |
+
Example:
|
32 |
+
>>> train(
|
33 |
+
>>> epochs=25,
|
34 |
+
>>> model=model,
|
35 |
+
>>> device=0/'cuda'/'cpu',
|
36 |
+
>>> train_loader=train_loader,
|
37 |
+
>>> valid_loader=valid_loader,
|
38 |
+
>>> criterion=fn_loss,
|
39 |
+
>>> optimizer=optimizer)
|
40 |
+
"""
|
41 |
+
|
42 |
+
|
43 |
+
best_DICESCORE = 0
|
44 |
+
model.to(device)
|
45 |
+
summary = {
|
46 |
+
'train_loss' : [],
|
47 |
+
'train_dice' : [],
|
48 |
+
'valid_loss' : [],
|
49 |
+
'valid_dice' : []
|
50 |
+
}
|
51 |
+
for epoch in range(epochs):
|
52 |
+
logger.info(f"EPOCH {epoch}/{epochs}")
|
53 |
+
train_Loss = 0
|
54 |
+
train_Dicescore = 0
|
55 |
+
model.train()
|
56 |
+
for x, y in tqdm(train_loader):
|
57 |
+
x = x.type(torch.FloatTensor).to(device)
|
58 |
+
y = y.type(torch.FloatTensor).to(device)
|
59 |
+
|
60 |
+
predict = model(x)
|
61 |
+
loss = criterion(predict, y)
|
62 |
+
train_Loss += loss.item()
|
63 |
+
|
64 |
+
optimizer.zero_grad()
|
65 |
+
loss.backward()
|
66 |
+
optimizer.step()
|
67 |
+
|
68 |
+
predict = torch.sigmoid(predict)
|
69 |
+
predict = (predict > 0.5).float()
|
70 |
+
|
71 |
+
dice_score = (2 * (y*predict).sum() + 1e-8)/((y+predict).sum() + 1e-8)
|
72 |
+
|
73 |
+
try:
|
74 |
+
train_Dicescore += dice_score.cpu().item()
|
75 |
+
except:
|
76 |
+
train_Dicescore += dice_score
|
77 |
+
|
78 |
+
train_Loss /= len(train_loader)
|
79 |
+
train_Dicescore /= len(train_loader)
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
with torch.no_grad():
|
84 |
+
val_Loss = 0
|
85 |
+
val_Dicescore = 0
|
86 |
+
model.eval()
|
87 |
+
for x, y in tqdm(valid_loader):
|
88 |
+
x = x.type(torch.FloatTensor).to(device)
|
89 |
+
y = y.type(torch.FloatTensor).to(device)
|
90 |
+
|
91 |
+
predict = model(x)
|
92 |
+
loss = criterion(predict, y)
|
93 |
+
val_Loss += loss.item()
|
94 |
+
|
95 |
+
predict = torch.sigmoid(predict)
|
96 |
+
predict = (predict > 0.5).float()
|
97 |
+
|
98 |
+
dice_score = (2 * (y*predict).sum() + 1e-8)/((y+predict).sum() + 1e-8)
|
99 |
+
try:
|
100 |
+
val_Dicescore += dice_score.cpu().item()
|
101 |
+
except:
|
102 |
+
val_Dicescore += dice_score
|
103 |
+
|
104 |
+
val_Loss /= len(valid_loader)
|
105 |
+
val_Dicescore /= len(valid_loader)
|
106 |
+
|
107 |
+
|
108 |
+
logger.info(f"Loss: {train_Loss} - Dice Score: {train_Dicescore} - Validation Loss: {val_Loss} - Validation Dice Score: {val_Dicescore}")
|
109 |
+
|
110 |
+
if val_Dicescore > best_DICESCORE:
|
111 |
+
best_DICESCORE = val_Dicescore
|
112 |
+
torch.save(model, PATH)
|
113 |
+
|
114 |
+
summary['train_loss'] = train_Loss
|
115 |
+
summary['train_dice'] = train_Dicescore
|
116 |
+
summary['valid_loss'] = val_Loss
|
117 |
+
summary['valid_dice'] = val_Dicescore
|
118 |
+
|
119 |
+
|
120 |
+
return summary
|