"""Interactive demo: synthesize handwriting in a given writer's style.

Loads a pretrained SynthesisNetwork together with one writer's samples and
renders sentences typed on stdin to PNG files under results/.
"""
import argparse
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image, ImageDraw
from tensorboardX import SummaryWriter

from config.GlobalVariables import *
from DataLoader import DataLoader
from helper import *
from SynthesisNetwork import SynthesisNetwork
def main(params):
    """Run the interactive handwriting-synthesis demo.

    Loads the pretrained SynthesisNetwork plus one writer's training samples,
    derives the writer's mean character-independent style vector, then
    repeatedly reads a sentence from stdin and renders it to results/*.png.

    Args:
        params: argparse.Namespace with integer fields writer_id, num_samples,
            generating_default and direct_use.
    """
    np.random.seed(0)
    torch.manual_seed(0)
    device = 'cpu'

    # All generated images are written under results/; make sure it exists.
    os.makedirs('results', exist_ok=True)

    dl = DataLoader(num_writer=1, num_samples=10, divider=5.0, datadir='./data/writers')
    net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)

    if not torch.cuda.is_available():
        checkpoint = torch.load('./model/250000.pt', map_location=torch.device('cpu'))
        # Retrained checkpoints wrap the weights (together with the loss) in a
        # dict under 'model_state_dict'; older checkpoints are a bare state
        # dict.  Load the file once and unwrap explicitly instead of the
        # original bare try/except around a second torch.load.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            checkpoint = checkpoint['model_state_dict']
        net.load_state_dict(checkpoint)

    # The first six entries of the batch are unused by this script.
    [_, _, _, _, _, _,
     all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length,
     all_word_level_term, all_word_level_char, all_word_level_char_length,
     all_segment_level_stroke_in, all_segment_level_stroke_out, all_segment_level_stroke_length,
     all_segment_level_term, all_segment_level_char, all_segment_level_char_length] = dl.next_batch(
        TYPE='TRAIN', uid=params.writer_id, tids=list(range(params.num_samples)))

    batch_word_level_stroke_in = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_in]
    batch_word_level_stroke_out = [torch.FloatTensor(a).to(device) for a in all_word_level_stroke_out]
    batch_word_level_stroke_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_stroke_length]
    batch_word_level_term = [torch.FloatTensor(a).to(device) for a in all_word_level_term]
    batch_word_level_char = [torch.LongTensor(a).to(device) for a in all_word_level_char]
    batch_word_level_char_length = [torch.LongTensor(a).to(device).unsqueeze(-1) for a in all_word_level_char_length]
    batch_segment_level_stroke_in = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_in]
    batch_segment_level_stroke_out = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_stroke_out]
    batch_segment_level_stroke_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_stroke_length]
    batch_segment_level_term = [[torch.FloatTensor(a).to(device) for a in b] for b in all_segment_level_term]
    batch_segment_level_char = [[torch.LongTensor(a).to(device) for a in b] for b in all_segment_level_char]
    batch_segment_level_char_length = [[torch.LongTensor(a).to(device).unsqueeze(-1) for a in b] for b in all_segment_level_char_length]

    if params.generating_default == 1:
        # Sanity check: reconstruct the writer's own training words with the
        # network's four decoding variants and save them stacked vertically.
        with torch.no_grad():
            commands_list = net.sample([
                batch_word_level_stroke_in, batch_word_level_stroke_out,
                batch_word_level_stroke_length, batch_word_level_term,
                batch_word_level_char, batch_word_level_char_length,
                batch_segment_level_stroke_in, batch_segment_level_stroke_out,
                batch_segment_level_stroke_length, batch_segment_level_term,
                batch_segment_level_char, batch_segment_level_char_length])
        # b_commands / c_commands are produced but intentionally not drawn.
        [t_commands, o_commands, a_commands, b_commands, c_commands, d_commands] = commands_list
        dst = Image.new('RGB', (750, 640))
        dst.paste(draw_commands(t_commands), (0, 0))
        dst.paste(draw_commands(o_commands), (0, 160))
        dst.paste(draw_commands(a_commands), (0, 320))
        dst.paste(draw_commands(d_commands), (0, 480))
        dst.save('results/default.png')

    # Derive the writer's mean global style vector W from the first word of
    # the first sample: strip the per-character component out of each
    # character-conditioned style vector W_c via the inverse character matrix,
    # then average.
    with torch.no_grad():
        word_inf_state_out = net.inf_state_fc1(batch_word_level_stroke_out[0])
        word_inf_state_out = net.inf_state_relu(word_inf_state_out)
        word_inf_state_out, _ = net.inf_state_lstm(word_inf_state_out)

        user_word_level_char = batch_word_level_char[0]
        user_word_level_term = batch_word_level_term[0]

        original_Wc = []
        word_batch_id = 0

        curr_seq_len = batch_word_level_stroke_length[0][word_batch_id][0]
        curr_char_len = batch_word_level_char_length[0][word_batch_id][0]

        char_vector = torch.eye(len(CHARACTERS))[user_word_level_char[word_batch_id][:curr_char_len]].to(device)
        current_term = user_word_level_term[word_batch_id][:curr_seq_len].unsqueeze(-1)
        # Indices of the timesteps where a character ends (term == 1).
        split_ids = torch.nonzero(current_term)[:, 0]

        char_vector_1 = net.char_vec_fc_1(char_vector)
        char_vector_1 = net.char_vec_relu_1(char_vector_1)
        char_out_1 = char_vector_1.unsqueeze(0)
        char_out_1, (c, h) = net.char_lstm_1(char_out_1)
        char_out_1 = char_out_1.squeeze(0)
        char_out_1 = net.char_vec_fc2_1(char_out_1)
        char_matrix_1 = char_out_1.view([-1, 1, 256, 256])
        char_matrix_1 = char_matrix_1.squeeze(1)
        char_matrix_inv_1 = torch.inverse(char_matrix_1)

        W_c_t = word_inf_state_out[word_batch_id][:curr_seq_len]
        W_c = torch.stack([W_c_t[i] for i in split_ids])
        original_Wc.append(W_c)

        # W = C^-1 @ W_c for each character, then the mean over characters.
        W = torch.bmm(char_matrix_inv_1, W_c.unsqueeze(2)).squeeze(-1)
        mean_global_W = torch.mean(W, 0)

    def sample_word(target_word):
        """Return absolute [x, y, pen] drawing commands for one word.

        Greedily covers target_word with the longest segments seen in the
        writer's training data; characters not covered by any segment fall
        back to a style vector synthesized from mean_global_W.
        """
        # Map every segment string seen in training to its candidate strokes.
        available_segments = {}
        for sid, sentence in enumerate(all_segment_level_char[0]):
            for wid, word in enumerate(sentence):
                segment = ''.join([CHARACTERS[i] for i in word])
                split_ids = np.asarray(np.nonzero(all_segment_level_term[0][sid][wid]))
                available_segments.setdefault(segment, []).append(
                    [all_segment_level_stroke_out[0][sid][wid][:all_segment_level_stroke_length[0][sid][wid]],
                     split_ids])

        index = 0
        all_W_c = []
        while index <= len(target_word):
            available = False
            # Try the longest remaining prefix first; once a segment matches,
            # index jumps to its end and the shorter slices become empty.
            for end_index in range(len(target_word), index, -1):
                segment = target_word[index:end_index]
                if segment in available_segments:
                    available = True
                    candidates = available_segments[segment]
                    # Pick one of the candidate renditions at random.
                    segment_level_stroke_out, split_ids = candidates[np.random.randint(len(candidates))]
                    out = net.inf_state_fc1(torch.FloatTensor(segment_level_stroke_out).to(device).unsqueeze(0))
                    out = net.inf_state_relu(out)
                    seg_W_c, _ = net.inf_state_lstm(out)
                    tmp = []
                    for id in split_ids[0]:
                        tmp.append(seg_W_c[0, id].squeeze())
                    all_W_c.append(tmp)
                    index = end_index
            if index == len(target_word):
                break
            if not available:
                # Unseen character: synthesize W_c = C @ mean_global_W.
                character = target_word[index]
                char_vector = torch.eye(len(CHARACTERS))[CHARACTERS.index(character)].to(device).unsqueeze(0)
                out = net.char_vec_fc_1(char_vector)
                out = net.char_vec_relu_1(out)
                out, _ = net.char_lstm_1(out.unsqueeze(0))
                out = out.squeeze(0)
                out = net.char_vec_fc2_1(out)
                char_matrix = out.view([-1, 256, 256])
                TYPE_A_WC = torch.bmm(char_matrix, mean_global_W.repeat(char_matrix.size(0), 1).unsqueeze(2)).squeeze()
                all_W_c.append([TYPE_A_WC])
                index += 1

        all_commands = []
        current_id = 0
        # Decode repeatedly: each pass blends the collected style vectors with
        # the magic LSTM and emits strokes until the decoder stalls (res < 0)
        # or the whole word is covered.
        while True:
            word_Wc_rec_TYPE_D = []
            TYPE_D_REF = []
            cid = 0
            for segment_batch_id in range(len(all_W_c)):
                if len(TYPE_D_REF) == 0:
                    # Before any reference exists, pass the raw W_c through,
                    # skipping characters already emitted in earlier passes.
                    for each_segment_Wc in all_W_c[segment_batch_id]:
                        if cid >= current_id:
                            word_Wc_rec_TYPE_D.append(each_segment_Wc)
                        cid += 1
                    if len(word_Wc_rec_TYPE_D) > 0:
                        TYPE_D_REF.append(all_W_c[segment_batch_id][-1])
                else:
                    # Blend each W_c with the running references.
                    for each_segment_Wc in all_W_c[segment_batch_id]:
                        magic_inp = torch.cat([torch.stack(TYPE_D_REF, 0), each_segment_Wc.unsqueeze(0)], 0)
                        magic_inp = magic_inp.unsqueeze(0)
                        TYPE_D_out, (c, h) = net.magic_lstm(magic_inp)
                        TYPE_D_out = TYPE_D_out.squeeze(0)
                        word_Wc_rec_TYPE_D.append(TYPE_D_out[-1])
                    TYPE_D_REF.append(all_W_c[segment_batch_id][-1])
            WC_ = torch.stack(word_Wc_rec_TYPE_D)
            tmp_commands, res = net.sample_from_w_fix(WC_, target_word)
            current_id = current_id + res
            if len(all_commands) == 0:
                all_commands.append(tmp_commands)
            else:
                # Drop the duplicated first command on continuation passes.
                all_commands.append(tmp_commands[1:])
            if res < 0 or current_id >= len(target_word):
                break

        # Integrate the (dx, dy) deltas into absolute coordinates and shift
        # the word so it starts at x == 0.
        commands = []
        px, py = 0, 100
        for coms in all_commands:
            for i, [dx, dy, t] in enumerate(coms):
                x = px + dx * 5
                y = py + dy * 5
                commands.append([x, y, t])
                px, py = x, y
        commands = np.asarray(commands)
        commands[:, 0] -= np.min(commands[:, 0])
        return commands

    def sample(target_sentence):
        """Render a sentence word-by-word via sample_word to results/hello.png."""
        words = target_sentence.split(' ')
        im = Image.fromarray(np.zeros([160, 750]))
        dr = ImageDraw.Draw(im)
        width = 50
        # Fix: px/py were referenced before assignment when the very first
        # command had t == 0; start from the same origin sample_word uses.
        px, py = 0, 100
        for word in words:
            all_commands = sample_word(word)
            for [x, y, t] in all_commands:
                if t == 0:  # pen down: connect to the previous point
                    dr.line((px + width, py, x + width, y), 255, 1)
                px, py = x, y
            width += np.max(all_commands[:, 0]) + 25
        im.convert("RGB").save('results/hello.png')

    def sample_word2(target_word):
        """Return a list of per-chunk absolute command arrays for one word.

        Unlike sample_word, matched training segments are drawn directly from
        their recorded strokes; only unseen characters are decoded from the
        mean style vector.
        """
        available_segments = {}
        for sid, sentence in enumerate(all_segment_level_char[0]):
            for wid, word in enumerate(sentence):
                segment = ''.join([CHARACTERS[i] for i in word])
                split_ids = np.asarray(np.nonzero(all_segment_level_term[0][sid][wid]))
                available_segments.setdefault(segment, []).append(
                    [all_segment_level_stroke_in[0][sid][wid][:all_segment_level_stroke_length[0][sid][wid]],
                     all_segment_level_stroke_out[0][sid][wid][:all_segment_level_stroke_length[0][sid][wid]],
                     split_ids])

        index = 0
        all_W_c = []
        all_commands = []
        know_chars = []
        while index <= len(target_word):
            available = False
            for end_index in range(len(target_word), index, -1):
                segment = target_word[index:end_index]
                if segment in available_segments:
                    available = True
                    candidates = available_segments[segment]
                    segment_level_stroke_in, segment_level_stroke_out, split_ids = candidates[np.random.randint(len(candidates))]
                    out = net.inf_state_fc1(torch.FloatTensor(segment_level_stroke_in).to(device).unsqueeze(0))
                    out = net.inf_state_relu(out)
                    seg_W_c, _ = net.inf_state_lstm(out)
                    tmp = []
                    for id in split_ids[0]:
                        tmp.append(seg_W_c[0, id].squeeze())
                    all_W_c.append(tmp)
                    # Fix: record the covered character positions BEFORE
                    # advancing index (the original ran range(index, end_index)
                    # after index = end_index, so it was always empty).
                    for i in range(index, end_index):
                        know_chars.append(i)
                    index = end_index
                    # Replay the recorded strokes as absolute coordinates.
                    commands = []
                    px, py = 0, 100
                    for i, [dx, dy, t] in enumerate(segment_level_stroke_out):
                        x = px + dx * 5
                        y = py + dy * 5
                        commands.append([x, y, t])
                        px, py = x, y
                    commands = np.asarray(commands)
                    commands[:, 0] -= np.min(commands[:, 0])
                    all_commands.append(commands)
            if index == len(target_word):
                break
            if not available:
                character = target_word[index]
                char_vector = torch.eye(len(CHARACTERS))[CHARACTERS.index(character)].to(device).unsqueeze(0)
                out = net.char_vec_fc_1(char_vector)
                out = net.char_vec_relu_1(out)
                out, _ = net.char_lstm_1(out.unsqueeze(0))
                out = out.squeeze(0)
                out = net.char_vec_fc2_1(out)
                char_matrix = out.view([-1, 256, 256])
                TYPE_A_WC = torch.bmm(char_matrix, mean_global_W.repeat(char_matrix.size(0), 1).unsqueeze(2)).squeeze().unsqueeze(0)
                index += 1
                temp_commands = net.sample_from_w(TYPE_A_WC, character)
                commands = []
                px, py = 0, 100
                for i, [dx, dy, t] in enumerate(temp_commands):
                    x = px + dx * 5
                    y = py + dy * 5
                    commands.append([x, y, t])
                    px, py = x, y
                commands = np.asarray(commands)
                commands[:, 0] -= np.min(commands[:, 0])
                all_commands.append(commands)
        return all_commands

    def sample2(target_sentence):
        """Render a sentence via sample_word2 to results/hello2.png."""
        words = target_sentence.split(' ')
        im = Image.fromarray(np.zeros([160, 750]))
        dr = ImageDraw.Draw(im)
        width = 50
        # Fix: px/py were referenced before assignment on the first pen-down
        # command; start from the same origin the command arrays use.
        px, py = 0, 100
        for word in words:
            all_commands = sample_word2(word)
            for c in all_commands:
                for [x, y, t] in c:
                    if t == 0:  # pen down: connect to the previous point
                        dr.line((px + width, py, x + width, y), 255, 1)
                    px, py = x, y
                width += np.max(c[:, 0]) + 5
            width += 25
        im.convert("RGB").save('results/hello2.png')

    # Simple REPL: generate an image for every non-empty input line.
    while True:
        target_word = input("Type a sentence to generate : ")
        if len(target_word) > 0:
            if params.direct_use == 0:
                sample(target_word)
            else:
                sample2(target_word)
if __name__ == '__main__':
    # Command-line interface: every option is an integer flag with a default.
    arg_parser = argparse.ArgumentParser(
        description='Arguments for generating samples with the handwriting synthesis model.')
    for flag, default_value in (('--writer_id', 80),
                                ('--num_samples', 10),
                                ('--generating_default', 0),
                                ('--direct_use', 0)):
        arg_parser.add_argument(flag, type=int, default=default_value)
    main(arg_parser.parse_args())