mp-02 committed on
Commit 0f12594 · verified · 1 Parent(s): f3df04e

Update inference.py

Files changed (1)
  1. inference.py +33 -82
inference.py CHANGED
@@ -16,98 +16,49 @@ model.to(device)

  import json

- def token2json(words, labels):
-     result = {}
-     current_entity = None
-     current_text = []
-
-     for word, label in zip(words, labels):
-         if label.startswith("B-"):
-             if current_entity:
-                 result[current_entity] = " ".join(current_text).strip()
-                 current_text = []
-             current_entity = label[2:].lower()
-             current_text = [word]
-         elif label.startswith("I-"):
-             if current_entity == label[2:].lower():
-                 current_text.append(word)
-             else:
-                 # Handle I- sequences not preceded by a B-
-                 if current_entity:
-                     result[current_entity] = " ".join(current_text).strip()
-                 current_entity = label[2:].lower()
-                 current_text = [word]
-         else:  # "O" label
-             if current_entity:
-                 result[current_entity] = " ".join(current_text).strip()
-             current_entity = None
-             current_text = []
-
-     # Add the last entity, if present
-     if current_entity:
-         result[current_entity] = " ".join(current_text).strip()
-
-     return json.dumps(result, ensure_ascii=False, indent=2)
-
-
- def prediction(image):
-
-     boxes, words = OCR(image)
-     encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
-     offset_mapping = encoding.pop('offset_mapping')
-
-     for k, v in encoding.items():
-         encoding[k] = v.to(device)

      outputs = model(**encoding)

-     predictions = outputs.logits.argmax(-1).squeeze().tolist()
-     token_boxes = encoding.bbox.squeeze().tolist()
-     probabilities = torch.softmax(outputs.logits, dim=-1)
-     confidence_scores = probabilities.max(-1).values.squeeze().tolist()
-
-     inp_ids = encoding.input_ids.squeeze().tolist()
-     inp_words = [tokenizer.decode(i) for i in inp_ids]
-
-     width, height = image.size
-     is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
-
-     true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
-     true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
-     true_confidence_scores = [confidence_scores[idx] for idx, conf in enumerate(confidence_scores) if not is_subword[idx]]
-     true_words = []
-
-     for id, i in enumerate(inp_words):
-         if not is_subword[id]:
-             true_words.append(i)
-         else:
-             true_words[-1] = true_words[-1]+i
-
-     true_predictions = true_predictions[1:-1]
-     true_boxes = true_boxes[1:-1]
-     true_words = true_words[1:-1]
-     true_confidence_scores = true_confidence_scores[1:-1]
-
-     for i, conf in enumerate(true_confidence_scores):
-         if conf < 0.6:
-             true_predictions[i] = "O"
-
-     d = token2json(true_words, true_predictions)
-
-     """for id, i in enumerate(true_predictions):
-         if i not in d.keys():
-             d[i] = true_words[id]
-         else:
-             d[i] = d[i] + ", " + true_words[id]
-     d = {k: v.strip() for (k, v) in d.items()}
-     d.pop("O")"""
-
-     draw = ImageDraw.Draw(image, "RGBA")
-     font = ImageFont.load_default()
-
-     for prediction, box, confidence in zip(true_predictions, true_boxes, true_confidence_scores):
          draw.rectangle(box)
          draw.text((box[0]+10, box[1]-10), text=prediction+ ", "+ str(confidence), font=font, fill="black", font_size="15")

-     return image, d
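For reference, a condensed, self-contained sketch of the BIO-grouping idea behind the removed token2json helper; the words and labels below are invented purely for illustration and are not taken from the repository.

import json

words  = ["Coca", "Cola", "2.50", "3.80"]                  # hypothetical OCR tokens
labels = ["B-PRODUCT", "I-PRODUCT", "B-PRICE", "B-TOTAL"]  # hypothetical BIO tags

result, entity, text = {}, None, []
for word, label in zip(words, labels):
    if label == "O":                     # outside any entity: reset the running span
        entity, text = None, []
        continue
    key = label[2:].lower()
    if label.startswith("B-") or key != entity:
        entity, text = key, []           # start a new entity span
    text.append(word)
    result[entity] = " ".join(text)      # keep the accumulated text for this entity

print(json.dumps(result, ensure_ascii=False))
# {"product": "Coca Cola", "price": "2.50", "total": "3.80"}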
 
 

  import json

+ boxes, words = OCR(image)
+ # Preprocess the image and text with the LayoutLMv3 processor
+ encoding = processor(image, words=words, boxes=boxes, return_tensors="pt", padding="max_length", truncation=True)

+ # Run inference with the fine-tuned model
+ with torch.no_grad():
      outputs = model(**encoding)

+ # Get the labels predicted by the model
+ logits = outputs.logits
+ predicted_ids = logits.argmax(-1).squeeze().tolist()

+ predictions = outputs.logits.argmax(-1).squeeze().tolist()
+ token_boxes = encoding.bbox.squeeze().tolist()
+ probabilities = torch.softmax(outputs.logits, dim=-1)
+ confidence_scores = probabilities.max(-1).values.squeeze().tolist()

+ # Map the predicted IDs to the classification labels
+ labels = processor.tokenizer.convert_ids_to_tokens(predicted_ids)

+ # Function to create the JSON output in a CORD-like format
+ def create_json_output(words, labels, boxes):
+     output = []
+
+     for word, label, box in zip(words, labels, boxes):
+         # Consider only the relevant labels (excluding "O")
+         if label != "O":
+             output.append({
+                 "text": word,
+                 "category": label,  # the category predicted by the model (e.g. "B-PRODUCT", "B-PRICE", "B-TOTAL")
+                 "bounding_box": box  # the bounding-box coordinates for the word
+             })
+
+     # Convert to JSON
+     json_output = json.dumps(output, indent=4)
+     return json_output

+ # Create the JSON from the results obtained above
+ json_result = create_json_output(words, labels, boxes)

+ for prediction, box, confidence in zip(true_predictions, true_boxes, true_confidence_scores):
      draw.rectangle(box)
      draw.text((box[0]+10, box[1]-10), text=prediction+ ", "+ str(confidence), font=font, fill="black", font_size="15")

+ return image, json_result
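As a quick illustration of the new output format, a minimal usage sketch of create_json_output; the words, labels, and boxes below are made up, and the call assumes the function defined above is in scope.

# Hypothetical OCR words, predicted labels, and boxes (illustrative only)
sample_words  = ["Milk", "1.99"]
sample_labels = ["B-PRODUCT", "B-PRICE"]
sample_boxes  = [[30, 40, 120, 60], [300, 40, 360, 60]]

print(create_json_output(sample_words, sample_labels, sample_boxes))
# Prints a JSON array with one object per non-"O" word, each carrying
# "text", "category", and "bounding_box" keys.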