Michael Rey commited on
Commit
b1205cc
·
1 Parent(s): 6abf0ea

moved a file

Browse files
Files changed (2) hide show
  1. src/data_loader.py +31 -0
  2. src/streamlit_app.py +1 -1
src/data_loader.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import datasets, transforms
2
+ from torch.utils.data import DataLoader
3
+
4
+ def get_data_loaders(data_dir, batch_size=32):
5
+ # Data augmentation + normalization for training
6
+ transform_train = transforms.Compose([
7
+ transforms.RandomResizedCrop(128),
8
+ transforms.RandomHorizontalFlip(),
9
+ transforms.RandomRotation(10),
10
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
11
+ transforms.ToTensor(),
12
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
13
+ ])
14
+
15
+ # Only resize + normalize for validation
16
+ transform_val = transforms.Compose([
17
+ transforms.Resize((128, 128)),
18
+ transforms.ToTensor(),
19
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
20
+ ])
21
+
22
+ train_dir = f"{data_dir}/training"
23
+ val_dir = f"{data_dir}/validation"
24
+
25
+ train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
26
+ val_dataset = datasets.ImageFolder(val_dir, transform=transform_val)
27
+
28
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
29
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
30
+
31
+ return train_loader, val_loader, train_dataset.classes
src/streamlit_app.py CHANGED
@@ -6,7 +6,7 @@ import torchvision.transforms as transforms
6
  from PIL import Image
7
 
8
  from resnet_model import MonkeyResNet
9
- from utils.data_loader import get_data_loaders
10
 
11
  # Ensure the parent directory is in the system path for module imports
12
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 
6
  from PIL import Image
7
 
8
  from resnet_model import MonkeyResNet
9
+ from data_loader import get_data_loaders
10
 
11
  # Ensure the parent directory is in the system path for module imports
12
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))