# Neural_painting / app.py
# Author: minhnh
# Commit e0b460b: "Fix bug import wrong dep"
import gradio as gr
import os
import cv2
import torch
import numpy as np
import argparse
import torch.nn as nn
import torch.nn.functional as F
import gc
from baseline.DRL.actor import *
from baseline.Renderer.stroke_gen import *
from baseline.Renderer.model import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Side length, in pixels, of one canvas the models paint on.
width = 128
actor_path = 'ckpts/actor.pkl'
renderer_path = 'ckpts/renderer.pkl'
# The second painting pass splits the image into a divide x divide grid of
# width x width patches, i.e. canvas_cnt patches painted as one batch.
divide = 4
canvas_cnt = divide * divide
# Load the "Default" renderer and actor eagerly; other stroke styles are loaded
# lazily by load_model_if_needed(). map_location="cpu" lets checkpoints saved on
# a GPU load on CPU-only hosts (matching load_model_if_needed); the modules are
# then moved to `device`.
Decoder = FCN()
Decoder.load_state_dict(torch.load(renderer_path, map_location="cpu"))
Decoder = Decoder.to(device).eval()
# Input has 9 channels (canvas 3 + target 3 + step 1 + coord 2, see paint_img);
# output 65 = 5 strokes x 13 params (action_bundle = 5). The meaning of 18 is
# defined by ResNet elsewhere — presumably the depth; verify.
actor = ResNet(9, 18, 65)
actor.load_state_dict(torch.load(actor_path, map_location="cpu"))
actor = actor.to(device).eval()
# Per-stroke-style model caches, keyed by the dropdown choice.
decoders = {"Default": Decoder}
actors = {"Default": actor}
def decode(x, canvas, decoder=Decoder):
    """Render a bundle of 5 strokes per sample onto ``canvas`` one after another.

    ``x`` is viewed as rows of 10 + 3 values: the first 10 are fed to
    ``decoder`` to produce a stroke mask, the last 3 tint that mask as a colour.
    Consecutive groups of 5 rows belong to the same canvas in the batch.

    Returns ``(canvas, res)``: the final canvas and a list with the canvas after
    each of the 5 strokes.
    """
    params = x.view(-1, 10 + 3)
    shape_params = params[:, :10]
    colors = params[:, -3:]

    alpha = 1 - decoder(shape_params)
    alpha = alpha.view(-1, width, width, 1)
    tinted = alpha * colors.view(-1, 1, 1, 3)

    # Move channels first, then group every 5 strokes per canvas.
    alpha = alpha.permute(0, 3, 1, 2).view(-1, 5, 1, width, width)
    tinted = tinted.permute(0, 3, 1, 2).view(-1, 5, 3, width, width)

    frames = []
    for k in range(5):
        # Alpha-composite stroke k over the current canvas.
        canvas = canvas * (1 - alpha[:, k]) + tinted[:, k]
        frames.append(canvas)

    # Release memory after every bundle; presumably to keep the shared host's
    # memory low between steps — note this runs on every call.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return canvas, frames
def small2large(x):
    """Stitch a (divide*divide, width, width, C) tile stack into one image.

    Tile index ``p * divide + q`` lands at grid row ``p``, column ``q``; the
    result has shape (divide*width, divide*width, C).
    """
    grid = x.reshape(divide, divide, width, width, -1)   # (row, col, y, x, C)
    grid = grid.transpose(0, 2, 1, 3, 4)                 # (row, y, col, x, C)
    return grid.reshape(divide * width, divide * width, -1)
def large2small(x):
    """Cut a (divide*width, divide*width, 3) image into canvas_cnt tiles.

    Inverse of small2large for 3-channel images: returns shape
    (canvas_cnt, width, width, 3) with tiles in row-major grid order.
    """
    tiles = x.reshape(divide, width, divide, width, 3)   # (row, y, col, x, C)
    tiles = tiles.transpose(0, 2, 1, 3, 4)               # (row, col, y, x, C)
    return tiles.reshape(canvas_cnt, width, width, 3)
