Dmtlant commited on
Commit
1620753
1 Parent(s): afedc98

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -0
app.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torchvision.utils as vutils
6
+ import torchvision.datasets as dset
7
+ import torchvision.transforms as transforms
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from PIL import Image
11
+ import os
12
+
13
+ # Параметры
14
+ nc = 3 # Количество каналов в изображении
15
+ nz = 100 # Размер вектора шума
16
+ ngf = 64 # Размер карт признаков генератора
17
+ ndf = 64 # Размер карт признаков дискриминатора
18
+ num_epochs = 5 # Количество эпох обучения
19
+ lr = 0.0002 # Скорость обучения
20
+ beta1 = 0.5 # Beta1 для Adam оптимизатора
21
+ batch_size = 64 # Размер батча
22
+ image_size = 64 # Размер изображения
23
+
24
+ # Генератор
25
+ class Generator(nn.Module):
26
+ def __init__(self):
27
+ super(Generator, self).__init__()
28
+ self.main = nn.Sequential(
29
+ nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
30
+ nn.BatchNorm2d(ngf * 8),
31
+ nn.ReLU(True),
32
+ nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
33
+ nn.BatchNorm2d(ngf * 4),
34
+ nn.ReLU(True),
35
+ nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
36
+ nn.BatchNorm2d(ngf * 2),
37
+ nn.ReLU(True),
38
+ nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
39
+ nn.BatchNorm2d(ngf),
40
+ nn.ReLU(True),
41
+ nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
42
+ nn.Tanh()
43
+ )
44
+
45
+ def forward(self, input):
46
+ return self.main(input)
47
+
48
+ # Дискриминатор
49
+ class Discriminator(nn.Module):
50
+ def __init__(self):
51
+ super(Discriminator, self).__init__()
52
+ self.main = nn.Sequential(
53
+ nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
54
+ nn.LeakyReLU(0.2, inplace=True),
55
+ nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
56
+ nn.BatchNorm2d(ndf * 2),
57
+ nn.LeakyReLU(0.2, inplace=True),
58
+ nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
59
+ nn.BatchNorm2d(ndf * 4),
60
+ nn.LeakyReLU(0.2, inplace=True),
61
+ nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
62
+ nn.BatchNorm2d(ndf * 8),
63
+ nn.LeakyReLU(0.2, inplace=True),
64
+ nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
65
+ nn.Sigmoid()
66
+ )
67
+
68
+ def forward(self, input):
69
+ return self.main(input)
70
+
71
+ # Настройка Streamlit
72
+ st.title('DCGAN Training and Generation')
73
+
74
+ # Создание боковой панели
75
+ st.sidebar.title('Параметры')
76
+ num_epochs = st.sidebar.slider('Количество эпох', 1, 50, 5)
77
+ batch_size = st.sidebar.slider('Размер батча', 16, 128, 64)
78
+ lr = st.sidebar.number_input('Скорость обучения', 0.0001, 0.01, 0.0002)
79
+
80
+ # Загрузка данных
81
+ @st.cache_data
82
+ def load_data():
83
+ dataset = dset.CIFAR10(root='./data', download=True,
84
+ transform=transforms.Compose([
85
+ transforms.Resize(image_size),
86
+ transforms.ToTensor(),
87
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
88
+ ]))
89
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
90
+ shuffle=True, num_workers=2)
91
+ return dataloader
92
+
93
+ # Функция для визуализации результатов
94
+ def plot_training_results(G_losses, D_losses):
95
+ fig, ax = plt.subplots(figsize=(10, 5))
96
+ plt.plot(G_losses, label='Generator Loss')
97
+ plt.plot(D_losses, label='Discriminator Loss')
98
+ plt.xlabel('Iterations')
99
+ plt.ylabel('Loss')
100
+ plt.legend()
101
+ st.pyplot(fig)
102
+
103
+ # Функция генерации изображений
104
+ def generate_images(netG, num_images=64):
105
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
+ with torch.no_grad():
107
+ noise = torch.randn(num_images, nz, 1, 1, device=device)
108
+ fake = netG(noise).detach().cpu()
109
+ img = vutils.make_grid(fake, padding=2, normalize=True)
110
+ img = np.transpose(img, (1, 2, 0))
111
+ return img
112
+
113
+ # Функция обучения
114
+ def train_model():
115
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
+ st.write(f"Using device: {device}")
117
+
118
+ # Создание сетей
119
+ netG = Generator().to(device)
120
+ netD = Discriminator().to(device)
121
+
122
+ # Критерий и оптимизаторы
123
+ criterion = nn.BCELoss()
124
+ optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
125
+ optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
126
+
127
+ # Загрузка данных
128
+ dataloader = load_data()
129
+
130
+ # Списки для отслеживания прогресса
131
+ G_losses = []
132
+ D_losses = []
133
+
134
+ # Прогресс бар
135
+ progress_bar = st.progress(0)
136
+ status_text = st.empty()
137
+
138
+ # Обучение
139
+ for epoch in range(num_epochs):
140
+ for i, data in enumerate(dataloader, 0):
141
+ ############################
142
+ # (1) Обновление дискриминатора
143
+ ###########################
144
+ netD.zero_grad()
145
+ real = data[0].to(device)
146
+ b_size = real.size(0)
147
+ label = torch.full((b_size,), 1, dtype=torch.float, device=device)
148
+ output = netD(real).view(-1)
149
+ errD_real = criterion(output, label)
150
+ errD_real.backward()
151
+
152
+ noise = torch.randn(b_size, nz, 1, 1, device=device)
153
+ fake = netG(noise)
154
+ label.fill_(0)
155
+ output = netD(fake.detach()).view(-1)
156
+ errD_fake = criterion(output, label)
157
+ errD_fake.backward()
158
+ errD = errD_real + errD_fake
159
+ optimizerD.step()
160
+
161
+ ############################
162
+ # (2) Обновление генератора
163
+ ###########################
164
+ netG.zero_grad()
165
+ label.fill_(1)
166
+ output = netD(fake).view(-1)
167
+ errG = criterion(output, label)
168
+ errG.backward()
169
+ optimizerG.step()
170
+
171
+ # Обновление статуса
172
+ if i % 100 == 0:
173
+ status_text.text(f'Эпоха [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] '
174
+ f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f}')
175
+
176
+ G_losses.append(errG.item())
177
+ D_losses.append(errD.item())
178
+
179
+ # Показать промежуточные результаты
180
+ if i % 500 == 0:
181
+ with torch.no_grad():
182
+ fake = netG(torch.randn(64, nz, 1, 1, device=device)).detach().cpu()
183
+ img = vutils.make_grid(fake, padding=2, normalize=True)
184
+ img = np.transpose(img, (1, 2, 0))
185
+ st.image(img, caption=f'Эпоха {epoch}, Batch {i}')
186
+
187
+ # Обновление прогресс бара
188
+ progress_bar.progress((epoch + 1) / num_epochs)
189
+
190
+ # Сохранение модели
191
+ torch.save(netG.state_dict(), 'generator.pth')
192
+ return netG, G_losses, D_losses
193
+
194
+ # Основной интерфейс Streamlit
195
+ def main():
196
+ st.sidebar.title('DCGAN Control Panel')
197
+
198
+ # Выбор режима
199
+ mode = st.sidebar.selectbox('Выберите режим',
200
+ ['Обучение', 'Генерация'])
201
+
202
+ if mode == 'Обучение':
203
+ if st.button('Начать обучение'):
204
+ st.write('Начинаем обучение...')
205
+ netG, G_losses, D_losses = train_model()
206
+ st.write('Обучение завершено!')
207
+
208
+ # Отображение графиков потерь
209
+ st.subheader('Графики потерь')
210
+ plot_training_results(G_losses, D_losses)
211
+
212
+ # Генерация финальных изображений
213
+ st.subheader('Финальные сгенерированные изображения')
214
+ final_images = generate_images(netG)
215
+ st.image(final_images, caption='Финальные сгенерированные изображения')
216
+
217
+ elif mode == 'Генерация':
218
+ if os.path.exists('generator.pth'):
219
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
220
+ netG = Generator().to(device)
221
+ netG.load_state_dict(torch.load('generator.pth', map_location=device))
222
+
223
+ num_images = st.slider('Количество изображений', 1, 64, 16)
224
+
225
+ if st.button('Сгенерировать изображения'):
226
+ images = generate_images(netG, num_images)
227
+ st.image(images, caption='Сгенерированные изображения')
228
+
229
+ # Опция сохранения
230
+ if st.button('Сохранить изображения'):
231
+ im = Image.fromarray((images * 255).astype(np.uint8))
232
+ im.save('generated_images.png')
233
+ st.success('Изображения сохранены!')
234
+ else:
235
+ st.error('Модель не найдена. Пожалуйста, сначала обучите модель.')
236
+
237
+ # Запуск приложения
238
+ if __name__ == '__main__':
239
+ main()
240
+
241
+ # Дополнительные настройки
242
+ st.sidebar.markdown("""
243
+ ## О проекте
244
+ Это приложение демонстрирует работу DCGAN (Deep Convolutional Generative Adversarial Network)
245
+ для генерации изображений.
246
+
247
+ ### Особенности:
248
+ - Обучение на датасете CIFAR-10
249
+ - Генерация изображений 64x64
250
+ - Возможность настройки параметров
251
+ - Визуализация процесса обучения
252
+ """)
253
+
254
+ # Настройки кэширования
255
+ if st.sidebar.checkbox('Очистить кэш'):
256
+ st.caching.clear_cache()
257
+ st.success('Кэш очищен!')
258
+
259
+ # Дополнительные метрики
260
+ if st.sidebar.checkbox('Показать дополнительные метрики'):
261
+ st.sidebar.write(f'Размер батча: {batch_size}')
262
+ st.sidebar.write(f'Количество эпох: {num_epochs}')
263
+ st.sidebar.write(f'Скорость обучения: {lr}')