Spaces:
baselqt
/
No application file

simswap55 / src /DataManager /ImageDataManager.py
LB5's picture
Upload 45 files
22b8701
from src.DataManager.base import BaseDataManager
from src.DataManager.utils import imread_rgb, imwrite_rgb
import numpy as np
from pathlib import Path
class ImageDataManager(BaseDataManager):
    """Serve input images one at a time and save a result for each of them.

    Images are read with ``imread_rgb`` and results are written with
    ``imwrite_rgb`` into ``<output_dir>/img`` as ``swap_<input file name>``.
    """

    # Suffixes (compared case-insensitively) picked up when src_data is a
    # directory. Subclasses may override this to accept other formats.
    IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(self, src_data: Path, output_dir: Path):
        """Collect the input image paths and prepare the output directory.

        Args:
            src_data: A single image file (used as-is, whatever its
                extension), or a directory whose top-level files with a
                suffix in ``IMAGE_SUFFIXES`` are used. A ``str`` is also
                accepted.
            output_dir: Root output directory (``str`` also accepted). It and
                its ``img`` subdirectory are created if missing.

        Raises:
            ValueError: If no input image was found.
        """
        src_data = Path(src_data)
        self.output_dir: Path = Path(output_dir) / "img"
        # parents=True also creates output_dir itself and any missing
        # ancestors, instead of failing on a fresh nested path.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.data_paths = []
        if src_data.is_file():
            self.data_paths.append(src_data)
        elif src_data.is_dir():
            # Case-insensitive suffix match so e.g. "IMG_01.JPG" is not
            # silently skipped; sorted so the processing order does not
            # depend on the filesystem's directory listing order.
            self.data_paths = sorted(
                path
                for path in src_data.iterdir()
                if path.is_file() and path.suffix.lower() in self.IMAGE_SUFFIXES
            )
        # Explicit raise rather than assert: asserts vanish under `python -O`.
        if not self.data_paths:
            raise ValueError("Data must be supplied!")

        self.data_paths_iter = iter(self.data_paths)
        # Index into data_paths of the image last returned by get();
        # -1 means get() has not been called yet.
        self.last_idx = -1

    def __len__(self) -> int:
        """Return the number of input images."""
        return len(self.data_paths)

    def get(self) -> np.ndarray:
        """Read and return the next input image via ``imread_rgb``.

        Raises:
            StopIteration: When every input image has already been returned.
        """
        img_path = next(self.data_paths_iter)
        # Only advanced after next() succeeds, so it stays on the last
        # valid index once the iterator is exhausted.
        self.last_idx += 1
        return imread_rgb(img_path)

    def save(self, img: np.ndarray):
        """Write ``img`` as the result for the image last returned by get().

        The file is written to ``self.output_dir`` as
        ``swap_<input file name>``.

        Raises:
            RuntimeError: If called before any image has been fetched.
        """
        # Without this guard, last_idx == -1 would index data_paths[-1] and
        # silently name the output after the *last* input file.
        if self.last_idx < 0:
            raise RuntimeError("save() called before get()")
        filename = "swap_" + Path(self.data_paths[self.last_idx]).name
        imwrite_rgb(self.output_dir / filename, img)