Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
from sklearn.cluster import KMeans | |
def _image_resize(image: Image.Image, pixels: int = 90000, **kwargs):
    """Return a copy of ``image`` downscaled to at most about ``pixels`` pixels.

    The aspect ratio is preserved. Images already within the budget are
    returned as an unmodified copy, so the caller always receives a new object
    and the input is never touched.

    Args:
        image: Source PIL image; it is not modified.
        pixels: Target upper bound on ``width * height`` of the result.
        **kwargs: Forwarded unchanged to ``Image.resize`` (e.g. ``resample``).

    Returns:
        A new ``Image.Image``.

    Raises:
        ValueError: If ``pixels`` is not positive.
    """
    if pixels <= 0:
        raise ValueError(f'pixels must be positive, got {pixels!r}')

    width, height = image.size
    # Linear scale factor: dividing both sides by ``ratio`` divides the area by ratio ** 2.
    ratio = (width * height / pixels) ** 0.5
    if ratio <= 1.0:
        return image.copy()

    # Clamp to 1 so very elongated images never collapse to a zero-sized side.
    new_size = (max(1, int(width / ratio)), max(1, int(height / ratio)))
    return image.resize(new_size, **kwargs)
def get_main_colors(image: Image.Image, n: int = 28, pixels: int = 90000) -> Image.Image:
    """Quantize ``image`` to its ``n`` dominant colors using k-means.

    K-means is fitted on a downscaled sample (about ``pixels`` pixels) to keep
    fitting cheap. Every pixel of the full-resolution image is then replaced by
    the center of the cluster it is assigned to.

    Args:
        image: Source PIL image of any mode. It is converted to RGB (an alpha
            channel, if any, is dropped) and is not modified.
        n: Requested number of colors. Cast to ``int`` (the value may arrive as
            a float, e.g. from a UI slider). Clamped to the number of fitting
            samples, because k-means cannot build more clusters than samples.
        pixels: Approximate pixel budget for the fitting sample.

    Returns:
        A new RGB ``Image.Image`` with the same size as ``image``.
    """
    # convert() already returns a new image, and nothing below mutates the
    # input, so no defensive copy is needed.
    rgb = image if image.mode == 'RGB' else image.convert('RGB')

    sample = np.asarray(_image_resize(rgb, pixels)).reshape(-1, 3)
    n_clusters = max(1, min(int(n), len(sample)))
    kmeans = KMeans(n_clusters=n_clusters)
    kmeans.fit(sample)

    width, height = rgb.size
    labels = kmeans.predict(np.asarray(rgb).reshape(-1, 3))
    # Centers are means of uint8 values, so rounding keeps them within [0, 255].
    quantized = kmeans.cluster_centers_[labels].round().astype(np.uint8)
    # A (height, width, 3) uint8 array is interpreted as RGB; the `mode`
    # argument of fromarray is deprecated in recent Pillow releases.
    return Image.fromarray(quantized.reshape((height, width, 3)))
def main_func(image: Image.Image, n: int, pixels: int, fixed_width: bool, width: int):
    """UI callback: quantize ``image`` and optionally rescale it to ``width``.

    Args:
        image: The uploaded image; may be ``None`` when nothing was uploaded.
        n: Number of colors to keep.
        pixels: Pixel budget used when fitting the color clusters.
        fixed_width: When true, rescale the result to ``width`` pixels wide,
            keeping the aspect ratio.
        width: Target width in pixels (only used when ``fixed_width`` is true).

    Returns:
        The quantized ``Image.Image``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of an AttributeError traceback.
    """
    if image is None:
        raise gr.Error('Please upload an image first.')

    new_image = get_main_colors(image, n, pixels)
    if fixed_width:
        old_width, old_height = new_image.size
        new_width = max(1, int(round(width)))
        # Keep the aspect ratio; clamp so very wide images never get a zero height.
        new_height = max(1, int(round(old_height * new_width / old_width)))
        # NEAREST avoids blending, so the output keeps only the quantized colors.
        new_image = new_image.resize((new_width, new_height), resample=Image.NEAREST)
    return new_image
if __name__ == '__main__':
    # Build the UI: input controls in the left column, result in the right one.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                ch_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    ch_clusters = gr.Slider(value=8, minimum=2, maximum=256, step=2, label='Clusters')
                    ch_pixels = gr.Slider(value=100000, minimum=10000, maximum=1000000, step=10000,
                                          label='Pixels for Clustering')
                ch_fixed_width = gr.Checkbox(value=True, label='Width Fixed')
                ch_width = gr.Slider(value=200, minimum=12, maximum=2048, label='Width')
                ch_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                ch_output = gr.Image(type='pil', label='Output Image')

        ch_submit.click(
            main_func,
            inputs=[ch_image, ch_clusters, ch_pixels, ch_fixed_width, ch_width],
            outputs=[ch_output],
        )

    # os.cpu_count() returns None when the core count cannot be determined;
    # fall back to a single worker instead of passing None to the queue.
    # NOTE(review): passed positionally as in the original. In gradio 3.x the
    # first parameter of Blocks.queue is concurrency_count, while in gradio 4.x
    # it is status_update_rate. Confirm against the pinned gradio version.
    demo.queue(os.cpu_count() or 1).launch()