File size: 3,694 Bytes
52cbb9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import numpy as np
import cv2
import ot
from PIL import Image

def transfer_channel(source_channel, target_channel):
    """Remap one 8-bit channel so its intensity histogram follows another's.

    Both channels are binned into 256 unit-width histograms over [0, 256) and
    normalised. An exact optimal-transport plan between them is solved with
    ``ot.emd`` under a (max-normalised) squared-difference cost. Every source
    intensity ``i`` is then replaced by the target intensity that receives the
    largest share of ``i``'s mass in that plan.

    Args:
        source_channel: array of intensities to remap, expected in [0, 255].
            Values are looked up by their integer part (clipped to [0, 255]),
            which is the same bin ``np.histogram`` assigned them to.
        target_channel: array whose intensity distribution should be imitated.

    Returns:
        Array with the same shape and dtype as ``source_channel`` holding the
        remapped intensities. An empty source, or one with no values in
        [0, 256), yields an all-zero array.

    Raises:
        ValueError: if ``target_channel`` has no values in [0, 256), so no
            target distribution exists.
    """
    source_channel = np.asarray(source_channel)
    result = np.zeros_like(source_channel)

    source_hist, _ = np.histogram(source_channel, bins=256, range=(0, 256))
    target_hist, _ = np.histogram(target_channel, bins=256, range=(0, 256))
    source_total = source_hist.sum()
    target_total = target_hist.sum()

    if source_total == 0:
        # Nothing to transport (empty or entirely out-of-range input).
        return result
    if target_total == 0:
        raise ValueError("target_channel has no values in the range [0, 256)")

    # True division of int64 counts yields float64 probabilities, as ot.emd expects.
    source_prob = source_hist / source_total
    target_prob = target_hist / target_total

    levels = np.arange(256, dtype=np.float64)
    cost = (levels[:, np.newaxis] - levels[np.newaxis, :]) ** 2
    cost /= cost.max()

    plan = ot.emd(source_prob, target_prob, cost)

    # One lookup table instead of 256 full-image mask passes. Rows of empty
    # source bins are all zero and map to 0, but no pixel indexes them.
    lut = plan.argmax(axis=1)
    indices = np.clip(source_channel, 0, 255).astype(np.intp)
    result[...] = lut[indices]
    return result

def optimal_transport_color_transfer(source, target):
    """Recolor a BGR image so each Lab channel follows the target's distribution.

    Both images are converted from BGR to OpenCV's Lab encoding. The L, a and
    b channels of ``source`` are then remapped independently with
    ``transfer_channel`` against the matching channel of ``target``.

    Args:
        source: BGR image to recolor.
        target: BGR image supplying the color distribution.

    Returns:
        The remapped Lab data cast to uint8 and converted back to BGR.
    """
    to_lab = cv2.COLOR_BGR2Lab
    src_lab = cv2.cvtColor(source, to_lab).astype(np.float64)
    tgt_lab = cv2.cvtColor(target, to_lab).astype(np.float64)

    for channel in (0, 1, 2):  # L, a, b
        src_lab[..., channel] = transfer_channel(src_lab[..., channel], tgt_lab[..., channel])

    transferred_bgr = cv2.cvtColor(src_lab.astype(np.uint8), cv2.COLOR_Lab2BGR)
    return transferred_bgr
    

def rgb_to_hex(rgb):
    """Format the first three components of ``rgb`` as a lowercase ``#rrggbb`` string.

    Each component is rendered with the ``02x`` format spec, so it must be an
    integer. Any components after the third are ignored.
    """
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f"#{red:02x}{green:02x}{blue:02x}"

