File size: 3,626 Bytes
b741ff1
 
dba785d
 
42c5e09
dba785d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42c5e09
 
 
 
dba785d
42c5e09
dba785d
 
 
 
 
 
 
 
 
 
42c5e09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dba785d
 
 
42c5e09
dba785d
 
 
 
 
 
 
 
 
 
 
 
 
 
42c5e09
 
 
 
 
 
 
 
dba785d
 
 
 
42c5e09
dba785d
 
b741ff1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os

import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.cluster import KMeans


def _image_resize(image: Image.Image, pixels: int = 90000, **kwargs):
    """Return a copy of *image* scaled down to roughly *pixels* total pixels.

    The aspect ratio is preserved. An image whose area is already at or below
    the budget is returned as a plain copy, so the result never aliases the
    input.

    :param image: Source image.
    :param pixels: Target pixel count (width * height) after downscaling.
    :param kwargs: Forwarded to ``Image.resize`` (e.g. ``resample=``).
    :return: A new image whose area is at most about *pixels*.
    :raises ValueError: If *pixels* is not positive.
    """
    if pixels <= 0:
        # Guard: 0 would divide by zero and a negative budget would make the
        # square root below complex.
        raise ValueError(f'pixels must be positive, got {pixels!r}')

    width, height = image.size
    # Linear scale factor: dividing each side by sqrt(area / budget) shrinks
    # the area by exactly area / budget.
    ratio = (width * height / pixels) ** 0.5
    if ratio <= 1.0:
        return image.copy()

    # Clamp each side to >= 1 px: for extreme aspect ratios the short side
    # would otherwise truncate to 0, which ``Image.resize`` rejects.
    new_size = (max(1, int(width / ratio)), max(1, int(height / ratio)))
    return image.resize(new_size, **kwargs)


def get_main_colors(image: Image.Image, n: int = 28, pixels: int = 90000) -> Image.Image:
    """Posterize *image* to at most *n* colors using k-means clustering.

    K-means is fitted on a downscaled copy (about *pixels* pixels) to keep
    fitting cheap. Every pixel of the full-resolution image is then mapped
    to its nearest cluster center.

    :param image: Source image in any mode; it is converted to RGB first.
    :param n: Requested number of colors (k-means clusters). It is capped at
        the number of pixels in the downscaled sample, because KMeans cannot
        fit more clusters than samples.
    :param pixels: Pixel budget for the downscaled fitting sample.
    :return: A new RGB image with the same size as *image*, whose pixels are
        all (rounded) cluster-center colors.
    """
    # convert() already returns a new image; copy only when no conversion is
    # needed, so the caller's image is never shared with the result.
    image = image.convert('RGB') if image.mode != 'RGB' else image.copy()

    sample = np.asarray(_image_resize(image, pixels)).reshape(-1, 3)
    # Slider values may arrive as floats (e.g. 8.0) and KMeans wants an int.
    # KMeans also raises if n_samples < n_clusters, which tiny images hit.
    n_clusters = min(int(n), len(sample))
    kmeans = KMeans(n_clusters=n_clusters)
    kmeans.fit(sample)

    width, height = image.size
    full = np.asarray(image).reshape(-1, 3)
    # Centers are means of 0..255 values, so rounding keeps them in uint8 range.
    palette = kmeans.cluster_centers_.round().astype(np.uint8)
    labels = kmeans.predict(full)
    quantized = palette[labels].reshape((height, width, 3))
    # A uint8 (h, w, 3) array is inferred as RGB; the explicit ``mode``
    # argument is unnecessary (and deprecated in recent Pillow releases).
    return Image.fromarray(quantized)


def main_func(image: Image.Image, n: int, pixels: int, fixed_width: bool, width: int):
    """Posterize *image* and tabulate the resulting palette (UI callback).

    :param image: Uploaded image, or ``None`` when nothing was uploaded.
    :param n: Number of colors (k-means clusters) to reduce the image to.
    :param pixels: Pixel budget of the downscaled copy used for clustering.
    :param fixed_width: When true, rescale the result to *width* pixels wide
        using nearest-neighbour resampling, so no new colors are introduced.
    :param width: Target output width; used only when *fixed_width* is set.
    :return: ``(new_image, table)``. ``table`` has one row per distinct
        color of ``new_image``, with columns Hex, Pixels, Ratio, Red, Green
        and Blue, sorted by pixel count in descending order.
    :raises ValueError: If no image was supplied.
    """
    if image is None:
        # Otherwise this surfaces as an opaque AttributeError deep inside.
        raise ValueError('Please upload an image first.')

    new_image = get_main_colors(image, n, pixels)
    if fixed_width:
        old_width, old_height = new_image.size
        scale = width / old_width
        # Clamp to >= 1 px: scaling a very wide image down to a small width
        # could round its height to 0, which ``Image.resize`` rejects.
        new_size = (
            max(1, int(round(old_width * scale))),
            max(1, int(round(old_height * scale))),
        )
        new_image = new_image.resize(new_size, resample=Image.NEAREST)

    rgb = pd.DataFrame(np.asarray(new_image).reshape(-1, 3), columns=['r', 'g', 'b'])
    # groupby sorts the (r, g, b) keys; a stable sort on the count then keeps
    # equally-frequent colors in that deterministic order.
    table = (
        rgb.groupby(['r', 'g', 'b']).size()
        .reset_index(name='count')
        .sort_values('count', ascending=False, kind='mergesort', ignore_index=True)
    )
    hexes = [f'#{r:02x}{g:02x}{b:02x}' for r, g, b in zip(table['r'], table['g'], table['b'])]

    new_table = pd.DataFrame({
        'Hex': hexes,
        'Pixels': table['count'],
        'Ratio': table['count'] / table['count'].sum(),
        'Red': table['r'],
        'Green': table['g'],
        'Blue': table['b'],
    })

    return new_image, new_table


if __name__ == '__main__':
    # Limit pandas' float display precision to 3 digits.
    pd.set_option("display.precision", 3)

    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: source image, clustering parameters, submit button.
            with gr.Column():
                src_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    clusters_in = gr.Slider(value=8, minimum=2, maximum=256, step=2, label='Clusters')
                    pixels_in = gr.Slider(value=100000, minimum=10000, maximum=1000000, step=10000,
                                          label='Pixels for Clustering')
                    fixed_width_in = gr.Checkbox(value=True, label='Width Fixed')
                    width_in = gr.Slider(value=200, minimum=12, maximum=2048, label='Width')

                submit_btn = gr.Button(value='Submit', variant='primary')

            # Right column: tabs for the posterized image and its palette table.
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab('Output Image'):
                        out_image = gr.Image(type='pil', label='Output Image')
                    with gr.Tab('Color Map'):
                        out_table = gr.Dataframe(
                            headers=['Hex', 'Pixels', 'Ratio', 'Red', 'Green', 'Blue'],
                            label='Color Map'
                        )

        # Order must match main_func's positional parameters / return tuple.
        callback_inputs = [src_image, clusters_in, pixels_in, fixed_width_in, width_in]
        callback_outputs = [out_image, out_table]
        submit_btn.click(main_func, inputs=callback_inputs, outputs=callback_outputs)

    # NOTE(review): the positional argument is ``concurrency_count`` in
    # Gradio 3.x; Gradio 4 changed queue()'s signature — confirm against the
    # pinned Gradio version.
    worker_count = os.cpu_count()
    demo.queue(worker_count).launch()