brayden-gg committed on
Commit
904e4e5
·
1 Parent(s): 1a69cb2

switched to SVG rendering

Browse files
app.py CHANGED
@@ -76,8 +76,8 @@ def update_writer_slider(val):
76
  weights = [1 - writer_weight, writer_weight]
77
 
78
  net.clamp_mdn = 0
79
- im = convenience.draw_words(writer_words, all_word_writer_Ws, all_word_writer_Cs, weights, net)
80
- return im.convert("RGB")
81
 
82
 
83
  def update_chosen_writers(writer1, writer2):
@@ -109,9 +109,9 @@ def update_char_slider(weight):
109
 
110
  all_W_c = convenience.get_character_blend_W_c(character_weights, char_Ws, char_Cs)
111
  all_commands = convenience.get_commands(net, blend_chars[0], all_W_c)
112
- im = convenience.commands_to_image(all_commands, 160, 750, 375, 30)
113
 
114
- return im.convert("RGB")
115
 
116
 
117
  def update_blend_chars(c1, c2):
@@ -145,8 +145,8 @@ def update_mdn_word(target_word):
145
  def sample_mdn(maxs, maxr):
146
  net.clamp_mdn = maxr
147
  net.scale_sd = maxs
148
- im = convenience.draw_words(mdn_words, all_word_mdn_Ws, all_word_mdn_Cs, [1], net)
149
- return im.convert("RGB")
150
 
151
 
152
  update_writer_word(" ".join(writer_words))
@@ -173,7 +173,7 @@ with gr.Blocks() as demo:
173
  writer_submit = gr.Button("Submit")
174
  with gr.Row():
175
  writer_default_image = update_writer_slider(writer_weight)
176
- writer_output = gr.Image(writer_default_image)
177
 
178
  writer_submit.click(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output], show_progress=False)
179
  writer_slider.change(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output], show_progress=False)
@@ -192,7 +192,7 @@ with gr.Blocks() as demo:
192
  char_slider = gr.Slider(0, 1, value=char_weight, label=f"'{blend_chars[0]}' vs. '{blend_chars[1]}'")
193
  with gr.Row():
194
  char_default_image = update_char_slider(char_weight)
195
- char_output = gr.Image(char_default_image)
196
 
197
  char_slider.change(fn=update_char_slider, inputs=[char_slider], outputs=[char_output], show_progress=False)
198
 
@@ -210,7 +210,7 @@ with gr.Blocks() as demo:
210
  mdn_sample_button = gr.Button(value="Resample!")
211
  with gr.Row():
212
  default_im = sample_mdn(net.scale_sd, net.clamp_mdn)
213
- mdn_output = gr.Image(default_im)
214
 
215
  max_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output], show_progress=False)
216
  scale_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output], show_progress=False)
 
76
  weights = [1 - writer_weight, writer_weight]
77
 
78
  net.clamp_mdn = 0
79
+ svg = convenience.draw_words_svg(writer_words, all_word_writer_Ws, all_word_writer_Cs, weights, net)
80
+ return svg
81
 
82
 
83
  def update_chosen_writers(writer1, writer2):
 
109
 
110
  all_W_c = convenience.get_character_blend_W_c(character_weights, char_Ws, char_Cs)
111
  all_commands = convenience.get_commands(net, blend_chars[0], all_W_c)
112
+ svg = convenience.commands_to_svg(all_commands, 750, 160, 375)
113
 
114
+ return svg
115
 
116
 
117
  def update_blend_chars(c1, c2):
 
145
  def sample_mdn(maxs, maxr):
146
  net.clamp_mdn = maxr
147
  net.scale_sd = maxs
148
+ svg = convenience.draw_words_svg(mdn_words, all_word_mdn_Ws, all_word_mdn_Cs, [1], net)
149
+ return svg
150
 
151
 
152
  update_writer_word(" ".join(writer_words))
 
173
  writer_submit = gr.Button("Submit")
174
  with gr.Row():
175
  writer_default_image = update_writer_slider(writer_weight)
176
+ writer_output = gr.HTML(writer_default_image)
177
 
178
  writer_submit.click(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output], show_progress=False)
