Spaces:
Runtime error
Runtime error
Commit
·
a2621e6
1
Parent(s):
3822ccd
Update app.py
Browse files
app.py
CHANGED
@@ -134,24 +134,16 @@ model = Generator(3, 32, 3, 4).cpu() # input_dim, num_filter, output_dim, num_re
|
|
134 |
model.load_state_dict(torch.load('G_A_HW4_SAVE.pt',map_location=torch.device('cpu')))
|
135 |
print(model)
|
136 |
model.eval()
|
137 |
-
model_2 = Generator(3, 32, 3, 4)
|
138 |
-
model_2.load_state_dict(torch.load('G_B_HW4_SAVE.pt',map_location=torch.device('cpu')))
|
139 |
-
model_2.eval()
|
140 |
|
141 |
totensor = torchvision.transforms.ToTensor()
|
142 |
normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
143 |
topilimage = torchvision.transforms.ToPILImage()
|
144 |
|
145 |
-
def predict(input_1
|
146 |
im1 = normalize_fn(totensor(input_1))
|
147 |
print(im1.shape)
|
148 |
preds1 = model(im1.unsqueeze(0))/2 + 0.5
|
149 |
print(preds1.shape)
|
150 |
-
|
151 |
-
im2 = normalize_fn(totensor(input_2))
|
152 |
-
print(im2.shape)
|
153 |
-
preds2 = model(im2.unsqueeze(0))/2 + 0.5
|
154 |
-
print(preds2.shape)
|
155 |
-
return topilimage(preds1.squeeze(0).detach()), topilimage(preds2.squeeze(0).detach())
|
156 |
|
157 |
gr_interface = gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(256,256)), outputs="image", title='Emoji_CycleGAN').launch()
|
|
|
134 |
model.load_state_dict(torch.load('G_A_HW4_SAVE.pt',map_location=torch.device('cpu')))
|
135 |
print(model)
|
136 |
model.eval()
|
|
|
|
|
|
|
137 |
|
138 |
totensor = torchvision.transforms.ToTensor()
|
139 |
normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
140 |
topilimage = torchvision.transforms.ToPILImage()
|
141 |
|
142 |
+
def predict(input_1):
|
143 |
im1 = normalize_fn(totensor(input_1))
|
144 |
print(im1.shape)
|
145 |
preds1 = model(im1.unsqueeze(0))/2 + 0.5
|
146 |
print(preds1.shape)
|
147 |
+
return topilimage(preds1.squeeze(0).detach())
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
gr_interface = gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(256,256)), outputs="image", title='Emoji_CycleGAN').launch()
|