Spaces:
Configuration error
Configuration error
from src.DataManager.base import BaseDataManager | |
from src.DataManager.utils import imread_rgb, imwrite_rgb | |
import numpy as np | |
from pathlib import Path | |
class ImageDataManager(BaseDataManager):
    """Data manager that reads input images from disk and writes results back.

    Inputs are either a single file, or every file directly inside a
    directory (non-recursive) whose suffix is in ``IMAGE_SUFFIXES``.
    Images are handed out one at a time by :meth:`get`. :meth:`save`
    writes a result into ``output_dir / "img"``, named after the image most
    recently returned by :meth:`get` with a ``swap_`` prefix.
    """

    # Lower-case suffixes accepted when ``src_data`` is a directory; matched
    # case-insensitively. Subclasses may override to accept other formats.
    IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})

    def __init__(self, src_data: Path, output_dir: Path):
        """Collect the input image paths and create the output directory.

        Args:
            src_data: A single image file, or a directory whose top-level
                files with a suffix in ``IMAGE_SUFFIXES`` (any letter case)
                are used, in sorted path order.
            output_dir: Base output directory. Results are written to
                ``output_dir / "img"``; missing parent directories are created.

        Raises:
            AssertionError: If no input image was found.
        """
        output_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir: Path = output_dir / "img"
        self.output_dir.mkdir(exist_ok=True)

        if src_data.is_file():
            self.data_paths = [src_data]
        elif src_data.is_dir():
            # Filter by suffix rather than globbing "*.jpg" etc. so that
            # upper-case extensions (".JPG") are found on case-sensitive file
            # systems and directories named like images are skipped. Sorting
            # gives a deterministic processing order.
            self.data_paths = sorted(
                path
                for path in src_data.iterdir()
                if path.is_file() and path.suffix.lower() in self.IMAGE_SUFFIXES
            )
        else:
            self.data_paths = []

        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        if not self.data_paths:
            raise AssertionError("Data must be supplied!")

        self.data_paths_iter = iter(self.data_paths)
        # Index into data_paths of the image last returned by get(); -1 = none yet.
        self.last_idx = -1

    def __len__(self):
        """Return the number of input images."""
        return len(self.data_paths)

    def get(self) -> np.ndarray:
        """Read and return the next input image.

        Returns:
            The image as loaded by ``imread_rgb``.

        Raises:
            StopIteration: Once every input image has been returned.
        """
        img_path = next(self.data_paths_iter)
        # Only advance after next() succeeded, so save() keeps pointing at
        # the last image actually returned.
        self.last_idx += 1
        return imread_rgb(img_path)

    def save(self, img: np.ndarray):
        """Write ``img`` as the result for the image last returned by :meth:`get`.

        The output path is ``output_dir / "img" / ("swap_" + <input file name>)``.

        Args:
            img: Image array passed unchanged to ``imwrite_rgb``.

        Raises:
            RuntimeError: If called before any image was fetched with :meth:`get`.
        """
        # Without this guard, last_idx == -1 would index data_paths[-1] and
        # silently name the output after the *last* input file.
        if self.last_idx < 0:
            raise RuntimeError(
                "save() called before get(); there is no current image to name the output after"
            )
        filename = "swap_" + Path(self.data_paths[self.last_idx]).name
        imwrite_rgb(self.output_dir / filename, img)