Spaces:
baselqt
/
No application file

File size: 2,386 Bytes
e6a22e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from pathlib import Path
from typing import Optional
from tqdm import tqdm

import hydra
from omegaconf import DictConfig
import numpy as np

from src.simswap import SimSwap
from src.DataManager.ImageDataManager import ImageDataManager
from src.DataManager.VideoDataManager import VideoDataManager
from src.DataManager.utils import imread_rgb


class Application:
    """Run a SimSwap model over every item of a single attribute source.

    The identity comes from ``config.data.id_image`` (required), optionally
    refined by ``config.data.specific_id_image``. Attributes come from exactly
    one of ``config.data.att_image`` (a file or directory, wrapped in an
    ``ImageDataManager``) or ``config.data.att_video`` (a file, wrapped in a
    ``VideoDataManager``). Both managers are given ``config.data.output_dir``,
    and each model output is handed back to the selected manager's ``save``.
    """

    def __init__(self, config: DictConfig):
        """Load identity image(s), select the attribute source and build the model.

        Args:
            config: Hydra config. Reads the ``config.data.*`` entries listed on
                the class and passes ``config.pipeline`` to ``SimSwap``.

        Raises:
            FileNotFoundError: if ``config.data.id_image`` is not an existing file.
            ValueError: if both attribute sources resolve to usable inputs, or
                if neither does.
        """
        id_image_path = self._to_optional_path(config.data.id_image)
        specific_id_image_path = self._to_optional_path(config.data.specific_id_image)
        att_image_path = self._to_optional_path(config.data.att_image)
        att_video_path = self._to_optional_path(config.data.att_video)
        output_dir = Path(config.data.output_dir)

        # Explicit raise rather than `assert`: asserts are stripped under `python -O`.
        if id_image_path is None or not id_image_path.is_file():
            raise FileNotFoundError(f"Can't find {config.data.id_image} file!")

        self.id_image: Optional[np.ndarray] = imread_rgb(id_image_path)
        self.specific_id_image: Optional[np.ndarray] = (
            imread_rgb(specific_id_image_path)
            if specific_id_image_path is not None and specific_id_image_path.is_file()
            else None
        )

        self.att_image: Optional[ImageDataManager] = None
        if att_image_path is not None and (att_image_path.is_file() or att_image_path.is_dir()):
            self.att_image = ImageDataManager(src_data=att_image_path, output_dir=output_dir)

        self.att_video: Optional[VideoDataManager] = None
        if att_video_path is not None and att_video_path.is_file():
            self.att_video = VideoDataManager(
                src_data=att_video_path,
                output_dir=output_dir,
                clean_work_dir=config.data.clean_work_dir,
            )

        # Compare against None, not truthiness: the managers implement __len__
        # (used in run()), so an empty manager would otherwise count as "unset".
        if self.att_video is not None and self.att_image is not None:
            raise ValueError("Only one attribute source can be used!")

        self.data_manager = self.att_video if self.att_video is not None else self.att_image
        if self.data_manager is None:
            # Fail here with a clear message instead of a TypeError on len(None) in run().
            raise ValueError(
                "No attribute source found: set data.att_image to an existing file or "
                "directory, or data.att_video to an existing file."
            )

        self.model = SimSwap(
            config=config.pipeline,
            id_image=self.id_image,
            specific_image=self.specific_id_image,
        )

    @staticmethod
    def _to_optional_path(value) -> Optional[Path]:
        """Convert a config entry to a ``Path``, mapping unset values to ``None``.

        ``Path(None)`` raises ``TypeError``, and ``Path("")`` silently becomes
        ``Path(".")``, which is an existing directory. Unset optional entries
        (``None`` / empty string) must therefore be filtered out before conversion.
        """
        if value is None or value == "":
            return None
        return Path(value)

    def run(self):
        """Swap the identity into each item of the attribute source and save the result."""
        for _ in tqdm(range(len(self.data_manager))):
            # get() is called once per item, so it presumably advances an
            # internal cursor in the data manager — TODO confirm in DataManager.
            att_img = self.data_manager.get()
            output = self.model(att_img)
            self.data_manager.save(output)


@hydra.main(config_path="configs/", config_name="run_image.yaml")
def main(config: DictConfig):
    """Entry point: build the swap application from the composed config and run it."""
    Application(config).run()


if __name__ == "__main__":
    # No argument is passed: the @hydra.main decorator supplies `config`.
    main()