File size: 2,459 Bytes
5ebeb73
 
 
 
 
 
 
 
 
 
 
 
9429105
5ebeb73
 
 
c60ebd1
5ebeb73
 
 
 
c60ebd1
 
 
 
 
 
 
 
 
5ebeb73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417b347
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os

import torch
from huggingface_hub import snapshot_download
from mmdet.apis import DetInferencer

# from mmengine import Config
from mmocr.apis import TextRecInferencer


class HtrModels:
    """Resolve config/checkpoint paths for the HTR pipeline and build its inferencers.

    Paths default to files under ``./models``. If any of the three checkpoint
    files is missing there, all three checkpoints are fetched from the
    ``Riksarkivet/HTR_pipeline_models`` Hugging Face repository and the
    checkpoint attributes are pointed at the downloaded copies. The config
    paths always stay under ``./models``.
    """

    # Token slot passed to the Hugging Face Hub download.
    # NOTE(review): presumably the placeholder text is substituted with a real
    # token by a deployment step -- confirm against the build pipeline.
    _HUB_TOKEN = "__INSERT__FINS_HUGGINFACE_TOKEN__"
    # Prefix that is only present while the slot above is still unreplaced.
    _UNSET_TOKEN_PREFIX = "__INSERT__"

    def __init__(self, local_run=False):
        """Select the torch device and resolve all model file paths.

        Args:
            local_run: Kept for backward compatibility; it is not read by
                this class.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        model_folder = "./models"
        self.region_config = f"{model_folder}/RmtDet_regions/rtmdet_m_textregions_2_concat.py"
        self.region_checkpoint = f"{model_folder}/RmtDet_regions/epoch_12.pth"

        self.line_config = f"{model_folder}/RmtDet_lines/rtmdet_m_textlines_2_concat.py"
        self.line_checkpoint = f"{model_folder}/RmtDet_lines/epoch_12.pth"

        self.mmocr_config = f"{model_folder}/SATRN/_base_satrn_shallow_concat.py"
        self.mmocr_checkpoint = f"{model_folder}/SATRN/epoch_5.pth"

        # A single missing checkpoint triggers a download, after which all
        # three checkpoint paths are replaced by the downloaded ones.
        local_checkpoints = (
            self.region_checkpoint,
            self.line_checkpoint,
            self.mmocr_checkpoint,
        )
        if not all(os.path.exists(path) for path in local_checkpoints):
            config_path = self.get_config()
            self.region_checkpoint = config_path["region_checkpoint"]
            self.line_checkpoint = config_path["line_checkpoint"]
            self.mmocr_checkpoint = config_path["mmocr_checkpoint"]

    def load_region_model(self):
        """Return a ``DetInferencer`` built from the region config and checkpoint."""
        return DetInferencer(self.region_config, self.region_checkpoint, device=self.device)

    def load_line_model(self):
        """Return a ``DetInferencer`` built from the line config and checkpoint."""
        return DetInferencer(self.line_config, self.line_checkpoint, device=self.device)

    def load_htr_model(self):
        """Return a ``TextRecInferencer`` built from the SATRN config and checkpoint."""
        return TextRecInferencer(self.mmocr_config, self.mmocr_checkpoint, device=self.device)

    @staticmethod
    def _hub_token():
        """Return the configured Hub token, or ``None`` when none was filled in.

        An empty value or a still-unreplaced placeholder is not a credential;
        passing it on would send a bogus token with the download request.
        Returning ``None`` instead leaves token discovery to ``huggingface_hub``.
        """
        token = HtrModels._HUB_TOKEN
        if not token or token.startswith(HtrModels._UNSET_TOKEN_PREFIX):
            return None
        return token

    @staticmethod
    def get_config():
        """Download the ``*.pth`` checkpoints and return their local paths.

        The snapshot is cached under the current working directory
        (``cache_dir="./"``).

        Returns:
            dict: Keys ``"region_checkpoint"``, ``"line_checkpoint"`` and
            ``"mmocr_checkpoint"``, each mapped to a path inside the
            downloaded snapshot directory.
        """
        path_models = snapshot_download(
            "Riksarkivet/HTR_pipeline_models",
            allow_patterns=["*.pth"],
            token=HtrModels._hub_token(),
            cache_dir="./",
        )
        return {
            "region_checkpoint": os.path.join(path_models, "RmtDet_regions/epoch_12.pth"),
            "line_checkpoint": os.path.join(path_models, "RmtDet_lines/epoch_12.pth"),
            "mmocr_checkpoint": os.path.join(path_models, "SATRN/epoch_5.pth"),
        }


if __name__ == "__main__":
    # Running this file directly does nothing; it only defines HtrModels.
    pass