|
import streamlit as st |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import torchvision.utils as vutils |
|
import torchvision.datasets as dset |
|
import torchvision.transforms as transforms |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import os |
|
|
|
|
|
# Network shape settings shared by Generator and Discriminator.
nc = 3           # image channels (generator output / discriminator input)
nz = 100         # channels of the latent noise tensor fed to the generator
ngf = 64         # base feature-map width of the generator
ndf = 64         # base feature-map width of the discriminator
image_size = 64  # side length that training images are resized to

# Optimisation defaults. num_epochs, batch_size and lr are rebound further
# down from the sidebar widgets; beta1 is used as the first Adam beta.
num_epochs = 5
batch_size = 64
lr = 0.0002
beta1 = 0.5
|
|
|
|
|
class Generator(nn.Module):
    """DCGAN generator: turns latent noise into an image.

    The network is a stack of five transposed convolutions. The first four are
    each followed by batch norm and ReLU; the last one maps to ``nc`` channels
    and is followed by ``Tanh``.
    """

    def __init__(self):
        super().__init__()

        def stage(in_ch, out_ch, stride, padding):
            # One repeating unit: transposed conv -> batch norm -> ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        widths = [ngf * 8, ngf * 4, ngf * 2, ngf]

        # First stage projects the noise (stride 1, no padding); the following
        # stages use stride 2 / padding 1.
        layers = stage(nz, widths[0], 1, 0)
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers += stage(in_ch, out_ch, 2, 1)
        layers += [
            nn.ConvTranspose2d(widths[-1], nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        # The attribute name and layer order are part of the checkpoint format:
        # state_dict keys are "main.<index>.<param>".
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)
|
|
|
|
|
class Discriminator(nn.Module):
    """DCGAN discriminator: scores an image through a strided conv stack.

    Five convolutions with LeakyReLU(0.2) activations; the middle three are
    followed by batch norm, and the final single-channel convolution is passed
    through ``Sigmoid``.
    """

    def __init__(self):
        super().__init__()

        widths = [ndf, ndf * 2, ndf * 4, ndf * 8]

        # Input stage has no batch norm.
        layers = [
            nn.Conv2d(nc, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three identical downsampling stages with doubling channel width.
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers.extend(
                (
                    nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(out_ch),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            )
        # Output stage: collapse to one channel and squash with a sigmoid.
        layers.append(nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())

        # Keep the attribute name and layer order so state_dict keys stay
        # "main.<index>.<param>".
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)
|
|
|
|
|
# Page header and sidebar controls. The three widget values rebind the
# module-level num_epochs / batch_size / lr defaults defined above.
st.title('DCGAN Training and Generation')

st.sidebar.title('Параметры')
num_epochs = st.sidebar.slider('Количество эпох', 1, 50, 5)
batch_size = st.sidebar.slider('Размер батча', 16, 128, 64)
# A float number_input defaults to a 0.01 step and a two-decimal display,
# which would show 0.0002 as "0.00" and make the +/- buttons useless for a
# learning rate. Set a fine step and a four-decimal format explicitly.
lr = st.sidebar.number_input(
    'Скорость обучения',
    min_value=0.0001,
    max_value=0.01,
    value=0.0002,
    step=0.0001,
    format='%.4f',
)
|
|
|
|
|
@st.cache_resource
def _load_dataset():
    """Build the CIFAR-10 dataset (downloading to ./data if needed).

    Cached as a resource because it does not depend on any sidebar setting and
    is safe to share between reruns: images are resized to ``image_size``,
    converted to tensors and normalised with mean/std 0.5 per channel.
    """
    preprocess = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return dset.CIFAR10(root='./data', download=True, transform=preprocess)


def load_data():
    """Return a shuffled DataLoader over CIFAR-10.

    Only the dataset is cached. The loader itself is rebuilt on every call so
    that it always uses the current module-level ``batch_size``; caching the
    loader from a zero-argument function would pin it to whatever batch size
    was selected on the first call.

    Returns:
        torch.utils.data.DataLoader yielding ``(images, labels)`` batches.
    """
    return torch.utils.data.DataLoader(
        _load_dataset(),
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
|
|
|
|
|
def plot_training_results(G_losses, D_losses):
    """Draw generator and discriminator loss curves in the Streamlit page.

    Args:
        G_losses: Sequence of generator loss values, in recording order.
        D_losses: Sequence of discriminator loss values, in recording order.
    """
    fig, ax = plt.subplots(figsize=(10, 5))

    # Draw through the Axes created here rather than pyplot's implicit
    # "current figure", so the curves cannot land on some other figure.
    ax.plot(G_losses, label='Generator Loss')
    ax.plot(D_losses, label='Discriminator Loss')
    ax.set_xlabel('Iterations')
    ax.set_ylabel('Loss')
    ax.legend()

    st.pyplot(fig)

    # pyplot keeps a reference to every figure it creates until it is closed;
    # release it once rendered so repeated runs do not accumulate figures.
    plt.close(fig)
|
|
|
|
|
def generate_images(netG, num_images=64):
    """Sample images from the generator and tile them into a single grid.

    Args:
        netG: Generator module; it is called with noise of shape
            ``(num_images, nz, 1, 1)``.
        num_images: Number of samples to draw.

    Returns:
        np.ndarray holding the ``make_grid(..., normalize=True)`` output with
        the channel axis moved last (height, width, channel).
    """
    # Create the noise on the device the model's weights actually live on.
    # Fall back to the CUDA-if-available guess only for a parameterless model.
    first_param = next(netG.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        noise = torch.randn(num_images, nz, 1, 1, device=device)
        fake = netG(noise).detach().cpu()

    grid = vutils.make_grid(fake, padding=2, normalize=True)

    # np.transpose() on a tensor is routed back through Tensor.__array_wrap__
    # and can return a tensor again, which breaks callers that use ndarray
    # methods such as .astype(). Permute and convert explicitly instead.
    return grid.permute(1, 2, 0).contiguous().numpy()
|
|
|
|
|
def train_model():
    """Train the DCGAN and report progress in the Streamlit page.

    Builds a fresh Generator / Discriminator pair, trains them with BCE loss
    and Adam using the module-level ``num_epochs``, ``lr``, ``beta1`` and
    ``nz`` settings, shows a status line and periodic sample grids, and saves
    the generator weights to ``generator.pth`` at the end.

    Returns:
        Tuple ``(netG, G_losses, D_losses)``: the trained generator and the
        lists of generator / discriminator loss values recorded per batch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    st.write(f"Using device: {device}")

    netG = Generator().to(device)
    netD = Discriminator().to(device)

    criterion = nn.BCELoss()
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    dataloader = load_data()

    G_losses = []
    D_losses = []

    progress_bar = st.progress(0)
    status_text = st.empty()

    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            # ---- Discriminator update: real batch with target 1 ...
            netD.zero_grad()
            real = data[0].to(device)
            b_size = real.size(0)
            label = torch.full((b_size,), 1, dtype=torch.float, device=device)
            output = netD(real).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()

            # ... then a generated batch with target 0. detach() keeps this
            # backward pass from reaching the generator's parameters.
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(0)
            output = netD(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            errD = errD_real + errD_fake
            optimizerD.step()

            # ---- Generator update: the same fake batch, scored against
            # target 1 by the just-updated discriminator.
            netG.zero_grad()
            label.fill_(1)
            output = netD(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            optimizerG.step()

            loss_d = errD.item()
            loss_g = errG.item()

            if i % 100 == 0:
                status_text.text(f'Эпоха [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] '
                                 f'Loss_D: {loss_d:.4f} Loss_G: {loss_g:.4f}')

            G_losses.append(loss_g)
            D_losses.append(loss_d)

            if i % 500 == 0:
                # Separate name so the training batch in `fake` is untouched.
                with torch.no_grad():
                    preview = netG(torch.randn(64, nz, 1, 1, device=device)).detach().cpu()
                grid = vutils.make_grid(preview, padding=2, normalize=True)
                # Hand st.image a real ndarray: np.transpose() on a tensor can
                # return a tensor again via Tensor.__array_wrap__.
                img = grid.permute(1, 2, 0).contiguous().numpy()
                st.image(img, caption=f'Эпоха {epoch}, Batch {i}')

        progress_bar.progress((epoch + 1) / num_epochs)

    torch.save(netG.state_dict(), 'generator.pth')
    return netG, G_losses, D_losses
|
|
|
|
|
def main():
    """Render the main page: either run training or generate from a checkpoint.

    In training mode a button starts ``train_model()`` and then shows the loss
    curves and a final sample grid. In generation mode the generator weights
    are loaded from ``generator.pth`` (an error is shown if the file is
    missing), a grid of samples can be generated, and the last generated grid
    can be saved to ``generated_images.png``.
    """
    st.sidebar.title('DCGAN Control Panel')

    mode = st.sidebar.selectbox('Выберите режим', ['Обучение', 'Генерация'])

    if mode == 'Обучение':
        if st.button('Начать обучение'):
            st.write('Начинаем обучение...')
            netG, G_losses, D_losses = train_model()
            st.write('Обучение завершено!')

            st.subheader('Графики потерь')
            plot_training_results(G_losses, D_losses)

            st.subheader('Финальные сгенерированные изображения')
            # np.asarray() guarantees an ndarray even if the helper hands
            # back a CPU tensor.
            final_images = np.asarray(generate_images(netG))
            st.image(final_images, caption='Финальные сгенерированные изображения')

    elif mode == 'Генерация':
        if not os.path.exists('generator.pth'):
            st.error('Модель не найдена. Пожалуйста, сначала обучите модель.')
            return

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        netG = Generator().to(device)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # written by this app.
        netG.load_state_dict(torch.load('generator.pth', map_location=device))

        num_images = st.slider('Количество изображений', 1, 64, 16)

        # A button is True only on the run triggered by its own click, so a
        # "save" button nested under "generate" could never fire: clicking it
        # reruns the script with "generate" False. Keep the last grid in
        # session state and render the save button from that instead.
        state_key = '_generated_images'
        if st.button('Сгенерировать изображения'):
            st.session_state[state_key] = np.asarray(generate_images(netG, num_images))

        images = st.session_state.get(state_key)
        if images is not None:
            st.image(images, caption='Сгенерированные изображения')

            if st.button('Сохранить изображения'):
                im = Image.fromarray((images * 255).astype(np.uint8))
                im.save('generated_images.png')
                st.success('Изображения сохранены!')
|
|
|
|
|
# Build the main page only when this file is executed as the main module.
if __name__ == "__main__":
    main()
|
|
|
|
|
# Static "about" text shown at the bottom of the sidebar.
_ABOUT_MARKDOWN = """
## О проекте
Это приложение демонстрирует работу DCGAN (Deep Convolutional Generative Adversarial Network)
для генерации изображений.

### Особенности:
- Обучение на датасете CIFAR-10
- Генерация изображений 64x64
- Возможность настройки параметров
- Визуализация процесса обучения
"""

st.sidebar.markdown(_ABOUT_MARKDOWN)
|
|
|
|
|
# One-shot cache reset. A button is used instead of a checkbox because a
# ticked checkbox stays True and would wipe the caches again on every rerun.
if st.sidebar.button('Очистить кэш'):
    # st.caching.clear_cache() is gone from Streamlit versions that provide
    # st.cache_data; the supported way is the per-decorator clear() methods.
    st.cache_data.clear()
    st.cache_resource.clear()
    st.success('Кэш очищен!')
|
|
|
|
|
# Optional read-out of the currently selected training settings.
if st.sidebar.checkbox('Показать дополнительные метрики'):
    for line in (
        f'Размер батча: {batch_size}',
        f'Количество эпох: {num_epochs}',
        f'Скорость обучения: {lr}',
    ):
        st.sidebar.write(line)