import streamlit as st import torch import torch.nn as nn import torch.optim as optim import torchvision.utils as vutils import torchvision.datasets as dset import torchvision.transforms as transforms import numpy as np import matplotlib.pyplot as plt from PIL import Image import os # Параметры nc = 3 # Количество каналов в изображении nz = 100 # Размер вектора шума ngf = 64 # Размер карт признаков генератора ndf = 64 # Размер карт признаков дискриминатора num_epochs = 5 # Количество эпох обучения lr = 0.0002 # Скорость обучения beta1 = 0.5 # Beta1 для Adam оптимизатора batch_size = 64 # Размер батча image_size = 64 # Размер изображения # Генератор class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.main = nn.Sequential( nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf * 8), nn.ReLU(True), nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.ReLU(True), nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.ReLU(True), nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf), nn.ReLU(True), nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh() ) def forward(self, input): return self.main(input) # Дискриминатор class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.main = nn.Sequential( nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid() ) def forward(self, input): return self.main(input) # Настройка Streamlit st.title('DCGAN Training and Generation') # Создание боковой панели st.sidebar.title('Параметры') num_epochs = st.sidebar.slider('Количество эпох', 1, 50, 5) batch_size = st.sidebar.slider('Размер батча', 16, 128, 64) lr = st.sidebar.number_input('Скорость обучения', 0.0001, 0.01, 0.0002) # Загрузка данных @st.cache_data def load_data(): dataset = dset.CIFAR10(root='./data', download=True, transform=transforms.Compose([ transforms.Resize(image_size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])) dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2) return dataloader # Функция для визуализации результатов def plot_training_results(G_losses, D_losses): fig, ax = plt.subplots(figsize=(10, 5)) plt.plot(G_losses, label='Generator Loss') plt.plot(D_losses, label='Discriminator Loss') plt.xlabel('Iterations') plt.ylabel('Loss') plt.legend() st.pyplot(fig) # Функция генерации изображений def generate_images(netG, num_images=64): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") with torch.no_grad(): noise = torch.randn(num_images, nz, 1, 1, device=device) fake = netG(noise).detach().cpu() img = vutils.make_grid(fake, padding=2, normalize=True) img = np.transpose(img, (1, 2, 0)) return img # Функция обучения def train_model(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") st.write(f"Using device: {device}") # Создание сетей netG = Generator().to(device) netD = Discriminator().to(device) # Критерий и оптимизаторы criterion = nn.BCELoss() optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999)) optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999)) # Загрузка данных dataloader = load_data() # Списки для отслеживания прогресса G_losses = [] D_losses = [] # Прогресс бар progress_bar = st.progress(0) status_text = st.empty() # Обучение for epoch in range(num_epochs): for i, data in enumerate(dataloader, 0): ############################ # (1) Обновление дискриминатора ########################### netD.zero_grad() real = data[0].to(device) b_size = real.size(0) label = torch.full((b_size,), 1, dtype=torch.float, device=device) output = netD(real).view(-1) errD_real = criterion(output, label) errD_real.backward() noise = torch.randn(b_size, nz, 1, 1, device=device) fake = netG(noise) label.fill_(0) output = netD(fake.detach()).view(-1) errD_fake = criterion(output, label) errD_fake.backward() errD = errD_real + errD_fake optimizerD.step() ############################ # (2) Обновление генератора ########################### netG.zero_grad() label.fill_(1) output = netD(fake).view(-1) errG = criterion(output, label) errG.backward() optimizerG.step() # Обновление статуса if i % 100 == 0: status_text.text(f'Эпоха [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] ' f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f}') G_losses.append(errG.item()) D_losses.append(errD.item()) # Показать промежуточные результаты if i % 500 == 0: with torch.no_grad(): fake = netG(torch.randn(64, nz, 1, 1, device=device)).detach().cpu() img = vutils.make_grid(fake, padding=2, normalize=True) img = np.transpose(img, (1, 2, 0)) st.image(img, caption=f'Эпоха {epoch}, Batch {i}') # Обновление прогресс бара progress_bar.progress((epoch + 1) / num_epochs) # Сохранение модели torch.save(netG.state_dict(), 'generator.pth') return netG, G_losses, D_losses # Основной интерфейс Streamlit def main(): st.sidebar.title('DCGAN Control Panel') # Выбор режима mode = st.sidebar.selectbox('Выберите режим', ['Обучение', 'Генерация']) if mode == 'Обучение': if st.button('Начать обучение'): st.write('Начинаем обучение...') netG, G_losses, D_losses = train_model() st.write('Обучение завершено!') # Отображение графиков потерь st.subheader('Графики потерь') plot_training_results(G_losses, D_losses) # Генерация финальных изображений st.subheader('Финальные сгенерированные изображения') final_images = generate_images(netG) st.image(final_images, caption='Финальные сгенерированные изображения') elif mode == 'Генерация': if os.path.exists('generator.pth'): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") netG = Generator().to(device) netG.load_state_dict(torch.load('generator.pth', map_location=device)) num_images = st.slider('Количество изображений', 1, 64, 16) if st.button('Сгенерировать изображения'): images = generate_images(netG, num_images) st.image(images, caption='Сгенерированные изображения') # Опция сохранения if st.button('Сохранить изображения'): im = Image.fromarray((images * 255).astype(np.uint8)) im.save('generated_images.png') st.success('Изображения сохранены!') else: st.error('Модель не найдена. Пожалуйста, сначала обучите модель.') # Запуск приложения if __name__ == '__main__': main() # Дополнительные настройки st.sidebar.markdown(""" ## О проекте Это приложение демонстрирует работу DCGAN (Deep Convolutional Generative Adversarial Network) для генерации изображений. ### Особенности: - Обучение на датасете CIFAR-10 - Генерация изображений 64x64 - Возможность настройки параметров - Визуализация процесса обучения """) # Настройки кэширования if st.sidebar.checkbox('Очистить кэш'): st.caching.clear_cache() st.success('Кэш очищен!') # Дополнительные метрики if st.sidebar.checkbox('Показать дополнительные метрики'): st.sidebar.write(f'Размер батча: {batch_size}') st.sidebar.write(f'Количество эпох: {num_epochs}') st.sidebar.write(f'Скорость обучения: {lr}')