rahulvenkk commited on
Commit
89022d9
Β·
1 Parent(s): 37eff47

modified app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import cv2
2
  import numpy as np
3
  import gradio as gr
@@ -85,6 +86,12 @@ import os
85
  # preloaded_images = load_preuploaded_images()
86
  #
87
  # print("Preloaded images:", preloaded_images)
 
 
 
 
 
 
88
  with gr.Blocks() as demo:
89
  with gr.Row():
90
  gr.Markdown('''# Generate interventions!πŸš€
@@ -268,8 +275,9 @@ with gr.Blocks() as demo:
268
 
269
  # Imagenet-normalize the inputs (standardization)
270
  x = utils.imagenet_normalize(x).to(device)
271
- with torch.no_grad():
272
- counterfactual = model.get_counterfactual(x, points)
 
273
 
274
  counterfactual = counterfactual.squeeze()
275
 
@@ -286,4 +294,4 @@ with gr.Blocks() as demo:
286
 
287
 
288
  # Launch the app
289
- demo.queue().launch(inbrowser=True, share=True)
 
1
+ import spaces
2
  import cv2
3
  import numpy as np
4
  import gradio as gr
 
86
  # preloaded_images = load_preuploaded_images()
87
  #
88
  # print("Preloaded images:", preloaded_images)
89
+ @spaces.GPU
90
+ def get_c(x, points):
91
+ with torch.no_grad():
92
+ counterfactual = model.get_counterfactual(x, points)
93
+ return counterfactual
94
+
95
  with gr.Blocks() as demo:
96
  with gr.Row():
97
  gr.Markdown('''# Generate interventions!πŸš€
 
275
 
276
  # Imagenet-normalize the inputs (standardization)
277
  x = utils.imagenet_normalize(x).to(device)
278
+
279
+ counterfactual = get_c(x, points)
280
+
281
 
282
  counterfactual = counterfactual.squeeze()
283
 
 
294
 
295
 
296
  # Launch the app
297
+ demo.queue().launch()