Dmtlant commited on
Commit
5b29361
·
verified ·
1 Parent(s): 90e3043

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -36
app.py CHANGED
@@ -9,6 +9,13 @@ import numpy as np
9
  import matplotlib.pyplot as plt
10
  from PIL import Image
11
  import os
 
 
 
 
 
 
 
12
 
13
  # Параметры
14
  nc = 3 # Количество каналов в изображении
@@ -68,14 +75,16 @@ class Discriminator(nn.Module):
68
  def forward(self, input):
69
  return self.main(input)
70
 
71
- # Настройка Streamlit
72
- st.title('DCGAN Training and Generation')
73
-
74
- # Создание боковой панели
75
- st.sidebar.title('Параметры')
76
- num_epochs = st.sidebar.slider('Количество эпох', 1, 50, 5)
77
- batch_size = st.sidebar.slider('Размер батча', 16, 128, 64)
78
- lr = st.sidebar.number_input('Скорость обучения', 0.0001, 0.01, 0.0002)
 
 
79
 
80
  # Загрузка данных
81
  @st.cache_data
@@ -106,9 +115,7 @@ def generate_images(netG, num_images=64):
106
  with torch.no_grad():
107
  noise = torch.randn(num_images, nz, 1, 1, device=device)
108
  fake = netG(noise).detach().cpu()
109
- img = vutils.make_grid(fake, padding=2, normalize=True)
110
- img = np.transpose(img, (1, 2, 0))
111
- return img
112
 
113
  # Функция обучения
114
  def train_model():
@@ -158,7 +165,7 @@ def train_model():
158
  errD = errD_real + errD_fake
159
  optimizerD.step()
160
 
161
- ############################
162
  # (2) Обновление генератора
163
  ###########################
164
  netG.zero_grad()
@@ -178,11 +185,9 @@ def train_model():
178
 
179
  # Показать промежуточные результаты
180
  if i % 500 == 0:
181
- with torch.no_grad():
182
- fake = netG(torch.randn(64, nz, 1, 1, device=device)).detach().cpu()
183
- img = vutils.make_grid(fake, padding=2, normalize=True)
184
- img = np.transpose(img, (1, 2, 0))
185
- st.image(img, caption=f'Эпоха {epoch}, Batch {i}')
186
 
187
  # Обновление прогресс бара
188
  progress_bar.progress((epoch + 1) / num_epochs)
@@ -193,7 +198,14 @@ def train_model():
193
 
194
  # Основной интерфейс Streamlit
195
  def main():
196
- st.sidebar.title('DCGAN Control Panel')
 
 
 
 
 
 
 
197
 
198
  # Выбор режима
