Spaces:
Runtime error
Runtime error
Commit
·
19f06d7
1
Parent(s):
9adb7eb
Update app.py
Browse files
app.py
CHANGED
@@ -134,16 +134,22 @@ model = Generator(3, 32, 3, 4).cpu() # input_dim, num_filter, output_dim, num_re
|
|
134 |
model.load_state_dict(torch.load('G_A_HW4_SAVE.pt',map_location=torch.device('cpu')))
|
135 |
print(model)
|
136 |
model.eval()
|
|
|
|
|
137 |
|
138 |
totensor = torchvision.transforms.ToTensor()
|
139 |
normalize_fn = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
140 |
topilimage = torchvision.transforms.ToPILImage()
|
141 |
|
142 |
-
def predict(
|
143 |
-
|
144 |
-
print(
|
145 |
-
|
146 |
-
print(
|
147 |
-
|
|
|
|
|
|
|
|
|
148 |
|
149 |
gr_interface = gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(256,256)), outputs="image", title='Emoji_CycleGAN').launch()
|
|
|
# Load the A->B generator weights onto CPU (the Space has no GPU).
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — fine for our own checkpoint, but never point this at an
# untrusted file.
model.load_state_dict(torch.load('G_A_HW4_SAVE.pt', map_location=torch.device('cpu')))
print(model)
model.eval()

# Second generator (B->A direction): same architecture, different weights.
model_2 = Generator(3, 32, 3, 4)
model_2.load_state_dict(torch.load('G_B_HW4_SAVE.pt', map_location=torch.device('cpu')))
# BUG FIX: model_2 was left in training mode — BatchNorm/Dropout layers
# would use batch statistics / random masks at inference time.
model_2.eval()
# Image pre/post-processing helpers shared by predict().
_tf = torchvision.transforms
totensor = _tf.ToTensor()        # PIL/ndarray -> float tensor in [0, 1]
# Shift each RGB channel to [-1, 1]; predict() undoes this with /2 + 0.5.
normalize_fn = _tf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
topilimage = _tf.ToPILImage()    # tensor -> PIL image for the gradio output
|
144 |
+
def predict(input_1, input_2):
|
145 |
+
im1 = normalize_fn(totensor(input))
|
146 |
+
print(im1.shape)
|
147 |
+
preds1 = model(im1.unsqueeze(0))/2 + 0.5
|
148 |
+
print(preds1.shape)
|
149 |
+
im2 = normalize_fn(totensor(input))
|
150 |
+
print(im2.shape)
|
151 |
+
preds2 = model_2(im1.unsqueeze(0))/2 + 0.5
|
152 |
+
print(preds2.shape)
|
153 |
+
return topilimage(preds1.squeeze(0).detach()), topilimage(preds2.squeeze(0).detach())
|
# BUG FIX: predict() takes two images and returns two images, but the
# original Interface declared a single input and a single output, which
# fails as soon as gradio invokes the function.
gr_interface = gr.Interface(
    fn=predict,
    inputs=[gr.inputs.Image(shape=(256, 256)), gr.inputs.Image(shape=(256, 256))],
    outputs=["image", "image"],
    title='Emoji_CycleGAN',
).launch()