TextExtractor / backend /pytorch.py
deepsh2207's picture
Updates in readme, parameters
d812fea
raw
history blame
2.22 kB
import numpy as np
import torch
from doctr.models import ocr_predictor
from doctr.models.predictor import OCRPredictor
# Identifiers of the text-detection architectures listed by this backend
# (presumably the accepted values for `det_arch` in `load_predictor`).
# Order is preserved on purpose: callers may rely on the first entry or on
# positional indexing.
DET_ARCHS = [
    # "db_*" family
    "db_resnet50",
    "db_resnet34",
    "db_mobilenet_v3_large",
    # "linknet_*" family
    "linknet_resnet18",
    "linknet_resnet34",
    "linknet_resnet50",
]
# Identifiers of the text-recognition architectures listed by this backend
# (presumably the accepted values for `reco_arch` in `load_predictor`).
# Order is preserved on purpose: callers may rely on the first entry or on
# positional indexing.
RECO_ARCHS = [
    # "crnn_*" family
    "crnn_vgg16_bn",
    "crnn_mobilenet_v3_small",
    "crnn_mobilenet_v3_large",
    # remaining architectures
    "master",
    "sar_resnet31",
    "vitstr_small",
    "vitstr_base",
    "parseq",
]
def load_predictor(
    det_arch: str,
    reco_arch: str,
    assume_straight_pages: bool,
    straighten_pages: bool,
    bin_thresh: float,
    device: torch.device,
) -> OCRPredictor:
    """Build a pretrained OCR predictor and move it to the requested device.

    Args:
    ----
        det_arch: detection architecture passed to ``ocr_predictor``
        reco_arch: recognition architecture passed to ``ocr_predictor``
        assume_straight_pages: whether pages are assumed to be straight;
            orientation detection is switched on exactly when this is False
        straighten_pages: whether rotated pages should be straightened; the
            same flag is also forwarded as ``export_as_straight_boxes``
        bin_thresh: binarization threshold stored on the detection model's
            postprocessor
        device: torch.device on which the predictor is placed

    Returns:
    -------
        the configured OCRPredictor instance
    """
    # Keyword options are gathered up front so the derived flags are visible
    # at a glance before the predictor is constructed.
    predictor_options = {
        "pretrained": True,
        "assume_straight_pages": assume_straight_pages,
        "straighten_pages": straighten_pages,
        "export_as_straight_boxes": straighten_pages,
        "detect_orientation": not assume_straight_pages,
    }
    built = ocr_predictor(det_arch, reco_arch, **predictor_options)
    predictor = built.to(device)

    # The threshold is not a constructor argument here; it is overridden on
    # the detection postprocessor after the predictor has been created.
    postprocessor = predictor.det_predictor.model.postprocessor
    postprocessor.bin_thresh = bin_thresh
    return predictor
# def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
# """Forward an image through the predictor
# Args:
# ----
# predictor: instance of OCRPredictor
# image: image to process
# device: torch.device, the device to process the image on
# Returns:
# -------
# segmentation map
# """
# with torch.no_grad():
# processed_batches = predictor.det_predictor.pre_processor([image])
# out = predictor.det_predictor.model(processed_batches[0].to(device), return_model_output=True)
# seg_map = out["out_map"].to("cpu").numpy()
# return seg_map