itslukeypookie commited on
Commit
a888d12
·
verified ·
1 Parent(s): e0b9fd5
Files changed (1) hide show
  1. main.py +560 -0
main.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###### Train CIFAR10 with PyTorch. ######
2
+
3
+ ### IMPORT DEPENDENCIES
4
+
5
+ from torch.utils.data import DataLoader
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ import torch.nn.functional as F
10
+ import torch.backends.cudnn as cudnn
11
+ import gradio as gr
12
+ import wandb
13
+ import math
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+
17
+
18
+ import torchvision
19
+ import torchvision.transforms as transforms
20
+ import torchvision.models as models
21
+ import torch.optim.lr_scheduler as lr_scheduler
22
+ import os
23
+ import argparse
24
+ import torchattacks
25
+
26
+ from models import *
27
+
28
+ from tqdm import tqdm
29
+ from PIL import Image
30
+ import gradio as gr
31
+
32
+ # from utils import progress_bar
33
+
34
+ # CSS theme styling
35
+ theme = gr.themes.Base(
36
+ font=[gr.themes.GoogleFont('Montserrat'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
37
+ primary_hue="emerald",
38
+ secondary_hue="emerald",
39
+ neutral_hue="zinc"
40
+ ).set(
41
+ body_text_color='*neutral_950',
42
+ body_text_color_subdued='*neutral_950',
43
+ block_shadow='*shadow_drop_lg',
44
+ button_shadow='*shadow_drop_lg',
45
+ block_title_text_color='*neutral_950',
46
+ block_title_text_weight='500',
47
+ slider_color='*secondary_600'
48
+ )
49
+
50
+ def normalize(img):
51
+ min_im = np.min(img)
52
+ np_img = img - min_im
53
+ max_im = np.max(np_img)
54
+ np_img /= max_im
55
+ return np_img
56
+
57
+ def imshow(img, fig_name = "test_input.png"):
58
+ try:
59
+ img = img.clone().detach().cpu().numpy()
60
+ except:
61
+ print('img already numpy')
62
+
63
+ plt.imshow(normalize(np.transpose(img, (1, 2, 0))))
64
+ plt.savefig(fig_name)
65
+ print(f'Figure saved as {fig_name}')
66
+ return fig_name
67
+
68
+ def class_names(class_num, class_list): # converts the raw number label to text
69
+ if (class_num < 0) and (class_num >= 10):
70
+ gr.Warning("Class List Error")
71
+ return
72
+ return class_list[class_num]
73
+
74
+
75
+ ### MAIN FUNCTION
76
+ best_acc = 0
77
+ def main(drop_type, epochs_sldr, train_sldr, test_sldr, learning_rate, optimizer, sigma_sldr, adv_attack, username, scheduler):
78
+
79
+ ## Input protection
80
+ if not drop_type:
81
+ gr.Warning("Please select a model from the dropdown.")
82
+ return
83
+ if not username:
84
+ gr.Warning("Please enter a WandB username.")
85
+ return
86
+ if(epochs_sldr % 1 != 0):
87
+ gr.Warning("Number of epochs must be an integer.")
88
+ return
89
+ if(train_sldr % 1 != 0):
90
+ gr.Warning("Training batch size must be an integer.")
91
+ return
92
+ if(test_sldr % 1 != 0):
93
+ gr.Warning("Testing batch size must be an integer.")
94
+ return
95
+
96
+ num_epochs = int(epochs_sldr)
97
+ global learn_batch
98
+ learn_batch = int(train_sldr)
99
+ global test_batch
100
+ test_batch = int(test_sldr)
101
+ learning_rate = float(learning_rate)
102
+ optimizer_choose = str(optimizer)
103
+ sigma = float(sigma_sldr)
104
+ attack = str(adv_attack)
105
+ scheduler_choose = str(scheduler)
106
+
107
+ # REPLACE ENTITY WITH USERNAME BELOW
108
+ wandb.init(entity=username, project="model-training")
109
+
110
+ parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
111
+ parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
112
+ parser.add_argument('--resume', '-r', action='store_true',
113
+ help='resume from checkpoint')
114
+ args = parser.parse_args()
115
+
116
+ if torch.cuda.is_available():
117
+ device = 'cuda'
118
+ gr.Info("Cuda detected - running on Cuda")
119
+ elif torch.backends.mps.is_available():
120
+ device = 'mps'
121
+ gr.Info("MPS detected - running on Metal")
122
+ else:
123
+ device = 'cpu'
124
+ gr.Info("No GPU Detected - running on CPU")
125
+
126
+ start_epoch = 0 # start from epoch 0 or last checkpoint epoch
127
+
128
+ ## Data
129
+ try:
130
+ print('==> Preparing data..')
131
+ transform_train = transforms.Compose([
132
+ transforms.RandomCrop(32, padding=4),
133
+ transforms.RandomHorizontalFlip(),
134
+ transforms.ToTensor(),
135
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
136
+ ])
137
+
138
+ transform_test = transforms.Compose([
139
+ transforms.ToTensor(),
140
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
141
+ ])
142
+
143
+ trainset = torchvision.datasets.CIFAR10(
144
+ root='./data', train=True, download=True, transform=transform_train)
145
+ trainloader = DataLoader(
146
+ trainset, batch_size=learn_batch, shuffle=True, num_workers=2)
147
+
148
+ testset = torchvision.datasets.CIFAR10(
149
+ root='./data', train=False, download=True, transform=transform_test)
150
+ testloader = DataLoader(
151
+ testset, batch_size=test_batch, shuffle=True, num_workers=2)
152
+
153
+ classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
154
+ except Exception as e:
155
+ print(f"Error: {e}")
156
+ gr.Warning(f"Data Loading Error: {e}")
157
+
158
+ ## Model
159
+ try:
160
+ print('==> Building model..')
161
+ net = models_dict.get(drop_type, None)
162
+
163
+ # Make list of models containing either classifer or fc functions
164
+ classifier_models = ['ConvNext_Small', 'ConvNext_Base', 'ConvNext_Large', 'DenseNet', 'EfficientNet_B0', 'MobileNetV2',
165
+ 'MaxVit', 'MnasNet0_5', 'SqueezeNet', 'VGG19']
166
+ fc_models = ['GoogLeNet', 'InceptionNetV3', 'RegNet_X_400MF', 'ResNet18', 'ShuffleNet_V2_X0_5']
167
+
168
+ # Check dropdown choice for fc or classifier function implementation
169
+ if net in classifier_models:
170
+ num_ftrs = net.classifier[-1].in_features
171
+ net.classifier[-1] = torch.nn.Linear(num_ftrs, len(classes))
172
+ elif net in fc_models:
173
+ num_ftrs = net.fc.in_features
174
+ net.fc = torch.nn.Linear(num_ftrs, len(classes))
175
+
176
+ net = net.to(device)
177
+
178
+ except Exception as e:
179
+ print(f"Error: {e}")
180
+ gr.Warning(f"Model Building Error: {e}")
181
+
182
+ # if args.resume:
183
+ # # Load checkpoint.
184
+ # print('==> Resuming from checkpoint..')
185
+ # assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
186
+ # checkpoint = torch.load('./checkpoint/ckpt.pth')
187
+ # net.load_state_dict(checkpoint['net'])
188
+ # best_acc = checkpoint['acc']
189
+ # start_epoch = checkpoint['epoch']
190
+
191
+ SGDopt = optim.SGD(net.parameters(), lr=learning_rate,momentum=0.9, weight_decay=5e-4)
192
+ Adamopt = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=5e-4)
193
+
194
+ criterion = nn.CrossEntropyLoss()
195
+
196
+ if optimizer_choose == "SGD":
197
+ optimizer = SGDopt
198
+ elif optimizer_choose == "Adam":
199
+ optimizer = Adamopt
200
+ print (f'optimizer: {optimizer}')
201
+
202
+ #scheduler = lr_scheduler.LinearLR(optimizer, start_factor=learning_rate, end_factor=0.0001, total_iters=10)
203
+ if scheduler_choose == "CosineAnnealingLR":
204
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
205
+ elif scheduler_choose == "ReduceLROnPlateau":
206
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=5)
207
+ elif scheduler_choose == "StepLR":
208
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=30)
209
+ print (f'scheduler: {scheduler_choose}')
210
+
211
+ img_labels = [] # initialize list for label generation
212
+ raw_image_list = [] # initialize list for image generation
213
+ img_list1 = [] # initialize list for combined image/labels
214
+ img_list2 = [] # initialize list for gaussian image generation
215
+ img_list3 = [] # initialize list for adversarial attack image generation
216
+
217
+ # The following lists are used when generating all images in an epoch instead of 10:
218
+ full_img_labels = []
219
+ full_raw_image_list = []
220
+ full_img_list1 = []
221
+
222
+ adv_num = 1 # initialize adversarial image number for naming purposes
223
+ global gaussian_num
224
+ gaussian_num = 1 # initialize gaussian noise image number for naming purposes
225
+
226
+ for epoch in range(start_epoch, start_epoch+epochs_sldr):
227
+ if sigma == 0:
228
+ train(epoch, net, trainloader, device, optimizer, criterion, sigma)
229
+ else:
230
+ gaussian_fig = train(epoch, net, trainloader, device, optimizer, criterion, sigma)
231
+ acc, predicted = test(epoch, net, testloader, device, criterion)
232
+
233
+ if scheduler_choose == "ReduceLROnPlateau":
234
+ scheduler.step(metrics=acc)
235
+ elif not scheduler_choose == "None":
236
+ scheduler.step()
237
+
238
+ if (((epoch-1) % 10 == 0) or (epoch == 0)) and (epoch != 1): # generate images every 10 epochs (and the 0th epoch)
239
+ dataiter = iter(testloader)
240
+ imgs, labels = next(dataiter)
241
+ normalized_imgs = (imgs-imgs.min())/(imgs.max()-imgs.min())
242
+ atk = torchattacks.PGD(net, eps=0.00015, alpha=0.0000000000000001, steps=7)
243
+ if attack == "Yes":
244
+ if normalized_imgs is None:
245
+ print("error occured")
246
+ else:
247
+ print(torch.std(normalized_imgs))
248
+ atk.set_normalization_used(mean = torch.mean(normalized_imgs,axis=[0,2,3]), std=torch.std(normalized_imgs,axis=[0,2,3])/1.125)
249
+ adv_images = atk(imgs, labels)
250
+ fig_name = imshow(adv_images[0], fig_name = f'figures/adversarial_attack{adv_num}.png')
251
+ attack_fig = Image.open(fig_name)
252
+ for i in range(1): # generate 1 image per epoch
253
+ img_list3.append(attack_fig)
254
+ adv_num = adv_num + 1
255
+ for i in range(10): # generate 10 images per epoch
256
+ gradio_imgs = transforms.functional.to_pil_image(normalized_imgs[i])
257
+ raw_image_list.append(gradio_imgs)
258
+ predicted_text = class_names(predicted[i].item(), classes)
259
+ actual_text = class_names(labels[i].item(), classes)
260
+ label_text = f'Epoch: {epoch} | Predicted: {predicted_text} | Actual: {actual_text}'
261
+ img_labels.append(label_text)
262
+ for i in range(test_batch): # generate all images per epoch
263
+ full_gradio_imgs = transforms.functional.to_pil_image(normalized_imgs[i])
264
+ full_raw_image_list.append(full_gradio_imgs)
265
+ full_predicted_text = class_names(predicted[i].item(), classes)
266
+ full_actual_text = class_names(labels[i].item(), classes)
267
+ full_label_text = f'Epoch: {epoch} | Predicted: {full_predicted_text} | Actual: {full_actual_text}'
268
+ full_img_labels.append(full_label_text)
269
+ for i in range(len(raw_image_list)):
270
+ img_tuple = (raw_image_list[i], img_labels[i])
271
+ img_list1.append(img_tuple)
272
+ for i in range(len(full_raw_image_list)):
273
+ full_img_tuple = (full_raw_image_list[i], full_img_labels[i])
274
+ full_img_list1.append(full_img_tuple)
275
+ if sigma != 0:
276
+ for i in range(1): # generate 1 image per epoch
277
+ img_list2.append(gaussian_fig)
278
+ gaussian_num = gaussian_num + 1
279
+ if (sigma == 0) and (attack == "No"):
280
+ return str(acc)+"%", img_list1, full_img_list1, None, None
281
+ elif (sigma != 0) and (attack == "No"):
282
+ return str(acc)+"%", img_list1, full_img_list1, img_list2, None
283
+ elif (sigma == 0) and (attack == "Yes"):
284
+ return str(acc)+"%", img_list1, full_img_list1, None, img_list3
285
+ else:
286
+ return str(acc)+"%", img_list1, full_img_list1, img_list2, img_list3
287
+
288
+
289
+
290
+ ### TRAINING
291
+ def train(epoch, net, trainloader, device, optimizer, criterion, sigma, progress=gr.Progress()):
292
+ try:
293
+ print('\nEpoch: %d' % epoch)
294
+ net.train()
295
+ train_loss = 0
296
+ correct = 0
297
+ total = 0
298
+
299
+ iter_float = 50000/learn_batch
300
+ iterations = math.ceil(iter_float)
301
+ iter_prog = 0
302
+
303
+ for batch_idx, (inputs, targets) in tqdm(enumerate(trainloader)):
304
+ if sigma == 0:
305
+ inputs, targets = inputs.to(device), targets.to(device)
306
+ optimizer.zero_grad()
307
+ outputs = net(inputs)
308
+ else:
309
+ noise = np.random.normal(0, sigma, inputs.shape)
310
+ inputs += torch.tensor(noise)
311
+ inputs, targets = inputs.to(device), targets.to(device)
312
+ optimizer.zero_grad()
313
+ outputs = net(inputs)
314
+ n_inputs = inputs.clone().detach().cpu().numpy()
315
+ if(batch_idx%99 == 0):
316
+ fig_name = imshow(n_inputs[0], fig_name= f'figures/gaussian_noise{gaussian_num}.png')
317
+ gaussian_fig = Image.open(fig_name)
318
+
319
+ loss = criterion(outputs, targets)
320
+ loss.backward()
321
+ optimizer.step()
322
+
323
+ train_loss += loss.item()
324
+ _, predicted = outputs.max(1)
325
+ total += targets.size(0)
326
+ correct += predicted.eq(targets).sum().item()
327
+
328
+ iter_prog = iter_prog + 1 # Iterating iteration amount
329
+ progress(iter_prog/iterations, desc=f"Training Epoch {epoch}", total=iterations)
330
+
331
+
332
+ # progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
333
+ # % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
334
+
335
+ except Exception as e:
336
+ print(f"Error: {e}")
337
+ gr.Warning(f"Training Error: {e}")
338
+ if sigma != 0:
339
+ return gaussian_fig
340
+
341
+
342
+ ### TESTING
343
+
344
+ def test(epoch, net, testloader, device, criterion, progress = gr.Progress()):
345
+ try:
346
+ net.eval()
347
+ test_loss = 0
348
+ correct = 0
349
+ total = 0
350
+
351
+ iter_float = 10000/test_batch
352
+ iterations = math.ceil(iter_float)
353
+ iter_prog = 0
354
+
355
+ with torch.no_grad():
356
+ for batch_idx, (inputs, targets) in tqdm(enumerate(testloader)):
357
+ inputs, targets = inputs.to(device), targets.to(device)
358
+ outputs = net(inputs)
359
+ loss = criterion(outputs, targets)
360
+
361
+ test_loss += loss.item()
362
+ _, predicted = outputs.max(1)
363
+ total += targets.size(0)
364
+ correct += predicted.eq(targets).sum().item()
365
+
366
+ iter_prog = iter_prog + 1 # Iterating iteration amount
367
+ progress(iter_prog/iterations, desc=f"Testing Epoch {epoch}", total=iterations)
368
+
369
+ wandb.log({'epoch': epoch+1, 'loss': test_loss})
370
+ wandb.log({"acc": correct/total})
371
+
372
+ # progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
373
+ # % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
374
+
375
+ # Save checkpoint.
376
+ global best_acc
377
+ global acc
378
+ acc = 100.*correct/total
379
+ print(acc)
380
+ if acc > best_acc:
381
+ best_acc = acc
382
+ return best_acc, predicted
383
+ else:
384
+ return acc, predicted
385
+ # if acc > best_acc:
386
+ # print('Saving..')
387
+ # state = {
388
+ # 'net': net.state_dict(),
389
+ # 'acc': acc,
390
+ # 'epoch': epoch,
391
+ # }
392
+ # if not os.path.isdir('checkpoint'):
393
+ # os.mkdir('checkpoint')
394
+ # torch.save(state, './checkpoint/ckpt.pth')
395
+ # best_acc = acc
396
+
397
+ except Exception as e:
398
+ print(f"Error: {e}")
399
+ gr.Warning(f"Testing Error: {e}")
400
+
401
+
402
+ models_dict = {
403
+ #"AlexNet": models.AlexNet(weights=models.AlexNet_Weights.DEFAULT),
404
+ #"ConvNext_Small": models.convnext_small(weights=models.ConvNeXt_Small_Weights.DEFAULT),
405
+ #"ConvNext_Base": models.convnext_base(weights=models.ConvNeXt_Base_Weights.DEFAULT),
406
+ #"ConvNext_Large": models.convnext_large(weights=models.ConvNeXt_Large_Weights.DEFAULT),
407
+ "DenseNet": models.densenet121(weights=models.DenseNet121_Weights.DEFAULT),
408
+ #"EfficientNet_B0": models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT),
409
+ #"GoogLeNet": models.googlenet(weights=models.GoogLeNet_Weights.DEFAULT),
410
+ # "InceptionNetV3": models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT),
411
+ # "MaxVit": models.maxvit_t(weights=models.MaxVit_T_Weights.DEFAULT),
412
+ #"MnasNet0_5": models.mnasnet0_5(weights=models.MNASNet0_5_Weights.DEFAULT),
413
+ #"MobileNetV2": models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT),
414
+ "ResNet18": models.resnet18(weights=models.ResNet18_Weights.DEFAULT),
415
+ "ResNet50": models.resnet50(weights=models.ResNet50_Weights.DEFAULT),
416
+ #"RegNet_X_400MF": models.regnet_x_400mf(weights=models.RegNet_X_400MF_Weights.DEFAULT),
417
+ #"ShuffleNet_V2_X0_5": models.shufflenet_v2_x0_5(weights=models.ShuffleNet_V2_X0_5_Weights.DEFAULT),
418
+ #"SqueezeNet": models.squeezenet1_0(weights=models.SqueezeNet1_0_Weights.DEFAULT),
419
+ "VGG19": models.vgg19(weights=models.VGG19_Weights.DEFAULT)
420
+ }
421
+
422
+ # Store dictionary keys into list for dropdown menu choices
423
+ names = list(models_dict.keys())
424
+
425
+ # Optimizer names
426
+ optimizers = ["SGD","Adam"]
427
+
428
+ # Scheduler names
429
+ schedulers = ["None","CosineAnnealingLR","ReduceLROnPlateau","StepLR"]
430
+
431
+ ### GRADIO APP INTERFACE
432
+
433
+ def togglepicsettings(choice):
434
+ yes=gr.Gallery(visible=True)
435
+ no=gr.Gallery(visible=False)
436
+ if choice == "Yes":
437
+ return yes,no
438
+ else:
439
+ return no,yes
440
+
441
+ def settings(choice):
442
+ if choice == "Advanced":
443
+ advanced = [
444
+ gr.Slider(visible=True),
445
+ gr.Slider(visible=True),
446
+ gr.Slider(visible=True),
447
+ gr.Dropdown(visible=True),
448
+ gr.Dropdown(visible=True),
449
+ gr.Radio(visible=True)
450
+ ]
451
+ return advanced
452
+ else:
453
+ basic = [
454
+ gr.Slider(visible=False),
455
+ gr.Slider(visible=False),
456
+ gr.Slider(visible=False),
457
+ gr.Dropdown(visible=False),
458
+ gr.Dropdown(visible=False),
459
+ gr.Radio(visible=False)
460
+ ]
461
+ return basic
462
+
463
+ def attacks(choice):
464
+ if choice == "Yes":
465
+ yes = [
466
+ gr.Markdown(visible=True),
467
+ gr.Radio(visible=True),
468
+ gr.Radio(visible=True)
469
+ ]
470
+ return yes
471
+ if choice == "No":
472
+ no = [
473
+ gr.Markdown(visible=False),
474
+ gr.Radio(visible=False),
475
+ gr.Radio(visible=False)
476
+ ]
477
+ return no
478
+
479
+ def gaussian(choice):
480
+ if choice == "Yes":
481
+ yes = [
482
+ gr.Slider(visible=True),
483
+ gr.Gallery(visible=True),
484
+ ]
485
+ return yes
486
+ else:
487
+ no = [
488
+ gr.Slider(visible=False),
489
+ gr.Gallery(visible=False),
490
+ ]
491
+ return no
492
+ def adversarial(choice):
493
+ if choice == "Yes":
494
+ yes = gr.Gallery(visible=True)
495
+ return yes
496
+ else:
497
+ no = gr.Gallery(visible=False)
498
+
499
+ ## Main app for functionality
500
+ with gr.Blocks(css=".caption-label {display:none}") as functionApp:
501
+ with gr.Row():
502
+ gr.Markdown("# CIFAR-10 Model Training GUI")
503
+ with gr.Row():
504
+ gr.Markdown("## Parameters")
505
+ with gr.Row():
506
+ inp = gr.Dropdown(choices=names, label="Training Model", value="ResNet18", info="Choose one of 13 common models provided in the dropdown to use for training.")
507
+ username = gr.Textbox(label="Weights and Biases", info="Enter your username or team name from the Weights and Biases API.")
508
+ epochs_sldr = gr.Slider(label="Number of Epochs", minimum=1, maximum=100, step=1, value=1, info="How many times the model will see the entire dataset during trianing.")
509
+ with gr.Column():
510
+ setting_radio = gr.Radio(["Basic", "Advanced"], label="Settings", value="Basic")
511
+ btn = gr.Button("Run")
512
+ with gr.Row():
513
+ train_sldr = gr.Slider(visible=False, label="Training Batch Size", minimum=1, maximum=1000, step=1, value=128, info="The number of training samples processed before the model's internal parameters are updated.")
514
+ test_sldr = gr.Slider(visible=False, label="Testing Batch Size", minimum=1, maximum=1000, step=1, value=100, info="The number of testing samples processed at once during the evaluation phase.")
515
+ learning_rate_sldr = gr.Slider(visible=False, label="Learning Rate", minimum=0.0001, maximum=0.1, step=0.0001, value=0.001, info="The learning rate of the optimization program.")
516
+ optimizer = gr.Dropdown(visible=False, label="Optimizer", choices=optimizers, value="SGD", info="The optimization algorithm used to minimize the loss function during training.")
517
+ scheduler = gr.Dropdown(visible=False, label="Scheduler", choices=schedulers, value="CosineAnnealingLR", info="The scheduler used to iteratively alter learning rate.")
518
+ use_attacks = gr.Radio(["Yes", "No"], visible=False, label="Use Attacking Methods?", value="No")
519
+ setting_radio.change(fn=settings, inputs=setting_radio, outputs=[train_sldr, test_sldr, learning_rate_sldr, optimizer, scheduler, use_attacks])
520
+ with gr.Row():
521
+ attack_method = gr.Markdown("## Attacking Methods", visible=False)
522
+ with gr.Row():
523
+ use_sigma = gr.Radio(["Yes","No"], visible=False, label="Use Gaussian Noise?", value="No")
524
+ sigma_sldr = gr.Slider(visible=False, label="Gaussian Noise", minimum=0, maximum=1, value=0, step=0.1, info="The sigma value of the gaussian noise eqaution. A value of 0 disables gaussian noise.")
525
+ adv_attack = gr.Radio(["Yes","No"], visible=False, label="Use Adversarial Attacks?", value="No")
526
+ with gr.Row():
527
+ gr.Markdown("## Training Results")
528
+ with gr.Row():
529
+ accuracy = gr.Textbox(label = "Accuracy", info="The validation accuracy of the trained model (accuracy evaluated on testing data).")
530
+ with gr.Column():
531
+ showpics = gr.Radio(["Yes","No"], visible = True, label = "Show all pictures?", value = "No")
532
+ pics = gr.Gallery(preview=False, selected_index=0, object_fit='contain', label="Testing Images")
533
+ allpics = gr.Gallery(preview=True, selected_index=0, object_fit='contain', label="Full Testing Images",visible = False)
534
+ showpics.change(fn=togglepicsettings, inputs=[showpics], outputs = [allpics, pics])
535
+ with gr.Row():
536
+ gaussian_pics = gr.Gallery(visible=False, preview=False, selected_index=0, object_fit='contain', label="Gaussian Noise")
537
+ attack_pics = gr.Gallery(visible=False, preview=False, selected_index=0, object_fit='contain', label="Adversarial Attack")
538
+ use_attacks.change(fn=attacks, inputs=use_attacks, outputs=[attack_method, use_sigma, adv_attack])
539
+ use_sigma.change(fn=gaussian, inputs=use_sigma, outputs=[sigma_sldr, gaussian_pics])
540
+ adv_attack.change(fn=adversarial, inputs=adv_attack, outputs=attack_pics)
541
+ btn.click(fn=main, inputs=[inp, epochs_sldr, train_sldr, test_sldr, learning_rate_sldr, optimizer, sigma_sldr, adv_attack, username, scheduler], outputs=[accuracy, pics, allpics, gaussian_pics, attack_pics])
542
+
543
+ ## Documentation app (implemented as second tab)
544
+
545
+ markdown_file_path = 'documentation.md'
546
+ with open(markdown_file_path, 'r') as file:
547
+ markdown_content = file.read()
548
+
549
+ with gr.Blocks() as documentationApp:
550
+ with gr.Row():
551
+ gr.Markdown("# CIFAR-10 Training Interface Documentation")
552
+ with gr.Row():
553
+ gr.Markdown(markdown_content) # Can be collapesed in VSCode to hide paragraphs from view. Vscode can also wrap text.
554
+
555
+ ### LAUNCH APP
556
+
557
+ if __name__ == '__main__':
558
+ mainApp = gr.TabbedInterface([functionApp, documentationApp], ["Welcome", "Documentation"], theme=theme)
559
+ mainApp.queue()
560
+ mainApp.launch()