SpiralSense / configs.py
cycool29's picture
Update
73666ad
raw
history blame
1.48 kB
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from models import *
# Constants
# --- Training hyper-parameters -------------------------------------------
RANDOM_SEED = 123
BATCH_SIZE = 8
NUM_EPOCHS = 150
# Defined exactly once; this module previously assigned it twice.
WARMUP_EPOCHS = 5
LEARNING_RATE = 0.0001
# Presumably step-LR scheduler settings (step length in epochs, decay factor)
# — confirm against the training script.
STEP_SIZE = 10
GAMMA = 0.3
CUTMIX_ALPHA = 0.3
# Execution is deliberately pinned to the CPU. To use a GPU when available,
# replace the value below with:
#   torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE = torch.device("cpu")
NUM_PRINT = 100
TASK = 1
# --- Data locations ------------------------------------------------------
# Each path ends in "Task " — presumably callers append the task number
# (e.g. RAW_DATA_DIR + str(TASK)); confirm against the call sites.
RAW_DATA_DIR = r"data/train/raw/Task "
AUG_DATA_DIR = r"data/train/augmented/Task "
EXTERNAL_DATA_DIR = r"data/train/external/Task "
COMBINED_DATA_DIR = r"data/train/combined/Task "
TEST_DATA_DIR = r"data/test/Task "
TEMP_DATA_DIR = r"data/temp/Task "
# --- Regularisation / stopping -------------------------------------------
LABEL_SMOOTHING_EPSILON = 0.1
EARLY_STOPPING_PATIENCE = 20
# --- Labels --------------------------------------------------------------
# Listed in alphabetical order; presumably this must match the label indices
# produced by the data loader (e.g. ImageFolder sorts class folders) — confirm.
CLASSES = [
    "Alzheimer Disease",
    "Cerebral Palsy",
    "Dystonia",
    "Essential Tremor",
    "Healthy",
    "Huntington Disease",
    "Parkinson Disease",
]
# Derived rather than hard-coded so it can never disagree with CLASSES.
NUM_CLASSES = len(CLASSES)
# Model instance shared by importers of this config; EfficientNetB3WithNorm
# comes from the project's `models` module.
MODEL = EfficientNetB3WithNorm(num_classes=NUM_CLASSES)
# Checkpoint file is named after the model's class:
# "output/checkpoints/<ClassName>.pth".
MODEL_SAVE_PATH = f"output/checkpoints/{type(MODEL).__name__}.pth"
# Image preprocessing pipeline:
#   1. resize to a fixed 224x224,
#   2. convert to a tensor,
#   3. normalize with one scalar mean/std, which is broadcast over every channel.
# The mean/std values are presumably dataset statistics computed
# elsewhere — TODO confirm their origin.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=0.8289, std=0.2006),
    ]
)
class CustomDataset(Dataset):
    """Expose an indexable collection of ``(image, label)`` pairs as a Dataset.

    Args:
        dataset: any object supporting ``len()`` and integer indexing whose
            items unpack into exactly two values.

    Indexing returns a 2-tuple ``(image, label)``. An item that does not
    unpack into exactly two values raises ``ValueError``/``TypeError``.
    """

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        # Unpacking enforces the two-element contract before returning.
        image, target = sample
        return (image, target)