minhnh commited on
Commit
e9cd82f
·
1 Parent(s): 940c64c

Support concurrent per-call model choice (Before using a global model)

Browse files
Files changed (1) hide show
  1. app.py +27 -14
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import os
3
  import cv2
 
4
  import torch
5
  import numpy as np
6
  import argparse
@@ -28,10 +29,12 @@ actor.load_state_dict(torch.load(actor_path))
28
  actor = actor.to(device).eval()
29
  Decoder = Decoder.to(device).eval()
30
 
 
 
31
 
32
- def decode(x, canvas): # b * (10 + 3)
33
  x = x.view(-1, 10 + 3)
34
- stroke = 1 - Decoder(x[:, :10])
35
  stroke = stroke.view(-1, width, width, 1)
36
  color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3)
37
  stroke = stroke.permute(0, 3, 1, 2)
@@ -98,7 +101,9 @@ def save_img(res, imgid, origin_shape, output_name, divide=False):
98
 
99
 
100
 
101
- def paint_img(img, max_step = 40):
 
 
102
  max_step = int(max_step)
103
  # imgid = 0
104
  # output_name = os.path.join('output', str(len(os.listdir('output'))) if os.path.exists('output') else '0')
@@ -130,7 +135,7 @@ def paint_img(img, max_step = 40):
130
  for i in range(max_step):
131
  stepnum = T * i / max_step
132
  actions = actor(torch.cat([canvas, img, stepnum, coord], 1))
133
- canvas, res = decode(actions, canvas)
134
  for j in range(5):
135
  # save_img(res[j], imgid)
136
  # imgid += 1
@@ -152,7 +157,7 @@ def paint_img(img, max_step = 40):
152
  for i in range(max_step):
153
  stepnum = T * i / max_step
154
  actions = actor(torch.cat([canvas, patch_img, stepnum, coord], 1))
155
- canvas, res = decode(actions, canvas)
156
  # print('divided canvas step {}, L2Loss = {}'.format(i, ((canvas - patch_img) ** 2).mean()))
157
  for j in range(5):
158
  # save_img(res[j], imgid, True)
@@ -168,8 +173,8 @@ def paint_img(img, max_step = 40):
168
  yield output
169
 
170
 
171
- def change_model(choice: str):
172
- global Decoder, actor
173
  if choice == "Default":
174
  actor_path = 'ckpts/actor.pkl'
175
  renderer_path = 'ckpts/renderer.pkl'
@@ -182,11 +187,19 @@ def change_model(choice: str):
182
  else:
183
  actor_path = 'ckpts/actor_notrans.pkl'
184
  renderer_path = 'ckpts/bezierwotrans.pkl'
185
-
186
- Decoder.load_state_dict(torch.load(renderer_path, map_location= "cpu"))
187
- actor.load_state_dict(torch.load(actor_path, map_location= "cpu"))
188
- actor = actor.to(device).eval()
189
- Decoder = Decoder.to(device).eval()
 
 
 
 
 
 
 
 
190
 
191
  from typing import Generator
192
  def wrapper(func):
@@ -233,9 +246,9 @@ with gr.Blocks() as demo:
233
  output.render()
234
 
235
 
236
- dropdown.select(change_model, dropdown)
237
  click_event = translate_btn.click(lambda x: gr.Button(value="Cancel", variant="stop") if x == "Paint" else gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)\
238
- .then(wrapper(paint_img), inputs=[translate_btn, input_image, step], outputs=output, trigger_mode = 'multiple')\
239
  .then(lambda x: gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)
240
  clr_btn.click(None, None, cancels=[click_event])
241
  examples = gr.Examples(examples=examples,
 
1
  import gradio as gr
2
  import os
3
  import cv2
4
+ from regex import D
5
  import torch
6
  import numpy as np
7
  import argparse
 
29
  actor = actor.to(device).eval()
30
  Decoder = Decoder.to(device).eval()
31
 
32
+ decoders = {"Default": Decoder}
33
+ actors = {"Default": actor}
34
 
35
+ def decode(x, canvas, decoder = Decoder): # b * (10 + 3)
36
  x = x.view(-1, 10 + 3)
37
+ stroke = 1 - decoder(x[:, :10])
38
  stroke = stroke.view(-1, width, width, 1)
39
  color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3)
40
  stroke = stroke.permute(0, 3, 1, 2)
 
101
 
102
 
103
 
104
+ def paint_img(img, max_step = 40, model_choices = "Default"):
105
+ Decoder = decoders[model_choices]
106
+ actor = actors[model_choices]
107
  max_step = int(max_step)
108
  # imgid = 0
109
  # output_name = os.path.join('output', str(len(os.listdir('output'))) if os.path.exists('output') else '0')
 
135
  for i in range(max_step):
136
  stepnum = T * i / max_step
137
  actions = actor(torch.cat([canvas, img, stepnum, coord], 1))
138
+ canvas, res = decode(actions, canvas, Decoder)
139
  for j in range(5):
140
  # save_img(res[j], imgid)
141
  # imgid += 1
 
157
  for i in range(max_step):
158
  stepnum = T * i / max_step
159
  actions = actor(torch.cat([canvas, patch_img, stepnum, coord], 1))
160
+ canvas, res = decode(actions, canvas, Decoder)
161
  # print('divided canvas step {}, L2Loss = {}'.format(i, ((canvas - patch_img) ** 2).mean()))
162
  for j in range(5):
163
  # save_img(res[j], imgid, True)
 
173
  yield output
174
 
175
 
176
+ def load_model_if_needed(choice: str):
177
+ # global Decoder, actor
178
  if choice == "Default":
179
  actor_path = 'ckpts/actor.pkl'
180
  renderer_path = 'ckpts/renderer.pkl'
 
187
  else:
188
  actor_path = 'ckpts/actor_notrans.pkl'
189
  renderer_path = 'ckpts/bezierwotrans.pkl'
190
+ if choice not in decoders:
191
+ Decoder = FCN()
192
+ Decoder.load_state_dict(torch.load(renderer_path, map_location= "cpu"))
193
+ Decoder = Decoder.to(device).eval()
194
+ decoders[choice] = Decoder
195
+ if choice not in actors:
196
+ actor = ResNet(9, 18, 65) # action_bundle = 5, 65 = 5 * 13
197
+ actor.load_state_dict(torch.load(actor_path, map_location= "cpu"))
198
+ actor = actor.to(device).eval()
199
+ actors[choice] = actor
200
+
201
+
202
+
203
 
204
  from typing import Generator
205
  def wrapper(func):
 
246
  output.render()
247
 
248
 
249
+ dropdown.select(load_model_if_needed, dropdown)
250
  click_event = translate_btn.click(lambda x: gr.Button(value="Cancel", variant="stop") if x == "Paint" else gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)\
251
+ .then(wrapper(paint_img), inputs=[translate_btn, input_image, step, dropdown], outputs=output, trigger_mode = 'multiple')\
252
  .then(lambda x: gr.Button(value="Paint", variant="primary"), translate_btn, translate_btn)
253
  clr_btn.click(None, None, cancels=[click_event])
254
  examples = gr.Examples(examples=examples,