orhir commited on
Commit
7b49f58
·
verified ·
1 Parent(s): 6036218

Update gradio_utils/utils.py

Browse files
Files changed (1) hide show
  1. gradio_utils/utils.py +8 -5
gradio_utils/utils.py CHANGED
@@ -111,10 +111,13 @@ def plot_results(support_img, query_img, support_kp, support_w, query_kp, query_
111
  @spaces.GPU(duration=30)
112
  def estimate(model, data):
113
  model.cuda()
 
 
 
 
114
  with torch.no_grad():
115
  return model(**data)
116
 
117
- @spaces.GPU(duration=30)
118
  def process(query_img, state,
119
  cfg_path='configs/test/1shot_split1.py',
120
  checkpoint_path='ckpt/1shot_split1.pth'):
@@ -159,10 +162,10 @@ def process(query_img, state,
159
  torch.tensor(target_weight_s).float()[None])
160
 
161
  data = {
162
- 'img_s': [support_img.to(device)],
163
- 'img_q': q_img.to(device),
164
- 'target_s': [target_s.to(device)],
165
- 'target_weight_s': [target_weight_s.to(device)],
166
  'target_q': None,
167
  'target_weight_q': None,
168
  'return_loss': False,
 
111
  @spaces.GPU(duration=30)
112
  def estimate(model, data):
113
  model.cuda()
114
+ data['img_s'] = [s.cuda() for s in data['img_s']]
115
+ data['img_q'] = data['img_q'].cuda()
116
+ data['target_s'] = [s.cuda() for s in data['target_s']]
117
+ data['target_weight_s'] = [s.cuda() for s in data['target_weight_s']]
118
  with torch.no_grad():
119
  return model(**data)
120
 
 
121
  def process(query_img, state,
122
  cfg_path='configs/test/1shot_split1.py',
123
  checkpoint_path='ckpt/1shot_split1.pth'):
 
162
  torch.tensor(target_weight_s).float()[None])
163
 
164
  data = {
165
+ 'img_s': [support_img],
166
+ 'img_q': q_img,
167
+ 'target_s': [target_s],
168
+ 'target_weight_s': [target_weight_s],
169
  'target_q': None,
170
  'target_weight_q': None,
171
  'return_loss': False,