from jittor.dataset.mnist import MNIST
import jittor.transform as transform
from jittor.dataset.dataset import ImageFolder
import jittor as jt
from jittor import nn, Module
import os
import argparse
from time import *
import PIL.Image as Image
import numpy as np
import matplotlib.pyplot as plt

# Render figures off-screen (no display needed on a training server).
plt.switch_backend('agg')
jt.flags.use_cuda = 1

# Size of the generator's input noise vector; shared by training, validation
# and sampling so the three can never disagree.
LATENT_DIM = 1024

# Command-line options
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='celebA', help='训练数据集类型')
parser.add_argument('--train_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='训练数据集地址')
parser.add_argument('--eval_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='验证数据集地址')
parser.add_argument('--n_epochs', type=int, default=100, help='训练的时期数')
parser.add_argument('--batch_size', type=int, default=64, help='批次大小')
parser.add_argument('--lr', type=float, default=0.0002, help='学习率')
parser.add_argument('--b1', type=float, default=0.5, help='梯度的一阶动量衰减')
parser.add_argument('--b2', type=float, default=0.999, help='梯度的二阶动量衰减')
# NOTE: generator/discriminator below are hard-wired for 112x112 images
# (7x7 seed upsampled x16; four stride-2 convs back down to 7x7).
parser.add_argument('--img_size', type=int, default=112, help='每个图像尺寸的大小')
parser.add_argument('--celebA_channels', type=int, default=3, help='图像通道数')
parser.add_argument('--mnist_channels', type=int, default=1, help='图像通道数')
parser.add_argument('--img_row', type=int, default=5, help='拼接/预览图像的行数')
parser.add_argument('--img_column', type=int, default=5, help='拼接/预览图像的列数')
# Default 37 preserves the previously hard-coded resume point; it is only used
# when a checkpoint is actually found.
parser.add_argument('--start_epoch', type=int, default=37, help='从检查点恢复训练时的起始时期')
# Options from the original template that are not used: --n_cpu,
# --latent_dim (fixed by LATENT_DIM above) and --sample_interval.
opt = parser.parse_args()
print(opt)


def DataLoader(dataclass, img_size, batch_size, train_dir, eval_dir):
    """Build the training and evaluation loaders for one dataset.

    Args:
        dataclass: 'MNIST' or 'celebA'. Any other value prompts on stdin for a
            new choice and retries.
        img_size: size passed to ``transform.Resize``.
        batch_size: training batch size (also the celebA evaluation batch size;
            the MNIST evaluation loader always uses batch size 1).
        train_dir: root of the training data.
        eval_dir: root of the evaluation data.

    Returns:
        ``(train_loader, eval_loader)``.
    """
    if dataclass == 'MNIST':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.Gray(),
            transform.ImageNormalize(mean=[0.5], std=[0.5])])
        train_loader = MNIST(data_root=train_dir, train=True, transform=Transform)\
            .set_attrs(batch_size=batch_size, shuffle=True)
        eval_loader = MNIST(data_root=eval_dir, train=False, transform=Transform)\
            .set_attrs(batch_size=1, shuffle=True)
    elif dataclass == 'celebA':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        train_loader = ImageFolder(train_dir)\
            .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
        eval_loader = ImageFolder(eval_dir)\
            .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
    else:
        print("没有加载%s数据集的程序,请选择MNIST或者celebA!" % dataclass)
        dataclass = input("请输入:MNIST或者celebA:")
        # Return the retried result. Previously it was discarded and the
        # return below failed with UnboundLocalError.
        # NOTE(review): the choice made here does not update opt.task, which
        # still drives the channel count and sample rendering below.
        return DataLoader(dataclass, img_size, batch_size, train_dir, eval_dir)
    return train_loader, eval_loader


# Load the datasets
train_loader, eval_loader = DataLoader(dataclass=opt.task, img_size=opt.img_size,
                                       batch_size=opt.batch_size,
                                       train_dir=opt.train_dir, eval_dir=opt.eval_dir)