179
  writer_slider.change(fn=update_writer_slider, inputs=[writer_slider], outputs=[writer_output], show_progress=False)
 
192
  char_slider = gr.Slider(0, 1, value=char_weight, label=f"'{blend_chars[0]}' vs. '{blend_chars[1]}'")
193
  with gr.Row():
194
  char_default_image = update_char_slider(char_weight)
195
+ char_output = gr.HTML(char_default_image)
196
 
197
  char_slider.change(fn=update_char_slider, inputs=[char_slider], outputs=[char_output], show_progress=False)
198
 
 
210
  mdn_sample_button = gr.Button(value="Resample!")
211
  with gr.Row():
212
  default_im = sample_mdn(net.scale_sd, net.clamp_mdn)
213
+ mdn_output = gr.HTML(default_im)
214
 
215
  max_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output], show_progress=False)
216
  scale_rand.change(fn=sample_mdn, inputs=[scale_rand, max_rand], outputs=[mdn_output], show_progress=False)
config/__pycache__/GlobalVariables.cpython-38.pyc CHANGED
Binary files a/config/__pycache__/GlobalVariables.cpython-38.pyc and b/config/__pycache__/GlobalVariables.cpython-38.pyc differ
 
config/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/config/__pycache__/__init__.cpython-38.pyc and b/config/__pycache__/__init__.cpython-38.pyc differ
 
convenience.py CHANGED
@@ -14,10 +14,12 @@ from config.GlobalVariables import *
14
  from tensorboardX import SummaryWriter
15
  from SynthesisNetwork import SynthesisNetwork
16
  from DataLoader import DataLoader
 
17
  # import ffmpeg # for problems with ffmpeg uninstall ffmpeg and then install ffmpeg-python
18
 
19
  L = 256
20
 
 
21
  def get_mean_global_W(net, loaded_data, device):
22
  """gets the mean global style vector for a given writer"""
