smishr-18 committed on
Commit a578142 · verified · 1 Parent(s): aef1b28

Upload 30 files
Dockerfile CHANGED
@@ -1,13 +1,15 @@
 # lightweight python
 FROM python:3.11

-
+RUN pip install --upgrade pip
 # Copy local code to the container image.
 WORKDIR /app
 COPY . /app

+
+
 # Install dependencies
 RUN pip install -r requirements.txt
-EXPOSE 8051
+EXPOSE 80
 # Run the streamlit on container startup
 CMD ["streamlit", "run", "app.py"]
app.py ADDED
@@ -0,0 +1,66 @@
+from src.model.unet import UNet
+import streamlit as st
+import torch
+from torchvision import transforms
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+from PIL import Image
+import numpy as np
+import config.configure as config
+from src.pipelines.predict import predict_mask
+
+model = UNet(3, 1, [64, 128, 256, 512])
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+model.load_state_dict(torch.load(config.SAVE_MODEL_PATH, map_location=torch.device(device)))
+model.eval()  # use running BatchNorm statistics at inference time
+
+# Set up transformations for the input image
+transform = A.Compose([
+    A.Resize(224, 224, p=1.0),
+    ToTensorV2(),
+])
+
+# Streamlit app
+def main():
+    st.title("MRI Segmentation App")
+
+    # Upload image through Streamlit
+    uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+    if uploaded_image is not None:
+        # Display the uploaded and processed images side by side
+        col1, col2 = st.columns(2)
+
+        # Display the uploaded image in the first column
+        col1.header("Original Image")
+        col1.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
+
+        # Run the segmentation model on the image
+        processed_image = generate_image(uploaded_image)
+
+        # Display the processed image in the second column
+        col2.header("Processed Image")
+        col2.image(processed_image, caption="Processed Image", use_column_width=True)
+
+# Function to generate a segmented image using the PyTorch model
+def generate_image(uploaded_image):
+    # Load the uploaded image
+    input_image = Image.open(uploaded_image)
+
+    image = np.array(input_image).astype(np.float32) / 255.
+    # Apply transformations
+    input_tensor = transform(image=image)["image"].unsqueeze(0)
+
+    # Predict a binary mask with the PyTorch model
+    mask = predict_mask(data=input_tensor, device=device, model=model, inference=True)
+    mask = mask[0].permute(1, 2, 0)
+    image = input_tensor[0].permute(1, 2, 0)
+
+    # Overlay the mask on the input image
+    mask = image + mask * 0.3
+    mask = mask.permute(2, 0, 1)
+    mask = transforms.ToPILImage()(mask)
+    return mask
+
+
+if __name__ == "__main__":
+    main()
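
The app exercises the full load-transform-predict path, so the checkpoint can also be sanity-checked headlessly before deploying. A minimal sketch under the same assumptions as app.py (run from the repo root, checkpoint at config.SAVE_MODEL_PATH); the sample path data/sample.tif is hypothetical:

import numpy as np
import torch
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import config.configure as config
from src.model.unet import UNet
from src.pipelines.predict import predict_mask

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = UNet(3, 1, [64, 128, 256, 512])
model.load_state_dict(torch.load(config.SAVE_MODEL_PATH, map_location=torch.device(device)))
model.eval()

transform = A.Compose([A.Resize(224, 224, p=1.0), ToTensorV2()])
image = np.array(Image.open("data/sample.tif")).astype(np.float32) / 255.  # hypothetical sample slice
tensor = transform(image=image)["image"].unsqueeze(0)

mask = predict_mask(data=tensor, device=device, model=model, inference=True)
print(mask.shape)  # expected: torch.Size([1, 1, 224, 224]), values in {0., 1.}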
config/.kaggle/kaggle.json ADDED
@@ -0,0 +1 @@
+ {"username":"YOUR_KAGGLE_USERNAME","key":"YOUR_API"}
config/__init__.py ADDED
File without changes
config/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (145 Bytes).
 
config/__pycache__/configure.cpython-311.pyc ADDED
Binary file (529 Bytes).
 
config/configure.py ADDED
@@ -0,0 +1,5 @@
+import os
+mask_images_path = os.getcwd() + "/data/lgg-mri-segmentation/kaggle_3m/*/*_mask.tif"
+SAVE_MODEL_PATH = os.getcwd() + "/src/model/model/best_model.pth"
+SAVE_DATA_PATH = os.getcwd() + "/data"
+DATASET_NAME = "mateuszbuda/lgg-mri-segmentation"
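
Every value here is built from os.getcwd(), so the paths only resolve correctly when scripts are launched from the repository root. A quick check of that assumption:

import config.configure as config

# Run from the repo root; both paths should point inside the repository
print(config.SAVE_MODEL_PATH)  # <cwd>/src/model/model/best_model.pth
print(config.SAVE_DATA_PATH)   # <cwd>/data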
data/__init__.py ADDED
File without changes
main.py ADDED
@@ -0,0 +1,62 @@
+import os
+import torch.nn as nn
+import torch
+import matplotlib.pyplot as plt
+import torchsummary
+import torchview  # torchview requires graphviz
+import config.configure as config
+from src import logger
+from src.data.data_ingestion import DataIngestion
+from src.data.data_preprocess import data_loaders
+from src.pipelines.training import model_fit
+from src.model.unet import UNet
+
+STAGE_NAME = "Data Ingestion stage"
+try:
+    logger.info(f">>>>>>>> Starting {STAGE_NAME} <<<<<<<<")
+    data_ingestion = DataIngestion()
+    data_ingestion.download()
+except Exception as e:
+    logger.exception(e)
+    raise e
+
+STAGE_NAME = 'Training'
+BATCH_SIZE = 32
+NUM_WORKERS = 3
+EPOCHS = 50
+PATH = config.SAVE_MODEL_PATH
+
+try:
+    logger.info("Preparing DataLoaders")
+
+    # getting the dataloaders
+    train_loader, valid_loader = data_loaders(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, train_split=True)
+
+    # setting up the model
+    loss_fn = nn.BCEWithLogitsLoss()
+    in_channels = 3
+    out_channels = 1
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    features = [64, 128, 256, 512]
+    model = UNet(in_channels=in_channels, out_channels=out_channels, features=features)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+
+    # starting the training stage
+    logger.info(f"Starting {STAGE_NAME} Stage \n\n ==============")
+
+    summary = model_fit(
+        epochs=EPOCHS,
+        model=model,
+        device=device,
+        train_loader=train_loader,
+        valid_loader=valid_loader,
+        criterion=loss_fn,
+        optimizer=optimizer,
+        PATH=PATH
+    )
+
+except Exception as e:
+    logger.exception(e)
+    raise e
notebook/MRI-Segmentation-Tutorial.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,18 @@
+# Core Libraries
+numpy>=1.21.0
+pandas>=1.3.3
+kaggle>=1.6.3
+# Data Visualization
+matplotlib>=3.4.3
+seaborn>=0.13.0
+pillow>=9.3.0
+
+tqdm>=4.66.0
+
+# Web Application Framework
+streamlit>=1.28.0
+
+scikit-learn>=0.24.2
+torch==2.1.2
+torchvision>=0.16.2
+albumentations>=1.3.1
src/__init__.py ADDED
@@ -0,0 +1,20 @@
+import logging
+import os
+# Configure the logging settings
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+# Create a logger with the name of the current module or script
+logger = logging.getLogger(__name__)
+
+# Create a FileHandler and set the log file path
+if not os.path.exists('logs/'):
+    os.makedirs('logs/')
+log_file_path = 'logs/my_log_file.log'
+file_handler = logging.FileHandler(log_file_path)
+
+# Create a formatter and set it for the FileHandler
+formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+file_handler.setFormatter(formatter)
+
+# Add the FileHandler to the logger
+logger.addHandler(file_handler)
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (927 Bytes).
 
src/data/__init__.py ADDED
File without changes
src/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (147 Bytes).
 
src/data/__pycache__/data_ingestion.cpython-311.pyc ADDED
Binary file (1.58 kB).
 
src/data/__pycache__/data_preprocess.cpython-311.pyc ADDED
Binary file (6.05 kB).
 
src/data/data_ingestion.py ADDED
@@ -0,0 +1,22 @@
+import os
+from src import logger
+import config.configure as config
+
+# The kaggle package authenticates on import, so KAGGLE_CONFIG_DIR must be set first
+os.environ["KAGGLE_CONFIG_DIR"] = "config/.kaggle/"
+import kaggle
+
+class DataIngestion():
+    def __init__(self):
+        pass
+
+    def download(self):
+        logger.info("Downloading dataset into data/")
+        dataset_name = config.DATASET_NAME
+
+        # Folder where the dataset will be saved
+        output_folder = config.SAVE_DATA_PATH
+
+        # Download the dataset to the specified folder and unzip it
+        kaggle.api.dataset_download_files(dataset_name, path=output_folder, unzip=True)
+        logger.info("Download Completed")
+
+if __name__ == '__main__':
+    data = DataIngestion()
+    data.download()
src/data/data_preprocess.py ADDED
@@ -0,0 +1,113 @@
+import glob
+import pandas as pd
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+from PIL import Image
+from sklearn.model_selection import train_test_split
+from config.configure import mask_images_path
+from src import logger
+
+def get_dataframe(path: str) -> pd.DataFrame:
+    """
+    Create a DataFrame containing image paths, mask paths, and labels.
+
+    Args:
+        path (str): glob pattern matching the mask images
+
+    Returns:
+        pd.DataFrame: DataFrame with image paths, mask paths, and labels.
+    """
+
+    image_masks = glob.glob(path)
+    image_paths = [file_path.replace("_mask", '') for file_path in image_masks]
+
+    def labels(mask_paths):
+        # A mask with any non-zero pixel marks a slice that contains a tumour
+        label = []
+        for mask in mask_paths:
+            img = Image.open(mask)
+            label.append(1 if np.array(img).sum() > 0 else 0)
+        return label
+
+    mask_labels = labels(image_masks)
+
+    df = pd.DataFrame({
+        'image_path': image_paths,
+        'mask_path': image_masks,
+        'label': mask_labels
+    })
+
+    return df
+
+class MRIDataset(Dataset):
+    def __init__(self, paths, transform):
+        """
+        Custom dataset for MRI images.
+
+        Args:
+            paths (pd.DataFrame): DataFrame containing image and mask paths.
+            transform: Data augmentation and transformation pipeline.
+        """
+        self.paths = paths
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, idx):
+        image_path, mask_path = self.paths.iloc[idx]
+        image = Image.open(image_path)
+        mask = Image.open(mask_path)
+
+        image = np.array(image).astype(np.float32) / 255.
+        mask = np.array(mask).astype(np.float32) / 255.
+
+        if self.transform:
+            transformed = self.transform(image=image, mask=mask)
+        else:
+            transformed = ToTensorV2()(image=image, mask=mask)
+        return transformed['image'], transformed['mask'].unsqueeze(0)
+
+
+def data_loaders(batch_size, num_workers, train_split=False):
+
+    logger.info("Preprocessing Data")
+    df = get_dataframe(mask_images_path)
+
+    train_transforms = A.Compose([
+        A.Resize(224, 224, p=1.0),
+        A.RandomBrightnessContrast(p=0.2),
+        A.HorizontalFlip(p=0.5),
+        A.VerticalFlip(p=0.5),
+        ToTensorV2(),
+    ])
+
+    # Only resize val and test data
+    val_transforms = A.Compose([
+        A.Resize(224, 224, p=1.0),
+        ToTensorV2(),
+    ])
+
+    # splitting the dataset
+    train_x, val_x, train_y, val_y = train_test_split(df.drop('label', axis=1), df.label, test_size=0.3)
+    val_x, test_x, val_y, test_y = train_test_split(val_x, val_y, test_size=0.2)
+
+    train_data = MRIDataset(train_x, train_transforms)
+    val_data = MRIDataset(val_x, val_transforms)
+    test_data = MRIDataset(test_x[test_y == 1], val_transforms)
+
+    if train_split:
+        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+        return train_loader, val_loader
+    else:
+        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
+        return test_loader
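
As a usage sketch (assuming the Kaggle dataset has already been downloaded under data/), the loaders can be built and a batch inspected; the shapes follow from the Resize and ToTensorV2 transforms above:

from src.data.data_preprocess import data_loaders

train_loader, val_loader = data_loaders(batch_size=32, num_workers=3, train_split=True)
images, masks = next(iter(train_loader))
print(images.shape)  # torch.Size([32, 3, 224, 224]) for a full batch
print(masks.shape)   # torch.Size([32, 1, 224, 224])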
src/model/__init__.py ADDED
File without changes
src/model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (148 Bytes). View file
 
src/model/__pycache__/unet.cpython-311.pyc ADDED
Binary file (6.69 kB). View file
 
src/model/model/best_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:80bca9fca6dd62bb74e5072bcac8a4e2d232d43b6f45e0202bf6d5a353cd2b70
+size 124203732
src/model/unet.py ADDED
@@ -0,0 +1,102 @@
+import torch
+import torch.nn as nn
+
+class DownSampling(nn.Module):
+
+    def __init__(self, in_channels, out_channels, max_pool):
+        """
+        DownSampling block in the U-Net architecture.
+
+        Args:
+            in_channels (int): Number of input channels.
+            out_channels (int): Number of output channels.
+            max_pool (bool): Whether to use max pooling.
+        """
+        super(DownSampling, self).__init__()
+        self.max_pool = max_pool
+        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.batchnorm2d = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU()
+        self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+
+        x = self.relu(self.batchnorm2d(x))
+        skip_connection = x
+
+        # The bottleneck block (max_pool=False) returns only the feature map
+        if self.max_pool:
+            next_layer = self.maxpool2d(x)
+        else:
+            return x
+        return next_layer, skip_connection
+
+class UpSampling(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        """
+        UpSampling block in the U-Net architecture.
+
+        Args:
+            in_channels (int): Number of input channels.
+            out_channels (int): Number of output channels.
+        """
+        super(UpSampling, self).__init__()
+        self.up = nn.ConvTranspose2d(in_channels, out_channels=out_channels, kernel_size=2, stride=2)
+        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.relu = nn.ReLU()
+        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+        self.batchnorm = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x, prev_skip):
+        x = self.up(x)
+        # Concatenate the upsampled features with the matching skip connection
+        x = torch.cat((x, prev_skip), dim=1)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        next_layer = self.relu(self.batchnorm(x))
+        return next_layer
+
+class UNet(nn.Module):
+
+    """
+    U-Net architecture.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        features (list): List of feature sizes for downsampling and upsampling.
+    """
+    def __init__(self, in_channels, out_channels, features):
+        super(UNet, self).__init__()
+        self.ups = nn.ModuleList()
+        self.downs = nn.ModuleList()
+
+        for feature in features:
+            self.downs.append(DownSampling(in_channels, feature, True))
+            in_channels = feature
+
+        for feature in reversed(features):
+            self.ups.append(UpSampling(2 * feature, feature))
+
+        self.bottleneck = DownSampling(features[-1], 2 * features[-1], False)
+        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
+
+    def forward(self, x):
+        skip_connections = []
+        for down in self.downs:
+            x, skip_connection = down(x)
+            skip_connections.append(skip_connection)
+        skip_connections = skip_connections[::-1]
+        x = self.bottleneck(x)
+        for i, up in enumerate(self.ups):
+            x = up(x, skip_connections[i])
+
+        return self.final_conv(x)
+
+if __name__ == "__main__":
+    # Example usage
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    features = [64, 128, 256, 512]
+    model = UNet(1, 1, features=features).to(device)
+    print(model(torch.rand(1, 1, 512, 512).to(device)).shape)
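
The self-test above uses a single input channel, while the rest of the pipeline instantiates the network for 3-channel MRI slices at 224x224; a matching sketch:

import torch
from src.model.unet import UNet

model = UNet(3, 1, [64, 128, 256, 512])
out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1, 224, 224]); raw logits, apply torch.sigmoid for probabilities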
src/pipelines/__init__.py ADDED
File without changes
src/pipelines/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (152 Bytes).
 
src/pipelines/__pycache__/predict.cpython-311.pyc ADDED
Binary file (3.66 kB).
 
src/pipelines/__pycache__/training.cpython-311.pyc ADDED
Binary file (5.27 kB).
 
src/pipelines/predict.py ADDED
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+from typing import *
+from src import logger
+
+def predict_mask(
+    data: Any,
+    device: Any,
+    model: nn.Module,
+    inference: bool,
+    valid_loader=None,
+    criterion=None,
+):
+    """
+    Predicts a mask for the image, or evaluates the model on a loader.
+
+    Args:
+        data (Any): image data to predict on (used when inference=True)
+        device (0/'cuda'/'cpu'/Any): name of device
+        model (nn.Module): trained model
+        inference (bool): True to predict a single batch, False to evaluate
+        valid_loader (DataLoader): test loader for evaluation
+        criterion (nn.Module): loss criterion
+
+    Example:
+        >>> predict_mask(
+        >>>     data=input_tensor,
+        >>>     model=model,
+        >>>     device='cuda',
+        >>>     inference=True,
+        >>> )
+    """
+
+    if inference:
+        with torch.no_grad():
+            image = data.type(torch.FloatTensor).to(device)
+            model = model.to(device)
+            pred = model(image)
+            pred = torch.sigmoid(pred)
+            # Threshold the probabilities into a binary mask
+            mask = (pred > 0.6).float()
+
+        return mask.cpu().detach()
+    else:
+        with torch.no_grad():
+            val_Loss = 0
+            val_Dicescore = 0
+            model.eval()
+            for x, y in tqdm(valid_loader):
+                x = x.type(torch.FloatTensor).to(device)
+                y = y.type(torch.FloatTensor).to(device)
+
+                predict = model(x)
+                loss = criterion(predict, y)
+                val_Loss += loss.item()
+
+                predict = torch.sigmoid(predict)
+                predict = (predict > 0.5).float()
+
+                # Dice score: (2*intersection + eps) / (union + eps)
+                dice_score = (2 * (y * predict).sum() + 1e-8) / ((y + predict).sum() + 1e-8)
+                val_Dicescore += float(dice_score)
+
+            val_Loss /= len(valid_loader)
+            val_Dicescore /= len(valid_loader)
+
+            logger.info(f"Test Loss: {val_Loss} - Dice Score: {val_Dicescore}")
src/pipelines/training.py ADDED
@@ -0,0 +1,120 @@
+import torch
+from tqdm import tqdm
+import torch.nn as nn
+from src import logger
+from typing import *
+import warnings
+warnings.filterwarnings('ignore')
+
+def model_fit(
+    epochs: int,
+    model: nn.Module,
+    device: Any,
+    train_loader: Any,
+    valid_loader: Any,
+    criterion: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    PATH: str
+):
+    """
+    Trains the model and keeps the checkpoint with the best validation Dice score.
+
+    Args:
+        epochs (int): number of epochs
+        model (nn.Module): model for training
+        device (Union[int, str]): number or name of device
+        train_loader (Any): pytorch loader for the train set
+        valid_loader (Any): pytorch loader for the validation set
+        criterion (nn.Module): loss criterion
+        optimizer (torch.optim.Optimizer): optimizer for model training
+        PATH (str): path for saving the model
+
+    Example:
+        >>> model_fit(
+        >>>     epochs=25,
+        >>>     model=model,
+        >>>     device='cuda',
+        >>>     train_loader=train_loader,
+        >>>     valid_loader=valid_loader,
+        >>>     criterion=fn_loss,
+        >>>     optimizer=optimizer,
+        >>>     PATH=PATH)
+    """
+
+    best_DICESCORE = 0
+    model.to(device)
+    summary = {
+        'train_loss': [],
+        'train_dice': [],
+        'valid_loss': [],
+        'valid_dice': []
+    }
+    for epoch in range(epochs):
+        logger.info(f"EPOCH {epoch + 1}/{epochs}")
+        train_Loss = 0
+        train_Dicescore = 0
+        model.train()
+        for x, y in tqdm(train_loader):
+            x = x.type(torch.FloatTensor).to(device)
+            y = y.type(torch.FloatTensor).to(device)
+
+            predict = model(x)
+            loss = criterion(predict, y)
+            train_Loss += loss.item()
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            predict = torch.sigmoid(predict)
+            predict = (predict > 0.5).float()
+
+            # Dice score: (2*intersection + eps) / (union + eps)
+            dice_score = (2 * (y * predict).sum() + 1e-8) / ((y + predict).sum() + 1e-8)
+            train_Dicescore += float(dice_score)
+
+        train_Loss /= len(train_loader)
+        train_Dicescore /= len(train_loader)
+
+        with torch.no_grad():
+            val_Loss = 0
+            val_Dicescore = 0
+            model.eval()
+            for x, y in tqdm(valid_loader):
+                x = x.type(torch.FloatTensor).to(device)
+                y = y.type(torch.FloatTensor).to(device)
+
+                predict = model(x)
+                loss = criterion(predict, y)
+                val_Loss += loss.item()
+
+                predict = torch.sigmoid(predict)
+                predict = (predict > 0.5).float()
+
+                dice_score = (2 * (y * predict).sum() + 1e-8) / ((y + predict).sum() + 1e-8)
+                val_Dicescore += float(dice_score)
+
+            val_Loss /= len(valid_loader)
+            val_Dicescore /= len(valid_loader)
+
+        logger.info(f"Loss: {train_Loss} - Dice Score: {train_Dicescore} - Validation Loss: {val_Loss} - Validation Dice Score: {val_Dicescore}")
+
+        # Checkpoint the weights whenever the validation Dice score improves,
+        # saving the state_dict so app.py can load it with load_state_dict
+        if val_Dicescore > best_DICESCORE:
+            best_DICESCORE = val_Dicescore
+            torch.save(model.state_dict(), PATH)
+
+        # Record the per-epoch history
+        summary['train_loss'].append(train_Loss)
+        summary['train_dice'].append(train_Dicescore)
+        summary['valid_loss'].append(val_Loss)
+        summary['valid_dice'].append(val_Dicescore)
+
+    return summary
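
model_fit returns the per-epoch history, so the training and validation curves can be inspected after a run. A minimal plotting sketch with matplotlib (already a project dependency), taking the dict returned by model_fit; the output path is hypothetical:

import matplotlib.pyplot as plt

def plot_history(summary):
    # summary: the dict returned by model_fit
    plt.plot(summary['train_dice'], label='train dice')
    plt.plot(summary['valid_dice'], label='valid dice')
    plt.xlabel('epoch')
    plt.ylabel('Dice score')
    plt.legend()
    plt.savefig('logs/dice_curves.png')  # hypothetical output path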