# Spaces:
# Runtime error
# Runtime error
import logging

import cv2
import numpy as np
from PIL import Image
from skimage.color import lab2rgb
from skimage.color import rgb2lab
from sklearn.cluster import KMeans
def count_high_freq_colors(image, min_count=2, mean_factor=1.25):
    """Return the colors of *image* that occur noticeably more often than average.

    Pixel counts come from ``image.getcolors``. A color is kept when its count
    is strictly greater than both ``min_count`` and ``mean_factor`` times the
    mean count over all distinct colors.

    Args:
        image: A PIL image (anything exposing ``getcolors(maxcolors=...)`` and
            a ``(width, height)`` ``size``).
        min_count: Absolute lower bound a count must exceed. Defaults to 2,
            the value previously hard-coded.
        mean_factor: Multiplier applied to the mean count. Defaults to 1.25,
            the value previously hard-coded.

    Returns:
        A list of ``(count, color)`` tuples sorted by count, most frequent
        first. Empty when the image has no colors at all.
    """
    colors = image.getcolors(maxcolors=1024 * 1024)
    if colors is None:
        # getcolors gives up and returns None once the number of distinct
        # colors exceeds maxcolors; the pixel count is an upper bound on the
        # number of distinct colors, so this second call always succeeds.
        width, height = image.size
        colors = image.getcolors(maxcolors=width * height)
    if not colors:
        # Guard against dividing by zero when computing the mean below.
        return []

    sorted_colors = sorted(colors, key=lambda entry: entry[0], reverse=True)
    mean_freq = sum(count for count, _ in sorted_colors) / len(sorted_colors)
    threshold = max(min_count, mean_freq * mean_factor)
    return [entry for entry in sorted_colors if entry[0] > threshold]
def get_high_freq_colors(image, similarity_threshold=30):
    """Merge similar dominant colors in a copy of *image* and recount them.

    Every ordered pair of high-frequency colors is compared. The rarer color
    is replaced by a strictly more frequent one (via ``replace_color`` on the
    copy) when either:

    * ``color_distance`` between them is below ``similarity_threshold``, or
    * ``color_distance`` between the rarer color and
      ``opaque_color_on_white(more_frequent, 0.5)`` is below 5.

    Args:
        image: Presumably a PIL image (it is copied and handed to
            ``count_high_freq_colors``); only the copy is passed to
            ``replace_color``.
        similarity_threshold: Distance below which two colors are merged.

    Returns:
        ``[high_freq_colors, image_copy]``, where ``high_freq_colors`` is
        ``count_high_freq_colors(image_copy)`` recomputed after merging.
    """
    image_copy = image.copy()
    high_freq_colors = count_high_freq_colors(image)
    # Frequencies are taken from the original image and are not refreshed
    # while merging. NOTE(review): replace_color is assumed to edit image_copy
    # in place (its return value is ignored), and a color that was already
    # merged away can still be picked as the replacement for a rarer color;
    # confirm both are intended.
    for freq1, color1 in high_freq_colors:
        for freq2, color2 in high_freq_colors:
            # Only a strictly more frequent color may replace color1. Testing
            # this first skips self-pairs and half of all pairs without
            # calling the distance helpers.
            if freq2 <= freq1:
                continue
            if (color_distance(color1, color2) < similarity_threshold
                    or color_distance(color1, opaque_color_on_white(color2, 0.5)) < 5):
                replace_color(image_copy, color1, color2)
    high_freq_colors = count_high_freq_colors(image_copy)
    # Diagnostic output; previously an unconditional print() on every call.
    logging.getLogger(__name__).debug('high-frequency colors: %s', high_freq_colors)
    return [high_freq_colors, image_copy]
def color_quantization(image, color_frequency_list, fill_value=255):
    """Overwrite every pixel whose color is not listed with *fill_value*.

    Args:
        image: NumPy array indexed as (row, column, channel); it is modified
            in place and also returned.
        color_frequency_list: Iterable of ``(count, color)`` pairs, such as
            the output of ``count_high_freq_colors``. Only the colors are
            used; a pixel survives if it matches any of them in every channel.
        fill_value: Value written to the cleared pixels, broadcast across all
            channels. The default of 255 gives white for 3-channel images, as
            before, and opaque white for 4-channel images. The previous
            hard-coded ``(255, 255, 255)`` could not be broadcast onto those.

    Returns:
        *image*, after modification.
    """
    # Pixels matching at least one listed color are kept; everything else is
    # cleared. This equals AND-ing the negated per-color masks (De Morgan).
    keep = np.zeros(image.shape[:2], dtype=bool)
    for _, color in color_frequency_list:
        keep |= np.all(image == color, axis=2)
    image[~keep] = fill_value
    return image
def create_binary_matrix(img_arr, target_color):
    """Save a black/white PNG mask of the pixels that match *target_color*.

    Positions where *img_arr* equals *target_color* in every element of the
    last axis become 255 (white); all others become 0 (black). The file is
    written to the current working directory under a timestamped name.

    Args:
        img_arr: NumPy array whose last axis holds the channels; presumably
            (height, width, channels) — confirm with callers.
        target_color: Color compared against the last axis of *img_arr*.

    Returns:
        The file name of the written PNG.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file was not written.
    """
    from datetime import datetime

    mask = np.all(img_arr == target_color, axis=-1)
    # Build the 0/255 image directly as 8-bit, the depth PNG stores, instead
    # of a platform-int array that OpenCV would have to convert first.
    binary_image = mask.astype(np.uint8) * 255
    binary_file_name = f'mask-{datetime.now().timestamp()}.png'
    # cv2.imwrite signals failure by returning False rather than raising;
    # without this check a missing file name would be returned to callers.
    if not cv2.imwrite(binary_file_name, binary_image):
        raise OSError(f'could not write mask image {binary_file_name!r}')
    return binary_file_name