brayden-gg committed on
Commit
9c37eb2
1 Parent(s): 3d3e7ab

improved speed of char interpolation

Browse files
Files changed (2) hide show
  1. app.py +116 -119
  2. convenience.py +17 -27
app.py CHANGED
@@ -8,7 +8,6 @@ from DataLoader import DataLoader
8
  import convenience
9
  import gradio as gr
10
 
11
- #@title Demo
12
  device = 'cpu'
13
  num_samples = 10
14
 
@@ -16,42 +15,51 @@ net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)
16
 
17
  if not torch.cuda.is_available():
18
  net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device(device))["model_state_dict"])
19
-
20
 
21
  dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')
22
 
23
 
24
  writer_options = [5, 14, 15, 16, 17, 22, 25, 80, 120, 137, 147, 151]
25
  all_loaded_data = []
 
26
  avail_char = "0 1 2 3 4 5 6 7 8 9 a b c d e f g h i j k l m n o p q r s t u v w x y z A B C D E F G H I J K L M N O P Q R S T U V W X Y Z ! ? \" ' * + - = : ; , . < > \ / [ ] ( ) # $ % &"
27
  avail_char_list = avail_char.split(" ")
28
- for writer_id in [120, 80]:
29
  loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
30
  all_loaded_data.append(loaded_data)
31
 
32
  default_loaded_data = all_loaded_data[-1]
33
- mdn_words = []
34
- mdn_mean_Ws = []
35
- all_word_mdn_Ws = []
36
- all_word_mdn_Cs = []
37
 
38
  # data for writer interpolation
39
- writer_words = []
40
  writer_mean_Ws = []
41
  all_word_writer_Ws = []
42
  all_word_writer_Cs = []
43
- weight = 0.7
 
44
 
45
- def update_target_word(target_word):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  writer_words.clear()
47
  for word in target_word.split(" "):
48
  writer_words.append(word)
49
 
50
- writer_mean_Ws.clear()
51
- for loaded_data in all_loaded_data:
52
- mean_global_W = convenience.get_mean_global_W(net, loaded_data, device)
53
- writer_mean_Ws.append(mean_global_W)
54
-
55
  all_word_writer_Ws.clear()
56
  all_word_writer_Cs.clear()
57
  for word in writer_words:
@@ -59,82 +67,75 @@ def update_target_word(target_word):
59
  all_word_writer_Ws.append(all_writer_Ws)
60
  all_word_writer_Cs.append(all_writer_Cs)
61
 
62
- return update_writer_slider(weight)
63
-
64
 
65
  # for writer interpolation
66
  def update_writer_slider(val):
67
- global weight
68
- weight = val
 
 
69
  net.clamp_mdn = 0
70
- im = convenience.draw_words(writer_words, all_word_writer_Ws, all_word_writer_Cs, [1 - weight, weight], net)
71
  return im.convert("RGB")
72
 
 
73
  def update_chosen_writers(writer1, writer2):
74
  net.clamp_mdn = 0
75
- id1, id2 = int(writer1.split(" ")[1]), int(writer2.split(" ")[1])
76
  all_loaded_data.clear()
77
- for writer_id in [id1, id2]:
78
  loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
79
  all_loaded_data.append(loaded_data)
80
 
81
- return gr.Slider.update(label=f"{writer1} vs. {writer2}"), update_writer_slider(weight)
 
 
 
 
 
82
 
83
  # for character blend
84
- def interpolate_chars(c1, c2, weight):
85
- """Generates an image of handwritten text based on target_sentence"""
86
 
 
 
 
87
  net.clamp_mdn = 0
88
 
89
- letters = [c1, c2]
 
90
  character_weights = [1 - weight, weight]
91
- M = len(letters)
92
- mean_global_W = convenience.get_mean_global_W(net, all_loaded_data[0], device)
93
-
94
- all_Cs = torch.zeros(1, M, convenience.L, convenience.L)
95
- for i in range(M): # get corners of grid
96
- W_vector, char_matrix = convenience.get_DSD(net, letters[i], [mean_global_W], [default_loaded_data], device)
97
- all_Cs[:, i, :, :] = char_matrix
98
-
99
- all_Ws = mean_global_W.reshape(1, 1, convenience.L)
100
-
101
- all_W_c = convenience.get_character_blend_W_c(character_weights, all_Ws, all_Cs)
102
- all_commands = convenience.get_commands(net, letters[0], all_W_c)
103
-
104
- width = 60
105
- x_offset = 325
106
- im = Image.fromarray(np.zeros([160, 750]))
107
- dr = ImageDraw.Draw(im)
108
- for [x, y, t] in all_commands:
109
- if t == 0:
110
- dr.line((
111
- px + width/2 + x_offset,
112
- py - width/2, # letters are shifted down for some reason
113
- x + width/2 + + x_offset,
114
- y - width/2), 255, 1)
115
- px, py = x, y
116
-
117
-
118
  return im.convert("RGB")
