File size: 2,891 Bytes
a3a3ae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
# MLSD Line Detection
# From https://github.com/navervision/mlsd
# Apache-2.0 license

import warnings
from abc import ABCMeta

import cv2
import numpy as np
import torch
from PIL import Image

from scepter.modules.annotator.base_annotator import BaseAnnotator
from scepter.modules.annotator.mlsd.mbv2_mlsd_large import MobileV2_MLSD_Large
from scepter.modules.annotator.mlsd.utils import pred_lines
from scepter.modules.annotator.registry import ANNOTATORS
from scepter.modules.annotator.utils import resize_image, resize_image_ori
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS


@ANNOTATORS.register_class()
class MLSDdetector(BaseAnnotator, metaclass=ABCMeta):
    """MLSD (Mobile Line Segment Detection) annotator.

    Wraps the MobileV2_MLSD_Large network to detect line segments in an
    image and render them as a single-channel line map (white lines on a
    black background). From https://github.com/navervision/mlsd.
    """
    def __init__(self, cfg, logger=None):
        """Build the MLSD model and optionally load pretrained weights.

        Args:
            cfg: Config object. Reads PRETRAINED_MODEL (optional checkpoint
                path), THR_V (score threshold, default 0.1) and THR_D
                (distance threshold, default 0.1).
            logger: Optional logger forwarded to BaseAnnotator.
        """
        super().__init__(cfg, logger=logger)
        model = MobileV2_MLSD_Large()
        pretrained_model = cfg.get('PRETRAINED_MODEL', None)
        if pretrained_model:
            with FS.get_from(pretrained_model, wait_finish=True) as local_path:
                # map_location='cpu' lets GPU-saved checkpoints load on any
                # host; the framework moves the model to its device later.
                model.load_state_dict(
                    torch.load(local_path, map_location='cpu'), strict=True)
        self.model = model.eval()
        self.thr_v = cfg.get('THR_V', 0.1)
        self.thr_d = cfg.get('THR_D', 0.1)

    @torch.no_grad()
    @torch.inference_mode()
    @torch.autocast('cuda', enabled=False)
    def forward(self, image):
        """Detect line segments and draw them on a black canvas.

        Args:
            image: HxWxC image as a PIL.Image, np.ndarray or torch.Tensor.

        Returns:
            Single-channel np.ndarray of shape (H, W) with detected lines
            drawn in white (255) on black (0), or None if line prediction
            raised an exception (best-effort behavior).

        Raises:
            TypeError: If ``image`` is not one of the supported types.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
        elif isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()
        elif isinstance(image, np.ndarray):
            image = image.copy()
        else:
            # BUG FIX: the original `raise f'...'` raised a plain str, which
            # itself fails with "exceptions must derive from BaseException",
            # masking the real problem. Raise a proper TypeError instead.
            raise TypeError(
                f'Unsupported datatype {type(image)}, only support '
                f'np.ndarray, torch.Tensor, Pillow Image.')
        h, w, c = image.shape
        # Cap the working resolution at 1024 on the short side; `k` is the
        # scale factor used to restore the original size afterwards.
        image, k = resize_image(image, 1024 if min(h, w) > 1024 else min(h, w))
        img_output = np.zeros_like(image)
        try:
            lines = pred_lines(image,
                               self.model, [image.shape[0], image.shape[1]],
                               self.thr_v,
                               self.thr_d,
                               device=we.device_id)
            for line in lines:
                x_start, y_start, x_end, y_end = [int(val) for val in line]
                cv2.line(img_output, (x_start, y_start), (x_end, y_end),
                         [255, 255, 255], 1)
        except Exception as e:
            # Best-effort: keep the annotator pipeline alive on failure.
            warnings.warn(f'{e}')
            return None
        img_output = resize_image_ori(h, w, img_output, k)
        # All three channels carry the same line drawing; return one of them.
        return img_output[:, :, 0]

    @staticmethod
    def get_config_template():
        """Return the YAML configuration template for this annotator."""
        return dict_to_yaml('ANNOTATORS',
                            __class__.__name__,
                            MLSDdetector.para_dict,
                            set_name=True)