rahulvenkk
commited on
Commit
Β·
89022d9
1
Parent(s):
37eff47
modified app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import cv2
|
2 |
import numpy as np
|
3 |
import gradio as gr
|
@@ -85,6 +86,12 @@ import os
|
|
85 |
# preloaded_images = load_preuploaded_images()
|
86 |
#
|
87 |
# print("Preloaded images:", preloaded_images)
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
with gr.Blocks() as demo:
|
89 |
with gr.Row():
|
90 |
gr.Markdown('''# Generate interventions!π
|
@@ -268,8 +275,9 @@ with gr.Blocks() as demo:
|
|
268 |
|
269 |
# Imagenet-normalize the inputs (standardization)
|
270 |
x = utils.imagenet_normalize(x).to(device)
|
271 |
-
|
272 |
-
|
|
|
273 |
|
274 |
counterfactual = counterfactual.squeeze()
|
275 |
|
@@ -286,4 +294,4 @@ with gr.Blocks() as demo:
|
|
286 |
|
287 |
|
288 |
# Launch the app
|
289 |
-
demo.queue().launch(
|
|
|
1 |
+
import spaces
|
2 |
import cv2
|
3 |
import numpy as np
|
4 |
import gradio as gr
|
|
|
86 |
# preloaded_images = load_preuploaded_images()
|
87 |
#
|
88 |
# print("Preloaded images:", preloaded_images)
|
89 |
+
@spaces.GPU
|
90 |
+
def get_c(x, points):
|
91 |
+
with torch.no_grad():
|
92 |
+
counterfactual = model.get_counterfactual(x, points)
|
93 |
+
return counterfactual
|
94 |
+
|
95 |
with gr.Blocks() as demo:
|
96 |
with gr.Row():
|
97 |
gr.Markdown('''# Generate interventions!π
|
|
|
275 |
|
276 |
# Imagenet-normalize the inputs (standardization)
|
277 |
x = utils.imagenet_normalize(x).to(device)
|
278 |
+
|
279 |
+
counterfactual = get_c(x, points)
|
280 |
+
|
281 |
|
282 |
counterfactual = counterfactual.squeeze()
|
283 |
|
|
|
294 |
|
295 |
|
296 |
# Launch the app
|
297 |
+
demo.queue().launch()
|