119
 
120
- def choose_blend_chars(c1, c2):
121
- return gr.Slider.update(label=f"'{c1}' vs. '{c2}'")
 
 
 
 
 
 
 
 
122
 
123
  # for MDN
124
 
 
125
  def update_mdn_word(target_word):
126
  mdn_words.clear()
127
  for word in target_word.split(" "):
128
  mdn_words.append(word)
129
 
130
- mdn_mean_Ws.clear()
131
- mean_global_W = convenience.get_mean_global_W(net, default_loaded_data, device)
132
- mdn_mean_Ws.append(mean_global_W)
133
-
134
  all_word_mdn_Ws.clear()
135
  all_word_mdn_Cs.clear()
136
  for word in mdn_words:
137
- all_writer_Ws, all_writer_Cs = convenience.get_DSD(net, word, mdn_mean_Ws, [default_loaded_data], device)
138
  all_word_mdn_Ws.append(all_writer_Ws)
139
  all_word_mdn_Cs.append(all_writer_Cs)
140
 
@@ -148,76 +149,72 @@ def sample_mdn(maxs, maxr):
148
  return im.convert("RGB")
149
 
150
 
151
- update_target_word("hello world")
152
- update_mdn_word("hello world")
 
 
 
 
153
 
154
  with gr.Blocks() as demo:
155
  with gr.Tabs():
156
  with gr.TabItem("Blend Writers"):
157
- target_word = gr.Textbox(label="Target Word", value="hello world", max_lines=1)
158
- with gr.Row():
159
- left_ratio_options = ["Style " + str(id) for i, id in enumerate(writer_options) if i % 2 == 0]
160
- right_ratio_options = ["Style " + str(id) for i, id in enumerate(writer_options) if i % 2 == 1]
161
- with gr.Column():
162
- writer1 = gr.Radio(left_ratio_options, value="Style 120", label="Style for first writer")
163
- with gr.Column():
164
- writer2 = gr.Radio(right_ratio_options, value="Style 80", label="Style for second writer")
165
- with gr.Row():
166
- writer_slider = gr.Slider(0, 1, value=0.7, label="Style 120 vs. Style 80")
167
- with gr.Row():
168
- writer_submit = gr.Button("Submit")
169
- with gr.Row():
170
- writer_default_image = convenience.sample_blended_writers([0.3, 0.7], "hello world", net, all_loaded_data, device).convert("RGB")
171
- writer_output = gr.Image(writer_default_image)
172
-
173
- writer_submit.click(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output])
174
- writer_slider.change(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output])
175
- target_word.submit(fn=update_target_word, inputs=[target_word], outputs=[writer_output])
176
-
177
- writer1.change(fn=update_chosen_writers, inputs=[writer1, writer2], outputs=[writer_slider, writer_output])
178
- writer2.change(fn=update_chosen_writers, inputs=[writer1, writer2], outputs=[writer_slider, writer_output])
179
-
180
 
181
  with gr.TabItem("Blend Characters"):
182
  with gr.Row():
183
  with gr.Column():
184
- char1 = gr.Dropdown(choices=avail_char_list, value="y", label="Character 1")
185
  with gr.Column():
186
- char2 = gr.Dropdown(choices=avail_char_list, value="s", label="Character 2")
187
  with gr.Row():
188
- char_slider = gr.Slider(0, 1, value=0.7, label="'y' vs. 's'")
189
  with gr.Row():
190
- char_default_image = convenience.sample_blended_chars([0.3, 0.7], ["y", "s"], net, [default_loaded_data], device).convert("RGB")
191
- char_output = gr.Image(char_default_image)
192
-
193
- char_slider.change(fn=interpolate_chars, inputs=[char1, char2, char_slider], outputs=[char_output])
194
-
195
- char1.change(fn=choose_blend_chars, inputs=[char1, char2], outputs=[char_slider])
196
- char2.change(fn=choose_blend_chars, inputs=[char1, char2], outputs=[char_slider])
197
-
198
-
199
  with gr.TabItem("Add Randomness"):
