File size: 1,266 Bytes
22b8701 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
from src.DataManager.base import BaseDataManager
from src.DataManager.utils import imread_rgb, imwrite_rgb
import numpy as np
from pathlib import Path
class ImageDataManager(BaseDataManager):
    """Hand out source images one at a time and save processed results.

    The manager collects image paths from ``src_data`` (a single file, or the
    ``*.jpg`` / ``*.jpeg`` / ``*.png`` files directly inside a directory),
    serves them sequentially through :meth:`get`, and writes results through
    :meth:`save` into ``<output_dir>/img`` under the name
    ``swap_<source file name>``.

    Attributes set by ``__init__``:
        output_dir: directory the results are written to (``output_dir / "img"``).
        data_paths: list of collected source image paths.
        data_paths_iter: iterator over ``data_paths`` consumed by :meth:`get`.
        last_idx: index in ``data_paths`` of the image most recently returned
            by :meth:`get`; ``-1`` until the first call.
    """

    # Glob patterns, in collection order, applied when ``src_data`` is a
    # directory. Matching is non-recursive.
    _IMAGE_PATTERNS = ("*.jpg", "*.jpeg", "*.png")

    def __init__(self, src_data: Path, output_dir: Path):
        """Collect the source image paths and prepare the output directory.

        Args:
            src_data: Path to one image file, or to a directory whose
                top-level ``*.jpg`` / ``*.jpeg`` / ``*.png`` files are used.
            output_dir: Base output directory; results go to its ``img``
                sub-directory, which is created if missing.

        Raises:
            ValueError: If no image path could be collected (``src_data`` is
                neither an existing file nor a directory containing a
                matching image).
        """
        # NOTE(review): BaseDataManager.__init__ is not invoked here, exactly
        # as in the previous implementation - confirm the base class needs no
        # initialisation of its own.

        # parents=True creates ``output_dir`` (and any missing ancestors)
        # together with the ``img`` sub-directory in a single call.
        self.output_dir: Path = output_dir / "img"
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.data_paths = []
        if src_data.is_file():
            self.data_paths.append(src_data)
        elif src_data.is_dir():
            # glob() yields entries in a filesystem-dependent order; sorting
            # each pattern's matches keeps the jpg -> jpeg -> png grouping
            # while making the processing order reproducible.
            for pattern in self._IMAGE_PATTERNS:
                self.data_paths.extend(sorted(src_data.glob(pattern)))

        # Explicit raise instead of ``assert``: assertions are removed when
        # Python runs with -O, which would let an empty manager through.
        if not self.data_paths:
            raise ValueError("Data must be supplied!")

        self.data_paths_iter = iter(self.data_paths)
        self.last_idx = -1

    def __len__(self):
        """Return the number of collected source image paths."""
        return len(self.data_paths)

    def get(self) -> np.ndarray:
        """Return the next source image, loaded with ``imread_rgb``.

        Raises:
            StopIteration: When every collected path has already been served.
        """
        img_path = next(self.data_paths_iter)
        # Only advance after next() succeeded, so ``last_idx`` always points
        # at the image that was actually handed out.
        self.last_idx += 1
        return imread_rgb(img_path)

    def save(self, img: np.ndarray):
        """Write ``img`` as ``swap_<name of the last image returned by get>``.

        Args:
            img: Image to write; passed unchanged to ``imwrite_rgb``.

        Raises:
            RuntimeError: If called before any successful :meth:`get`, since
                there is then no source image to derive the file name from.
        """
        # Without this guard ``last_idx == -1`` would silently index the LAST
        # path in the list and mislabel the output.
        if self.last_idx < 0:
            raise RuntimeError(
                "save() called before get(): no source image to name the output after"
            )
        filename = "swap_" + Path(self.data_paths[self.last_idx]).name
        imwrite_rgb(self.output_dir / filename, img)
|