rahulvenkk
commited on
Commit
·
110d56f
1
Parent(s):
a45652e
modified app.py gradio cuda
Browse files
app.py
CHANGED
@@ -26,7 +26,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
26 |
|
27 |
|
28 |
# Load CWM 3-frame model (automatically download pre-trained checkpoint)
|
29 |
-
model = model_factory.load_model('vitb_8x8patch_3frames')
|
30 |
|
31 |
model.requires_grad_(False)
|
32 |
model.eval()
|
@@ -91,7 +91,7 @@ import os
|
|
91 |
# print("Preloaded images:", preloaded_images)
|
92 |
@spaces.GPU
|
93 |
def get_c(x, points):
|
94 |
-
x = utils.imagenet_normalize(x)
|
95 |
with torch.no_grad():
|
96 |
counterfactual = model.get_counterfactual(x, points)
|
97 |
return counterfactual
|
|
|
26 |
|
27 |
|
28 |
# Load CWM 3-frame model (automatically download pre-trained checkpoint)
|
29 |
+
model = model_factory.load_model('vitb_8x8patch_3frames')#.to(device)
|
30 |
|
31 |
model.requires_grad_(False)
|
32 |
model.eval()
|
|
|
91 |
# print("Preloaded images:", preloaded_images)
|
92 |
@spaces.GPU
|
93 |
def get_c(x, points):
|
94 |
+
x = utils.imagenet_normalize(x)#.to(device)
|
95 |
with torch.no_grad():
|
96 |
counterfactual = model.get_counterfactual(x, points)
|
97 |
return counterfactual
|