Spaces:
Runtime error
Runtime error
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT License. | |
import torch | |
from models.networks.base_network import BaseNetwork | |
from models.networks.generator import * | |
from models.networks.encoder import * | |
import util.util as util | |
def find_network_using_name(target_network_name, filename):
    """Resolve a network class from a ``models.networks`` submodule.

    The requested class name is ``target_network_name + filename``.
    It is looked up in module ``models.networks.<filename>``.
    The actual name matching is delegated to ``util.find_class_in_module``.

    Args:
        target_network_name: network variant prefix, e.g. ``opt.netG``.
        filename: submodule under ``models.networks``; it is also appended to
            the class name, e.g. ``"generator"``.

    Returns:
        The resolved network class.

    Raises:
        AssertionError: if the resolved class is not a subclass of
            ``BaseNetwork``.
    """
    target_class_name = target_network_name + filename
    module_name = "models.networks." + filename
    network = util.find_class_in_module(target_class_name, module_name)
    # Explicit raise instead of ``assert`` so the check still runs under
    # ``python -O``.  AssertionError is kept so callers see the same type.
    if not issubclass(network, BaseNetwork):
        raise AssertionError(
            "Class %s should be a subclass of BaseNetwork" % network
        )
    return network
def modify_commandline_options(parser, is_train):
    """Let each selected network class extend the argument parser.

    The generator hook (chosen by ``--netG``) runs first.  The discriminator
    hook (chosen by ``--netD``) runs only when ``is_train`` is true.  The
    ``conv`` encoder hook always runs last.  Each hook returns the parser to
    use for the next one.

    Args:
        parser: argument parser; ``parse_known_args`` is used to read the
            currently known ``netG`` / ``netD`` values.
        is_train: forwarded to every network's ``modify_commandline_options``.

    Returns:
        The parser returned by the last hook.
    """
    known_opt, _ = parser.parse_known_args()

    def _run_hook(current_parser, net_name, module_suffix):
        # Look up the network class and let it add its own options.
        network_cls = find_network_using_name(net_name, module_suffix)
        return network_cls.modify_commandline_options(current_parser, is_train)

    parser = _run_hook(parser, known_opt.netG, "generator")
    if is_train:
        parser = _run_hook(parser, known_opt.netD, "discriminator")
    return _run_hook(parser, "conv", "encoder")
def create_network(cls, opt):
    """Instantiate ``cls``, optionally move it to CUDA, and initialize weights.

    Args:
        cls: network class.  It is called as ``cls(opt)``, and the instance
            must provide ``print_network()``, ``cuda()`` and
            ``init_weights(init_type, init_variance)``.
        opt: options object.  This function reads ``gpu_ids``, ``init_type``
            and ``init_variance``.

    Returns:
        The constructed network, with weights initialized.

    Raises:
        AssertionError: if ``opt.gpu_ids`` is non-empty but CUDA is not
            available.  Nothing is constructed in that case.
    """
    use_gpu = len(opt.gpu_ids) > 0
    # Validate before constructing so a misconfiguration fails fast.  The
    # explicit raise (not ``assert``) keeps the check active under
    # ``python -O``.
    if use_gpu and not torch.cuda.is_available():
        raise AssertionError(
            "gpu_ids=%s requests GPU execution, but CUDA is not available"
            % (opt.gpu_ids,)
        )
    net = cls(opt)
    net.print_network()
    if use_gpu:
        net.cuda()
    # Weight initialization runs after the optional device move.
    net.init_weights(opt.init_type, opt.init_variance)
    return net
def define_G(opt):
    """Build the generator class named by ``opt.netG`` via ``create_network``."""
    return create_network(find_network_using_name(opt.netG, "generator"), opt)
def define_D(opt):
    """Build the discriminator class named by ``opt.netD`` via ``create_network``."""
    return create_network(find_network_using_name(opt.netD, "discriminator"), opt)
def define_E(opt):
    """Build the encoder network via ``create_network``.

    The encoder variant is fixed to ``"conv"``, the only encoder type used.
    """
    return create_network(find_network_using_name("conv", "encoder"), opt)