# MRISegmentation / main.py
# Uploaded by smishr-18 ("Upload 30 files", commit a578142).
import os
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import torchsummary
import torchview
import config.configure as config
from src import logger
from src.data.data_ingestion import DataIngestion
from src.data.data_preprocess import data_loaders
from src.pipelines.training import model_fit
from src.model.unet import UNet
# Requires Graphviz (NOTE(review): presumably for torchview model-graph rendering — confirm).
# --- Stage 1: data ingestion --------------------------------------------------
# Runs DataIngestion.download() before any training happens (presumably this
# fetches the raw dataset to disk — confirm in src/data/data_ingestion.py).
# Any failure is logged with its traceback and re-raised so the script stops
# here instead of continuing to the training stage.
STAGE_NAME = "Data Ingestion stage"
try:
    logger.info(f">>>>>>>> Starting {STAGE_NAME} <<<<<<<<")
    data_ingestion = DataIngestion()
    data_ingestion.download()
except Exception as e:
    logger.exception(e)
    # Bare `raise` re-raises the active exception with its original traceback
    # (`raise e` would append a redundant frame pointing at this line).
    raise
# --- Stage 2: model training --------------------------------------------------
# Builds the DataLoaders, a UNet, its loss and optimizer, then delegates the
# training loop to model_fit. Any failure is logged and re-raised.
STAGE_NAME = 'Training'
BATCH_SIZE = 32
NUM_WORKERS = 3  # forwarded to data_loaders as num_workers
EPOCHS = 50
PATH = config.SAVE_MODEL_PATH  # forwarded to model_fit as PATH (presumably the checkpoint location — confirm)
try:
    logger.info('Preparing DataLoaders')
    # Build the train/validation DataLoaders.
    train_loader, valid_loader = data_loaders(
        batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, train_split=True
    )

    # BCEWithLogitsLoss applies the sigmoid internally, so the model is expected
    # to emit raw logits (assumes UNet has no final sigmoid — confirm).
    loss_fn = nn.BCEWithLogitsLoss()
    in_channels = 3   # channels of each input image
    out_channels = 1  # single-channel output, matching a binary mask target
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    features = [64, 128, 256, 512]  # passed to UNet as `features` (presumably per-level widths)
    model = UNet(in_channels=in_channels, out_channels=out_channels, features=features)
    # Move the model to the target device *before* constructing the optimizer,
    # as the torch.optim docs recommend, so the optimizer is built over the
    # parameters that will actually be trained. Harmless if model_fit also
    # calls .to(device).
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Start the training stage.
    logger.info(f"Starting {STAGE_NAME} Stage \n\n ==============")
    summary = model_fit(
        epochs=EPOCHS,
        model=model,
        device=device,
        train_loader=train_loader,
        valid_loader=valid_loader,
        criterion=loss_fn,
        optimizer=optimizer,
        PATH=PATH,
    )
except Exception as e:
    logger.exception(e)
    # Bare `raise` preserves the original traceback.
    raise