23
  [_, _, _, _, _, _, all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length, all_word_level_term, all_word_level_char, all_word_level_char_length, all_segment_level_stroke_in, all_segment_level_stroke_out,
@@ -231,14 +233,14 @@ def get_character_blend_W_c(character_weights, all_Ws, all_Cs):
231
  W_vector = all_Ws[0, 0, :].unsqueeze(-1)
232
 
233
  weights_tensor = torch.tensor(character_weights).repeat_interleave(L * L).reshape(1, M, L, L) # repeat accross remaining dimensions
234
- char_matrix = (weights_tensor * all_Cs).sum(axis=1).squeeze() # take weighted sum accross characters axis
235
 
236
  W_c = char_matrix @ W_vector
237
 
238
  return W_c.reshape(1, 1, L)
239
 
240
 
241
- def get_commands(net, target_word, all_W_c): # seems like target_word is only used for length
242
  """converts character-dependent style-dependent DSDs to a list of commands for drawing"""
243
  all_commands = []
244
  current_id = 0
@@ -285,6 +287,7 @@ def get_commands(net, target_word, all_W_c): # seems like target_word is only us
285
 
286
  return commands
287
 
 
288
  def mdn_video(target_word, num_samples, scale_sd, clamp_mdn, net, all_loaded_data, device):
289
  '''
290
  Method creating gif of mdn samples
@@ -306,7 +309,7 @@ def mdn_video(target_word, num_samples, scale_sd, clamp_mdn, net, all_loaded_dat
306
  writer_Ws, writer_Cs = get_DSD(net, word, [mean_global_W], [all_loaded_data[0]], device)
307
  word_Ws.append(writer_Ws)
308
  word_Cs.append(writer_Cs)
309
-
310
  im = draw_words(words, word_Ws, word_Cs, [1], net)
311
  im.convert("RGB").save(f'results/{us_target_word}_mdn_samples/sample_{i}.png')
312
  # Convert frames to video using ffmpeg
@@ -314,6 +317,7 @@ def mdn_video(target_word, num_samples, scale_sd, clamp_mdn, net, all_loaded_dat
314
  videos = photos.output(f'results/{us_target_word}_video.mov', vcodec="libx264", pix_fmt="yuv420p")
315
  videos.run(overwrite_output=True)
316
 
 
317
  def sample_blended_writers(writer_weights, target_sentence, net, all_loaded_data, device="cpu"):
318
  """Generates an image of handwritten text based on target_sentence"""
319
  words = target_sentence.split(' ')
@@ -329,7 +333,7 @@ def sample_blended_writers(writer_weights, target_sentence, net, all_loaded_data
329
  writer_Ws, writer_Cs = get_DSD(net, word, writer_mean_Ws, all_loaded_data, device)
330
  word_Ws.append(writer_Ws)
331
  word_Cs.append(writer_Cs)
332
-
333
  return draw_words(words, word_Ws, word_Cs, writer_weights, net)
334
 
335
 
@@ -356,10 +360,10 @@ def sample_character_grid(letters, grid_size, net, all_loaded_data, device="cpu"
356
  wx = i / (grid_size - 1)
357
  wy = j / (grid_size - 1)
358
 
359
- character_weights = [(1 - wx) * (1 - wy), # top left is 1 at (0, 0)
360
- wx * (1 - wy), # top right is 1 at (1, 0)
361
  (1 - wx) * wy, # bottom left is 1 at (0, 1)
362
- wx * wy] # bottom right is 1 at (1, 1)
363
  all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs)
364
  all_commands = get_commands(net, letters[0], all_W_c)
365
 
@@ -377,6 +381,7 @@ def sample_character_grid(letters, grid_size, net, all_loaded_data, device="cpu"
377
 
378
  return im
379
 
 
380
  def writer_interpolation_video(target_sentence, transition_time, net, all_loaded_data, device="cpu"):
381
  """
382
  Generates a video of interpolating between each provided writer
@@ -416,6 +421,7 @@ def writer_interpolation_video(target_sentence, transition_time, net, all_loaded
416
  videos = photos.output(f"results/{target_sentence}_blend_video.mov", vcodec="libx264", pix_fmt="yuv420p")
417
  videos.run(overwrite_output=True)
418
 
 
419
  def mdn_single_sample(target_word, scale_sd, clamp_mdn, net, all_loaded_data, device):
420
  '''
421
  Method creating gif of mdn samples
@@ -462,7 +468,7 @@ def sample_blended_chars(character_weights, letters, net, all_loaded_data, devic
462
  def char_interpolation_video(letters, transition_time, net, all_loaded_data, device="cpu"):
463
  """Generates an image of handwritten text based on target_sentence"""
464
 
465
- os.makedirs(f"./results/{''.join(letters)}_frames", exist_ok=True) # make a folder for the frames
466
 
467
  M = len(letters)
468
  mean_global_W = get_mean_global_W(net, all_loaded_data[0], device)
@@ -507,6 +513,25 @@ def draw_words(words, word_Ws, word_Cs, writer_weights, net):
507
 
508
  return im
509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  def commands_to_image(commands, imW, imH, xoff, yoff):
511
  im = Image.fromarray(np.zeros([imW, imH]))
512
  dr = ImageDraw.Draw(im)
@@ -519,4 +544,15 @@ def commands_to_image(commands, imW, imH, xoff, yoff):
519
  y - yoff), 255, 1)
520
  px, py = x, y
521
  return im
522
-
 
 
 
 
 
 
 
 
 
 
 
 
14
  from tensorboardX import SummaryWriter
15
  from SynthesisNetwork import SynthesisNetwork
16
  from DataLoader import DataLoader
17
+ import svgwrite
18
  # import ffmpeg # for problems with ffmpeg uninstall ffmpeg and then install ffmpeg-python
19
 
20
  L = 256
21
 
22
+
23
  def get_mean_global_W(net, loaded_data, device):
24
  """gets the mean global style vector for a given writer"""