200
- mdn_word = gr.Textbox(label="Target Word", value="hello world", max_lines=1)
201
- '''
202
- with gr.Row():
203
- radio_options3 = ["Writer " + str(n) for n in writer_options]
204
- writer = gr.Radio(radio_options3, value="Writer 80", label="Style for Writer")
205
- writer.change(fn=new_writer_mdn, inputs=[writer, slider3, slider4], outputs=[output])
206
- '''
207
- with gr.Row():
208
- with gr.Column():
209
- max_rand = gr.Slider(0, 1, value=1, label="Maximum Randomness")
210
- with gr.Column():
211
- scale_rand = gr.Slider(0, 3, value=0.5, label="Scale of Randomness")
212
- with gr.Row():
213
- mdn_sample_button = gr.Button(value="Resample!")
214
- with gr.Row():
215
- default_im = convenience.mdn_single_sample("hello world", 0.5, 1, net, [default_loaded_data], device).convert('RGB')
216
- mdn_output = gr.Image(default_im)
217
-
218
- max_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output])
219
- scale_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output])
220
- mdn_sample_button.click(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output])
221
- mdn_word.submit(fn=update_mdn_word, inputs=[mdn_word], outputs=[mdn_output])
222
 
223
  demo.launch()
 
8
  import convenience
9
  import gradio as gr
10
 
 
11
  device = 'cpu'
12
  num_samples = 10
13
 
 
15
 
16
  if not torch.cuda.is_available():
17
  net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device(device))["model_state_dict"])
18
+
19
 
20
  dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')
21
 
22
 
23
  writer_options = [5, 14, 15, 16, 17, 22, 25, 80, 120, 137, 147, 151]
24
  all_loaded_data = []
25
+ chosen_writers = [120, 80]
26
  avail_char = "0 1 2 3 4 5 6 7 8 9 a b c d e f g h i j k l m n o p q r s t u v w x y z A B C D E F G H I J K L M N O P Q R S T U V W X Y Z ! ? \" ' * + - = : ; , . < > \ / [ ] ( ) # $ % &"
27
  avail_char_list = avail_char.split(" ")
28
+ for writer_id in chosen_writers:
29
  loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
30
  all_loaded_data.append(loaded_data)
31
 
32
  default_loaded_data = all_loaded_data[-1]
 
 
 
 
33
 
34
  # data for writer interpolation
35
+ writer_words = ["hello", "world"]
36
  writer_mean_Ws = []
37
  all_word_writer_Ws = []
38
  all_word_writer_Cs = []
39
+ writer_weight = 0.7
40
+
41
 
42
+ # data for char interpolation
43
+ blend_chars = ["y", "s"]
44
+ char_mean_global_W = None
45
+ char_weight = 0.7
46
+ default_mean_global_W = convenience.get_mean_global_W(net, default_loaded_data, device)
47
+ char_Ws = default_mean_global_W.reshape(1, 1, convenience.L)
48
+ char_Cs = all_Cs = torch.zeros(1, 2, convenience.L, convenience.L)
49
+
50
+
51
+ # data for MDN
52
+ mdn_words = ["hello", "world"]
53
+ mdn_mean_global_W = None
54
+ all_word_mdn_Ws = []
55
+ all_word_mdn_Cs = []
56
+
57
+
58
+ def update_writer_word(target_word):
59
  writer_words.clear()
60
  for word in target_word.split(" "):
61
  writer_words.append(word)
62
 
 
 
 
 
 
63
  all_word_writer_Ws.clear()
64
  all_word_writer_Cs.clear()
65
  for word in writer_words:
 
67
  all_word_writer_Ws.append(all_writer_Ws)
68
  all_word_writer_Cs.append(all_writer_Cs)
69
 
70
+ return update_writer_slider(writer_weight)
71
+
72
 
73
  # for writer interpolation
74
  def update_writer_slider(val):
75
+ global writer_weight
76
+ writer_weight = val
77
+ weights = [1 - writer_weight, writer_weight]
78
+
79
  net.clamp_mdn = 0
80
+ im = convenience.draw_words(writer_words, all_word_writer_Ws, all_word_writer_Cs, weights, net)
81
  return im.convert("RGB")
82
 
83
+
84
  def update_chosen_writers(writer1, writer2):
85
  net.clamp_mdn = 0
86
+ chosen_writers[0], chosen_writers[1] = int(writer1.split(" ")[1]), int(writer2.split(" ")[1])
87
  all_loaded_data.clear()
88
+ for writer_id in chosen_writers:
89
  loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
90
  all_loaded_data.append(loaded_data)
91
 
92
+ writer_mean_Ws.clear()
93
+ for loaded_data in all_loaded_data:
94
+ mean_global_W = convenience.get_mean_global_W(net, loaded_data, device)
95
+ writer_mean_Ws.append(mean_global_W)
96
+
97
+ return gr.Slider.update(label=f"{writer1} vs. {writer2}"), update_writer_slider(writer_weight)
98
 
