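# NOTE: the lines below are residue from the web file viewer this source was
# copied from; they are kept as comments so the module stays importable.
#   Uploader: GabrielML
#   Commit:   "Fixes" (8850972)
#   Viewer links: raw / history / blame
#   File size: 2.84 kB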
import functools

import imagehash
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
# Module-level cache that imageCacheWrapper indexes as ``ImageCache[image]``.
# It starts as None, so it must be rebound to a subscriptable cache object
# (presumably a CustomImageCache instance -- confirm against the caller)
# before any wrapped function is called; otherwise the lookup raises TypeError.
ImageCache = None
class AnimalDataset(Dataset):
    """In-memory image dataset built from a table of paths and labels.

    Every image listed in ``df["path"]`` is opened eagerly in ``__init__``,
    converted to RGB, resized and kept in ``self.images``, so ``__getitem__``
    performs no disk I/O. Memory use therefore grows with the number of rows.

    Args:
        df: Table indexable by column name whose columns expose ``.values``
            (presumably a pandas DataFrame -- confirm against the caller).
            It must provide the columns ``"path"``, ``"target"`` and
            ``"encoded_target"``.
        transform: Optional callable applied to the cached PIL image on every
            ``__getitem__`` call. With ``None`` the PIL image is returned
            unchanged.
        image_size: ``(width, height)`` every image is resized to while
            loading. Defaults to the previously hard-coded ``(224, 224)``.
    """

    def __init__(self, df, transform=None, image_size=(224, 224)):
        self.paths = df["path"].values
        self.targets = df["target"].values
        self.encoded_target = df["encoded_target"].values
        self.transform = transform
        self.image_size = tuple(image_size)
        self.images = []
        for path in tqdm(self.paths):
            # convert()/resize() return independent in-memory images, so the
            # underlying file can be closed as soon as they have been built.
            with Image.open(path) as source:
                self.images.append(
                    source.convert("RGB").resize(self.image_size)
                )

    def __len__(self):
        """Return the number of samples (one per row of the input table)."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, encoded_target, target)`` for sample ``idx``.

        ``image`` is the cached PIL image, passed through ``self.transform``
        when one was given. ``encoded_target`` is a 0-dim ``torch.long``
        tensor built from the ``"encoded_target"`` column, and ``target`` is
        the raw value from the ``"target"`` column.
        """
        img = self.images[idx]
        # Compare against None explicitly: a callable may legitimately be
        # falsy (e.g. a container-like transform that defines __len__).
        if self.transform is not None:
            img = self.transform(img)
        encoded_target = torch.tensor(
            self.encoded_target[idx], dtype=torch.long
        )
        return img, encoded_target, self.targets[idx]
# Preprocessing constants shared by both pipelines below. The mean/std values
# match the ImageNet channel statistics quoted in the torchvision docs.
_TARGET_SIZE = (224, 224)
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, light random augmentation (horizontal flip and
# a rotation of up to 10 degrees), tensor conversion, then normalisation.
train_transform = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=list(_CHANNEL_MEAN),
            std=list(_CHANNEL_STD),
        ),
    ]
)

# Deterministic pipeline (no augmentation): resize, convert the image to a
# PyTorch tensor, then normalise with the same statistics as above.
transform = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=list(_CHANNEL_MEAN),
            std=list(_CHANNEL_STD),
        ),
    ]
)
class CustomImageCache:
    """Bounded in-memory image cache keyed by ``imagehash.average_hash``.

    Looking an image up with ``cache[image]`` returns the image stored under
    the same hash when there is one (a hit); otherwise it stores ``image``
    and returns it (a miss). Because the key is a perceptual hash, a hit may
    return a previously stored image rather than the exact object passed in
    -- NOTE(review): confirm that this is the intended de-duplication.

    When the cache is full, the least recently used entry is evicted.

    Args:
        cache_size: Maximum number of entries kept. A value of zero or less
            disables storing (every lookup is then a miss).
        debug: When true, print a line for every hit, miss and eviction.
    """

    def __init__(self, cache_size=50, debug=False):
        self.cache = dict()
        # Honour the caller's limit (it used to be pinned to 50).
        self.cache_size = cache_size
        self.debug = debug
        self.cache_hits = 0
        self.cache_misses = 0

    def __getitem__(self, image):
        """Return the cached image for ``image``, storing it on a miss.

        ``image`` is either a PIL image or a dict with ``"image"`` and
        ``"mask"`` PIL images, which are blended 50/50 into a single image
        before hashing.
        """
        if isinstance(image, dict):
            # Image and mask both arrive as Pillow images: combine them.
            image = Image.blend(image["image"], image["mask"], alpha=0.5)
        key = imagehash.average_hash(image)

        if key in self.cache:
            if self.debug:
                print("Cache hit!")
            self.cache_hits += 1
            # Re-insert so the entry becomes the most recently used one;
            # dicts preserve insertion order, which is the eviction order.
            cached = self.cache.pop(key)
            self.cache[key] = cached
            return cached

        if self.debug:
            print("Cache miss!")
        self.cache_misses += 1
        if self.cache_size <= 0:
            # Nothing may be stored; behave as a pass-through.
            return image
        # Evict from the oldest end. dict.popitem() would drop the newest
        # entry instead, which defeats the purpose of the cache once full.
        while len(self.cache) >= self.cache_size:
            if self.debug:
                print("Cache full, popping item!")
            del self.cache[next(iter(self.cache))]
        self.cache[key] = image
        return image

    def __len__(self):
        """Return the number of entries currently stored."""
        return len(self.cache)

    def print_info(self):
        """Print the current size and the hit/miss counters."""
        print(f"Cache size: {len(self)}")
        print(f"Cache hits: {self.cache_hits}")
        print(f"Cache misses: {self.cache_misses}")
def imageCacheWrapper(fn):
    """Decorate ``fn`` so its image argument is routed through ``ImageCache``.

    The returned function takes a single ``image`` argument, looks it up in
    the module-level ``ImageCache`` (``ImageCache[image]``) and calls ``fn``
    with the result. The global is read at call time, so it may be assigned
    after decoration, but it must be a subscriptable cache by the time the
    wrapped function runs; while it is still ``None`` the call raises
    ``TypeError``.

    Args:
        fn: Callable accepting one image argument.

    Returns:
        A wrapper with the same name and docstring as ``fn``.
    """

    @functools.wraps(fn)  # keep fn's __name__/__doc__ for debugging and UIs
    def wrapper(image):
        return fn(ImageCache[image])

    return wrapper