25
  [_, _, _, _, _, _, all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length, all_word_level_term, all_word_level_char, all_word_level_char_length, all_segment_level_stroke_in, all_segment_level_stroke_out,
 
233
  W_vector = all_Ws[0, 0, :].unsqueeze(-1)
234
 
235
  weights_tensor = torch.tensor(character_weights).repeat_interleave(L * L).reshape(1, M, L, L) # repeat accross remaining dimensions
236
+ char_matrix = (weights_tensor * all_Cs).sum(axis=1).squeeze() # take weighted sum accross characters axis
237
 
238
  W_c = char_matrix @ W_vector
239
 
240
  return W_c.reshape(1, 1, L)
241
 
242
 
243
+ def get_commands(net, target_word, all_W_c): # seems like target_word is only used for length
244
  """converts character-dependent style-dependent DSDs to a list of commands for drawing"""
245
  all_commands = []
246
  current_id = 0
 
287
 
288
  return commands
289
 
290
+
291
  def mdn_video(target_word, num_samples, scale_sd, clamp_mdn, net, all_loaded_data, device):
292
  '''
293
  Method creating gif of mdn samples
 
309
  writer_Ws, writer_Cs = get_DSD(net, word, [mean_global_W], [all_loaded_data[0]], device)
310
  word_Ws.append(writer_Ws)
311
  word_Cs.append(writer_Cs)
312
+
313
  im = draw_words(words, word_Ws, word_Cs, [1], net)
314
  im.convert("RGB").save(f'results/{us_target_word}_mdn_samples/sample_{i}.png')
315
  # Convert frames to video using ffmpeg
 
317
  videos = photos.output(f'results/{us_target_word}_video.mov', vcodec="libx264", pix_fmt="yuv420p")
318
  videos.run(overwrite_output=True)
319
 
320
+
321
  def sample_blended_writers(writer_weights, target_sentence, net, all_loaded_data, device="cpu"):
322
  """Generates an image of handwritten text based on target_sentence"""
323
  words = target_sentence.split(' ')
 
333
  writer_Ws, writer_Cs = get_DSD(net, word, writer_mean_Ws, all_loaded_data, device)
334
  word_Ws.append(writer_Ws)
335
  word_Cs.append(writer_Cs)
336
+
337
  return draw_words(words, word_Ws, word_Cs, writer_weights, net)
338
 
339
 
 
360
  wx = i / (grid_size - 1)
361
  wy = j / (grid_size - 1)
362
 
363
+ character_weights = [(1 - wx) * (1 - wy), # top left is 1 at (0, 0)
364
+ wx * (1 - wy), # top right is 1 at (1, 0)
365
  (1 - wx) * wy, # bottom left is 1 at (0, 1)
366
+ wx * wy] # bottom right is 1 at (1, 1)
367
  all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs)
368
  all_commands = get_commands(net, letters[0], all_W_c)
369
 
 
381
 
382
  return im
383
 
384
+
385
  def writer_interpolation_video(target_sentence, transition_time, net, all_loaded_data, device="cpu"):
386
  """
387
  Generates a video of interpolating between each provided writer
 
421
  videos = photos.output(f"results/{target_sentence}_blend_video.mov", vcodec="libx264", pix_fmt="yuv420p")
422
  videos.run(overwrite_output=True)
423
 
424
+
425
  def mdn_single_sample(target_word, scale_sd, clamp_mdn, net, all_loaded_data, device):
426
  '''
427
  Method creating gif of mdn samples
 
468
  def char_interpolation_video(letters, transition_time, net, all_loaded_data, device="cpu"):
469
  """Generates an image of handwritten text based on target_sentence"""
470
 
471
+ os.makedirs(f"./results/{''.join(letters)}_frames", exist_ok=True) # make a folder for the frames
472
 
473
  M = len(letters)
474
  mean_global_W = get_mean_global_W(net, all_loaded_data[0], device)
 
513
 
514
  return im
515
 
516
+
517
def draw_words_svg(words, word_Ws, word_Cs, writer_weights, net):
    """Render a sequence of words as one SVG document and return it as a string.

    For every word, the blended style is computed with
    ``get_writer_blend_W_c`` and converted to ``[x, y, t]`` pen commands with
    ``get_commands``. Words are placed left to right on a fixed 750x160
    canvas with a black background. Nothing is written to disk here: the
    drawing is never saved, only serialised with ``tostring()``.

    Args:
        words: words to draw, iterated in lockstep with ``word_Ws`` and
            ``word_Cs``.
        word_Ws: per-word values forwarded to ``get_writer_blend_W_c``.
        word_Cs: per-word values forwarded to ``get_writer_blend_W_c``.
        writer_weights: blend weights forwarded to ``get_writer_blend_W_c``.
        net: network object forwarded to ``get_commands``.

    Returns:
        The SVG markup as a string.
    """
    dwg = svgwrite.Drawing("output.svg", size=(750, 160), style="background-color: black;")
    width = 50  # x offset of the word currently being drawn
    for word, all_writer_Ws, all_writer_Cs in zip(words, word_Ws, word_Cs):
        all_W_c = get_writer_blend_W_c(writer_weights, all_writer_Ws, all_writer_Cs)
        all_commands = get_commands(net, word, all_W_c)

        # Start every word without an open path so that a leading "line-to"
        # command can neither hit an unbound name nor connect to the
        # previous word's last stroke.
        path = None
        for x, y, t in all_commands:
            if t == 0 and path is not None:
                # t == 0: extend the current stroke with a straight segment.
                path.push("L", x + width, y)
            else:
                # Any other t (or the first point of a word): begin a new
                # stroke. fill="none" keeps SVG's default black fill from
                # painting over strokes that were drawn earlier.
                path = svgwrite.path.Path(stroke="white", stroke_width="1", fill="none")
                dwg.add(path)
                path.push("M", x + width, y)

        # Advance past the widest x of this word plus a 25-unit gap.
        width += np.max(all_commands[:, 0]) + 25
    return dwg.tostring()
533
+
534
+
535
  def commands_to_image(commands, imW, imH, xoff, yoff):
536
  im = Image.fromarray(np.zeros([imW, imH]))
537
  dr = ImageDraw.Draw(im)
 
544
  y - yoff), 255, 1)
