colonelwatch commited on
Commit
9faa360
·
1 Parent(s): 6b02d73

Add initial app.py and requirements.txt

Browse files
Files changed (2) hide show
  1. app.py +145 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from skimage import exposure, color, util
3
+ from matplotlib import pyplot as plt
4
+ import gradio as gr
5
+
6
+ # https://en.wikipedia.org/wiki/Rotation_matrix#General_rotations
7
+ def _rotation_matrix(yaw, pitch, roll):
8
+ yaw_matrix = np.array([
9
+ [np.cos(yaw), -np.sin(yaw), 0],
10
+ [np.sin(yaw), np.cos(yaw), 0],
11
+ [0, 0, 1],
12
+ ])
13
+ pitch_matrix = np.array([
14
+ [np.cos(pitch), 0, np.sin(pitch)],
15
+ [0, 1, 0],
16
+ [-np.sin(pitch), 0, np.cos(pitch)],
17
+ ])
18
+ roll_matrix = np.array([
19
+ [1, 0, 0],
20
+ [0, np.cos(roll), -np.sin(roll)],
21
+ [0, np.sin(roll), np.cos(roll)],
22
+ ])
23
+
24
+ return yaw_matrix @ pitch_matrix @ roll_matrix
25
+
26
+ def _calculate_transform():
27
+ t_cie = np.array([50, 0, 0]) # center of CIELAB color space
28
+ # lightness axis in CIELAB space is spanned by the vector [1, 0, 0]
29
+ t_sol = np.array([55.5, -6.125, -2.875]) # center of Solarized base palette in CIELAB space
30
+ v_sol = np.array([0.951, 0.145, 0.272]) # principal component of Solarized base palette in CIELAB space
31
+
32
+ # find the rotation matrix that rotates [1, 0, 0] to v_sol
33
+ pitch = -np.arcsin(v_sol[2])
34
+ yaw = np.arcsin(v_sol[1]/np.cos(pitch))
35
+ roll = 0 # roll is a free parameter
36
+ R = _rotation_matrix(yaw, pitch, roll)
37
+
38
+ def rotate(x):
39
+ return (x-t_cie) @ R.T + t_sol
40
+
41
+ return rotate
42
+
43
+ transform = _calculate_transform()
44
+
45
+ # light_min and light_max define a range of lightnesss between 0 and 100
46
+ # chroma_attenuation is a factor between 0 and 1
47
+ def preprocess_image(image, light_min, light_max, chroma_attenutation):
48
+ lightness_range = (light_min, light_max)
49
+ chroma_range = (-128*chroma_attenutation, 128*chroma_attenutation)
50
+
51
+ image_lab = color.rgb2lab(image)
52
+ image_lab[:, :, 0] = exposure.rescale_intensity(image_lab[:, :, 0], in_range=(0, 100), out_range=lightness_range)
53
+ image_lab[:, :, 1] = exposure.rescale_intensity(image_lab[:, :, 1], in_range=(-128, 128), out_range=chroma_range)
54
+ image_lab[:, :, 2] = exposure.rescale_intensity(image_lab[:, :, 2], in_range=(-128, 128), out_range=chroma_range)
55
+ image = color.lab2rgb(image_lab)
56
+
57
+ return image
58
+
59
+ def preprocess_image_parallel(image, light_min, light_max, chroma_attenutation):
60
+ preprocess_kwargs = {'light_min': light_min, 'light_max': light_max, 'chroma_attenutation': chroma_attenutation}
61
+ image = util.apply_parallel(
62
+ preprocess_image, image,
63
+ (1024, 1024), # restricted chunk size to prevent OOM-kill
64
+ dtype=np.float64, # required, according to error message
65
+ extra_keywords=preprocess_kwargs,
66
+ channel_axis=2, # third axis holds RGB channels
67
+ )
68
+ return image
69
+
70
+ def lightness_hist(image):
71
+ fig = plt.figure(figsize=(12, 12/5)) # set aspect ratio of figure to 5:1
72
+ ax = fig.add_subplot()
73
+
74
+ image_lightness = color.rgb2lab(image)[:, :, 0].flatten()
75
+
76
+ ax.hist(image_lightness, bins=64, label=None)
77
+ ax.axvline(x=8.13974087, color='#586e75', label='Solarized dark target range')
78
+ ax.axvline(x=59.4372606, color='#586e75', label=None)
79
+ ax.axvline(x=38.76215165, color='#93a1a1', label='Solarized light target range')
80
+ ax.axvline(x=93.86995897, color='#93a1a1', label=None)
81
+ ax.set_xlim(0, 100)
82
+
83
+ ax.legend()
84
+ ax.set_xlabel('Lightness')
85
+ ax.set_ylabel('Frequency')
86
+
87
+ # set aspect ratio of final plot to 7:1 (different from figure aspect ratio to fit other elements)
88
+ x_left, x_right = ax.get_xlim()
89
+ y_bottom, y_top = ax.get_ylim()
90
+ ax.set_aspect((x_right-x_left)/(y_top-y_bottom)/7)
91
+
92
+ return fig
93
+
94
+ def transform_image(image):
95
+ shape = image.shape # record shape
96
+
97
+ workmem = color.rgb2lab(image) # convert to CIELAB
98
+ workmem = workmem.reshape(-1, 3)
99
+ workmem = transform(workmem) # transform is a function defined globally
100
+ workmem = workmem.reshape(shape) # undo flatten
101
+ workmem = color.lab2rgb(workmem) # convert back to RGB
102
+
103
+ workmem = util.img_as_ubyte(workmem) # convert back to uint8 rgb
104
+
105
+ return workmem
106
+
107
+ def transform_image_parallel(image):
108
+ image = util.apply_parallel(
109
+ transform_image, image,
110
+ (1024, 1024), # restricted chunk size to prevent OOM-kill
111
+ dtype=np.uint8, # required, according to error message
112
+ channel_axis=2, # third axis holds RGB channels
113
+ )
114
+ return image
115
+
116
+ with gr.Blocks() as demo:
117
+ with gr.Row():
118
+ with gr.Column(scale=1, min_width=320):
119
+ input_image = gr.Image(label='Input')
120
+ light_min_slider = gr.Slider(minimum=0, maximum=100, value=10, label='Lightness minimum')
121
+ light_max_slider = gr.Slider(minimum=0, maximum=100, value=70, label='Lightness maximum')
122
+ chroma_attenutation_slider = gr.Slider(minimum=0, maximum=1, value=0.25, label='Chroma attenuation')
123
+ preprocess_button = gr.Button(value='Preprocess into workspace')
124
+ transform_button = gr.Button(value='Transform workspace')
125
+ with gr.Column(scale=2, min_width=640):
126
+ workspace_image = gr.Image(label='Workspace', interactive=False)
127
+ hist = gr.Plot(label='Lightness histogram')
128
+
129
+ preprocess_button.click(
130
+ preprocess_image_parallel,
131
+ inputs=[input_image, light_min_slider, light_max_slider, chroma_attenutation_slider],
132
+ outputs=[workspace_image]
133
+ ).then(
134
+ lightness_hist,
135
+ inputs=[workspace_image],
136
+ outputs=[hist]
137
+ )
138
+
139
+ transform_button.click(
140
+ transform_image_parallel,
141
+ inputs=[workspace_image],
142
+ outputs=[workspace_image]
143
+ )
144
+
145
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ numpy
2
+ scikit-image
3
+ matplotlib