Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +34 -6
inference.py
CHANGED
@@ -14,6 +14,34 @@ label2id = model.config.label2id
|
|
14 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
model.to(device)
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def prediction(image):
|
19 |
|
@@ -55,18 +83,18 @@ def prediction(image):
|
|
55 |
|
56 |
for i, conf in enumerate(true_confidence_scores):
|
57 |
if conf < 0.6 :
|
58 |
-
true_predictions[i] = "O"
|
|
|
59 |
|
60 |
-
d =
|
61 |
-
|
|
|
62 |
if i not in d.keys():
|
63 |
d[i] = true_words[id]
|
64 |
else:
|
65 |
d[i] = d[i] + ", " + true_words[id]
|
66 |
d = {k: v.strip() for (k, v) in d.items()}
|
67 |
-
d.pop("O")
|
68 |
-
|
69 |
-
# TODO:process the json
|
70 |
|
71 |
draw = ImageDraw.Draw(image, "RGBA")
|
72 |
font = ImageFont.load_default()
|
|
|
14 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
model.to(device)
|
16 |
|
17 |
+
import json
|
18 |
+
|
19 |
+
def token2json(words, labels):
|
20 |
+
result = []
|
21 |
+
current_entity = None
|
22 |
+
|
23 |
+
for token, label in zip(words, labels):
|
24 |
+
if label.startswith("B-"):
|
25 |
+
if current_entity:
|
26 |
+
result.append(current_entity)
|
27 |
+
current_entity = {"type": label[2:], "text": token}
|
28 |
+
elif label.startswith("I-"):
|
29 |
+
if current_entity and current_entity["type"] == label[2:]:
|
30 |
+
current_entity["text"] += " " + token
|
31 |
+
else:
|
32 |
+
if current_entity:
|
33 |
+
result.append(current_entity)
|
34 |
+
current_entity = {"type": label[2:], "text": token}
|
35 |
+
else: # "O" label
|
36 |
+
if current_entity:
|
37 |
+
result.append(current_entity)
|
38 |
+
current_entity = None
|
39 |
+
|
40 |
+
if current_entity:
|
41 |
+
result.append(current_entity)
|
42 |
+
|
43 |
+
return json.dumps(result, ensure_ascii=False, indent=2)
|
44 |
+
|
45 |
|
46 |
def prediction(image):
|
47 |
|
|
|
83 |
|
84 |
for i, conf in enumerate(true_confidence_scores):
|
85 |
if conf < 0.6 :
|
86 |
+
true_predictions[i] = "O"
|
87 |
+
|
88 |
|
89 |
+
d = token2json(true_words, true_predictions)
|
90 |
+
|
91 |
+
"""for id, i in enumerate(true_predictions):
|
92 |
if i not in d.keys():
|
93 |
d[i] = true_words[id]
|
94 |
else:
|
95 |
d[i] = d[i] + ", " + true_words[id]
|
96 |
d = {k: v.strip() for (k, v) in d.items()}
|
97 |
+
d.pop("O")"""
|
|
|
|
|
98 |
|
99 |
draw = ImageDraw.Draw(image, "RGBA")
|
100 |
font = ImageFont.load_default()
|