Spaces:
Running
on
Zero
Running
on
Zero
Update gradio_utils/utils.py
Browse files- gradio_utils/utils.py +8 -5
gradio_utils/utils.py
CHANGED
@@ -111,10 +111,13 @@ def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_
|
|
111 |
@spaces.GPU(duration=30)
|
112 |
def estimate(model, data):
|
113 |
model.cuda()
|
|
|
|
|
|
|
|
|
114 |
with torch.no_grad():
|
115 |
return model(**data)
|
116 |
|
117 |
-
@spaces.GPU(duration=30)
|
118 |
def process(query_img, state,
|
119 |
cfg_path='configs/test/1shot_split1.py',
|
120 |
checkpoint_path='ckpt/1shot_split1.pth'):
|
@@ -159,10 +162,10 @@ def process(query_img, state,
|
|
159 |
torch.tensor(target_weight_s).float()[None])
|
160 |
|
161 |
data = {
|
162 |
-
'img_s': [support_img
|
163 |
-
'img_q': q_img
|
164 |
-
'target_s': [target_s
|
165 |
-
'target_weight_s': [target_weight_s
|
166 |
'target_q': None,
|
167 |
'target_weight_q': None,
|
168 |
'return_loss': False,
|
|
|
111 |
@spaces.GPU(duration=30)
|
112 |
def estimate(model, data):
|
113 |
model.cuda()
|
114 |
+
data['img_s'] = [s.cuda() for s in data['img_s']]
|
115 |
+
data['img_q'] = data['img_q'].cuda()
|
116 |
+
data['target_s'] = [s.cuda() for s in data['target_s']]
|
117 |
+
data['target_weight_s'] = [s.cuda() for s in data['target_weight_s']]
|
118 |
with torch.no_grad():
|
119 |
return model(**data)
|
120 |
|
|
|
121 |
def process(query_img, state,
|
122 |
cfg_path='configs/test/1shot_split1.py',
|
123 |
checkpoint_path='ckpt/1shot_split1.pth'):
|
|
|
162 |
torch.tensor(target_weight_s).float()[None])
|
163 |
|
164 |
data = {
|
165 |
+
'img_s': [support_img],
|
166 |
+
'img_q': q_img,
|
167 |
+
'target_s': [target_s],
|
168 |
+
'target_weight_s': [target_weight_s],
|
169 |
'target_q': None,
|
170 |
'target_weight_q': None,
|
171 |
'return_loss': False,
|