File size: 3,334 Bytes
6709fc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e28b9f2
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

import sys
import os
from importlib import import_module
from options import Settings
import csv 

class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes.

    ``d.key`` and ``d['key']`` refer to the same value, because the instance
    namespace is the mapping itself.
    """

    def __init__(self, *args, **kwargs):
        # Accepts exactly what ``dict`` accepts (mapping, pairs, keywords).
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the dict storage so attribute
        # access and item access always stay in sync.
        self.__dict__ = self


"""
    This function modified from the Genforce library: https://github.com/genforce/genforce
"""
def parse_config(config_file):
    """Parses configuration from a python file.

    The file is imported as a module. Every module-level name that does not
    start with ``__`` must be bound to an iterable of mappings; each mapping
    is converted to an ``AttrDict`` and appended to the result, following the
    order of the module namespace.

    Args:
        config_file: Path to an existing ``.py`` file.

    Returns:
        A flat list of ``AttrDict`` objects collected from all module-level
        names of the config file.

    Raises:
        AssertionError: If ``config_file`` is not an existing file or does
            not have a ``.py`` extension.
    """
    assert os.path.isfile(config_file), f'Config file not found: {config_file}'
    directory, filename = os.path.split(config_file)
    module_name, extension = os.path.splitext(filename)
    assert extension == '.py', f'Config file must be a .py file: {config_file}'

    # Set aside a module already registered under the same name: otherwise
    # import_module would return that cached module instead of the config
    # file, and the cleanup below would evict it for the rest of the program.
    shadowed = sys.modules.pop(module_name, None)
    sys.path.insert(0, directory)
    try:
        module = import_module(module_name)
    finally:
        # Always undo the global import-state changes, even when the config
        # file raises while being imported. Remove by value rather than by
        # position in case the config module itself edited sys.path.
        try:
            sys.path.remove(directory)
        except ValueError:
            pass
        # Drop the config module so a later call re-reads the file.
        sys.modules.pop(module_name, None)
        if shadowed is not None:
            sys.modules[module_name] = shadowed

    config = []
    for key, value in module.__dict__.items():
        # Skip module metadata such as __name__, __file__ and __builtins__.
        if key.startswith('__'):
            continue
        config.extend(AttrDict(val.items()) for val in value)
    return config

# Utility class for the demo
class AppUtils():
    """Helper for the demo: lists editing methods, their edits and ranges.

    On construction the available edit names are loaded:
    - ``interfacegan_edits`` is a fixed list.
    - ``ganspace_edits`` comes from the first column of the tab-separated
      ``ganspace_configs.csv`` under ``Settings.ganspace_directions``.
    - ``styleclip_edits`` comes from the first column of the comma-separated
      ``styleclip_mapping_configs.csv`` under ``Settings.styleclip_settings``.

    Names read from CSV have underscores replaced by spaces, are title-cased
    and sorted.
    """

    # Numeric triple returned by get_range() for each lower-cased method name.
    # NOTE(review): presumably (minimum, maximum, step) of a strength slider;
    # confirm against the caller.
    _RANGES = {
        'interfacegan': (-5, 5, 0.1),
        'ganspace': (-25, 25, 0.1),
        'styleclip': (0, 0.2, 0.01),
    }

    def __init__(self):
        self.interfacegan_edits = ['Smile', 'Age' , 'Pose']
        self.ganspace_edits = self._load_edit_names(
            os.path.join(Settings.ganspace_directions, 'ganspace_configs.csv'),
            delimiter="\t",
        )
        self.styleclip_edits = self._load_edit_names(
            os.path.join(Settings.styleclip_settings, 'styleclip_mapping_configs.csv'),
        )

    @staticmethod
    def _load_edit_names(csv_path, delimiter=","):
        """Return the sorted, display-formatted first column of a CSV file.

        Args:
            csv_path: Path of the CSV file to read.
            delimiter: Field separator used by the file.

        Returns:
            Sorted list of names with ``_`` replaced by a space and
            title-cased. Blank lines and rows whose first cell is empty are
            skipped instead of raising ``IndexError``.
        """
        # newline="" lets the csv module handle line endings itself, as its
        # documentation recommends.
        with open(csv_path, "r", newline="") as f:
            names = [
                row[0].replace('_', ' ').title()
                for row in csv.reader(f, delimiter=delimiter)
                if row and row[0].strip()
            ]
        names.sort()
        return names

    def get_methods(self):
        """Return the display names of the supported editing methods."""
        return ["InterfaceGAN", "GANSpace", "StyleClip"]

    def get_edits(self, method):
        """Return the edit names for ``method`` (matched case-insensitively).

        Raises:
            AttributeError: If there is no ``<method>_edits`` attribute.
        """
        method = method.lower()
        return getattr(self, f"{method}_edits")

    def args_to_cfg(self, method, edit, strength):
        """Build an ``AttrDict`` edit configuration from display values.

        Args:
            method: Method name; stored lower-cased.
            edit: Edit name; stored lower-cased with spaces turned into
                underscores.
            strength: Value convertible to ``float``.

        Returns:
            ``AttrDict`` with ``method``, ``edit`` and ``strength``; for the
            ``styleclip`` method it also carries ``type = 'mapper'``.
        """
        method = method.lower()
        edit = edit.lower()
        edit = edit.replace(' ', '_')
        strength = float(strength)
        cfg = AttrDict()
        cfg.method = method
        cfg.edit = edit
        cfg.strength = strength
        if method == 'styleclip':
            cfg.type = 'mapper'
        return cfg

    def get_range(self, method):
        """Return the numeric triple for ``method``, or ``None`` if unknown.

        The method name is matched case-insensitively.
        """
        return self._RANGES.get(method.lower())

    def get_examples(self):
        """Return the example rows shipped with the demo.

        Each row is ``[image path, method, edit, strength, flag]``.
        NOTE(review): the meaning of the trailing boolean is not visible
        here; confirm against the caller. Method and edit names are not in
        display casing; ``args_to_cfg`` lower-cases both, so casing does not
        matter there.
        """
        examples = [
            ["samples/demo_samples/102.jpg", "InterfaceGAN",  "Age", 2.0,  False],
            ["samples/demo_samples/116.jpg", "Ganspace",  "lipstick", 10.0,  False],
            ["samples/demo_samples/carlsen.jpg", "Styleclip", "curly hair", 0.11, True],
            ["samples/demo_samples/shakira.jpeg", "StyleClip",  "purple hair",  0.1,  True],
            ["samples/demo_samples/shaq.jpg", "InterfaceGAN",  "Smile", -1.7,  True],
        ]
        return examples