Spaces:
baselqt
/
No application file

File size: 1,266 Bytes
22b8701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from src.DataManager.base import BaseDataManager
from src.DataManager.utils import imread_rgb, imwrite_rgb

import numpy as np
from pathlib import Path


class ImageDataManager(BaseDataManager):
    """Data manager that reads still images and writes processed copies.

    Input images are read one at a time via :meth:`get` (using
    ``imread_rgb``); each image passed to :meth:`save` is written with
    ``imwrite_rgb`` into ``<output_dir>/img`` as ``swap_<input file name>``,
    where the input is the one most recently returned by :meth:`get`.
    """

    # File suffixes (compared case-insensitively) collected from a directory.
    IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(self, src_data: Path, output_dir: Path):
        """Collect the input image paths and prepare the output directory.

        Args:
            src_data: A single input file, or a directory whose top-level
                ``.jpg``/``.jpeg``/``.png`` files (any letter case) are used.
            output_dir: Root output directory. Results go into its ``img``
                subdirectory, which is created (with any missing parents).

        Raises:
            ValueError: If ``src_data`` yields no input files.
        """
        src_data = Path(src_data)

        # A single mkdir with parents=True also creates output_dir itself.
        self.output_dir: Path = Path(output_dir) / "img"
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.data_paths = self._collect_paths(src_data)
        # Raise rather than assert: assert statements are stripped under -O.
        if not self.data_paths:
            raise ValueError(
                f"Data must be supplied! No input images found at {src_data}"
            )

        self.data_paths_iter = iter(self.data_paths)

        # Index of the path most recently returned by get(); -1 before any call.
        self.last_idx = -1

    @classmethod
    def _collect_paths(cls, src_data: Path) -> list:
        """Return the input image paths described by ``src_data``."""
        if src_data.is_file():
            return [src_data]
        if src_data.is_dir():
            # Case-insensitive suffix match so e.g. ``photo.JPG`` is included on
            # every platform; is_file() skips directories named like images.
            # Sorted for a deterministic processing order.
            return sorted(
                path
                for path in src_data.iterdir()
                if path.is_file() and path.suffix.lower() in cls.IMAGE_SUFFIXES
            )
        return []

    def __len__(self):
        """Return the total number of input images."""
        return len(self.data_paths)

    def get(self) -> np.ndarray:
        """Read and return the next input image (as loaded by ``imread_rgb``).

        Raises:
            StopIteration: Once every input image has been returned.
        """
        img_path = next(self.data_paths_iter)
        # Only advance after next() succeeded, so last_idx stays valid at the end.
        self.last_idx += 1
        return imread_rgb(img_path)

    def save(self, img: np.ndarray):
        """Write ``img`` as the result for the image last returned by get().

        Args:
            img: Image array handed to ``imwrite_rgb``.

        Raises:
            RuntimeError: If called before any image was read with get().
        """
        if self.last_idx < 0:
            # Without this guard index -1 would silently name the output after
            # the last input file instead of the one being processed.
            raise RuntimeError("save() called before any image was read with get()")

        filename = "swap_" + Path(self.data_paths[self.last_idx]).name
        imwrite_rgb(self.output_dir / filename, img)