Update inference.py
inference.py  CHANGED  (+33 -82)
@@ -16,98 +16,49 @@ model.to(device)
 
 import json
 
-
-    current_text = []
-
-    for word, label in zip(words, labels):
-        if label.startswith("B-"):
-            if current_entity:
-                result[current_entity] = " ".join(current_text).strip()
-                current_text = []
-            current_entity = label[2:].lower()
-            current_text = [word]
-        elif label.startswith("I-"):
-            if current_entity == label[2:].lower():
-                current_text.append(word)
-            else:
-                # Handle I- sequences that are not preceded by a B-
-                if current_entity:
-                    result[current_entity] = " ".join(current_text).strip()
-                current_entity = label[2:].lower()
-                current_text = [word]
-        else:  # "O" label
-            if current_entity:
-                result[current_entity] = " ".join(current_text).strip()
-            current_entity = None
-            current_text = []
-
-    # Add the last entity, if any
-    if current_entity:
-        result[current_entity] = " ".join(current_text).strip()
-
-    return json.dumps(result, ensure_ascii=False, indent=2)
-
-
-def prediction(image):
-
-    boxes, words = OCR(image)
-    encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
-    offset_mapping = encoding.pop('offset_mapping')
-
-    for k, v in encoding.items():
-        encoding[k] = v.to(device)
+    boxes, words = OCR(image)
+    # Preprocess the image and the text with the LayoutLMv3 processor
+    encoding = processor(image, words=words, boxes=boxes, return_tensors="pt", padding="max_length", truncation=True)
+
+    # Run inference with the fine-tuned model
+    with torch.no_grad():
     outputs = model(**encoding)
 
-    confidence_scores = probabilities.max(-1).values.squeeze().tolist()
-
-    inp_ids = encoding.input_ids.squeeze().tolist()
-    inp_words = [tokenizer.decode(i) for i in inp_ids]
-
-    width, height = image.size
-    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
-
-            true_words.append(i)
-        else:
-            true_words[-1] = true_words[-1]+i
-
-    true_predictions = true_predictions[1:-1]
-    true_boxes = true_boxes[1:-1]
-    true_words = true_words[1:-1]
-    true_confidence_scores = true_confidence_scores[1:-1]
-
-    for i, conf in enumerate(true_confidence_scores):
-        if conf < 0.6:
-            true_predictions[i] = "O"
-
-        else:
-            d[i] = d[i] + ", " + true_words[id]
-    d = {k: v.strip() for (k, v) in d.items()}
-    d.pop("O")"""
-
+
+    # Get the labels predicted by the model
+    logits = outputs.logits
+    predicted_ids = logits.argmax(-1).squeeze().tolist()
+
+    predictions = outputs.logits.argmax(-1).squeeze().tolist()
+    token_boxes = encoding.bbox.squeeze().tolist()
+    probabilities = torch.softmax(outputs.logits, dim=-1)
+    confidence_scores = probabilities.max(-1).values.squeeze().tolist()
+
+    # Map the predicted IDs to classification labels
+    labels = processor.tokenizer.convert_ids_to_tokens(predicted_ids)
+
+    # Helper that builds the JSON output in a CORD-like format
+    def create_json_output(words, labels, boxes):
+        output = []
+
+        for word, label, box in zip(words, labels, boxes):
+            # Keep only the relevant labels (skip "O")
+            if label != "O":
+                output.append({
+                    "text": word,
+                    "category": label,   # the category predicted by the model (e.g. "B-PRODUCT", "B-PRICE", "B-TOTAL")
+                    "bounding_box": box  # the bounding-box coordinates of the word
+                })
+
+        # Convert to JSON
+        json_output = json.dumps(output, indent=4)
+        return json_output
+
+    # Build the JSON from the results obtained above
+    json_result = create_json_output(words, labels, boxes)
+
+    for prediction, box, confidence in zip(true_predictions, true_boxes, true_confidence_scores):
         draw.rectangle(box)
         draw.text((box[0]+10, box[1]-10), text=prediction+ ", "+ str(confidence), font=font, fill="black", font_size="15")
 
-
+    return image, json_result
+
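A note on reusing the outputs above: `processor.tokenizer.convert_ids_to_tokens` maps vocabulary token IDs to tokens, while the values in `predicted_ids` are class indices from the token-classification head, and the model predicts one label per sub-word token rather than per OCR word. Below is a minimal sketch of the usual id-to-label mapping and sub-word aggregation; it assumes `model` is a `LayoutLMv3ForTokenClassification` (so `model.config.id2label` is available) and that `encoding` was produced by the processor's fast tokenizer, which exposes `word_ids()`. The helper name `word_level_labels` is illustrative, not part of this repository.

import torch

def word_level_labels(logits, encoding, id2label, num_words):
    # logits: (1, seq_len, num_labels) from the token-classification head
    probabilities = torch.softmax(logits, dim=-1)
    confidences, class_ids = probabilities.max(-1)

    labels = ["O"] * num_words
    scores = [0.0] * num_words
    seen = set()
    for token_idx, word_idx in enumerate(encoding.word_ids(batch_index=0)):
        # word_ids() returns None for special/padding tokens; keep only the
        # first sub-word token of each OCR word
        if word_idx is None or word_idx in seen:
            continue
        seen.add(word_idx)
        labels[word_idx] = id2label[class_ids[0, token_idx].item()]
        scores[word_idx] = confidences[0, token_idx].item()
    return labels, scores

# e.g. labels, scores = word_level_labels(outputs.logits, encoding, model.config.id2label, len(words))

The resulting word-level `labels` can then be passed to `create_json_output(words, labels, boxes)` together with the OCR words and boxes, and `scores` can drive the same `conf < 0.6` filter used in the previous version of the file.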