drhead commited on
Commit
7130696
·
verified ·
1 Parent(s): fbd5ebe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -141,7 +141,8 @@ def run_classifier(image: Image.Image, threshold):
141
  tensor = transform(img).unsqueeze(0)
142
 
143
  with torch.no_grad():
144
- probits = model(tensor)[0] # type: torch.Tensor
 
145
  values, indices = probits.cpu().topk(250)
146
 
147
  tag_score = {allowed_tags[idx.item()]: val.item() for idx, val in zip(indices, values)}
@@ -175,7 +176,8 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
175
  handle_forward = model.norm.register_forward_hook(hook_forward)
176
  handle_backward = model.norm.register_full_backward_hook(hook_backward)
177
 
178
- probits = model(tensor)[0]
 
179
 
180
  model.zero_grad()
181
  probits[target_tag_index].backward(retain_graph=True)
 
141
  tensor = transform(img).unsqueeze(0)
142
 
143
  with torch.no_grad():
144
+ logits = model(tensor)
145
+ probits = torch.nn.functional.sigmoid(logits[0])
146
  values, indices = probits.cpu().topk(250)
147
 
148
  tag_score = {allowed_tags[idx.item()]: val.item() for idx, val in zip(indices, values)}
 
176
  handle_forward = model.norm.register_forward_hook(hook_forward)
177
  handle_backward = model.norm.register_full_backward_hook(hook_backward)
178
 
179
+ logits = model(tensor)
180
+ probits = torch.nn.functional.sigmoid(logits[0])
181
 
182
  model.zero_grad()
183
  probits[target_tag_index].backward(retain_graph=True)