mehdiabruee commited on
Commit
19f06d7
·
1 Parent(s): 9adb7eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -134,16 +134,22 @@ model = Generator(3, 32, 3, 4).cpu() # input_dim, num_filter, output_dim, num_re
134
  model.load_state_dict(torch.load('G_A_HW4_SAVE.pt',map_location=torch.device('cpu')))
135
  print(model)
136
  model.eval()
 
 
137
 
138
  totensor = torchvision.transforms.ToTensor()
139
  normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
140
  topilimage = torchvision.transforms.ToPILImage()
141
 
142
- def predict(input):
143
- im = normalize_fn(totensor(input))
144
- print(im.shape)
145
- preds = model(im.unsqueeze(0))/2 + 0.5
146
- print(preds.shape)
147
- return topilimage(preds.squeeze(0).detach())
 
 
 
 
148
 
149
  gr_interface = gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(256,256)), outputs="image", title='Emoji_CycleGAN').launch()
 
134
  model.load_state_dict(torch.load('G_A_HW4_SAVE.pt',map_location=torch.device('cpu')))
135
  print(model)
136
  model.eval()
137
+ model_2 = Generator(3, 32, 3, 4)
138
+ model_2.load_state_dict(torch.load('G_B_HW4_SAVE.pt',map_location=torch.device('cpu')))
139
 
140
  totensor = torchvision.transforms.ToTensor()
141
  normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
142
  topilimage = torchvision.transforms.ToPILImage()
143
 
144
+ def predict(input_1, input_2):
145
+ im1 = normalize_fn(totensor(input))
146
+ print(im1.shape)
147
+ preds1 = model(im1.unsqueeze(0))/2 + 0.5
148
+ print(preds1.shape)
149
+ im2 = normalize_fn(totensor(input))
150
+ print(im2.shape)
151
+ preds2 = model_2(im1.unsqueeze(0))/2 + 0.5
152
+ print(preds2.shape)
153
+ return topilimage(preds1.squeeze(0).detach()), topilimage(preds2.squeeze(0).detach())
154
 
155
  gr_interface = gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(256,256)), outputs="image", title='Emoji_CycleGAN').launch()