File size: 2,834 Bytes
f1f9265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import random

import cv2
import numpy as np
from PIL import Image
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode

from tools.controlnet.annotator.hed import HEDdetector
from tools.controlnet.annotator.util import HWC3, nms, resize_image

# Module-level cache for the edge detector: get_scribble_map replaces this with an
# HEDdetector the first time an HED-based detector is requested and reuses it afterwards.
preprocessor = None


def transform_control_signal(control_signal, hw):
    """
    Load a control image and convert it into a normalised, batched tensor.

    Args:
        control_signal: File path (str), PIL image, or numpy array holding the
            control image.
        hw: 2-D indexable whose entries ``hw[0, 0]`` and ``hw[0, 1]`` give the
            target size (passed to torchvision as ``(height, width)``).

    Returns:
        The transformed image with a leading batch dimension added via
        ``unsqueeze(0)``.

    Raises:
        ValueError: If ``control_signal`` is none of the supported types.
    """
    # Normalise every accepted input type to a PIL image.
    if isinstance(control_signal, str):
        image = Image.open(control_signal)
    elif isinstance(control_signal, Image.Image):
        image = control_signal
    elif isinstance(control_signal, np.ndarray):
        image = Image.fromarray(control_signal)
    else:
        raise ValueError("control_signal must be a path or a PIL.Image.Image or a numpy array")

    target_size = (int(hw[0, 0]), int(hw[0, 1]))

    def to_rgb(img):
        return img.convert("RGB")

    steps = [
        T.Lambda(to_rgb),
        T.Resize(target_size, interpolation=InterpolationMode.BICUBIC),  # Image.BICUBIC
        T.CenterCrop(target_size),
        T.ToTensor(),
        # Single-value mean/std of 0.5 is applied to every channel.
        T.Normalize([0.5], [0.5]),
    ]
    tensor = T.Compose(steps)(image)
    return tensor.unsqueeze(0)


def get_scribble_map(input_image, det, detect_resolution=512, thickness=None):
    """
    Generate a scribble map from an input image.

    Args:
        input_image: Input image (numpy array, HWC format).
        det: Detector name. ``"None"`` skips detection and returns a copy of
            the HWC3-converted input. A name containing ``"HED"`` (e.g.
            ``"Scribble_HED"``) runs the HED detector. No other detector is
            constructed by this function.
        detect_resolution: Resolution passed to ``resize_image`` before the
            detector is run.
        thickness: Line thickness. ``None`` picks a random integer in [0, 24].
            ``0`` erodes the map with a 4x4 kernel (thinner lines); values
            greater than 1 dilate it with a square kernel of side
            ``thickness // 2``; ``1`` leaves the map unchanged.

    Returns:
        The processed scribble map (or a copy of the input when ``det`` is
        ``"None"``).

    Raises:
        ValueError: If detection is requested but no detector is available for
            ``det`` (previously this surfaced as an opaque
            ``TypeError: 'NoneType' object is not callable``).
    """
    global preprocessor

    # Lazily create the HED detector and cache it for later calls.
    if "HED" in det and not isinstance(preprocessor, HEDdetector):
        preprocessor = HEDdetector()

    input_image = HWC3(input_image)

    if det == "None":
        return input_image.copy()

    # NOTE: if an HED detector was cached by an earlier call, a non-HED name
    # still reuses it; this is kept as-is for backward compatibility.
    if preprocessor is None:
        raise ValueError(
            f"Unsupported detector {det!r}: only 'None' and names containing 'HED' are available"
        )

    # Generate scribble map
    detected_map = preprocessor(resize_image(input_image, detect_resolution))
    detected_map = HWC3(detected_map)

    # Post-processing: thin the edges, blur them, then binarise to 0 / 255.
    detected_map = nms(detected_map, 127, 3.0)
    detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
    detected_map[detected_map > 4] = 255
    detected_map[detected_map < 255] = 0

    # Control line thickness
    if thickness is None:
        thickness = random.randint(0, 24)  # Random thickness, including 0
    if thickness == 0:
        # Use erosion operation to get thinner lines
        kernel = np.ones((4, 4), np.uint8)
        detected_map = cv2.erode(detected_map, kernel, iterations=1)
    elif thickness > 1:
        # thickness 2 and 3 give a 1x1 kernel, so they dilate by nothing.
        kernel_size = thickness // 2
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        detected_map = cv2.dilate(detected_map, kernel, iterations=1)

    return detected_map