Dmtlant commited on
Commit
6b694b5
·
verified ·
1 Parent(s): 5b29361

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -262
app.py CHANGED
@@ -1,270 +1,24 @@
1
  import streamlit as st
2
- import torch
3
- import torch.nn as nn
4
- import torch.optim as optim
5
- import torchvision.utils as vutils
6
- import torchvision.datasets as dset
7
- import torchvision.transforms as transforms
8
- import numpy as np
9
- import matplotlib.pyplot as plt
10
- from PIL import Image
11
- import os
12
- import sys
13
 
14
- # Проверка правильного запуска
15
- if not 'streamlit' in sys.modules:
16
- print("Пожалуйста, запустите приложение с помощью команды:")
17
- print("streamlit run dcgan_app.py")
18
- sys.exit(1)
19
 
20
- # Параметры
21
- nc = 3 # Количество каналов в изображении
22
- nz = 100 # Размер вектора шума
23
- ngf = 64 # Размер карт признаков генератора
24
- ndf = 64 # Размер карт признаков дискриминатора
25
- num_epochs = 5 # Количество эпох обучения
26
- lr = 0.0002 # Скорость обучения
27
- beta1 = 0.5 # Beta1 для Adam оптимизатора
28
- batch_size = 64 # Размер батча
29
- image_size = 64 # Размер изображения
30
-
31
- # Генератор
32
- class Generator(nn.Module):
33
- def __init__(self):
34
- super(Generator, self).__init__()
35
- self.main = nn.Sequential(
36
- nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
37
- nn.BatchNorm2d(ngf * 8),
38
- nn.ReLU(True),
39
- nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
40
- nn.BatchNorm2d(ngf * 4),
41
- nn.ReLU(True),
42
- nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
43
- nn.BatchNorm2d(ngf * 2),
44
- nn.ReLU(True),
45
- nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
46
- nn.BatchNorm2d(ngf),
47
- nn.ReLU(True),
48
- nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
49
- nn.Tanh()
50
- )
51
-
52
- def forward(self, input):
53
- return self.main(input)
54
-
55
- # Дискриминатор
56
- class Discriminator(nn.Module):
57
- def __init__(self):
58
- super(Discriminator, self).__init__()
59
- self.main = nn.Sequential(
60
- nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
61
- nn.LeakyReLU(0.2, inplace=True),
62
- nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
63
- nn.BatchNorm2d(ndf * 2),
64
- nn.LeakyReLU(0.2, inplace=True),
65
- nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
66
- nn.BatchNorm2d(ndf * 4),
67
- nn.LeakyReLU(0.2, inplace=True),
68
- nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
69
- nn.BatchNorm2d(ndf * 8),
70
- nn.LeakyReLU(0.2, inplace=True),
71
- nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
72
- nn.Sigmoid()
73
- )
74
-
75
- def forward(self, input):
76
- return self.main(input)
77
-
78
- # Функция для обработки тензора в изображение
79
- def process_tensor_to_image(tensor):
80
  try:
81
- img = vutils.make_grid(tensor, padding=2, normalize=True)
82
- img = np.transpose(img, (1, 2, 0))
83
- img = ((img + 1) * 127.5).astype(np.uint8)
84
- return Image.fromarray(img)
85
- except Exception as e:
86
- st.error(f"Ошибка при обработке изображения: {e}")
87
  return None
88
 