def smooth(img):
    """Blur the seams between stitched tiles, modifying ``img`` in place.

    For each tile, the pixels on both sides of its right and bottom borders are
    replaced by the mean of their 3x3 neighbourhood. Pixels are updated one at a
    time, so later pixels see already-smoothed neighbours. Pixels on the outer
    border of the whole image are left untouched. Returns ``img``.
    """
    last = divide * width - 1
    # Offsets in the exact order the 3x3 sum is accumulated (after the centre).
    neighbours = ((1, 0), (0, 1), (-1, 0), (0, -1),
                  (1, -1), (-1, 1), (-1, -1), (1, 1))

    def blur_pixel(px, py):
        if px in (0, last) or py in (0, last):
            return
        total = img[px, py]
        for dx, dy in neighbours:
            total = total + img[px + dx, py + dy]
        img[px, py] = total / 9

    for p in range(divide):
        for q in range(divide):
            top = p * width
            left = q * width
            # Vertical seam on the tile's right edge.
            for k in range(width):
                blur_pixel(top + k, left + width - 1)
                if q != divide - 1:
                    blur_pixel(top + k, left + width)
            # Horizontal seam on the tile's bottom edge.
            for k in range(width):
                blur_pixel(top + width - 1, left + k)
                if p != divide - 1:
                    blur_pixel(top + width, left + k)
    return img
def save_img(res, imgid, origin_shape, output_name, divide=False):
    """Write one rendered canvas batch to ``<output_name>/<imgid>.jpg``.

    Args:
        res: (N, 3, width, width) tensor of canvases.
        imgid: used as the file stem.
        origin_shape: ``(width, height)`` passed to ``cv2.resize``.
        output_name: destination directory (must already exist).
        divide: if true, treat the N canvases as tiles, stitch them with
            small2large and blur seams with smooth; otherwise save ``res[0]``.
            (Note: this parameter shadows the module-level ``divide`` inside
            this function only.)

    Raises:
        OSError: if OpenCV fails to write the file.

    NOTE(review): cv2.imwrite interprets channels as BGR; if ``res`` was
    painted from an RGB (e.g. Gradio) input, colours will be swapped — confirm.
    """
    output = res.detach().cpu().numpy()           # (N, 3, width, width)
    output = np.transpose(output, (0, 2, 3, 1))   # (N, width, width, 3)
    if divide:
        output = smooth(small2large(output))
    else:
        output = output[0]
    output = (output * 255).astype('uint8')
    output = cv2.resize(output, origin_shape)
    path = os.path.join(output_name, f"{imgid}.jpg")
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(path, output):
        raise OSError(f"failed to write image to {path}")
def _coord_grid():
    """Build the (1, 2, width, width) CoordConv input on the CPU.

    Channel 0 holds ``i / (width - 1)`` and channel 1 holds ``j / (width - 1)``
    for pixel (i, j). The ramp is divided in float64 and then cast, so the
    values equal assigning Python floats into a float32 tensor element-wise.
    """
    ramp = (torch.arange(width, dtype=torch.float64) / (width - 1.)).float()
    rows = ramp.view(width, 1).expand(width, width)
    cols = ramp.view(1, width).expand(width, width)
    return torch.stack([rows, cols]).unsqueeze(0)


def _to_frame(snapshot, origin_shape, tiled):
    """Convert a rendered (N, 3, width, width) canvas batch into one uint8 frame.

    When ``tiled`` is true the N patches are stitched with small2large and their
    seams blurred with smooth; otherwise only the first canvas is used. The
    frame is resized to ``origin_shape`` (``(width, height)`` for cv2.resize).
    """
    frame = np.transpose(snapshot.detach().cpu().numpy(), (0, 2, 3, 1))
    frame = smooth(small2large(frame)) if tiled else frame[0]
    frame = (frame * 255).astype('uint8')
    return cv2.resize(frame, origin_shape)