def hex_to_rgb(hex):
    """Parse a ``#rrggbb`` or ``#rgb`` hex color string into an ``(r, g, b)`` tuple.

    Surrounding whitespace and leading ``#`` characters are ignored. The
    three-digit shorthand is expanded by doubling each digit (``"#abc"`` ->
    ``"#aabbcc"``). Anything after the first six digits, such as an alpha
    byte, is ignored.

    Args:
        hex: color string. The name shadows the ``hex`` builtin but is kept so
            keyword callers keep working.

    Returns:
        Tuple of three ints in [0, 255].

    Raises:
        ValueError: if the string has fewer than six (or exactly three) hex
            digits, or contains non-hex characters among them.
    """
    digits = hex.strip().lstrip('#')
    if len(digits) == 3:
        digits = ''.join(ch * 2 for ch in digits)
    # int(..., 16) would silently accept spaces or signs inside a pair, so
    # validate the six digits explicitly.
    if len(digits) < 6 or not set(digits[:6]) <= set('0123456789abcdefABCDEF'):
        raise ValueError(f"invalid hex color {hex!r}: expected #rgb or #rrggbb")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def create_color_palette(colors, palette_width=800, palette_height=200):
    """
    Build an RGB palette image made of near-equal vertical color stripes.

    Args:
        colors: sequence of hex color strings (e.g. "#ff8800"), one stripe
            each, laid out left to right in the given order.
        palette_width: image width in pixels; must be at least len(colors)
            so every color gets at least one column.
        palette_height: image height in pixels; must be at least 1.

    Returns:
        A PIL RGB image of size (palette_width, palette_height). Stripe widths
        differ by at most one pixel, and every pixel carries one of ``colors``.

    Raises:
        ValueError: if ``colors`` is empty or the dimensions cannot hold one
            column per color.
    """
    n_colors = len(colors)
    if n_colors == 0:
        raise ValueError("colors must contain at least one hex color")
    if palette_width < n_colors or palette_height < 1:
        raise ValueError(
            f"a {palette_width}x{palette_height} palette cannot hold "
            f"{n_colors} color stripes"
        )

    rgb = np.array([hex_to_rgb(color) for color in colors], dtype=np.uint8)
    # Column x belongs to stripe x * n // width. This covers every column,
    # so no unfilled (black) pixels skew the palette's color distribution,
    # and it spreads any remainder evenly across the stripes.
    stripe_of_column = np.arange(palette_width) * n_colors // palette_width
    row = rgb[stripe_of_column]                                  # (width, 3)
    pixels = np.repeat(row[np.newaxis], palette_height, axis=0)  # (height, width, 3)
    return Image.fromarray(pixels)


# if __name__ == "__main__":
#     source = cv2.imread("estampa-test.png")
#     target = cv2.imread("color.png")
#     transferred = optimal_transport_color_transfer(source, target)

#     smooth = False#test with true to have a different result.
#     if smooth:
#         # Apply bilateral filtering
#         diameter = 30  # diameter of each pixel neighborhood, adjust based on your image size
#         sigma_color = 25  # larger value means colors farther to each other will mix together
#         sigma_space = 25  # larger values means farther pixels will influence each other if their colors are close enough
#         smoothed = cv2.bilateralFilter(transferred, diameter, sigma_color, sigma_space)
#         cv2.imwrite("result_OTA.jpg", smoothed)
#     else:
#         cv2.imwrite("result_OTA.jpg", transferred)


def recolor(source, colors, *, smooth=True, output_path="result.jpg"):
    """Recolor ``source`` toward a palette of hex colors and write it to disk.

    A palette image is built from ``colors`` with ``create_color_palette``.
    The color distribution of ``source`` is transported onto it with
    ``optimal_transport_color_transfer``. The result is optionally smoothed
    with ``cv2.bilateralFilter`` before being written with ``cv2.imwrite``,
    which is asked for JPEG quality 100.

    Args:
        source: BGR image to recolor.
        colors: sequence of hex color strings such as "#ff8800".
        smooth: apply the bilateral filter before saving. Defaults to True,
            which was the previously hard-coded behavior.
        output_path: destination file passed to ``cv2.imwrite``. Defaults to
            "result.jpg", as before.

    Returns:
        The boolean returned by ``cv2.imwrite`` (True when the file was written).
    """
    palette_rgb = np.array(create_color_palette(colors))
    palette_bgr = cv2.cvtColor(palette_rgb, cv2.COLOR_RGB2BGR)
    recolored = optimal_transport_color_transfer(source, palette_bgr)

    if smooth:
        diameter = 10     # diameter of each pixel neighborhood; adjust to the image size
        sigma_color = 25  # larger -> colors farther apart get mixed together
        sigma_space = 15  # larger -> farther pixels influence each other if their colors are close
        recolored = cv2.bilateralFilter(recolored, diameter, sigma_color, sigma_space)

    # Same encoder settings for both paths (previously only the smoothed one used quality 100).
    return cv2.imwrite(output_path, recolored, [cv2.IMWRITE_JPEG_QUALITY, 100])