Spaces:
Sleeping
Sleeping
import os | |
import torch.nn as nn | |
import torch | |
import matplotlib.pyplot as plt | |
import torchsummary | |
import torchview | |
import config.configure as config | |
from src import logger | |
from src.data.data_ingestion import DataIngestion | |
from src.data.data_preprocess import data_loaders | |
from src.pipelines.training import model_fit | |
from src.model.unet import UNet | |
## graphviz
# ---------------------------------------------------------------------------
# Stage 1: data ingestion.
# Calls DataIngestion().download() (presumably fetches the raw training data —
# confirm in src.data.data_ingestion). Any failure is logged with its traceback
# and re-raised, so the script stops before the training stage runs.
# ---------------------------------------------------------------------------
STAGE_NAME = "Data Ingestion stage"
try:
    logger.info(f">>>>>>>> Starting {STAGE_NAME} <<<<<<<<")
    data_ingestion = DataIngestion()
    data_ingestion.download()
except Exception as e:
    logger.exception(e)
    # Bare `raise` re-raises the active exception with its original
    # traceback; it is the idiomatic form of `raise e` inside an except block.
    raise
# ---------------------------------------------------------------------------
# Stage 2: training.
# Builds the train/validation DataLoaders, a UNet, the loss and the optimizer,
# then hands everything to `model_fit`. Any failure is logged with its
# traceback and re-raised, matching the ingestion stage.
# ---------------------------------------------------------------------------
STAGE_NAME = 'Training'

# Training hyper-parameters.
BATCH_SIZE = 32
NUM_WORKERS = 3  # passed to data_loaders as num_workers
EPOCHS = 50
# Forwarded to model_fit as PATH (presumably where the model is saved —
# confirm in src.pipelines.training).
PATH = config.SAVE_MODEL_PATH

try:
    logger.info('Preparing DataLoaders')
    # data_loaders returns a (train, valid) pair here; train_split=True
    # presumably requests that split — confirm in src.data.data_preprocess.
    train_loader, valid_loader = data_loaders(
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        train_split=True,
    )

    # 3-channel input, 1-channel output trained with BCEWithLogitsLoss
    # (which applies the sigmoid internally, so the model emits raw logits).
    # NOTE(review): looks like binary segmentation — confirm against the dataset.
    loss_fn = nn.BCEWithLogitsLoss()
    in_channels = 3
    out_channels = 1
    features = [64, 128, 256, 512]  # UNet `features` argument
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = UNet(in_channels=in_channels, out_channels=out_channels, features=features)
    # Move the model before constructing the optimizer, as the torch.optim
    # docs recommend. Module.to() is in-place and idempotent, so this is
    # harmless if model_fit also moves the model to `device`.
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    logger.info(f"Starting {STAGE_NAME} Stage \n\n ==============")
    summary = model_fit(
        epochs=EPOCHS,
        model=model,
        device=device,
        train_loader=train_loader,
        valid_loader=valid_loader,
        criterion=loss_fn,
        optimizer=optimizer,
        PATH=PATH,
    )
except Exception as e:
    logger.exception(e)
    # Re-raise the active exception with its original traceback.
    raise