Support concurrent per-call model choice (previously a single global model was used)
--- a/app.py
+++ b/app.py
@@ -1,6 +1,7 @@
 import gradio as gr
 import os
 import cv2
+from regex import D
 import torch
 import numpy as np
 import argparse
@@ -28,10 +29,12 @@ actor.load_state_dict(torch.load(actor_path))
 actor = actor.to(device).eval()
 Decoder = Decoder.to(device).eval()
 
+decoders = {"Default": Decoder}
+actors = {"Default": actor}
 
-def decode(x, canvas): # b * (10 + 3)
+def decode(x, canvas, decoder = Decoder): # b * (10 + 3)
     x = x.view(-1, 10 + 3)
-    stroke = 1 - Decoder(x[:, :10])
+    stroke = 1 - decoder(x[:, :10])
     stroke = stroke.view(-1, width, width, 1)
     color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3)
     stroke = stroke.permute(0, 3, 1, 2)
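Threading the renderer through decode as an argument, instead of reading the module-level Decoder, is what makes the per-call choice possible: two concurrent requests can decode with different renderers without touching shared state. A minimal sketch of the pattern, with a hypothetical toy module standing in for FCN (make_renderer, renderers, and render_strokes are illustrative names, not from app.py):

import torch
import torch.nn as nn

# Toy stand-in for the FCN renderer; any module mapping 10 stroke
# parameters to a stroke mask would do.
def make_renderer() -> nn.Module:
    return nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 4)).eval()

renderers = {"Default": make_renderer()}   # mirrors the `decoders` registry

def render_strokes(x: torch.Tensor, renderer: nn.Module = renderers["Default"]):
    # The model arrives as a parameter: each call binds its own renderer,
    # so no call can clobber another call's choice.
    with torch.no_grad():
        return 1 - renderer(x)

masks = render_strokes(torch.rand(2, 10), renderers["Default"])

One subtlety: the default `decoder = Decoder` is evaluated once at definition time, so the default stays pinned to the startup model; call sites that want another model must pass it explicitly, which the later hunks do.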
@@ -98,7 +101,9 @@ def save_img(res, imgid, origin_shape, output_name, divide=False):
 
 
 
-def paint_img(img, max_step = 40):
+def paint_img(img, max_step = 40, model_choices = "Default"):
+    Decoder = decoders[model_choices]
+    actor = actors[model_choices]
     max_step = int(max_step)
     # imgid = 0
     # output_name = os.path.join('output', str(len(os.listdir('output'))) if os.path.exists('output') else '0')
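With the registries in place, paint_img resolves both networks once per call, at the top of the function; everything below uses the local Decoder/actor, never the globals. A self-contained sketch of the same call-time lookup (the `models` registry and `run_steps` are hypothetical):

import torch
import torch.nn as nn

# Hypothetical registry keyed by dropdown value, like `decoders`/`actors`.
models = {"Default": nn.Linear(10, 10).eval()}

def run_steps(img: torch.Tensor, steps: int = 5, choice: str = "Default"):
    net = models[choice]              # resolved per call, not read from a global
    canvas = torch.zeros_like(img)
    with torch.no_grad():
        for _ in range(int(steps)):
            canvas = canvas + net(img)   # stand-in for one painting action
            yield canvas                 # stream intermediate canvases

for frame in run_steps(torch.rand(2, 10)):
    pass  # a UI layer would display each intermediate frame

Because the cached modules are only read, and stay in eval mode, concurrent generators can share them safely.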
@@ -130,7 +135,7 @@ def paint_img(img, max_step = 40):
     for i in range(max_step):
         stepnum = T * i / max_step
         actions = actor(torch.cat([canvas, img, stepnum, coord], 1))
-        canvas, res = decode(actions, canvas)
+        canvas, res = decode(actions, canvas, Decoder)
         for j in range(5):
             # save_img(res[j], imgid)
             # imgid += 1
@@ -152,7 +157,7 @@ def paint_img(img, max_step = 40):
     for i in range(max_step):
         stepnum = T * i / max_step
         actions = actor(torch.cat([canvas, patch_img, stepnum, coord], 1))
-        canvas, res = decode(actions, canvas)
+        canvas, res = decode(actions, canvas, Decoder)
         # print('divided canvas step {}, L2Loss = {}'.format(i, ((canvas - patch_img) ** 2).mean()))
         for j in range(5):
             # save_img(res[j], imgid, True)
@@ -168,8 +173,8 @@ def paint_img(img, max_step = 40):
         yield output
 
 
-def change_model(choice: str):
-    global Decoder, actor
+def load_model_if_needed(choice: str):
+    # global Decoder, actor
     if choice == "Default":
         actor_path = 'ckpts/actor.pkl'
         renderer_path = 'ckpts/renderer.pkl'
@@ -182,11 +187,19 @@ def change_model(choice: str):
     else:
         actor_path = 'ckpts/actor_notrans.pkl'
         renderer_path = 'ckpts/bezierwotrans.pkl'
-    Decoder.load_state_dict(torch.load(renderer_path))
-    actor = ResNet(9, 18, 65)
-    actor.load_state_dict(torch.load(actor_path))
-    actor = actor.to(device).eval()
-    Decoder = Decoder.to(device).eval()
+    if choice not in decoders:
+        Decoder = FCN()
+        Decoder.load_state_dict(torch.load(renderer_path, map_location= "cpu"))
+        Decoder = Decoder.to(device).eval()
+        decoders[choice] = Decoder
+    if choice not in actors:
+        actor = ResNet(9, 18, 65) # action_bundle = 5, 65 = 5 * 13
+        actor.load_state_dict(torch.load(actor_path, map_location= "cpu"))
+        actor = actor.to(device).eval()
+        actors[choice] = actor
+
+
+
 
 from typing import Generator
 def wrapper(func):
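The renamed load_model_if_needed no longer swaps the global pair; it fills the decoders/actors caches, loading each checkpoint at most once per dropdown choice. A generic sketch of this load-once-and-cache pattern (the Linear module and `path_for` mapping are hypothetical; app.py builds FCN()/ResNet(9, 18, 65) from the ckpts/ paths):

import torch
import torch.nn as nn

_cache: dict[str, nn.Module] = {}

def load_if_needed(choice: str, path_for: dict[str, str]) -> nn.Module:
    # The first request for a choice pays the load cost; later requests
    # reuse the cached, eval-mode module.
    if choice not in _cache:
        net = nn.Linear(10, 10)  # hypothetical; app.py builds FCN()/ResNet(...)
        net.load_state_dict(torch.load(path_for[choice], map_location="cpu"))
        _cache[choice] = net.eval()
    return _cache[choice]

One thing a reviewer might flag: two cold requests can race on the `choice not in _cache` check and both load the same checkpoint; a threading.Lock around the fill would make it strictly once. The race is harmless here, since both loads produce identical read-only modules.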
@@ -233,9 +246,9 @@ with gr.Blocks() as demo:
         output.render()
 
 
-    dropdown.select(change_model, dropdown)
+    dropdown.select(load_model_if_needed, dropdown)
     click_event = translate_btn.click(lambda x: gr.Button(value="Cancel", variant="stop") if x == "Paint" else gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)\
-        .then(wrapper(paint_img), inputs=[translate_btn, input_image, step], outputs=output, trigger_mode = 'multiple')\
+        .then(wrapper(paint_img), inputs=[translate_btn, input_image, step, dropdown], outputs=output, trigger_mode = 'multiple')\
         .then(lambda x: gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)
     clr_btn.click(None, None, cancels=[click_event])
     examples = gr.Examples(examples=examples,
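Two UI changes complete the feature: the dropdown's select event now only pre-warms the caches, and the dropdown is appended to the click chain's inputs so its current value reaches paint_img as model_choices. A stripped-down sketch of the same Gradio wiring (component names and choice labels are illustrative, not the app's exact ones):

import gradio as gr

def warm(choice: str):
    # Stand-in for load_model_if_needed: populate caches eagerly on selection.
    print(f"would load checkpoints for {choice!r}")

def paint(btn_label, image, steps, choice):
    # Generator handler: Gradio streams each yielded value to the output.
    for i in range(int(steps)):
        yield f"step {i + 1}/{int(steps)} with model {choice}"

with gr.Blocks() as demo:
    dropdown = gr.Dropdown(["Default", "NoTrans"], value="Default")
    image = gr.Image()
    steps = gr.Slider(1, 100, value=40)
    btn = gr.Button("Paint")
    out = gr.Textbox()
    dropdown.select(warm, dropdown)  # eager load when the user picks a model
    btn.click(paint, inputs=[btn, image, steps, dropdown], outputs=out)

# demo.launch()

Note that paint_img indexes decoders[model_choices] directly, so non-default choices depend on the select event having run first; calling load_model_if_needed at the top of paint_img as a fallback would remove that ordering assumption.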