|
import torch |
|
|
|
from Modules.ControllabilityGAN.wgan.resnet_init import init_resnet |
|
from Modules.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost |
|
|
|
|
|
def create_wgan(parameters, device, optimizer='adam'):
    """Build a ``WassersteinGanQuadraticCost`` from a configuration mapping.

    Args:
        parameters: Mapping holding the configuration. The keys read here are
            ``'model'``, ``'learning_rate'``, ``'betas'`` (only when the Adam
            optimizer is selected), ``'data_dim'``, ``'epochs'``,
            ``'batch_size'``, ``'n_max_iterations'`` and ``'gamma'``. The whole
            mapping is also forwarded to ``init_resnet``.
        device: Passed through unchanged to ``WassersteinGanQuadraticCost``.
        optimizer: Name of the optimizer used for both networks, either
            ``'adam'`` (default) or ``'rmsprop'``.

    Returns:
        The constructed ``WassersteinGanQuadraticCost`` instance.

    Raises:
        NotImplementedError: If ``parameters['model']`` is not ``"resnet"``.
        ValueError: If ``optimizer`` is neither ``'adam'`` nor ``'rmsprop'``.
    """
    model_name = parameters['model']
    if model_name != "resnet":
        raise NotImplementedError(f"Unsupported model: {model_name!r}")

    # Reject unknown optimizer names up front, before any network is built,
    # instead of failing later on an unbound optimizer variable.
    if optimizer not in ('adam', 'rmsprop'):
        raise ValueError(
            f"Unsupported optimizer: {optimizer!r} (expected 'adam' or 'rmsprop')"
        )

    generator, discriminator = init_resnet(parameters)

    learning_rate = parameters['learning_rate']
    if optimizer == 'adam':
        betas = parameters['betas']
        optimizer_g = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=betas)
        optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=betas)
    else:  # 'rmsprop'
        optimizer_g = torch.optim.RMSprop(generator.parameters(), lr=learning_rate)
        # The discriminator optimizer must hold the discriminator's own
        # parameters; otherwise the discriminator would never be updated.
        optimizer_d = torch.optim.RMSprop(discriminator.parameters(), lr=learning_rate)

    criterion = torch.nn.MSELoss()

    gan = WassersteinGanQuadraticCost(
        generator,
        discriminator,
        optimizer_g,
        optimizer_d,
        criterion=criterion,
        data_dimensions=parameters['data_dim'],
        epochs=parameters['epochs'],
        batch_size=parameters['batch_size'],
        device=device,
        n_max_iterations=parameters['n_max_iterations'],
        gamma=parameters['gamma'],
    )

    return gan
|
|