99
  # for character blend
 
 
100
 
101
+
102
+ def update_char_slider(weight):
103
+ """Generates an image of handwritten text based on target_sentence"""
104
  net.clamp_mdn = 0
105
 
106
+ global char_weight
107
+ char_weight = weight
108
  character_weights = [1 - weight, weight]
109
+
110
+ all_W_c = convenience.get_character_blend_W_c(character_weights, char_Ws, char_Cs)
111
+ all_commands = convenience.get_commands(net, blend_chars[0], all_W_c)
112
+ im = convenience.commands_to_image(all_commands, 160, 750, 375, 30)
113
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  return im.convert("RGB")
115
 
116
+
117
+ def update_blend_chars(c1, c2):
118
+ global blend_chars
119
+ blend_chars[0], blend_chars[1] = c1, c2
120
+
121
+ for i in range(2): # get corners of grid
122
+ _, char_matrix = convenience.get_DSD(net, blend_chars[i], default_mean_global_W, [default_loaded_data], device)
123
+ char_Cs[:, i, :, :] = char_matrix
124
+
125
+ return gr.Slider.update(label=f"'{c1}' vs. '{c2}'")
126
 
127
  # for MDN
128
 
129
+
130
  def update_mdn_word(target_word):
131
  mdn_words.clear()
132
  for word in target_word.split(" "):
133
  mdn_words.append(word)
134
 
 
 
 
 
135
  all_word_mdn_Ws.clear()
136
  all_word_mdn_Cs.clear()
137
  for word in mdn_words:
138
+ all_writer_Ws, all_writer_Cs = convenience.get_DSD(net, word, default_mean_global_W, [default_loaded_data], device)
139
  all_word_mdn_Ws.append(all_writer_Ws)
140
  all_word_mdn_Cs.append(all_writer_Cs)
141
 
 
149
  return im.convert("RGB")
150
 
151
 
152
+ update_writer_word(" ".join(writer_words))
153
+ update_chosen_writers(f"Writer {chosen_writers[0]}", f"Writer {chosen_writers[1]}")
154
+
155
+ update_mdn_word(" ".join(writer_words))
156
+
157
+ update_blend_chars(*blend_chars)
158
 
159
  with gr.Blocks() as demo:
160
  with gr.Tabs():
161
  with gr.TabItem("Blend Writers"):
162
+ target_word = gr.Textbox(label="Target Word", value=" ".join(writer_words), max_lines=1)
163
+ with gr.Row():
164
+ left_ratio_options = ["Style " + str(id) for i, id in enumerate(writer_options) if i % 2 == 0]
165
+ right_ratio_options = ["Style " + str(id) for i, id in enumerate(writer_options) if i % 2 == 1]
166
+ with gr.Column():
167
+ writer1 = gr.Radio(left_ratio_options, value="Style 120", label="Style for first writer")
168
+ with gr.Column():
169
+ writer2 = gr.Radio(right_ratio_options, value="Style 80", label="Style for second writer")
170
+ with gr.Row():
171
+ writer_slider = gr.Slider(0, 1, value=writer_weight, label="Style 120 vs. Style 80")
172
+ with gr.Row():
173
+ writer_submit = gr.Button("Submit")
174
+ with gr.Row():
175
+ writer_default_image = update_writer_slider(writer_weight)
176
+ writer_output = gr.Image(writer_default_image)
177
+
178
+ writer_submit.click(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output])
179
+ writer_slider.change(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output])
180
+ target_word.submit(fn=update_writer_word, inputs=[target_word], outputs=[writer_output])
181
+
182
+ writer1.change(fn=update_chosen_writers, inputs=[writer1, writer2], outputs=[writer_slider, writer_output])
183
+ writer2.change(fn=update_chosen_writers, inputs=[writer1, writer2], outputs=[writer_slider, writer_output])
 
184
 
185
  with gr.TabItem("Blend Characters"):
186
  with gr.Row():
187
  with gr.Column():
188
+ char1 = gr.Dropdown(choices=avail_char_list, value=blend_chars[0], label="Character 1")
189
  with gr.Column():
190
+ char2 = gr.Dropdown(choices=avail_char_list, value=blend_chars[1], label="Character 2")
191
  with gr.Row():
192
+ char_slider = gr.Slider(0, 1, value=char_weight, label=f"'{blend_chars[0]}' vs. '{blend_chars[1]}'")
193
  with gr.Row():
