isLandLZ commited on
Commit
71fd799
·
1 Parent(s): b33f5cd

Upload gan_model_load_test.py

Browse files
Files changed (1) hide show
  1. gan_model_load_test.py +101 -0
gan_model_load_test.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jittor as jt
2
+ from jittor import init
3
+ from jittor import nn
4
+
5
+ import argparse
6
+ import numpy as np
7
+ import cv2
8
+
9
+ jt.flags.use_cuda = 1
10
+
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument('--n_epochs', type=int, default=200, help='训练的时期数')
13
+ parser.add_argument('--batch_size', type=int, default=64, help='批次大小')
14
+ parser.add_argument('--lr', type=float, default=0.0002, help='学习率')
15
+ parser.add_argument('--b1', type=float, default=0.5, help='梯度的一阶动量衰减')
16
+ parser.add_argument('--b2', type=float, default=0.999, help='梯度的一阶动量衰减')
17
+ parser.add_argument('--n_cpu', type=int, default=8, help='批处理生成期间要使用的 cpu 线程数')
18
+ parser.add_argument('--latent_dim', type=int, default=100, help='潜在空间的维度')
19
+ parser.add_argument('--img_size', type=int, default=28, help='每个图像尺寸的大小')
20
+ parser.add_argument('--channels', type=int, default=1, help='图像通道数')
21
+ parser.add_argument('--sample_interval', type=int, default=400, help='图像样本之间的间隔')
22
+
23
+ opt = parser.parse_args()
24
+ print(opt)
25
+ img_shape = (opt.channels, opt.img_size, opt.img_size)
26
+
27
+ # 生成器
28
+ class Generator(nn.Module):
29
+
30
+ def __init__(self):
31
+ super(Generator, self).__init__()
32
+
33
+ def block(in_feat, out_feat, normalize=True):
34
+ layers = [nn.Linear(in_feat, out_feat)]
35
+ if normalize:
36
+ layers.append(nn.BatchNorm1d(out_feat, 0.8))
37
+ layers.append(nn.LeakyReLU(scale=0.2))
38
+ return layers
39
+ self.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh())
40
+
41
+ def execute(self, z):
42
+ img = self.model(z)
43
+ img = img.view((img.shape[0], *img_shape))
44
+ return img
45
+
46
+ # 判别器
47
+ class Discriminator(nn.Module):
48
+
49
+ def __init__(self):
50
+ super(Discriminator, self).__init__()
51
+ self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512), nn.LeakyReLU(scale=0.2), nn.Linear(512, 256), nn.LeakyReLU(scale=0.2), nn.Linear(256, 1), nn.Sigmoid())
52
+
53
+ def execute(self, img):
54
+ img_flat = img.view((img.shape[0], (- 1)))
55
+ validity = self.model(img_flat)
56
+ return validity
57
+
58
+ def deal_image(img, path=None, nrow=None):
59
+ N,C,W,H = img.shape# (25, 1, 28, 28)
60
+ '''
61
+ [-1,700,28] , img2的形状(1,700,28)
62
+ img[0][0][0] = img2[0][0]
63
+ img2:[
64
+ [1*28]
65
+ ......(一共700个)
66
+ ](1,700,28)
67
+ '''
68
+ img2=img.reshape([-1,W*nrow*nrow,H])
69
+ # [:,:28*5,:],img:(1,140,28)
70
+ img=img2[:,:W*nrow,:]
71
+ for i in range(1,nrow):#[1,5)
72
+ '''
73
+ img(1,140,28),img2(1,700,28)
74
+ img从(1,140,28)->(1,140,28+28)->...->(1,140,28+28+28+28)=(1,140,140)
75
+ np.concatenate把两个三维数组合并
76
+ '''
77
+ img=np.concatenate([img,img2[:,W*nrow*i:W*nrow*(i+1),:]],axis=2)
78
+ # img中的数据大小从(-1,1)--(+1)-->(0,2)--(/2)-->(0,1)--(*255)-->(0,255)转换成了像素值
79
+ img=(img+1.0)/2.0*255
80
+ # (1,140,140)--->(140,140,1)
81
+ # (channels通道数,imagesize,imagesize)转化为(imagesize,imagesize,channels通道数)
82
+ img=img.transpose((1,2,0))
83
+ if path:
84
+ # 根据地址保存图片样本数据
85
+ cv2.imwrite(path,img)
86
+ cv2.imshow('1',img)
87
+ cv2.waitKey(0)
88
+
89
+ # 初始化生成器和判别器,并加载模型
90
+ generator = Generator()
91
+ g_model_path = "saved_models/generator_last.pkl"
92
+ generator.load_parameters(jt.load(g_model_path))
93
+ generator.load(g_model_path)
94
+ discriminator = Discriminator()
95
+ d_model_path = "saved_models/discriminator_last.pkl"
96
+ discriminator.load_parameters(jt.load(d_model_path))
97
+ discriminator.load(d_model_path)
98
+
99
+ z = jt.array(np.random.normal(0, 1, (64, opt.latent_dim)).astype(np.float32))
100
+ gen_imgs = generator(z)
101
+ deal_image(gen_imgs.data[:25], "images_test/1.png",nrow=5)