199
  mode = st.sidebar.selectbox('Выберите режим',
@@ -212,7 +224,8 @@ def main():
212
  # Генерация финальных изображений
213
  st.subheader('Финальные сгенерированные изображения')
214
  final_images = generate_images(netG)
215
- st.image(final_images, caption='Финальные сгенерированные изображения')
 
216
 
217
  elif mode == 'Генерация':
218
  if os.path.exists('generator.pth'):
@@ -224,20 +237,16 @@ def main():
224
 
225
  if st.button('Сгенерировать изображения'):
226
  images = generate_images(netG, num_images)
227
- st.image(images, caption='Сгенерированные изображения')
228
-
229
- # Опция сохранения
230
- if st.button('Сохранить изображения'):
231
- im = Image.fromarray((images * 255).astype(np.uint8))
232
- im.save('generated_images.png')
233
- st.success('Изображения сохранены!')
234
  else:
235
  st.error('Модель не найдена. Пожалуйста, сначала обучите модель.')
236
 
237
- # Запуск приложения
238
- if __name__ == '__main__':
239
- main()
240
-
241
  # Дополнительные настройки
242
  st.sidebar.markdown("""
243
  ## О проекте
@@ -253,11 +262,9 @@ st.sidebar.markdown("""
253
 
254
  # Настройки кэширования
255
  if st.sidebar.checkbox('Очистить кэш'):
256
- st.caching.clear_cache()
257
  st.success('Кэш очищен!')
258
 
259
- # Дополнительные метрики
260
- if st.sidebar.checkbox('Показать дополнительные метрики'):
261
- st.sidebar.write(f'Размер батча: {batch_size}')
262
- st.sidebar.write(f'Количество эпох: {num_epochs}')
263
- st.sidebar.write(f'Скорость обучения: {lr}')
 
9
  import matplotlib.pyplot as plt
10
  from PIL import Image
11
  import os
12
+ import sys
13
+
14
+ # Проверка правильного запуска
15
+ if not 'streamlit' in sys.modules:
16
+ print("Пожалуйста, запустите приложение с помощью команды:")
17
+ print("streamlit run dcgan_app.py")
18
+ sys.exit(1)
19
 
20
  # Параметры
21
  nc = 3 # Количество каналов в изображении
 
75
  def forward(self, input):
76
  return self.main(input)
77
 
78
+ # Функция для обработки тензора в изображение
79
+ def process_tensor_to_image(tensor):
80
+ try:
81
+ img = vutils.make_grid(tensor, padding=2, normalize=True)
82
+ img = np.transpose(img, (1, 2, 0))
83
+ img = ((img + 1) * 127.5).astype(np.uint8)
84
+ return Image.fromarray(img)
85
+ except Exception as e:
86
+ st.error(f"Ошибка при обработке изображения: {e}")
87
+ return None
88
 
89
  # Загрузка данных
90
  @st.cache_data
 
115
  with torch.no_grad():
116
  noise = torch.randn(num_images, nz, 1, 1, device=device)
117
  fake = netG(noise).detach().cpu()
118
+ return process_tensor_to_image(fake)
 
 
119
 
120
  # Функция обучения
121
  def train_model():
 
165
  errD = errD_real + errD_fake
166
  optimizerD.step()
167
 
168
+ ############################
169
  # (2) Обновление генератора
170
  ###########################
171
  netG.zero_grad()
 
185
 
186
  # Показать промежуточные результаты
187
  if i % 500 == 0:
188
+ img = process_tensor_to_image(fake)
189
+ if img is not None:
190
+ st.image(img, caption=f'Эпоха {epoch}, Batch {i}')
 
 
191
 
192
  # Обновление прогресс бара
193
  progress_bar.progress((epoch + 1) / num_epochs)
 
198
 
199
  # Основной интерфейс Streamlit
200
  def main():
201
+ st.title('DCGAN Training and Generation')
202
+
203
+ # Настройка боковой панели
204
+ st.sidebar.title('Параметры')
205
+ global num_epochs, batch_size, lr
206
+ num_epochs = st.sidebar.slider('Количество эпох', 1, 50, 5)
207
+ batch_size = st.sidebar.slider('Размер батча', 16, 128, 64)
208
+ lr = st.sidebar.number_input('Скорость обучения', 0.0001, 0.01, 0.0002)
209
 
210
  # Выбор режима
211
  mode = st.sidebar.selectbox('Выберите режим',
 
224
  # Генерация финальных изображений
225
  st.subheader('Финальные сгенерированные изображения')
226
  final_images = generate_images(netG)
227
+ if final_images is not None:
228
+ st.image(final_images, caption='Финальные сгенерированные изображения')
229
 
230
  elif mode == 'Генерация':
231
  if os.path.exists('generator.pth'):
 
237
 
238
  if st.button('Сгенерировать изображения'):
239
  images = generate_images(netG, num_images)
240
+ if images is not None:
241
+ st.image(images, caption='Сгенерированные изображения')
242
+
243
+ # Опция сохранения
244
+ if st.button('Сохранить изображения'):
245
+ images.save('generated_images.png')
246
+ st.success('Изображения сохранены!')
247
  else:
248
  st.error('Модель не найдена. Пожалуйста, сначала обучите модель.')
249
 
 
 
 
 
250
  # Дополнительные настройки
251
  st.sidebar.markdown("""
252
  ## О проекте
 
262
 
263
  # Настройки кэширования
264
  if st.sidebar.checkbox('Очистить кэш'):
265
+ st.cache_data.clear()
266
  st.success('Кэш очищен!')
267
 
268
+ # Запуск приложения
269
+ if __name__ == '__main__':
270
+ main()