89
- # Загрузка данных
90
- @st.cache_data
91
- def load_data():
92
- dataset = dset.CIFAR10(root='./data', download=True,
93
- transform=transforms.Compose([
94
- transforms.Resize(image_size),
95
- transforms.ToTensor(),
96
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
97
- ]))
98
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
99
- shuffle=True, num_workers=2)
100
- return dataloader
101
-
102
- # Функция для визуализации результатов
103
- def plot_training_results(G_losses, D_losses):
104
- fig, ax = plt.subplots(figsize=(10, 5))
105
- plt.plot(G_losses, label='Generator Loss')
106
- plt.plot(D_losses, label='Discriminator Loss')
107
- plt.xlabel('Iterations')
108
- plt.ylabel('Loss')
109
- plt.legend()
110
- st.pyplot(fig)
111
-
112
- # Функция генерации изображений
113
- def generate_images(netG, num_images=64):
114
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
115
- with torch.no_grad():
116
- noise = torch.randn(num_images, nz, 1, 1, device=device)
117
- fake = netG(noise).detach().cpu()
118
- return process_tensor_to_image(fake)
119
-
120
- # Функция обучения
121
- def train_model():
122
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
123
- st.write(f"Using device: {device}")
124
-
125
- # Создание сетей
126
- netG = Generator().to(device)
127
- netD = Discriminator().to(device)
128
-
129
- # Критерий и оптимизаторы
130
- criterion = nn.BCELoss()
131
- optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
132
- optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
133
-
134
- # Загрузка данных
135
- dataloader = load_data()
136
-
137
- # Списки для отслеживания прогресса
138
- G_losses = []
139
- D_losses = []
140
-
141
- # Прогресс бар
142
- progress_bar = st.progress(0)
143
- status_text = st.empty()
144
-
145
- # Обучение
146
- for epoch in range(num_epochs):
147
- for i, data in enumerate(dataloader, 0):
148
- ############################
149
- # (1) Обновление дискриминатора
150
- ###########################
151
- netD.zero_grad()
152
- real = data[0].to(device)
153
- b_size = real.size(0)
154
- label = torch.full((b_size,), 1, dtype=torch.float, device=device)
155
- output = netD(real).view(-1)
156
- errD_real = criterion(output, label)
157
- errD_real.backward()
158
-
159
- noise = torch.randn(b_size, nz, 1, 1, device=device)
160
- fake = netG(noise)
161
- label.fill_(0)
162
- output = netD(fake.detach()).view(-1)
163
- errD_fake = criterion(output, label)
164
- errD_fake.backward()
165
- errD = errD_real + errD_fake
166
- optimizerD.step()
167
-
168
- ############################
169
- # (2) Обновление генератора
170
- ###########################
171
- netG.zero_grad()
172
- label.fill_(1)
173
- output = netD(fake).view(-1)
174
- errG = criterion(output, label)
175
- errG.backward()
176
- optimizerG.step()
177
-
178
- # Обновление статуса
179
- if i % 100 == 0:
180
- status_text.text(f'Эпоха [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] '
181
- f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f}')
182
-
183
- G_losses.append(errG.item())
184
- D_losses.append(errD.item())
185
-
186
- # Показать промежуточные результаты
187
- if i % 500 == 0:
188
- img = process_tensor_to_image(fake)
189
- if img is not None:
190
- st.image(img, caption=f'Эпоха {epoch}, Batch {i}')
191
-
192
- # Обновление прогресс бара
193
- progress_bar.progress((epoch + 1) / num_epochs)
194
-
195
- # Сохранение модели
196
- torch.save(netG.state_dict(), 'generator.pth')
197
- return netG, G_losses, D_losses
198
-
199
- # Основной интерфейс Streamlit
200
- def main():
201
- st.title('DCGAN Training and Generation')
202
-
203
- # Настройка боковой панели
204
- st.sidebar.title('Параметры')
205
- global num_epochs, batch_size, lr
206
- num_epochs = st.sidebar.slider('Количество эпох', 1, 50, 5)
207
- batch_size = st.sidebar.slider('Размер батча', 16, 128, 64)
208
- lr = st.sidebar.number_input('Скорость обучения', 0.0001, 0.01, 0.0002)
209
-
210
- # Выбор режима
211
- mode = st.sidebar.selectbox('Выберите режим',
212
- ['Обучение', 'Генерация'])
213
-
214
- if mode == 'Обучение':
215
- if st.button('Начать обучение'):
216
- st.write('Начинаем обучение...')
217
- netG, G_losses, D_losses = train_model()
218
- st.write('Обучение завершено!')
219
-
220
- # Отображение графиков потерь
221
- st.subheader('Графики потерь')
222
- plot_training_results(G_losses, D_losses)
223
-
224
- # Генерация финальных изображений
225
- st.subheader('Финальные сгенерированные изображения')
226
- final_images = generate_images(netG)
227
- if final_images is not None:
228
- st.image(final_images, caption='Финальные сгенерированные изображения')
229
-
230
- elif mode == 'Генерация':
231
- if os.path.exists('generator.pth'):
232
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
233
- netG = Generator().to(device)
234
- netG.load_state_dict(torch.load('generator.pth', map_location=device))
235
-
236
- num_images = st.slider('Количество изоб��ажений', 1, 64, 16)
237
-
238
- if st.button('Сгенерировать изображения'):
239
- images = generate_images(netG, num_images)
240
- if images is not None:
241
- st.image(images, caption='Сгенерированные изображения')
242
-
243
- # Опция сохранения
244
- if st.button('Сохранить изображения'):
245
- images.save('generated_images.png')
246
- st.success('Изображения сохранены!')
247
- else:
248
- st.error('Модель не найдена. Пожалуйста, сначала обучите модель.')
249
-
250
- # Дополнительные настройки
251
- st.sidebar.markdown("""
252
- ## О проекте
253
- Это приложение демонстрирует работу DCGAN (Deep Convolutional Generative Adversarial Network)
254
- для генерации изображений.
255
-
256
- ### Особенности:
257
- - Обучение на датасете CIFAR-10
258
- - Генерация изображений 64x64
259
- - Возможность настройки параметров
260
- - Визуализация процесса обучения
261
- """)
262
 
263
- # Настройки кэширования
264
- if st.sidebar.checkbox('Очистить кэш'):
265
- st.cache_data.clear()
266
- st.success('Кэш очищен!')
267
 
268
- # Запуск приложения
269
- if __name__ == '__main__':
270
- main()
 
 
 
1
  import streamlit as st
2
+ import requests
 
 
 
 
 
 
 
 
 
 
3
 
4
+ API_URL = "https://api-inference.huggingface.co/models/openai/whisper-large-v3-turbo"
5
+ headers = {"Authorization": f"Bearer {st.secrets['HF_API_KEY']}"} # Безопасное хранение токена
 
 
 
6
 
7
+ def query(file):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  try:
9
+ response = requests.post(API_URL, headers=headers, data=file.read())
10
+ response.raise_for_status() # Проверка на ошибки HTTP
11
+ return response.json()
12
+ except requests.exceptions.RequestException as e:
13
+ st.error(f"Ошибка запроса к API: {e}")
 
14
  return None
15
 
16
+ st.title("Транскрипция аудио")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["wav", "mp3", "flac"])
 
 
 
19
 
20
+ if uploaded_file is not None:
21
+ with st.spinner("Транскрибируется..."):
22
+ output = query(uploaded_file)
23
+ if output:
24
+ st.text_area("Транскрипт:", value=output["text"])