Spaces:
Running
Running
Update cord_inference.py
Browse files- cord_inference.py +4 -0
cord_inference.py
CHANGED
@@ -30,6 +30,8 @@ def prediction(image):
|
|
30 |
|
31 |
predictions = outputs.logits.argmax(-1).squeeze().tolist()
|
32 |
token_boxes = encoding.bbox.squeeze().tolist()
|
|
|
|
|
33 |
|
34 |
inp_ids = encoding.input_ids.squeeze().tolist()
|
35 |
inp_words = [tokenizer.decode(i) for i in inp_ids]
|
@@ -39,6 +41,7 @@ def prediction(image):
|
|
39 |
|
40 |
true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
|
41 |
true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
|
|
|
42 |
true_words = []
|
43 |
|
44 |
for id, i in enumerate(inp_words):
|
@@ -50,6 +53,7 @@ def prediction(image):
|
|
50 |
true_predictions = true_predictions[1:-1]
|
51 |
true_boxes = true_boxes[1:-1]
|
52 |
true_words = true_words[1:-1]
|
|
|
53 |
|
54 |
preds = []
|
55 |
l_words = []
|
|
|
30 |
|
31 |
predictions = outputs.logits.argmax(-1).squeeze().tolist()
|
32 |
token_boxes = encoding.bbox.squeeze().tolist()
|
33 |
+
probabilities = torch.softmax(outputs.logits, dim=-1)
|
34 |
+
confidence_scores = probabilities.max(-1).values.squeeze().tolist()
|
35 |
|
36 |
inp_ids = encoding.input_ids.squeeze().tolist()
|
37 |
inp_words = [tokenizer.decode(i) for i in inp_ids]
|
|
|
41 |
|
42 |
true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
|
43 |
true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
|
44 |
+
true_confidence_scores = [confidence_scores[idx] for idx, conf in enumerate(confidence_scores) if not is_subword[idx]]
|
45 |
true_words = []
|
46 |
|
47 |
for id, i in enumerate(inp_words):
|
|
|
53 |
true_predictions = true_predictions[1:-1]
|
54 |
true_boxes = true_boxes[1:-1]
|
55 |
true_words = true_words[1:-1]
|
56 |
+
true_confidence_scores = true_confidence_scores[1:-1]
|
57 |
|
58 |
preds = []
|
59 |
l_words = []
|