class generator(Module):
    """DCGAN-style generator: LATENT_DIM noise -> ``dim`` x 112 x 112 image in [-1, 1].

    The 7x7x256 seed is upsampled by four stride-2 transposed convolutions
    (7 -> 14 -> 28 -> 56 -> 112), interleaved with stride-1 refinements.
    """

    def __init__(self, dim=3):
        super(generator, self).__init__()
        self.fc = nn.Linear(LATENT_DIM, 7 * 7 * 256)
        self.fc_bn = nn.BatchNorm(256)
        self.deconv1 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)   # 7 -> 14
        self.deconv1_bn = nn.BatchNorm(256)
        self.deconv2 = nn.ConvTranspose(256, 256, 3, 1, 1)      # 14 -> 14
        self.deconv2_bn = nn.BatchNorm(256)
        self.deconv3 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)   # 14 -> 28
        self.deconv3_bn = nn.BatchNorm(256)
        self.deconv4 = nn.ConvTranspose(256, 256, 3, 1, 1)      # 28 -> 28
        self.deconv4_bn = nn.BatchNorm(256)
        self.deconv5 = nn.ConvTranspose(256, 128, 3, 2, 1, 1)   # 28 -> 56
        self.deconv5_bn = nn.BatchNorm(128)
        self.deconv6 = nn.ConvTranspose(128, 64, 3, 2, 1, 1)    # 56 -> 112
        self.deconv6_bn = nn.BatchNorm(64)
        self.deconv7 = nn.ConvTranspose(64, dim, 3, 1, 1)       # 112 -> 112, to image channels
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def execute(self, input):
        x = self.fc(input).reshape((-1, 256, 7, 7))
        x = self.relu(self.fc_bn(x))
        x = self.relu(self.deconv1_bn(self.deconv1(x)))
        x = self.relu(self.deconv2_bn(self.deconv2(x)))
        x = self.relu(self.deconv3_bn(self.deconv3(x)))
        x = self.relu(self.deconv4_bn(self.deconv4(x)))
        x = self.relu(self.deconv5_bn(self.deconv5(x)))
        x = self.relu(self.deconv6_bn(self.deconv6(x)))
        # tanh keeps outputs in [-1, 1], matching the ImageNormalize(0.5, 0.5) inputs.
        x = self.tanh(self.deconv7(x))
        return x


class discriminator(nn.Module):
    """Convolutional discriminator: ``dim`` x 112 x 112 image -> one unbounded score.

    Four 5x5 stride-2 convolutions reduce 112 -> 56 -> 28 -> 14 -> 7, so the
    final Linear layer only matches 112x112 inputs.
    """

    def __init__(self, dim=3):
        super(discriminator, self).__init__()
        self.conv1 = nn.Conv(dim, 64, 5, 2, 2)
        self.conv2 = nn.Conv(64, 128, 5, 2, 2)
        self.conv2_bn = nn.BatchNorm(128)
        self.conv3 = nn.Conv(128, 256, 5, 2, 2)
        self.conv3_bn = nn.BatchNorm(256)
        self.conv4 = nn.Conv(256, 512, 5, 2, 2)
        self.conv4_bn = nn.BatchNorm(512)
        self.fc = nn.Linear(512 * 7 * 7, 1)
        # NOTE(review): called below with the slope (0.2) as a second argument;
        # confirm the installed Jittor's Leaky_relu accepts that.
        self.leaky_relu = nn.Leaky_relu()

    def execute(self, input):
        x = self.leaky_relu(self.conv1(input), 0.2)
        x = self.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = self.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = self.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        x = x.reshape((x.shape[0], 512 * 7 * 7))
        x = self.fc(x)
        return x


def ls_loss(x, b):
    """Least-squares GAN loss: mean((x - t)^2) with t = 1 if ``b`` else 0.

    The target is built with the same shape as ``x`` (the discriminator's
    [N, 1] output). Subtracting an [N] target would broadcast to an N x N
    matrix with the same mean but quadratic cost.
    """
    target = jt.ones(x.shape) if b else jt.zeros(x.shape)
    return (x - target).sqr().mean()


