File size: 1,526 Bytes
1865436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import pytorch_lightning as pl
from src.ss.det_models.model import POIDetection
from src.ss.datasets_signboard_detection.datamodule import POIDataModule
from src.ss.det_models.inference_signboard_detection import POIDetectionTask


def load_model(checkpoint_path):
    """Restore a ``POIDetection`` model from a Lightning checkpoint.

    Args:
        checkpoint_path: Location of the checkpoint, forwarded unchanged to
            ``POIDetection.load_from_checkpoint``.

    Returns:
        The ``POIDetection`` instance produced by ``load_from_checkpoint``.
    """
    return POIDetection.load_from_checkpoint(checkpoint_path=checkpoint_path)


def inference_signboard(image_path, checkpoint, score, *, gpus=1):
    """Run signboard detection on ``image_path`` with a freshly loaded model.

    The checkpoint is reloaded on every call. ``SignBoardDetector`` keeps one
    loaded model around for repeated calls.

    Args:
        image_path: Passed as ``data_path`` to both ``POIDataModule`` and
            ``POIDetectionTask``.
        checkpoint: Checkpoint location handed to ``load_model``.
        score: Forwarded to ``POIDetectionTask`` as ``score``. This is
            presumably a detection-confidence threshold; confirm in
            ``inference_signboard_detection``.
        gpus: Value passed to ``pl.Trainer(gpus=...)``. Defaults to ``1``,
            the previously hard-coded value. In Lightning 1.x, ``0`` selects
            CPU.

    Returns:
        ``task.output`` as left by ``POIDetectionTask`` after
        ``trainer.predict`` has run.
    """
    dm = POIDataModule(data_path=image_path, seed=42)
    dm.setup("predict")

    model = load_model(checkpoint)
    task = POIDetectionTask(model, data_path=image_path, score=score)

    # NOTE(review): the ``gpus=`` argument was removed from ``pl.Trainer`` in
    # Lightning 2.0. Migrate to ``accelerator=..., devices=...`` when upgrading.
    trainer = pl.Trainer(gpus=gpus, max_epochs=-1)
    trainer.predict(task, datamodule=dm)
    return task.output


class SignBoardDetector:
    """Signboard detector that loads its ``POIDetection`` model once.

    The model is restored in ``__init__``. Each call to
    ``inference_signboard`` builds a fresh datamodule, task and trainer around
    that same model.
    """

    def __init__(self,
                 checkpoint) -> None:
        """Load the detection model.

        Args:
            checkpoint: Checkpoint location forwarded to
                ``POIDetection.load_from_checkpoint``.
        """
        self.model = POIDetection.load_from_checkpoint(
            checkpoint_path=checkpoint)

    def inference_signboard(self, image, score, *, gpus=1):
        """Run signboard detection on ``image`` with the preloaded model.

        Args:
            image: Passed as ``data_path`` to ``POIDataModule``.
            score: Forwarded to ``POIDetectionTask`` as ``score``. This is
                presumably a detection-confidence threshold; confirm in
                ``inference_signboard_detection``.
            gpus: Value passed to ``pl.Trainer(gpus=...)``. Defaults to
                ``1``, the previously hard-coded value. In Lightning 1.x,
                ``0`` selects CPU.

        Returns:
            ``task.output`` as left by ``POIDetectionTask`` after
            ``trainer.predict`` has run.
        """
        dm = POIDataModule(data_path=image, seed=42)
        dm.setup("predict")

        # NOTE(review): the module-level ``inference_signboard`` also passes
        # ``data_path=`` to ``POIDetectionTask``; this method does not.
        # Confirm whether that difference is intentional.
        task = POIDetectionTask(self.model, score=score)

        # NOTE(review): the ``gpus=`` argument was removed from ``pl.Trainer``
        # in Lightning 2.0. Migrate to ``accelerator=..., devices=...`` when
        # upgrading.
        trainer = pl.Trainer(gpus=gpus, max_epochs=-1)
        trainer.predict(task, datamodule=dm)
        return task.output