def paint_img(img, max_step = 40, model_choices = "Default"):
    """Progressively paint ``img`` stroke by stroke, yielding every intermediate frame.

    Painting runs in two passes when the image is tiled (``divide != 1``): a
    coarse pass over the whole image at width x width, then a pass over the
    divide x divide patches in one batch. ``max_step`` is split evenly between
    the two passes. Each actor step emits a bundle of 5 strokes, and a frame
    (uint8, resized to the input size) is yielded after every stroke.

    Args:
        img: input image; assumed HxWx3 uint8 (the reshape to
            (1, width, width, 3) below requires 3 channels).
        max_step: total number of actor steps (converted with int()).
        model_choices: stroke style key; loaded on demand if not cached yet.

    Raises:
        gr.Error: if no input image was provided.
    """
    if img is None:
        raise gr.Error("Please provide an input image before painting.")
    if model_choices not in actors or model_choices not in decoders:
        load_model_if_needed(model_choices)
    decoder = decoders[model_choices]
    painter = actors[model_choices]
    max_step = int(max_step)
    origin_shape = (img.shape[1], img.shape[0])  # cv2.resize takes (w, h)

    # Target for the tiled pass: canvas_cnt patches of width x width.
    patch_img = cv2.resize(img, (width * divide, width * divide))
    patch_img = large2small(patch_img)
    patch_img = np.transpose(patch_img, (0, 3, 1, 2))
    patch_img = torch.tensor(patch_img).to(device).float() / 255.

    # Target for the coarse pass: the whole image downscaled to width x width.
    target = cv2.resize(img, (width, width))
    target = target.reshape(1, width, width, 3)
    target = np.transpose(target, (0, 3, 1, 2))
    target = torch.tensor(target).to(device).float() / 255.

    T = torch.ones([1, 1, width, width], dtype=torch.float32).to(device)
    coord = _coord_grid().to(device)  # Coordconv
    canvas = torch.zeros([1, 3, width, width]).to(device)

    if divide != 1:
        max_step = max_step // 2

    # Grad mode is thread-local and this generator may be resumed on different
    # worker threads, so no_grad is scoped to each step and never spans a yield.
    for i in range(max_step):
        with torch.no_grad():
            stepnum = T * i / max_step
            actions = painter(torch.cat([canvas, target, stepnum, coord], 1))
            canvas, res = decode(actions, canvas, decoder)
        for snapshot in res:
            yield _to_frame(snapshot, origin_shape, tiled=False)

    if divide != 1:
        # Upscale the coarse canvas and cut it into patches to refine.
        canvas = canvas[0].detach().cpu().numpy()
        canvas = np.transpose(canvas, (1, 2, 0))
        canvas = cv2.resize(canvas, (width * divide, width * divide))
        canvas = large2small(canvas)
        canvas = np.transpose(canvas, (0, 3, 1, 2))
        canvas = torch.tensor(canvas).to(device).float()
        coord = coord.expand(canvas_cnt, 2, width, width)
        T = T.expand(canvas_cnt, 1, width, width)
        for i in range(max_step):
            with torch.no_grad():
                stepnum = T * i / max_step
                actions = painter(torch.cat([canvas, patch_img, stepnum, coord], 1))
                canvas, res = decode(actions, canvas, decoder)
            for snapshot in res:
                yield _to_frame(snapshot, origin_shape, tiled=True)
def load_model_if_needed(choice: str):
    """Load and cache the renderer/actor pair for stroke style ``choice``.

    Known styles are "Default", "Triangle" and "Round"; any other value falls
    back to the Bezier-without-transparency checkpoints. Models already present
    in ``decoders`` / ``actors`` are not reloaded. Loaded models are moved to
    ``device`` and put in eval mode.
    """
    checkpoints = {
        "Default": ('ckpts/actor.pkl', 'ckpts/renderer.pkl'),
        "Triangle": ('ckpts/actor_triangle.pkl', 'ckpts/triangle.pkl'),
        "Round": ('ckpts/actor_round.pkl', 'ckpts/round.pkl'),
    }
    fallback = ('ckpts/actor_notrans.pkl', 'ckpts/bezierwotrans.pkl')
    actor_path, renderer_path = checkpoints.get(choice, fallback)

    if choice not in decoders:
        renderer = FCN()
        renderer.load_state_dict(torch.load(renderer_path, map_location= "cpu"))
        decoders[choice] = renderer.to(device).eval()
    if choice not in actors:
        policy = ResNet(9, 18, 65)  # action_bundle = 5, 65 = 5 * 13
        policy.load_state_dict(torch.load(actor_path, map_location= "cpu"))
        actors[choice] = policy.to(device).eval()