545
  px, py = x, y
546
  return im
547
+
548
+
549
def commands_to_svg(commands, imW, imH, xoff):
    """Convert ``[x, y, t]`` pen commands into an SVG document string.

    Each command either extends the current stroke (``t == 0``) or starts a
    new one (any other ``t``). Every x coordinate is shifted by ``xoff``;
    y coordinates are used unchanged. Nothing is written to disk here: the
    drawing is never saved, only serialised with ``tostring()``.

    Args:
        commands: iterable of ``[x, y, t]`` triples.
        imW: width passed to the SVG ``size``.
        imH: height passed to the SVG ``size``.
        xoff: horizontal offset added to every x coordinate.

    Returns:
        The SVG markup as a string.
    """
    dwg = svgwrite.Drawing("output.svg", size=(imW, imH), style="background-color:black")
    path = None  # no stroke is open until the first point has been seen
    for x, y, t in commands:
        if t == 0 and path is not None:
            # t == 0: extend the current stroke with a straight segment.
            path.push("L", x + xoff, y)
        else:
            # Any other t (or the very first point, even if its t is 0):
            # begin a new stroke. fill="none" keeps SVG's default black fill
            # from painting over strokes that were drawn earlier.
            path = svgwrite.path.Path(stroke="white", stroke_width="1", fill="none")
            dwg.add(path)
            path.push("M", x + xoff, y)
    return dwg.tostring()
interpolation.py CHANGED
@@ -20,11 +20,10 @@ def main(params):
20
  net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)
21
 
22
  if not torch.cuda.is_available():
23
- try: # retrained model also contains loss in dict
24
  net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu'))["model_state_dict"])
25
  except:
26
  net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu')))
27
-
28
 
29
  dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')
30
 
@@ -34,7 +33,6 @@ def main(params):
34
  loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(params.num_samples)))
35
  all_loaded_data.append(loaded_data)
36
 
37
-
38
  if params.output == "image":
39
 
40
  if params.interpolate == "writer":
@@ -78,6 +76,7 @@ def main(params):
78
  else:
79
  raise ValueError("Invalid output")
80
 
 
81
  if __name__ == '__main__':
82
  parser = argparse.ArgumentParser(description='Arguments for generating samples with the handwriting synthesis model.')
83
 
@@ -89,25 +88,25 @@ if __name__ == '__main__':
89
  parser.add_argument('--interpolate', type=str, default="randomness", choices=["writer", "character", "randomness"])