def image_compose(array, IMAGE_SIZE=128, IMAGE_SAVE_PATH='./images_celebA'):
    """Tile random samples from a batch into an opt.img_row x opt.img_column grid and save it.

    Args:
        array: numpy batch indexed as (N, C, H, W). Values are expected in
            [-1, 1], as produced by the generator's tanh.
        IMAGE_SIZE: side length, in pixels, of each tile.
        IMAGE_SAVE_PATH: output file. PIL infers the format from the
            extension, so pass a path like "x.jpg"; the default has none.

    Returns:
        The return value of ``Image.save`` (None).
    """
    n_tiles = opt.img_row * opt.img_column
    to_image = Image.new('RGB', (opt.img_column * IMAGE_SIZE, opt.img_row * IMAGE_SIZE))
    picks = np.random.randint(0, array.shape[0], n_tiles)
    for k, idx in enumerate(picks):
        hwc = array[idx].transpose((1, 2, 0))
        # Map [-1, 1] -> [0, 255]. Scaling by 255 directly wrapped negatives in uint8.
        pixels = (np.clip((hwc + 1) / 2, 0, 1) * 255).astype(np.uint8)
        if pixels.shape[2] == 1:
            # Grayscale (e.g. MNIST): PIL rejects an (H, W, 1) array.
            pixels = pixels[:, :, 0]
        tile = Image.fromarray(pixels).resize((IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS)
        row, col = divmod(k, opt.img_column)
        to_image.paste(tile, (col * IMAGE_SIZE, row * IMAGE_SIZE))
    return to_image.save(IMAGE_SAVE_PATH)


# Noise reused by every save_img_result call, so successive epochs render the
# same latent points and can be compared side by side.
_fixed_z = None


def save_img_result(num_epoch, G, path='./images_celebA/result.png'):
    """Render a grid of generator samples from the fixed noise and save it to ``path``.

    The grid is opt.img_row x opt.img_column. The generator is switched to
    eval mode for sampling and always restored to train mode afterwards.
    """
    global _fixed_z
    rows, cols = opt.img_row, opt.img_column
    n = rows * cols
    if _fixed_z is None or _fixed_z.shape[0] != n:
        _fixed_z = jt.init.gauss((n, LATENT_DIM), 'float')
    G.eval()
    try:
        test_images = G(_fixed_z)
    finally:
        G.train()

    # squeeze=False keeps `ax` 2-D even for a 1-row or 1-column grid.
    fig, ax = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)
    for k in range(n):
        i, j = divmod(k, cols)
        cell = ax[i, j]
        cell.get_xaxis().set_visible(False)
        cell.get_yaxis().set_visible(False)
        # Undo the [-1, 1] normalisation for display.
        if opt.task == "MNIST":
            cell.imshow((test_images[k, 0].data + 1) / 2, cmap='gray')
        else:
            cell.imshow((test_images[k].data.transpose(1, 2, 0) + 1) / 2)

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')
    plt.savefig(path)
    plt.close()


def train(epoch):
    """Run one LSGAN epoch over ``train_loader``: a D step then a G step per batch."""
    for batch_idx, (x_, _) in enumerate(train_loader):
        mini_batch = x_.shape[0]

        # Discriminator step: push real images toward 1 and generated images toward 0.
        D_result = D(x_)                                         # [N, 1] scores for real images
        D_real_loss = ls_loss(D_result, True)
        z_ = jt.init.gauss((mini_batch, LATENT_DIM), 'float')    # noise, [N, LATENT_DIM]
        G_result = G(z_)                                         # generated images
        D_result_ = D(G_result)
        D_fake_loss = ls_loss(D_result_, False)
        D_train_loss = D_real_loss + D_fake_loss
        D_train_loss.sync()
        D_optim.step(D_train_loss)                               # updates D's parameters only

        # Generator step: make D score fresh fakes as real (target 1).
        z_ = jt.init.gauss((mini_batch, LATENT_DIM), 'float')
        G_result = G(z_)
        D_result = D(G_result)
        G_train_loss = ls_loss(D_result, True)
        G_train_loss.sync()
        G_optim.step(G_train_loss)                               # updates G's parameters only

        if batch_idx % 100 == 0:
            print("train: epoch{} batch_idx{} D training loss = {} G training loss = {} ".format(
                epoch, batch_idx, D_train_loss.data.mean(), G_train_loss.data.mean()))


