mp-02 commited on
Commit
4093517
·
verified ·
1 Parent(s): 3f079d1

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +34 -6
inference.py CHANGED
@@ -14,6 +14,34 @@ label2id = model.config.label2id
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  model.to(device)
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def prediction(image):
19
 
@@ -55,18 +83,18 @@ def prediction(image):
55
 
56
  for i, conf in enumerate(true_confidence_scores):
57
  if conf < 0.6 :
58
- true_predictions[i] = "O"
 
59
 
60
- d = {}
61
- for id, i in enumerate(true_predictions):
 
62
  if i not in d.keys():
63
  d[i] = true_words[id]
64
  else:
65
  d[i] = d[i] + ", " + true_words[id]
66
  d = {k: v.strip() for (k, v) in d.items()}
67
- d.pop("O")
68
-
69
- # TODO:process the json
70
 
71
  draw = ImageDraw.Draw(image, "RGBA")
72
  font = ImageFont.load_default()
 
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  model.to(device)
16
 
17
+ import json
18
+
19
+ def token2json(words, labels):
20
+ result = []
21
+ current_entity = None
22
+
23
+ for token, label in zip(words, labels):
24
+ if label.startswith("B-"):
25
+ if current_entity:
26
+ result.append(current_entity)
27
+ current_entity = {"type": label[2:], "text": token}
28
+ elif label.startswith("I-"):
29
+ if current_entity and current_entity["type"] == label[2:]:
30
+ current_entity["text"] += " " + token
31
+ else:
32
+ if current_entity:
33
+ result.append(current_entity)
34
+ current_entity = {"type": label[2:], "text": token}
35
+ else: # "O" label
36
+ if current_entity:
37
+ result.append(current_entity)
38
+ current_entity = None
39
+
40
+ if current_entity:
41
+ result.append(current_entity)
42
+
43
+ return json.dumps(result, ensure_ascii=False, indent=2)
44
+
45
 
46
  def prediction(image):
47
 
 
83
 
84
  for i, conf in enumerate(true_confidence_scores):
85
  if conf < 0.6 :
86
+ true_predictions[i] = "O"
87
+
88
 
89
+ d = token2json(true_words, true_predictions)
90
+
91
+ """for id, i in enumerate(true_predictions):
92
  if i not in d.keys():
93
  d[i] = true_words[id]
94
  else:
95
  d[i] = d[i] + ", " + true_words[id]
96
  d = {k: v.strip() for (k, v) in d.items()}
97
+ d.pop("O")"""
 
 
98
 
99
  draw = ImageDraw.Draw(image, "RGBA")
100
  font = ImageFont.load_default()