90
 
91
  # PARAMS FOR BOTH WRITER AND CHARACTER INTERPOLATION:
92
- # IF IMAGE - weights to use for a single sample of interpolation
93
  parser.add_argument('--blend_weights', type=float, nargs="+", default=[0.5, 0.5])
94
- # IF VIDEO - the number of frames for each character/writer
95
  parser.add_argument('--frames_per_step', type=int, default=10)
96
 
97
  # PARAMS IF WRITER INTERPOLATION:
98
  parser.add_argument('--target_word', type=str, default="hello world")
99
  parser.add_argument('--writer_ids', type=int, nargs="+", default=[80, 120])
100
-
101
  # PARAMS IF CHARACTER INTERPOLATION:
102
- # IF VIDEO OR BLEND
103
- parser.add_argument('--blend_chars', type=str, nargs="+", default = ["a", "b", "c", "d", "e"])
104
- # IF GRID
105
- parser.add_argument('--grid_chars', type=str, nargs="+", default= ["y", "s", "u", "n"])
106
  parser.add_argument('--grid_size', type=int, default=10)
107
 
108
  # PARAMS IF RANDOMNESS INTERPOLATION (--output will be ignored):
109
- parser.add_argument('--max_randomness', type=float, default=1)
110
- parser.add_argument('--scale_randomness', type=float, default=0.5)
111
  parser.add_argument('--num_random_samples', type=int, default=10)
112
 
113
  main(parser.parse_args())
 
20
  net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)
21
 
22
  if not torch.cuda.is_available():
23
+ try: # retrained model also contains loss in dict
24
  net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu'))["model_state_dict"])
25
  except:
26
  net.load_state_dict(torch.load('./model/250000.pt', map_location=torch.device('cpu')))
 
27
 
28
  dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')
29
 
 
33
  loaded_data = dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(params.num_samples)))
34
  all_loaded_data.append(loaded_data)
35
 
 
36
  if params.output == "image":
37
 
38
  if params.interpolate == "writer":
 
76
  else:
77
  raise ValueError("Invalid output")
78
 
79
+
80
  if __name__ == '__main__':
81
  parser = argparse.ArgumentParser(description='Arguments for generating samples with the handwriting synthesis model.')
82
 
 
88
  parser.add_argument('--interpolate', type=str, default="randomness", choices=["writer", "character", "randomness"])
89
 
90
  # PARAMS FOR BOTH WRITER AND CHARACTER INTERPOLATION:
91
+ # IF IMAGE - weights to use for a single sample of interpolation
92
  parser.add_argument('--blend_weights', type=float, nargs="+", default=[0.5, 0.5])
93
+ # IF VIDEO - the number of frames for each character/writer
94
  parser.add_argument('--frames_per_step', type=int, default=10)
95
 
96
  # PARAMS IF WRITER INTERPOLATION:
97
  parser.add_argument('--target_word', type=str, default="hello world")
98
  parser.add_argument('--writer_ids', type=int, nargs="+", default=[80, 120])
99
+
100
  # PARAMS IF CHARACTER INTERPOLATION:
101
+ # IF VIDEO OR BLEND
102
+ parser.add_argument('--blend_chars', type=str, nargs="+", default=["a", "b", "c", "d", "e"])
103
+ # IF GRID
104
+ parser.add_argument('--grid_chars', type=str, nargs="+", default=["y", "s", "u", "n"])
105
  parser.add_argument('--grid_size', type=int, default=10)
106
 
107
  # PARAMS IF RANDOMNESS INTERPOLATION (--output will be ignored):
108
+ parser.add_argument('--max_randomness', type=float, default=1)
109
+ parser.add_argument('--scale_randomness', type=float, default=0.5)
110
  parser.add_argument('--num_random_samples', type=int, default=10)
111
 
112
  main(parser.parse_args())
output.svg ADDED
requirements.txt CHANGED
@@ -8,4 +8,5 @@ torch==1.11.0
8
  typing_extensions==4.1.1
9
  ffmpeg-python
10
  gradio
 
11
 
 
8
  typing_extensions==4.1.1
9
  ffmpeg-python
10
  gradio
11
+ svgwrite
12