Spaces:
Build error
Build error
import os | |
import re | |
from random import random | |
import torch | |
import pickle | |
import argparse | |
import numpy as np | |
from helper import * | |
from PIL import Image | |
import torch.nn as nn | |
import torch.optim as optim | |
from config.GlobalVariables import * | |
from tensorboardX import SummaryWriter | |
from SynthesisNetwork import SynthesisNetwork | |
from DataLoader import DataLoader | |
# import ffmpeg # for problems with ffmpeg uninstall ffmpeg and then install ffmpeg-python | |
L = 256 | |
def get_mean_global_W(net, loaded_data, device): | |
"""gets the mean global style vector for a given writer""" | |
[_, _, _, _, _, _, all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length, all_word_level_term, all_word_level_char, all_word_level_char_length, all_segment_level_stroke_in, all_segment_level_stroke_out, | |
all_segment_level_stroke_length, all_segment_level_term, all_segment_level_char, all_segment_level_char_length] = loaded_data | |
batch_word_level_stroke_in = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_in] | |
batch_word_level_stroke_out = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_out] | |
batch_word_level_stroke_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_stroke_length] | |
batch_word_level_term = [torch.FloatTensor(a).to(device) for a in all_word_level_term] | |
batch_word_level_char = [torch.LongTensor(a).to(device) for a in all_word_level_char] | |
batch_word_level_char_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_char_length] | |
batch_segment_level_stroke_in = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_in] | |
batch_segment_level_stroke_out = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_out] | |
batch_segment_level_stroke_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_stroke_length] | |
batch_segment_level_term = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_term] | |
batch_segment_level_char = [[torch.LongTensor(a).to(device) for a in b] for b in all_segment_level_char] | |
batch_segment_level_char_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_char_length] | |
with torch.no_grad(): | |
word_inf_state_out = net.inf_state_fc1(batch_word_level_stroke_out[0]) | |
word_inf_state_out = net.inf_state_relu(word_inf_state_out) | |
word_inf_state_out, _ = net.inf_state_lstm(word_inf_state_out) | |
user_word_level_char = batch_word_level_char[0] | |
user_word_level_term = batch_word_level_term[0] | |
original_Wc = [] | |
word_batch_id = 0 | |
curr_seq_len = batch_word_level_stroke_length[0][word_batch_id][0] | |
curr_char_len = batch_word_level_char_length[0][word_batch_id][0] | |
char_vector = torch.eye(len(CHARACTERS))[user_word_level_char[word_batch_id][:curr_char_len]].to(device) | |
current_term = user_word_level_term[word_batch_id][:curr_seq_len].unsqueeze(-1) | |
split_ids = torch.nonzero(current_term)[:, 0] | |
char_vector_1 = net.char_vec_fc_1(char_vector) | |
char_vector_1 = net.char_vec_relu_1(char_vector_1) | |
char_out_1 = char_vector_1.unsqueeze(0) | |
char_out_1, (c, h) = net.char_lstm_1(char_out_1) | |
char_out_1 = char_out_1.squeeze(0) | |
char_out_1 = net.char_vec_fc2_1(char_out_1) | |
char_matrix_1 = char_out_1.view([-1, 1, 256, 256]) | |
char_matrix_1 = char_matrix_1.squeeze(1) | |
char_matrix_inv_1 = torch.inverse(char_matrix_1) | |
W_c_t = word_inf_state_out[word_batch_id][:curr_seq_len] | |
W_c = torch.stack([W_c_t[i] for i in split_ids]) | |
original_Wc.append(W_c) | |
W = torch.bmm(char_matrix_inv_1, W_c.unsqueeze(2)).squeeze(-1) | |
mean_global_W = torch.mean(W, 0) | |
return mean_global_W | |
def get_DSD(net, target_word, writer_mean_Ws, all_loaded_data, device): | |
""" | |
returns a style vector and character matrix for each character/segment in target_word | |
n is the number of writers | |
M is the number of characters in the target word | |
L is the latent vector size (in this case 256) | |
input: | |
- target_word, a string of length M to be converted to a DSD | |
- writer_mean_Ws, a list of n style vectors of size L | |
output: | |
- all_writer_Ws, a tensor of size n x M x L representing the style vectors for each writer and character | |
- all_writer_Cs, a tensor of size n x M x L x L representing the corresponding character matrix | |
""" | |
n = len(all_loaded_data) | |
M = len(target_word) | |
all_writer_Ws = torch.zeros(n, M, L) | |
all_writer_Cs = torch.zeros(n, M, L, L) | |
for i in range(n): | |
np.random.seed(0) | |
[_, _, _, _, _, _, all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length, all_word_level_term, all_word_level_char, all_word_level_char_length, all_segment_level_stroke_in, all_segment_level_stroke_out, | |
all_segment_level_stroke_length, all_segment_level_term, all_segment_level_char, all_segment_level_char_length] = all_loaded_data[i] | |
available_segments = {} | |
for sid, sentence in enumerate(all_segment_level_char[0]): | |
for wid, word in enumerate(sentence): | |
segment = ''.join([CHARACTERS[i] for i in word]) | |
split_ids = np.asarray(np.nonzero(all_segment_level_term[0][sid][wid])) | |
if segment in available_segments: | |
available_segments[segment].append([all_segment_level_stroke_out[0][sid][wid][:all_segment_level_stroke_length[0][sid][wid]], split_ids]) | |
else: | |
available_segments[segment] = [[all_segment_level_stroke_out[0][sid][wid][:all_segment_level_stroke_length[0][sid][wid]], split_ids]] | |
index = 0 | |
all_W = [] | |
all_C = [] | |
# while index <= len(target_word): | |
while index < len(target_word): | |
available = False | |
# Currently this just uses each character individually instead of the whole segment | |
# for end_index in range(len(target_word), index, -1): | |
# segment = target_word[index:end_index] | |
# print (segment) | |
segment = target_word[index] | |
if segment in available_segments: # method beta | |
# print(f'in dic - {segment}') | |
available = True | |
candidates = available_segments[segment] | |
segment_level_stroke_out, split_ids = candidates[np.random.randint(len(candidates))] | |
out = net.inf_state_fc1(torch.FloatTensor(segment_level_stroke_out).to(device).unsqueeze(0)) | |
out = net.inf_state_relu(out) | |
seg_W_c, (h_n, _) = net.inf_state_lstm(out) | |
character = segment[0] # take the first character of the segment? | |
# get character matrix using same method as method beta | |
char_vector = torch.eye(len(CHARACTERS))[CHARACTERS.index(character)].to(device).unsqueeze(0) | |
out = net.char_vec_fc_1(char_vector) | |
out = net.char_vec_relu_1(out) | |
out, _ = net.char_lstm_1(out.unsqueeze(0)) | |
out = out.squeeze(0) | |
out = net.char_vec_fc2_1(out) | |
char_matrix = out.view([-1, 256, 256]) | |
inv_char_matrix = char_matrix.inverse() | |
id = split_ids[0][0] | |
W_c_vector = seg_W_c[0, id].squeeze() | |
# invert to get writer-independed DSD | |
W_vector = torch.bmm(inv_char_matrix, W_c_vector.repeat(inv_char_matrix.size(0), 1).unsqueeze(2)) | |
all_W.append(W_vector) | |
all_C.append(char_matrix) | |
index += 1 | |
if index == len(target_word): | |
break | |
if not available: # method alpha | |
character = target_word[index] | |
# print(f'no dic - {character}') | |
char_vector = torch.eye(len(CHARACTERS))[CHARACTERS.index(character)].to(device).unsqueeze(0) | |
out = net.char_vec_fc_1(char_vector) | |
out = net.char_vec_relu_1(out) | |
out, _ = net.char_lstm_1(out.unsqueeze(0)) | |
out = out.squeeze(0) | |
out = net.char_vec_fc2_1(out) | |
char_matrix = out.view([-1, 256, 256]) | |
W_vector = writer_mean_Ws[i].repeat(char_matrix.size(0), 1).unsqueeze(2) | |
# all_W.append([W_vector]) | |
all_W.append(W_vector) | |
all_C.append(char_matrix) | |
index += 1 | |
all_writer_Ws[i, :, :] = torch.stack(all_W).squeeze() | |
all_writer_Cs[i, :, :, :] = torch.stack(all_C).squeeze() | |
return all_writer_Ws, all_writer_Cs | |
def get_writer_blend_W_c(writer_weights, all_Ws, all_Cs): | |
""" | |
generates character-dependent style-dependent DSDs for each character/segement in target_word, | |
averaging together the styles of the handwritings using provided weights | |
n is the number of writers | |
M is the number of characters in the target word | |
L is the latent vector size (in this case 256) | |
input: | |
- writer_weights, a list of length n weights for each writer that sum to one | |
- all_writer_Ws, an n x M x L tensor representing each weiter's style vector for every character | |
- all_writer_Cs, an n x M x L x L tensor representing the style's correspodning character matrix | |
output: | |
- an M x 1 x L tensor of M scharacter-dependent style-dependent DSDs | |
""" | |
n, M, _ = all_Ws.shape | |
weights_tensor = torch.tensor(writer_weights).repeat_interleave(M * L).reshape(n, M, L) # repeat accross remaining dimensions | |
W_vectors = (weights_tensor * all_Ws).sum(axis=0).unsqueeze(-1) # take weighted sum accross writers axis | |
char_matrices = all_Cs[0, :, :, :] # character matrices are independent of writer | |
W_cs = torch.bmm(char_matrices, W_vectors) | |
return W_cs.reshape(M, 1, L) | |
def get_character_blend_W_c(character_weights, all_Ws, all_Cs): | |
""" | |
generates a single character-dependent style-dependent DSD, | |
averaging together the characters using provided weights | |
M is the number of characters to blend | |
L is the latent vector size (in this case 256) | |
input: | |
- character_weights, a list of length M weights for each character that sum to one | |
- all_Ws, a 1 x M x L tensor representing the wwiter's style vector for each character | |
- all_Cs, 1 x M x L x L tensor representing the style's correspodning character matrix | |
output: | |
- a 1 x 1 x L tensor representing the character-dependent style-dependent DSDs | |
""" | |
M = len(character_weights) | |
W_vector = all_Ws[0, 0, :].unsqueeze(-1) | |
weights_tensor = torch.tensor(character_weights).repeat_interleave(L * L).reshape(1, M, L, L) # repeat accross remaining dimensions | |
char_matrix = (weights_tensor * all_Cs).sum(axis=1).squeeze() # take weighted sum accross characters axis | |
W_c = char_matrix @ W_vector | |
return W_c.reshape(1, 1, L) | |
def get_commands(net, target_word, all_W_c): # seems like target_word is only used for length | |
"""converts character-dependent style-dependent DSDs to a list of commands for drawing""" | |
all_commands = [] | |
current_id = 0 | |
while True: | |
word_Wc_rec_TYPE_D = [] | |
TYPE_D_REF = [] | |
cid = 0 | |
for segment_batch_id in range(len(all_W_c)): | |
if len(TYPE_D_REF) == 0: | |
for each_segment_Wc in all_W_c[segment_batch_id]: | |
if cid >= current_id: | |
word_Wc_rec_TYPE_D.append(each_segment_Wc) | |
cid += 1 | |
if len(word_Wc_rec_TYPE_D) > 0: | |
TYPE_D_REF.append(all_W_c[segment_batch_id][-1]) | |
else: | |
for each_segment_Wc in all_W_c[segment_batch_id]: | |
magic_inp = torch.cat([torch.stack(TYPE_D_REF, 0), each_segment_Wc.unsqueeze(0)], 0) | |
magic_inp = magic_inp.unsqueeze(0) | |
TYPE_D_out, (c, h) = net.magic_lstm(magic_inp) | |
TYPE_D_out = TYPE_D_out.squeeze(0) | |
word_Wc_rec_TYPE_D.append(TYPE_D_out[-1]) | |
TYPE_D_REF.append(all_W_c[segment_batch_id][-1]) | |
WC_ = torch.stack(word_Wc_rec_TYPE_D) | |
tmp_commands, res = net.sample_from_w_fix(WC_) | |
current_id += res | |
if len(all_commands) == 0: | |
all_commands.append(tmp_commands) | |
else: | |
all_commands.append(tmp_commands[1:]) | |
if res < 0 or current_id >= len(target_word): | |
break | |
commands = [] | |
px, py = 0, 100 | |
for coms in all_commands: | |
for i, [dx, dy, t] in enumerate(coms): | |
x = px + dx * 5 | |
y = py + dy * 5 | |
commands.append([x, y, t]) | |
px, py = x, y | |
commands = np.asarray(commands) | |
commands[:, 0] -= np.min(commands[:, 0]) | |
return commands | |
def mdn_video(target_word, num_samples, scale_sd, clamp_mdn, net, all_loaded_data, device): | |
''' | |
Method creating gif of mdn samples | |
num_samples: number of samples to be inputted | |
max_scale: the maximum value used to scale SD while sampling (increment is based on num samples) | |
''' | |
words = target_word.split(' ') | |
us_target_word = re.sub(r"\s+", '_', target_word) | |
os.makedirs(f"./results/{us_target_word}_mdn_samples", exist_ok=True) | |
for i in range(num_samples): | |
net.scale_sd = scale_sd | |
net.clamp_mdn = clamp_mdn | |
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device) | |
word_Ws = [] | |
word_Cs = [] | |
for word in words: | |
writer_Ws, writer_Cs = get_DSD(net, word, [mean_global_W], [all_loaded_data[0]], device) | |
word_Ws.append(writer_Ws) | |
word_Cs.append(writer_Cs) | |
im = draw_words(words, word_Ws, word_Cs, [1], net) | |
im.convert("RGB").save(f'results/{us_target_word}_mdn_samples/sample_{i}.png') | |
# Convert fromes to video using ffmpeg | |
photos = ffmpeg.input(f'results/{us_target_word}_mdn_samples/sample_*.png', pattern_type='glob', framerate=10) | |
videos = photos.output(f'results/{us_target_word}_video.mov', vcodec="libx264", pix_fmt="yuv420p") | |
videos.run(overwrite_output=True) | |
def sample_blended_writers(writer_weights, target_sentence, net, all_loaded_data, device="cpu"): | |
"""Generates an image of handwritten text based on target_sentence""" | |
words = target_sentence.split(' ') | |
writer_mean_Ws = [] | |
for loaded_data in all_loaded_data: | |
mean_global_W = get_mean_global_W(net, loaded_data, device) | |
writer_mean_Ws.append(mean_global_W) | |
word_Ws = [] | |
word_Cs = [] | |
for word in words: | |
writer_Ws, writer_Cs = get_DSD(net, word, writer_mean_Ws, all_loaded_data, device) | |
word_Ws.append(writer_Ws) | |
word_Cs.append(writer_Cs) | |
return draw_words(words, word_Ws, word_Cs, writer_weights, net) | |
def sample_character_grid(letters, grid_size, net, all_loaded_data, device="cpu"): | |
"""Generates an image of handwritten text based on target_sentence""" | |
width = 60 | |
im = Image.fromarray(np.zeros([(grid_size + 1) * width, (grid_size + 1) * width])) | |
dr = ImageDraw.Draw(im) | |
M = len(letters) | |
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device) | |
# all_Ws = torch.zeros(1, M, L) | |
all_Cs = torch.zeros(1, M, L, L) | |
for i in range(M): # get corners of grid | |
W_vector, char_matrix = get_DSD(net, letters[i], [mean_global_W], [all_loaded_data[0]], device) | |
# all_Ws[:, i, :] = W_vector | |
all_Cs[:, i, :, :] = char_matrix | |
all_Ws = mean_global_W.reshape(1, 1, L) | |
for i in range(grid_size): | |
for j in range(grid_size): | |
wx = i / (grid_size - 1) | |
wy = j / (grid_size - 1) | |
character_weights = [(1 - wx) * (1 - wy), # top left is 1 at (0, 0) | |
wx * (1 - wy), # top right is 1 at (1, 0) | |
(1 - wx) * wy, # bottom left is 1 at (0, 1) | |
wx * wy] # bottom right is 1 at (1, 1) | |
all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs) | |
all_commands = get_commands(net, letters[0], all_W_c) | |
offset_x = i * width | |
offset_y = j * width | |
for [x, y, t] in all_commands: | |
if t == 0: | |
dr.line(( | |
px + offset_x + width/2, | |
py + offset_y - width/2, # letters are shifted down for some reason | |
x + offset_x + width/2, | |
y + offset_y - width/2), 255, 1) | |
px, py = x, y | |
return im | |
def writer_interpolation_video(target_sentence, transition_time, net, all_loaded_data, device="cpu"): | |
""" | |
Generates a video of interpolating between each provided writer | |
""" | |
n = len(all_loaded_data) | |
os.makedirs(f"./results/{target_sentence}_blend_frames", exist_ok=True) | |
words = target_sentence.split(' ') | |
writer_mean_Ws = [] | |
for loaded_data in all_loaded_data: | |
mean_global_W = get_mean_global_W(net, loaded_data, device) | |
writer_mean_Ws.append(mean_global_W) | |
word_Ws = [] | |
word_Cs = [] | |
for word in words: | |
all_writer_Ws, all_writer_Cs = get_DSD(net, word, writer_mean_Ws, all_loaded_data, device) | |
word_Ws.append(all_writer_Ws) | |
word_Cs.append(all_writer_Cs) | |
for i in range(n - 1): | |
for j in range(transition_time): | |
completion = j/(transition_time) | |
individual_weights = [1 - completion, completion] | |
writer_weights = [0] * i + individual_weights + [0] * (n - 2 - i) | |
im = draw_words(words, word_Ws, word_Cs, writer_weights, net) | |
im.convert("RGB").save(f"./results/{target_sentence}_blend_frames/frame_{str(i * transition_time + j).zfill(3)}.png") | |
# Convert fromes to video using ffmpeg | |
photos = ffmpeg.input(f"./results/{target_sentence}_blend_frames/frame_*.png", pattern_type='glob', framerate=10) | |
videos = photos.output(f"results/{target_sentence}_blend_video.mov", vcodec="libx264", pix_fmt="yuv420p") | |
videos.run(overwrite_output=True) | |
def mdn_single_sample(target_word, scale_sd, clamp_mdn, net, all_loaded_data, device): | |
''' | |
Method creating gif of mdn samples | |
num_samples: number of samples to be inputted | |
max_scale: the maximum value used to scale SD while sampling (increment is based on num samples) | |
''' | |
words = target_word.split(' ') | |
net.scale_sd = scale_sd | |
net.clamp_mdn = clamp_mdn | |
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device) | |
word_Ws = [] | |
word_Cs = [] | |
for word in words: | |
writer_Ws, writer_Cs = get_DSD(net, word, [mean_global_W], [all_loaded_data[0]], device) | |
word_Ws.append(writer_Ws) | |
word_Cs.append(writer_Cs) | |
return draw_words(words, word_Ws, word_Cs, [1], net) | |
def sample_blended_chars(character_weights, letters, net, all_loaded_data, device="cpu"): | |
"""Generates an image of handwritten text based on target_sentence""" | |
M = len(letters) | |
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device) | |
all_Cs = torch.zeros(1, M, L, L) | |
for i in range(M): # get corners of grid | |
W_vector, char_matrix = get_DSD(net, letters[i], [mean_global_W], [all_loaded_data[0]], device) | |
all_Cs[:, i, :, :] = char_matrix | |
all_Ws = mean_global_W.reshape(1, 1, L) | |
all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs) | |
all_commands = get_commands(net, letters[0], all_W_c) | |
im = commands_to_image(all_commands, 100, 100, 30, 30) | |
return im | |
def char_interpolation_video(letters, transition_time, net, all_loaded_data, device="cpu"): | |
"""Generates an image of handwritten text based on target_sentence""" | |
os.makedirs(f"./results/{''.join(letters)}_frames", exist_ok=True) # make a folder for the frames | |
M = len(letters) | |
mean_global_W = get_mean_global_W(net, all_loaded_data[0], device) | |
all_Cs = torch.zeros(1, M, L, L) | |
for i in range(M): # get corners of grid | |
W_vector, char_matrix = get_DSD(net, letters[i], [mean_global_W], [all_loaded_data[0]], device) | |
all_Cs[:, i, :, :] = char_matrix | |
all_Ws = mean_global_W.reshape(1, 1, L) | |
for i in range(M - 1): | |
for j in range(transition_time): | |
completion = j / (transition_time - 1) | |
individual_weights = [1 - completion, completion] | |
character_weights = [0] * i + individual_weights + [0] * (M - 2 - i) | |
all_W_c = get_character_blend_W_c(character_weights, all_Ws, all_Cs) | |
all_commands = get_commands(net, letters[i], all_W_c) | |
im = commands_to_image(all_commands, 100, 100, 25, 25) | |
im.convert("RGB").save(f"results/{''.join(letters)}_frames/frames_{str(i * transition_time + j).zfill(3)}.png") | |
# Convert fromes to video using ffmpeg | |
photos = ffmpeg.input(f"results/{''.join(letters)}_frames/frames_*.png", pattern_type='glob', framerate=24) | |
videos = photos.output(f"results/{''.join(letters)}_video.mov", vcodec="libx264", pix_fmt="yuv420p") | |
videos.run(overwrite_output=True) | |
def draw_words(words, word_Ws, word_Cs, writer_weights, net): | |
im = Image.fromarray(np.zeros([160, 750])) | |
dr = ImageDraw.Draw(im) | |
width = 50 | |
for word, all_writer_Ws, all_writer_Cs in zip(words, word_Ws, word_Cs): | |
all_W_c = get_writer_blend_W_c(writer_weights, all_writer_Ws, all_writer_Cs) | |
all_commands = get_commands(net, word, all_W_c) | |
for [x, y, t] in all_commands: | |
if t == 0: | |
dr.line((px+width, py, x+width, y), 255, 1) | |
px, py = x, y | |
width += np.max(all_commands[:, 0]) + 25 | |
return im | |
def commands_to_image(commands, imW, imH, xoff, yoff): | |
im = Image.fromarray(np.zeros([imW, imH])) | |
dr = ImageDraw.Draw(im) | |
for [x, y, t] in commands: | |
if t == 0: | |
dr.line(( | |
px + xoff, | |
py - yoff, # letters are shifted down for some reason | |
x + xoff, | |
y - yoff), 255, 1) | |
px, py = x, y | |
return im | |