mp-02 commited on
Commit
b248a87
·
verified ·
1 Parent(s): 98f17a7

Update cord_inference.py

Browse files
Files changed (1) hide show
  1. cord_inference.py +4 -0
cord_inference.py CHANGED
@@ -30,6 +30,8 @@ def prediction(image):
30
 
31
  predictions = outputs.logits.argmax(-1).squeeze().tolist()
32
  token_boxes = encoding.bbox.squeeze().tolist()
 
 
33
 
34
  inp_ids = encoding.input_ids.squeeze().tolist()
35
  inp_words = [tokenizer.decode(i) for i in inp_ids]
@@ -39,6 +41,7 @@ def prediction(image):
39
 
40
  true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
41
  true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
 
42
  true_words = []
43
 
44
  for id, i in enumerate(inp_words):
@@ -50,6 +53,7 @@ def prediction(image):
50
  true_predictions = true_predictions[1:-1]
51
  true_boxes = true_boxes[1:-1]
52
  true_words = true_words[1:-1]
 
53
 
54
  preds = []
55
  l_words = []
 
30
 
31
  predictions = outputs.logits.argmax(-1).squeeze().tolist()
32
  token_boxes = encoding.bbox.squeeze().tolist()
33
+ probabilities = torch.softmax(outputs.logits, dim=-1)
34
+ confidence_scores = probabilities.max(-1).values.squeeze().tolist()
35
 
36
  inp_ids = encoding.input_ids.squeeze().tolist()
37
  inp_words = [tokenizer.decode(i) for i in inp_ids]
 
41
 
42
  true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
43
  true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
44
+ true_confidence_scores = [confidence_scores[idx] for idx, conf in enumerate(confidence_scores) if not is_subword[idx]]
45
  true_words = []
46
 
47
  for id, i in enumerate(inp_words):
 
53
  true_predictions = true_predictions[1:-1]
54
  true_boxes = true_boxes[1:-1]
55
  true_words = true_words[1:-1]
56
+ true_confidence_scores = true_confidence_scores[1:-1]
57
 
58
  preds = []
59
  l_words = []