kritsg commited on
Commit
38db7e9
·
1 Parent(s): b71b9d2

testing if cuda works

Browse files
Files changed (2) hide show
  1. bayes/models.py +4 -5
  2. data/mnist/mnist_model.py +1 -1
bayes/models.py CHANGED
@@ -20,7 +20,6 @@ from sklearn.model_selection import train_test_split
20
  import torch
21
  from torchvision import models, transforms
22
 
23
- from efficientnet.tfkeras import EfficientNetB0
24
  from data.mnist.mnist_model import Net
25
 
26
  def get_xtrain(segs):
@@ -41,9 +40,9 @@ def get_xtrain(segs):
41
  def process_imagenet_get_model(data):
42
  """Gets wrapped imagenet model."""
43
  # Get the vgg16 model, used in the experiments
44
- model = EfficientNetB0(weights='imagenet')
45
- # model.eval()
46
- # model.cuda()
47
 
48
  xtest = data['X']
49
  ytest = data['y'].astype(int)
@@ -73,7 +72,7 @@ def process_imagenet_get_model(data):
73
  perturbed_image[segments==i, 1] = background
74
  perturbed_image[segments==i, 2] = background
75
  perturbed_images.append(transf(perturbed_image)[None, :])
76
- perturbed_images = torch.from_numpy(np.concatenate(perturbed_images, axis=0)).float()
77
  predictions = []
78
  for q in range(0, perturbed_images.shape[0], batch_size):
79
  predictions.append(softmax(model(perturbed_images[q:q+batch_size])).cpu().detach().numpy())
 
20
  import torch
21
  from torchvision import models, transforms
22
 
 
23
  from data.mnist.mnist_model import Net
24
 
25
  def get_xtrain(segs):
 
40
  def process_imagenet_get_model(data):
41
  """Gets wrapped imagenet model."""
42
  # Get the vgg16 model, used in the experiments
43
+ model = models.vgg16(pretrained=True)
44
+ model.eval()
45
+ model.cuda()
46
 
47
  xtest = data['X']
48
  ytest = data['y'].astype(int)
 
72
  perturbed_image[segments==i, 1] = background
73
  perturbed_image[segments==i, 2] = background
74
  perturbed_images.append(transf(perturbed_image)[None, :])
75
+ perturbed_images = torch.from_numpy(np.concatenate(perturbed_images, axis=0)).float().cuda()
76
  predictions = []
77
  for q in range(0, perturbed_images.shape[0], batch_size):
78
  predictions.append(softmax(model(perturbed_images[q:q+batch_size])).cpu().detach().numpy())
data/mnist/mnist_model.py CHANGED
@@ -94,7 +94,7 @@ def main():
94
  parser.add_argument('--save-model', action='store_true', default=False,
95
  help='For Saving the current Model')
96
  args = parser.parse_args()
97
- use_cuda = False # not args.no_cuda and torch.cuda.is_available()
98
 
99
  torch.manual_seed(args.seed)
100
 
 
94
  parser.add_argument('--save-model', action='store_true', default=False,
95
  help='For Saving the current Model')
96
  args = parser.parse_args()
97
+ use_cuda = True # not args.no_cuda and torch.cuda.is_available()
98
 
99
  torch.manual_seed(args.seed)
100