194
+ char_default_image = update_char_slider(char_weight)
195
+ char_output = gr.Image(char_default_image)
196
+
197
+ char_slider.change(fn=update_char_slider, inputs=[char_slider], outputs=[char_output])
198
+
199
+ char1.change(fn=update_blend_chars, inputs=[char1, char2], outputs=[char_slider])
200
+ char2.change(fn=update_blend_chars, inputs=[char1, char2], outputs=[char_slider])
201
+
 
202
  with gr.TabItem("Add Randomness"):
203
+ mdn_word = gr.Textbox(label="Target Word", value=" ".join(mdn_words), max_lines=1)
204
+ with gr.Row():
205
+ with gr.Column():
206
+ max_rand = gr.Slider(0, 1, value=net.clamp_mdn, label="Maximum Randomness")
207
+ with gr.Column():
208
+ scale_rand = gr.Slider(0, 3, value=net.scale_sd, label="Scale of Randomness")
209
+ with gr.Row():
210
+ mdn_sample_button = gr.Button(value="Resample!")
211
+ with gr.Row():
212
+ default_im = sample_mdn(net.scale_sd, net.clamp_mdn)
213
+ mdn_output = gr.Image(default_im)
214
+
215
+ max_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output])
216
+ scale_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output])
217
+ mdn_sample_button.click(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output])
218
+ mdn_word.submit(fn=update_mdn_word, inputs=[mdn_word], outputs=[mdn_output])
 
 
 
 
 
 
219
 
220
  demo.launch()
convenience.py CHANGED
@@ -455,19 +455,7 @@ def sample_blended_chars(character_weights, letters, net, all_loaded_data, devic
455
  all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs)
456
  all_commands = get_commands(net, letters[0], all_W_c)
457
 
458
- width = 60
459
- im = Image.fromarray(np.zeros([100, 100]))
460
- dr = ImageDraw.Draw(im)
461
- for [x, y, t] in all_commands:
462
- if t == 0:
463
- dr.line((
464
- px + width/2,
465
- py - width/2, # letters are shifted down for some reason
466
- x + width/2,
467
- y - width/2), 255, 1)
468
- px, py = x, y
469
-
470
-
471
  return im
472
 
473
 
@@ -494,19 +482,7 @@ def char_interpolation_video(letters, transition_time, net, all_loaded_data, dev
494
  all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs)
495
  all_commands = get_commands(net, letters[i], all_W_c)
496
 
497
- im = Image.fromarray(np.zeros([100, 100]))
498
- dr = ImageDraw.Draw(im)
499
- width = 50
500
- for [x, y, t] in all_commands:
501
- if t == 0:
502
- dr.line((
503
- px + width/2,
504
- py - width/2, # letters are shifted down for some reason
505
- x + width/2,
506
- y - width/2), 255, 1)
507
- px, py = x, y
508
-
509
-
510
  im.convert("RGB").save(f"results/{''.join(letters)}_frames/frames_{str(i * transition_time + j).zfill(3)}.png")
511
 
512
  # Convert fromes to video using ffmpeg
@@ -529,4 +505,18 @@ def draw_words(words, word_Ws, word_Cs, writer_weights, net):
529
  px, py = x, y
530
  width += np.max(all_commands[:, 0]) + 25
531
 
532
- return im
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs)
456
  all_commands = get_commands(net, letters[0], all_W_c)
457
 
458
+ im = commands_to_image(all_commands, 100, 100, 30, 30)
 
 
 
 
 
 
 
 
 
 
 
 
459
  return im
460
 
461
 
 
482
  all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs)
483
  all_commands = get_commands(net, letters[i], all_W_c)
484
 
485
+ im = commands_to_image(all_commands, 100, 100, 25, 25)
 
 
 
 
 
 
 
 
 
 
 
 
486
  im.convert("RGB").save(f"results/{''.join(letters)}_frames/frames_{str(i * transition_time + j).zfill(3)}.png")
487
 
488
  # Convert fromes to video using ffmpeg
 
505
  px, py = x, y
506
  width += np.max(all_commands[:, 0]) + 25
507
 
508
+ return im
509
+
510
+ def commands_to_image(commands, imW, imH, xoff, yoff):
511
+ im = Image.fromarray(np.zeros([imW, imH]))
512
+ dr = ImageDraw.Draw(im)
513
+ for [x, y, t] in commands:
514
+ if t == 0:
515
+ dr.line((
516
+ px + xoff,
517
+ py - yoff, # letters are shifted down for some reason
518
+ x + xoff,
519
+ y - yoff), 255, 1)
520
+ px, py = x, y
521
+ return im
522
+