mehdiabruee commited on
Commit
76c8c49
·
1 Parent(s): 23770b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -136,19 +136,21 @@ print(model)
136
  model.eval()
137
  model_2 = Generator(3, 32, 3, 4)
138
  model_2.load_state_dict(torch.load('G_B_HW4_SAVE.pt',map_location=torch.device('cpu')))
 
139
 
140
  totensor = torchvision.transforms.ToTensor()
141
  normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
142
  topilimage = torchvision.transforms.ToPILImage()
143
 
144
- def predict(input_1, input_2):
145
- im1 = normalize_fn(totensor(input))
146
  print(im1.shape)
147
  preds1 = model(im1.unsqueeze(0))/2 + 0.5
148
  print(preds1.shape)
149
- im2 = normalize_fn(totensor(input))
 
150
  print(im2.shape)
151
- preds2 = model_2(im1.unsqueeze(0))/2 + 0.5
152
  print(preds2.shape)
153
  return topilimage(preds1.squeeze(0).detach()), topilimage(preds2.squeeze(0).detach())
154
 
 
136
  model.eval()
137
  model_2 = Generator(3, 32, 3, 4)
138
  model_2.load_state_dict(torch.load('G_B_HW4_SAVE.pt',map_location=torch.device('cpu')))
139
+ model_2.eval()
140
 
141
  totensor = torchvision.transforms.ToTensor()
142
  normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
143
  topilimage = torchvision.transforms.ToPILImage()
144
 
145
+ def predict(input_1):
146
+ im1 = normalize_fn(totensor(input_1))
147
  print(im1.shape)
148
  preds1 = model(im1.unsqueeze(0))/2 + 0.5
149
  print(preds1.shape)
150
+
151
+ im2 = normalize_fn(totensor(input_2))
152
  print(im2.shape)
153
+ preds2 = model(im2.unsqueeze(0))/2 + 0.5
154
  print(preds2.shape)
155
  return topilimage(preds1.squeeze(0).detach()), topilimage(preds2.squeeze(0).detach())
156