MnLgt committed
Commit f725299 · 1 Parent(s): a34e3aa
.gitignore CHANGED
@@ -1,3 +1,3 @@
  */partially_signed_agreement_1.png
 
- */**.pyc
+ *.pyc
app.py ADDED
@@ -0,0 +1,175 @@
+ """
+ This script creates a Gradio GUI for detecting and classifying signature blocks in document images
+ using the SignatureBlockModel. It loads example images from the /assets directory, displays
+ bounding boxes in the result image, and shows cropped signature blocks with labels in a separate view.
+ """
+
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import io
+ from typing import Tuple
+ import os
+
+ from scripts.signature_blocks import SignatureBlockModel
+
+ ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
+
+
+ def process_image(image: np.ndarray) -> Tuple[np.ndarray, str, np.ndarray]:
+     """
+     Process an input image using the SignatureBlockModel.
+
+     Args:
+         image (np.ndarray): Input image as a numpy array.
+
+     Returns:
+         Tuple[np.ndarray, str, np.ndarray]: Processed image, status, and signature crops image.
+     """
+     # Convert numpy array to PIL Image
+     pil_image = Image.fromarray(image)
+
+     # Initialize the model
+     model = SignatureBlockModel(pil_image)
+
+     # Get processed image with boxes
+     image_with_boxes = model.draw_boxes()
+
+     # Get signature crops
+     signature_crops = create_signature_crops(model)
+
+     # Determine status: all SIGNED_BLOCK (label 1) means fully executed,
+     # all UNSIGNED_BLOCK (label 2) or no detections means unsigned,
+     # and a mix of the two means partially executed.
+     labels = model.get_labels()
+     if not labels.any():
+         status = "Unsigned"
+     elif all(label == 1 for label in labels):
+         status = "Fully Executed"
+     elif all(label == 2 for label in labels):
+         status = "Unsigned"
+     else:
+         status = "Partially Executed"
+
+     return np.array(image_with_boxes), status, signature_crops
+
+
+ def resize_crop(crop: np.ndarray, factor: float = 0.5) -> np.ndarray:
+     """
+     Resize a crop by a scale factor.
+
+     Args:
+         crop (np.ndarray): Input crop as a numpy array.
+         factor (float): Scale factor applied to both dimensions.
+
+     Returns:
+         np.ndarray: Resized crop.
+     """
+     crop_image = Image.fromarray(crop).convert("RGB")
+     crop_size = crop_image.size
+     target_size = tuple(int(dim * factor) for dim in crop_size)
+     crop_image = crop_image.resize(target_size)
+     return np.array(crop_image)
+
+
+ def create_signature_crops(model: SignatureBlockModel) -> np.ndarray:
+     """
+     Create an image with stacked signature crops and labels.
+
+     Args:
+         model (SignatureBlockModel): The initialized SignatureBlockModel.
+
+     Returns:
+         np.ndarray: Image with stacked signature crops and labels.
+     """
+     boxes = model.get_boxes()
+     scores = model.get_scores()
+     labels = model.get_labels()
+     classes = model.classes
+
+     # Create a figure with the correct number of subplots
+     fig, axes = plt.subplots(len(boxes), 2, figsize=(10, 3 * len(boxes)))
+
+     # Ensure axes is always a 2D array, even with only one box
+     if len(boxes) == 1:
+         axes = axes.reshape(1, -1)
+
+     for (ax_label, ax_image), box, label, score in zip(axes, boxes, labels, scores):
+         crop = model.extract_box(box)
+         crop = resize_crop(crop, 0.7)
+
+         # Set background color to black for both subplots
+         ax_label.set_facecolor("black")
+         ax_image.set_facecolor("black")
+
+         # Add label text
+         label_text = f"Label: {classes[label]}\nScore: {score:.2f}"
+         ax_label.text(
+             0.05,
+             0.5,
+             label_text,
+             color="white",
+             fontsize=12,
+             verticalalignment="center",
+             horizontalalignment="left",
+         )
+         ax_label.axis("off")
+
+         # Display the crop
+         ax_image.imshow(crop)
+         ax_image.axis("off")
+
+     plt.tight_layout()
+
+     # Convert the matplotlib figure to a PNG image
+     buf = io.BytesIO()
+     plt.savefig(buf, format="png", facecolor="black", edgecolor="none")
+     buf.seek(0)
+     signature_crops = np.array(Image.open(buf))
+     plt.close(fig)
+
+     return signature_crops
+
+
+ def load_examples():
+     """
+     Load example images from the /assets directory.
+
+     Returns:
+         List[List[str]]: List of example image paths.
+     """
+     examples = []
+     for filename in os.listdir(ASSETS_DIR):
+         if filename.lower().endswith((".png", ".jpg", ".jpeg")):
+             examples.append([os.path.join(ASSETS_DIR, filename)])
+     return examples
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Signature Block Detection")
+     gr.Markdown("Upload a document image to detect and classify signature blocks.")
+
+     with gr.Row():
+         input_image = gr.Image(label="Upload Document Image")
+         output_image = gr.Image(label="Processed Image")
+
+     with gr.Row():
+         status_box = gr.Textbox(label="Document Status")
+         signature_crops = gr.Image(label="Signature Crops")
+
+     process_btn = gr.Button("Process Image")
+
+     examples = gr.Examples(
+         examples=load_examples(),
+         inputs=input_image,
+     )
+
+     process_btn.click(
+         fn=process_image,
+         inputs=input_image,
+         outputs=[output_image, status_box, signature_crops],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
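Since process_image is a plain function, it can be smoke-tested without serving the UI. A minimal sketch, assuming it is run from the repo root against one of the example assets added in this commit:

    # Headless check of process_image; importing app builds the Blocks UI
    # but does not launch it (launch() is guarded by __main__).
    import numpy as np
    from PIL import Image

    from app import process_image

    img = np.array(Image.open("assets/signed_agreement_1.jpg").convert("RGB"))
    boxed, status, crops = process_image(img)
    print(status)  # "Fully Executed", "Partially Executed", or "Unsigned"
    Image.fromarray(boxed).save("boxed.png")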
assets/signed_agreement_1.jpg ADDED
assets/signed_agreement_2.png ADDED
assets/signed_agreement_3.jpg ADDED
assets/unsigned_agreement_1.jpg ADDED
assets/unsigned_agreement_2.jpg ADDED
assets/unsigned_agreement_3.png ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio==4.44.0
+ matplotlib==3.8.4
+ numpy<2  # torch 2.0.1 wheels are built against NumPy 1.x
+ Pillow==10.4.0
+ torch==2.0.1
+ torchvision==0.15.2
+ opencv-python
scripts/execution_status.py ADDED
@@ -0,0 +1,54 @@
+ from typing import List, Any, Tuple
+
+ from scripts.signature_blocks import SignatureBlockModel
+
+
+ def flatten_list(xss: List[List[Any]]) -> List[Any]:
+     return [x for xs in xss for x in xs]
+
+
+ def agreement_status(labels: List[str]) -> str:
+     if labels:
+         if len(set(labels)) > 1:
+             return "Partially Executed"
+         elif list(set(labels))[0] == "SIGNED_BLOCK":
+             return "Fully Executed"
+         elif list(set(labels))[0] == "UNSIGNED_BLOCK":
+             return "Unsigned"
+     return "Unknown"
+
+
+ def execution_status(
+     images: List[Any], show: bool = False
+ ) -> Tuple[int, str, List[Any], List[Any]]:
+     if isinstance(images, list):
+         labels = []
+         boxes = []
+         crops = []
+         for page in images:
+             model = SignatureBlockModel(page)
+             if model.predictions[0]["boxes"].shape[0] > 0:
+                 page_labels = model.get_labels_names()
+                 labels.append(page_labels)
+                 boxes.extend(model.get_boxes())
+                 crops.extend(model.get_box_crops())
+                 if show:
+                     model.show_boxes()
+         num_sig_pages = len(labels)
+         status = agreement_status(flatten_list(labels))
+         return num_sig_pages, status, boxes, crops
+     else:
+         return None, None, None, None
+
+
+ if __name__ == "__main__":
+     from gabriel.parsers.pdf_parser import ParsePDF
+
+     filepath = "/Users/jordandavis/GitHub/gabriel/gabriel/datasets/MASTER_REVIEWED/SIGNATURE_PAGE/1a90afa457f328fc7f560d9b49af7b8f.pdf"
+     images = list(ParsePDF(filepath).yield_image())
+
+     num_sig_pages, status, boxes, crops = execution_status(images)
+     print(f"Num Sig Pages: {num_sig_pages}")
+     print(f"Status: {status}")
+     print(f"Boxes: {boxes}")
+     print(f"Crops: {crops}")
scripts/signature_blocks.py ADDED
@@ -0,0 +1,211 @@
+ import os
+ import cv2
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torchvision import models
+ from torchvision import transforms as T
+ from torchvision.ops import nms
+ from typing import Any
+
+ STATE_DICT = os.path.join(
+     os.path.dirname(__file__), "..", "state_dicts", "signature_blocks_v14.pth"
+ )
+
+
+ def get_device():
+     if torch.cuda.is_available():
+         device = "cuda"
+
+     # 'aten::hardsigmoid.out' is not currently implemented for the MPS device,
+     # and setting the fallback does not work either, so MPS stays disabled:
+     # elif torch.backends.mps.is_built():
+     #     device = "mps"
+     else:
+         device = "cpu"
+     return device
+
+
+ class ImgFactory:
+     """Normalizes model input: file paths are opened with PIL, images pass through."""
+
+     def serialize(self, img: Any) -> Any:
+         serializer = self._get_serializer(img)
+         return serializer(img)
+
+     def _get_serializer(self, img: Any) -> Any:
+         if isinstance(img, str):
+             return self._serialize_string_to_image
+         else:
+             return self._serialize_image_to_image
+
+     def _serialize_string_to_image(self, img):
+         return Image.open(img)
+
+     def _serialize_image_to_image(self, img):
+         return img
+
+
+ class SignatureBlockModel(ImgFactory):
+     def __init__(self, img, state_dict_path=STATE_DICT):
+         self.state_dict_path = state_dict_path
+         self.classes = {0: "NOTHING", 1: "SIGNED_BLOCK", 2: "UNSIGNED_BLOCK"}
+         self.n_classes = len(self.classes)
+         self.device = get_device()
+         self.model = self._load_model()
+         self.img = self.serialize(img)
+
+         with torch.no_grad():
+             self.model.eval()
+             self.predictions = self._get_prediction()
+
+     def _load_model(self):
+         weights = models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT
+         model = models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=weights)
+         # Swap the classification head for one sized to our three classes
+         in_features = model.roi_heads.box_predictor.cls_score.in_features
+
+         model.roi_heads.box_predictor = models.detection.faster_rcnn.FastRCNNPredictor(
+             in_features, self.n_classes
+         )
+
+         model.load_state_dict(
+             torch.load(self.state_dict_path, map_location=self.device)
+         )
+
+         return model.to(self.device)
+
+     def filter_overlap(self, predictions, iou_threshold=0.3):
+         boxes = predictions[0]["boxes"]
+         scores = predictions[0]["scores"]
+         nms_filter = nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
+         return nms_filter
+
+     def filter_scores(self, predictions, score_thrs=0.94):
+         nms_filter = self.filter_overlap(predictions)
+         boxes = predictions[0]["boxes"]
+         scores = predictions[0]["scores"]
+         labels = predictions[0]["labels"]
+
+         score_filter = scores[nms_filter] > score_thrs
+         boxes = boxes[nms_filter][score_filter]
+         scores = scores[nms_filter][score_filter]
+         labels = labels[nms_filter][score_filter]
+         return boxes, scores, labels
+
+     def _get_prediction(self):
+         transform = T.Compose([T.ToTensor()])
+         img = transform(self.img)
+         img = img.to(self.device)
+         predictions = self.model([img])
+         boxes, scores, labels = self.filter_scores(predictions)
+         return [{"boxes": boxes, "scores": scores, "labels": labels}]
+
+     def get_boxes(self):
+         # Reuse the predictions computed once in __init__
+         boxes = self.predictions[0]["boxes"].cpu().detach().numpy()
+         int_boxes = []
+         for box in boxes:
+             box = [int(x) for x in box]
+             int_boxes.append(box)
+         return int_boxes
+
+     def get_scores(self):
+         scores = self.predictions[0]["scores"].cpu().detach().numpy()
+         return scores
+
+     def get_labels(self):
+         labels = self.predictions[0]["labels"].cpu().detach().numpy()
+         return labels
+
+     def get_labels_names(self):
+         labels = self.get_labels()
+         label_names = [self.classes[label] for label in labels]
+         return label_names
+
+     def _get_prediction_dict(self):
+         boxes = self.get_boxes()
+         scores = self.get_scores()
+         labels = self.get_labels()
+         return {"boxes": boxes, "scores": scores, "labels": labels}
+
+     def _signature_crops(self, show=True):
+         boxes = self.get_boxes()
+         scores = self.get_scores()
+         labels = self.get_labels()
+         signature_crops = []
+         for box, label, score in zip(boxes, labels, scores):
+             crop = self.extract_box(box)
+             if show:
+                 plt.imshow(crop)
+                 plt.show()
+             signature_crops.append(crop)
+         return signature_crops
+
+     def get_prediction(self):
+         return self._get_prediction_dict()
+
+     def get_image(self):
+         return self.img
+
+     def get_image_array(self):
+         return np.array(self.img)
+
+     def get_box_crops(self):
+         boxes = self.get_boxes()
+         box_crops = []
+         for box in boxes:
+             crop = self.img.crop(box)
+             box_crops.append(crop)
+         return box_crops
+
+     def extract_box(self, box):
+         xmin, ymin, xmax, ymax = box
+         image = np.array(self.img)
+         return image[ymin:ymax, xmin:xmax]
+
+     def show_boxes(self):
+         boxes = self.get_boxes()
+         scores = self.get_scores()
+         labels = self.get_labels()
+         box_crops = []
+         for box, label, score in zip(boxes, labels, scores):
+             print(f"Status: {self.classes[label]}")
+             print(f"Score: {score}")
+             crop = self.extract_box(box)
+             plt.imshow(crop)
+             plt.show()
+             plt.close()
+             box_crops.append(crop)
+         return box_crops
+
+     def draw_boxes(self):
+         img = np.array(self.img)
+         boxes = self.get_boxes()
+         labels = self.get_labels()
+         thickness = 2
+         overlay = img.copy()
+         for box, label in zip(boxes, labels):
+             if label == 2:
+                 color = (0, 0, 255)  # red (in BGR order)
+             else:
+                 color = (0, 255, 0)  # green
+             cv2.rectangle(
+                 overlay, (box[0], box[1]), (box[2], box[3]), color, -1
+             )  # Filled rectangle
+
+         alpha = 0.4  # Transparency factor
+         image_boxes = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)
+
+         # Draw box outlines
+         for box, label in zip(boxes, labels):
+             if label == 2:
+                 color = (0, 0, 255)  # red (in BGR order)
+             else:
+                 color = (0, 255, 0)  # green
+             cv2.rectangle(
+                 image_boxes, (box[0], box[1]), (box[2], box[3]), color, thickness
+             )
+
+         return Image.fromarray(cv2.cvtColor(image_boxes, cv2.COLOR_BGR2RGB))
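Putting the class together, a minimal usage sketch, assuming the repo root as working directory and the LFS state dict below checked out:

    from scripts.signature_blocks import SignatureBlockModel

    # Accepts a file path or a PIL image (see ImgFactory.serialize)
    model = SignatureBlockModel("assets/unsigned_agreement_1.jpg")

    print(model.get_labels_names())  # e.g. ["UNSIGNED_BLOCK"]
    print(model.get_scores())        # confidences above the 0.94 cutoff
    model.draw_boxes().save("annotated.png")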
state_dicts/signature_blocks_v14.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8de7450d624842a805cb011db1a8bdd3359a817a4f7e5b4c8bcdaf9e340423b0
+ size 76042575