File size: 3,503 Bytes
a3a3ae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np


def resize_image(input_image, resolution):
    """Rescale an image so its shorter side is about ``resolution`` pixels.

    The scale factor is ``resolution / min(height, width)``. Each scaled side
    is then rounded to the nearest multiple of 64, so the aspect ratio is only
    approximately preserved.

    Args:
        input_image: Array whose first two axes are (height, width). Both
            2-D (single channel) and 3-D (height, width, channels) arrays
            are accepted.
        resolution: Target length of the shorter side, before the
            multiple-of-64 rounding.

    Returns:
        A tuple ``(img, k)``: the resized image and the scale factor ``k``
        computed before rounding.
    """
    # Only height and width are needed. Slicing (rather than unpacking three
    # values) keeps 2-D grayscale arrays from failing on the missing
    # channel axis.
    H, W = input_image.shape[:2]
    k = float(resolution) / min(H, W)
    # Snap each scaled side to the nearest multiple of 64.
    H = int(np.round(float(H) * k / 64.0)) * 64
    W = int(np.round(float(W) * k / 64.0)) * 64
    # Lanczos when enlarging, area averaging when shrinking.
    # cv2.resize takes the target size as (width, height).
    img = cv2.resize(
        input_image, (W, H),
        interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img, k


def resize_image_ori(h, w, image, k):
    """Resize ``image`` to exactly ``w`` x ``h`` pixels.

    Args:
        h: Target height in pixels.
        w: Target width in pixels.
        image: Image array passed straight to ``cv2.resize``.
        k: Value that selects the interpolation: Lanczos when ``k > 1``,
            area averaging otherwise. Presumably the scale factor returned
            by ``resize_image`` -- confirm against callers.

    Returns:
        The resized image.
    """
    if k > 1:
        interpolation = cv2.INTER_LANCZOS4
    else:
        interpolation = cv2.INTER_AREA
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(image, (w, h), interpolation=interpolation)


class AnnotatorProcessor():
    """Build a scepter ``GeneralAnnotator`` from short type names and run it.

    Supported type names are the keys of ``anno_type_map``. A type name is
    not always the key under which its result is reported: ``'pose'`` is
    configured with the output key ``'openpose'``. ``run`` accepts either.
    """

    # Per-annotator configuration dicts handed to scepter's Config.
    # 'OUTPUT_KEYS' names the key(s) each annotator's result is expected
    # under.
    canny_cfg = {
        'NAME': 'CannyAnnotator',
        'LOW_THRESHOLD': 100,
        'HIGH_THRESHOLD': 200,
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['canny']
    }
    hed_cfg = {
        'NAME': 'HedAnnotator',
        'PRETRAINED_MODEL':
        'ms://damo/scepter_scedit@annotator/ckpts/ControlNetHED.pth',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['hed']
    }
    openpose_cfg = {
        'NAME': 'OpenposeAnnotator',
        'BODY_MODEL_PATH':
        'ms://damo/scepter_scedit@annotator/ckpts/body_pose_model.pth',
        'HAND_MODEL_PATH':
        'ms://damo/scepter_scedit@annotator/ckpts/hand_pose_model.pth',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['openpose']
    }
    midas_cfg = {
        'NAME': 'MidasDetector',
        'PRETRAINED_MODEL':
        'ms://damo/scepter_scedit@annotator/ckpts/dpt_hybrid-midas-501f0c75.pt',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['depth']
    }
    mlsd_cfg = {
        'NAME': 'MLSDdetector',
        'PRETRAINED_MODEL':
        'ms://damo/scepter_scedit@annotator/ckpts/mlsd_large_512_fp32.pth',
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['mlsd']
    }
    color_cfg = {
        'NAME': 'ColorAnnotator',
        'RATIO': 64,
        'INPUT_KEYS': ['img'],
        'OUTPUT_KEYS': ['color']
    }

    # Public type name -> annotator configuration.
    anno_type_map = {
        'canny': canny_cfg,
        'hed': hed_cfg,
        'pose': openpose_cfg,
        'depth': midas_cfg,
        'mlsd': mlsd_cfg,
        'color': color_cfg
    }

    def __init__(self, anno_type):
        """Create the annotator pipeline for the requested type(s).

        Args:
            anno_type: One type name, or a list/tuple of type names. Every
                name must be a key of ``anno_type_map``.

        Raises:
            AssertionError: If any requested name is not a supported type.
            TypeError: If ``anno_type`` is neither a string nor a list/tuple.
        """
        # Validate before the scepter imports so bad input fails fast,
        # without loading the framework.
        if isinstance(anno_type, str):
            anno_type = [anno_type]
        elif not isinstance(anno_type, (list, tuple)):
            raise TypeError(f'Error anno_type: {anno_type}')

        unknown = [tp for tp in anno_type if tp not in self.anno_type_map]
        if unknown:
            # Raised explicitly instead of via `assert` so the check survives
            # `python -O`. AssertionError is kept so existing handlers still
            # match.
            raise AssertionError(
                f'Unsupported anno_type {unknown}; '
                f'expected one of {list(self.anno_type_map)}')

        from scepter.modules.annotator.registry import ANNOTATORS
        from scepter.modules.utils.config import Config
        from scepter.modules.utils.distribute import we

        general_dict = {
            'NAME': 'GeneralAnnotator',
            'ANNOTATORS': [self.anno_type_map[tp] for tp in anno_type]
        }
        general_anno = Config(cfg_dict=general_dict, load=False)
        self.general_ins = ANNOTATORS.build(general_anno).to(we.device_id)

    @classmethod
    def _resolve_output_key(cls, tp, outputs):
        """Return the key in ``outputs`` holding the result for ``tp``, or None.

        ``tp`` is tried as a literal output key first, then through the
        ``OUTPUT_KEYS`` of its entry in ``anno_type_map``.
        """
        if tp in outputs:
            return tp
        cfg = cls.anno_type_map.get(tp)
        if cfg is not None:
            for key in cfg['OUTPUT_KEYS']:
                if key in outputs:
                    return key
        return None

    def run(self, image, anno_type=None):
        """Annotate ``image`` and optionally select part of the result.

        The pipeline is called with ``{'img': image}``. Its result is
        presumably a mapping keyed by the annotators' ``OUTPUT_KEYS`` --
        confirm against ``GeneralAnnotator``.

        Args:
            image: Input image forwarded to the pipeline under ``'img'``.
            anno_type: ``None`` to return the full pipeline result, a single
                name to return just that entry, or an iterable of names to
                return a dict of the entries that are available. Names may
                be type names (e.g. ``'pose'``) or raw output keys (e.g.
                ``'openpose'``).

        Returns:
            The full result, a single entry, or a dict keyed by the requested
            names. A single name with no matching entry yields ``{}``.
        """
        output_image = self.general_ins({'img': image})
        if anno_type is None:
            return output_image

        if isinstance(anno_type, str):
            # Handle a lone name as one request. Iterating it would walk its
            # characters and never match a real output key.
            key = self._resolve_output_key(anno_type, output_image)
            return {} if key is None else output_image[key]

        selected = {}
        for tp in anno_type:
            key = self._resolve_output_key(tp, output_image)
            if key is not None:
                # Key by the caller's own name so lookups use what was asked.
                selected[tp] = output_image[key]
        return selected