from typing import Generator
def wrapper(func):
    """Turn generator function ``func`` into a start/cancel toggle handler.

    The returned generator takes the Paint/Cancel button value as its first
    argument, followed by ``func``'s arguments:

    * ``"Cancel"`` (the button was just switched to Cancel, i.e. a run was
      requested): start ``func(*rest, **kwargs)`` and stream everything it
      yields.
    * anything else: close the most recently started run, if any. After a
      successful close, yield a single ``None`` update.

    NOTE(review): ``wrapper(paint_img)`` is created once, so the remembered run
    is shared by every session using the app; one user's Cancel closes whatever
    run was started last, possibly another user's.
    """
    event = None  # most recently started generator returned by func

    def inner(*args, **kwargs):
        nonlocal event
        val = args[0]
        if val == "Cancel":
            event = func(*args[1:], **kwargs)
            yield from event
            return
        if event is None:
            return  # nothing has been started yet
        try:
            event.close()
        except ValueError:
            # "generator already executing": the run is mid-step in another
            # worker thread and cannot be closed from here; best effort only.
            return
        yield

    return inner
examples = [
    ["image/chaoyue.png"],
    ["image/degang.png"],
    ["image/JayChou.png"],
    ["image/Leslie.png"],
    ["image/mayun.png"],
]
# Created outside the Blocks context so the ClearButton can reference it before
# it is placed in the layout with output.render().
output = gr.Image(label="Painting Result")
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input image")
            with gr.Row():
                step = gr.Slider(20, 100, value=40, step=1, label='Painting step')
            with gr.Row():
                dropdown = gr.Dropdown(['Default', 'Round', 'Triangle', 'Bezier wo trans'], value='Default', label='Stroke choice')
            with gr.Row():
                with gr.Column():
                    clr_btn = gr.ClearButton([input_image, output], variant="stop")
                with gr.Column():
                    translate_btn = gr.Button(value="Paint", variant="primary")
        with gr.Column():
            output.render()
    dropdown.select(load_model_if_needed, dropdown)
    # Step 1: flip the button label (Paint <-> Cancel). wrapper() reads the new
    # label to decide whether to start painting or cancel the current run.
    toggle_event = translate_btn.click(
        lambda x: gr.Button(value="Cancel", variant="stop") if x == "Paint" else gr.Button(value="Paint", variant="primary"),
        translate_btn, translate_btn)
    # Step 2: stream painting frames (or cancel the running stream).
    paint_event = toggle_event.then(
        wrapper(paint_img), inputs=[translate_btn, input_image, step, dropdown],
        outputs=output, trigger_mode='multiple')
    # Step 3: restore the Paint button once the stream finishes.
    click_event = paint_event.then(
        lambda x: gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)
    # Cancel the painting stream itself; cancelling click_event (the tail of
    # the chain) would only cancel step 3. The button is reset here because a
    # cancelled run may not reach step 3.
    clr_btn.click(lambda: gr.Button(value="Paint", variant="primary"), None, translate_btn,
                  cancels=[paint_event])
    examples = gr.Examples(examples=examples,
                           inputs=[input_image], cache_examples=False)
# demo = gr.Interface(fn=paint_img, inputs=gr.Image(), outputs="image", examples = examples)
demo.queue(default_concurrency_limit=4)
demo.launch(server_name="0.0.0.0", )