from pathlib import Path

import numpy as np

from src.DataManager.base import BaseDataManager
from src.DataManager.utils import imread_rgb, imwrite_rgb


class ImageDataManager(BaseDataManager):
    """Serve still images one at a time and save processed results.

    Input is either a single image file or a directory of images. Each
    ``get()`` loads the next image; ``save()`` writes a result under
    ``<output_dir>/img`` named after the image most recently returned by
    ``get()``, prefixed with ``swap_``.
    """

    # Lower-case file suffixes picked up when ``src_data`` is a directory.
    # Matching is case-insensitive, so ``photo.JPG`` is included as well.
    IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(self, src_data: Path, output_dir: Path):
        """Collect the input images and prepare the output directory.

        Args:
            src_data: An image file, or a directory whose images (see
                ``IMAGE_SUFFIXES``) are used. ``str`` paths are accepted.
            output_dir: Root output directory; results are written to its
                ``img`` subdirectory. Missing directories (including parents)
                are created.

        Raises:
            ValueError: if ``src_data`` yields no images (e.g. it does not
                exist, or is a directory without matching files).
        """
        # Create <output_dir>/img, including any missing parents, up front so
        # save() never has to.
        self.output_dir: Path = Path(output_dir) / "img"
        self.output_dir.mkdir(parents=True, exist_ok=True)

        src_data = Path(src_data)
        self.data_paths = self._collect_paths(src_data)
        # An explicit raise rather than `assert`: asserts are stripped under -O.
        if not self.data_paths:
            raise ValueError(f"Data must be supplied! No images found at {src_data}")

        self.data_paths_iter = iter(self.data_paths)
        # Index into data_paths of the image last returned by get(); -1 = none yet.
        self.last_idx = -1

    def _collect_paths(self, src_data: Path) -> list:
        """Return the image paths designated by ``src_data`` (possibly empty)."""
        if src_data.is_file():
            return [src_data]
        if src_data.is_dir():
            # Sorted so processing order does not depend on the filesystem.
            return sorted(
                path
                for path in src_data.iterdir()
                if path.is_file() and path.suffix.lower() in self.IMAGE_SUFFIXES
            )
        return []

    def __len__(self):
        """Return the total number of input images."""
        return len(self.data_paths)

    def get(self) -> np.ndarray:
        """Load and return the next input image, read via ``imread_rgb``.

        Raises:
            StopIteration: once every input image has been returned.
        """
        img_path = next(self.data_paths_iter)
        # Advance only after next() succeeded, so last_idx tracks the image
        # actually handed out.
        self.last_idx += 1
        return imread_rgb(img_path)

    def save(self, img: np.ndarray):
        """Write ``img`` as ``swap_<name of last image from get()>`` via ``imwrite_rgb``.

        Raises:
            RuntimeError: if called before any ``get()``; without this check
                the output would silently be named after the last input file.
        """
        if self.last_idx < 0:
            raise RuntimeError(
                "save() called before get(): no source image to name the output after"
            )
        filename = "swap_" + self.data_paths[self.last_idx].name
        imwrite_rgb(self.output_dir / filename, img)