msong97 commited on
Commit
5b3128d
·
1 Parent(s): 128be65

Adds common natural img and remove previous

Browse files
Files changed (2) hide show
  1. .gitattributes +1 -0
  2. datasets.py +2 -2
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.pth.tar filter=lfs diff=lfs merge=lfs -text
37
  *.png filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.pth.tar filter=lfs diff=lfs merge=lfs -text
37
  *.png filter=lfs diff=lfs merge=lfs -text
38
+ *.JPEG filter=lfs diff=lfs merge=lfs -text
datasets.py CHANGED
@@ -93,7 +93,7 @@ class LsdirMiniDataset(torch.utils.data.Dataset):
93
  transform: Optional[Callable] = None,
94
  ) -> None:
95
  self.root = root
96
- self.image_files = [f for f in os.listdir(self.root) if f.lower().endswith('.png')]
97
  self.transform = transform
98
 
99
  def __len__(self) -> int:
@@ -101,7 +101,7 @@ class LsdirMiniDataset(torch.utils.data.Dataset):
101
 
102
  def __getitem__(self, idx):
103
  img_path = os.path.join(self.root, self.image_files[idx])
104
- img = Image.open(img_path)
105
  if self.transform:
106
  img = self.transform(img)
107
 
 
93
  transform: Optional[Callable] = None,
94
  ) -> None:
95
  self.root = root
96
+ self.image_files = [f for f in os.listdir(self.root) if f.lower().endswith(('.png', '.JPEG'))]
97
  self.transform = transform
98
 
99
  def __len__(self) -> int:
 
101
 
102
  def __getitem__(self, idx):
103
  img_path = os.path.join(self.root, self.image_files[idx])
104
+ img = Image.open(img_path).convert("RGB") # Ensure consistent 3-channel format
105
  if self.transform:
106
  img = self.transform(img)
107