# Signature-block detection: wraps a fine-tuned Faster R-CNN detector that
# locates signature blocks in a document image and classifies each as
# signed or unsigned.
import os
from typing import Any, List, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.models as models
from PIL import Image

# NOTE(review): this re-binds the same `models` name as the import above;
# redundant, but kept so neither import statement is silently dropped.
from torchvision import models
from torchvision import transforms as T
from torchvision.ops import nms

# Path to the fine-tuned Faster R-CNN weights, resolved relative to this file.
STATE_DICT = os.path.join(
    os.path.dirname(__file__), "..", "state_dicts", "signature_blocks_v14.pth"
)
def get_device():
    """Return the compute device name: ``"cuda"`` when available, else ``"cpu"``.

    MPS is deliberately not used: 'aten::hardsigmoid.out' is not currently
    implemented for the MPS device, and setting the CPU-fallback option does
    not work either.
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
class ImgFactory:
    """Normalize image inputs: path strings are opened with PIL, anything
    else is assumed to already be an image object and is passed through."""

    def serialize(self, img: Any) -> Any:
        """Return *img* as an image, loading it from disk when *img* is a path."""
        return self._get_serializer(img)(img)

    def _get_serializer(self, img: Any) -> Any:
        # Strings are treated as filesystem paths; anything else is assumed
        # to be image-like already.
        if isinstance(img, str):
            return self._serialize_string_to_image
        return self._serialize_image_to_image

    def _serialize_string_to_image(self, img):
        return Image.open(img)

    def _serialize_image_to_image(self, img):
        return img
class SignatureBlockModel(ImgFactory):
    """Fine-tuned Faster R-CNN (MobileNetV3-Large FPN backbone) that detects
    signature blocks in a document image and classifies each detection as
    SIGNED_BLOCK or UNSIGNED_BLOCK.

    Inference runs exactly once, at construction time, under
    ``torch.no_grad()``; every accessor reads the cached ``self.predictions``
    instead of re-running the network (the original re-ran a full forward
    pass — with gradient tracking — on every ``get_*`` call).
    """

    def __init__(self, img, state_dict_path=STATE_DICT):
        """Load the detector and run inference on *img*.

        Args:
            img: a path string or an already-loaded image (see ImgFactory).
            state_dict_path: path to the fine-tuned weights file.
        """
        self.state_dict_path = state_dict_path
        self.classes = {0: "NOTHING", 1: "SIGNED_BLOCK", 2: "UNSIGNED_BLOCK"}
        self.n_classes = len(self.classes)
        self.device = get_device()
        self.model = self._load_model()
        self.img = self.serialize(img)
        self.model.eval()
        with torch.no_grad():
            # Single forward pass; all accessors below reuse this result.
            self.predictions = self._get_prediction()

    def _load_model(self):
        """Build the torchvision detector, replace the box-predictor head
        with one sized for our classes, and load the fine-tuned weights."""
        weights = models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT
        model = models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=weights)
        # Swap the head: the pretrained predictor outputs COCO's 91 classes.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = models.detection.faster_rcnn.FastRCNNPredictor(
            in_features, self.n_classes
        )
        model.load_state_dict(
            torch.load(self.state_dict_path, map_location=self.device)
        )
        return model.to(self.device)

    def filter_overlap(self, predictions, iou_threshold=0.3):
        """Return indices of detections kept by non-maximum suppression."""
        boxes = predictions[0]["boxes"]
        scores = predictions[0]["scores"]
        return nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold)

    def filter_scores(self, predictions, score_thrs=0.94):
        """Apply NMS, then drop detections with score <= *score_thrs*.

        Returns:
            (boxes, scores, labels) tensors for the surviving detections.
        """
        keep = self.filter_overlap(predictions)
        boxes = predictions[0]["boxes"][keep]
        scores = predictions[0]["scores"][keep]
        labels = predictions[0]["labels"][keep]
        confident = scores > score_thrs
        return boxes[confident], scores[confident], labels[confident]

    def _get_prediction(self):
        """Run the detector on ``self.img`` and return filtered detections
        in torchvision's list-of-dicts format."""
        transform = T.Compose([T.ToTensor()])
        img = transform(self.img).to(self.device)
        predictions = self.model([img])
        boxes, scores, labels = self.filter_scores(predictions)
        return [{"boxes": boxes, "scores": scores, "labels": labels}]

    def get_boxes(self):
        """Detected boxes as lists of ints: ``[xmin, ymin, xmax, ymax]``."""
        boxes = self.predictions[0]["boxes"].cpu().detach().numpy()
        return [[int(coord) for coord in box] for box in boxes]

    def get_scores(self):
        """Confidence scores as a numpy array."""
        return self.predictions[0]["scores"].cpu().detach().numpy()

    def get_labels(self):
        """Integer class labels as a numpy array."""
        return self.predictions[0]["labels"].cpu().detach().numpy()

    def get_labels_names(self):
        """Human-readable class names, one per detection."""
        return [self.classes[label] for label in self.get_labels()]

    def _get_prediction_dict(self):
        # Plain-Python view of the cached predictions.
        return {
            "boxes": self.get_boxes(),
            "scores": self.get_scores(),
            "labels": self.get_labels(),
        }

    def _signature_crops(self, show=True):
        """Crop each detection out of the image array.

        NOTE: when ``show=True`` the returned list contains the matplotlib
        AxesImage objects from ``plt.imshow``, not the crops themselves —
        this quirk is preserved from the original implementation.
        """
        crops = []
        for box in self.get_boxes():
            crop = self.extract_box(box)
            if show:
                crop = plt.imshow(crop)
            crops.append(crop)
        return crops

    def get_prediction(self):
        """Public accessor for the filtered detections as plain Python types."""
        return self._get_prediction_dict()

    def get_image(self):
        """The serialized input image."""
        return self.img

    def get_image_array(self):
        """The input image as a numpy array."""
        return np.array(self.img)

    def get_box_crops(self):
        """PIL crops of each detected box."""
        return [self.img.crop(box) for box in self.get_boxes()]

    def extract_box(self, box):
        """Slice ``[xmin, ymin, xmax, ymax]`` out of the image array."""
        xmin, ymin, xmax, ymax = box
        image = np.array(self.img)
        return image[ymin:ymax, xmin:xmax]

    def show_boxes(self):
        """Display each crop with its class name and score; return the crops."""
        crops = []
        detections = zip(self.get_boxes(), self.get_labels(), self.get_scores())
        for box, label, score in detections:
            print(f"Status: {self.classes[label]}")
            print(f"Score: {score}")
            crop = self.extract_box(box)
            plt.imshow(crop)
            plt.show()
            plt.close()
            crops.append(crop)
        return crops

    def draw_boxes(self):
        """Return a PIL image with translucent filled boxes plus solid
        outlines: green = SIGNED_BLOCK (1), red = UNSIGNED_BLOCK (2)."""
        img = np.array(self.img)
        boxes = self.get_boxes()
        labels = self.get_labels()
        # Colors in cv2's BGR convention. Unknown labels fall back to gray
        # instead of the original's NameError / stale-color behavior when a
        # label other than 1 or 2 slipped through.
        colors = {1: (0, 255, 0), 2: (0, 0, 255)}
        fallback = (128, 128, 128)
        thickness = 2
        overlay = img.copy()
        for box, label in zip(boxes, labels):
            color = colors.get(int(label), fallback)
            cv2.rectangle(
                overlay, (box[0], box[1]), (box[2], box[3]), color, -1
            )  # filled rectangle
        alpha = 0.4  # transparency factor for the filled overlay
        image_boxes = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)
        # Draw solid box outlines on top of the blended image.
        for box, label in zip(boxes, labels):
            color = colors.get(int(label), fallback)
            cv2.rectangle(
                image_boxes, (box[0], box[1]), (box[2], box[3]), color, thickness
            )
        return Image.fromarray(cv2.cvtColor(image_boxes, cv2.COLOR_BGR2RGB))