def validate(epoch):
    """Report mean D and G LSGAN losses over ``eval_loader`` with both models in eval mode."""
    D_losses = []
    G_losses = []
    batch_idx = -1
    G.eval()
    D.eval()
    try:
        for batch_idx, (x_, _) in enumerate(eval_loader):
            mini_batch = x_.shape[0]

            # Discriminator loss
            D_result = D(x_)
            D_real_loss = ls_loss(D_result, True)
            z_ = jt.init.gauss((mini_batch, LATENT_DIM), 'float')
            G_result = G(z_)
            D_result_ = D(G_result)
            D_fake_loss = ls_loss(D_result_, False)
            D_train_loss = D_real_loss + D_fake_loss
            D_losses.append(D_train_loss.data.mean())

            # Generator loss
            z_ = jt.init.gauss((mini_batch, LATENT_DIM), 'float')
            G_result = G(z_)
            D_result = D(G_result)
            G_train_loss = ls_loss(D_result, True)
            G_losses.append(G_train_loss.data.mean())
    finally:
        # Always restore training mode, even if evaluation fails midway.
        G.train()
        D.train()

    if not D_losses:
        print("validate: epoch{}\tno evaluation batches".format(epoch))
        return
    print("validate: epoch{}\tbatch_idx{}\tD training loss = {}\tG training loss = {}"
          .format(epoch, batch_idx, str(np.array(D_losses).mean()), str(np.array(G_losses).mean())))


# Build generator and discriminator with the channel count of the selected dataset.
channels = opt.mnist_channels if opt.task == 'MNIST' else opt.celebA_channels
G = generator(channels)
D = discriminator(channels)

# Optimizers (defaults: lr 0.0002, betas (0.5, 0.999))
G_optim = jt.nn.Adam(G.parameters(), opt.lr, betas=(opt.b1, opt.b2))
D_optim = jt.nn.Adam(D.parameters(), opt.lr, betas=(opt.b1, opt.b2))

# Output locations (the default task keeps the original './images_celebA' etc.)
save_img_path = './images_' + opt.task
save_model_path = './save_model_' + opt.task
os.makedirs(save_img_path, exist_ok=True)
os.makedirs(save_model_path, exist_ok=True)
g_ckpt = os.path.join(save_model_path, 'generator_' + opt.task + '.pkl')
d_ckpt = os.path.join(save_model_path, 'discriminator_' + opt.task + '.pkl')

# Resume only when both checkpoints exist; a fresh run starts at epoch 0
# instead of crashing on a missing file.
if os.path.exists(g_ckpt) and os.path.exists(d_ckpt):
    G.load_parameters(jt.load(g_ckpt))
    D.load_parameters(jt.load(d_ckpt))
    start_epoch = opt.start_epoch
else:
    print('No checkpoint found in {}, training from scratch'.format(save_model_path))
    start_epoch = 0

for epoch in range(start_epoch, opt.n_epochs):
    print('number of epochs', epoch)
    train(epoch)
    # validate(epoch)
    result_img_path = os.path.join(save_img_path, str(epoch) + '.png')
    save_img_result(epoch, G, path=result_img_path)
    # Checkpoint every 10 epochs and always after the last one.
    if (epoch + 1) % 10 == 0 or epoch + 1 == opt.n_epochs:
        G.save(g_ckpt)
        D.save(d_ckpt)