JunhaoZhuang commited on
Commit
e7c2a21
·
verified ·
1 Parent(s): 3331e4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -2
app.py CHANGED
@@ -179,9 +179,15 @@ global pipeline
179
  global MultiResNetModel
180
  global causal_dit
181
  global controlnet
 
182
 
183
  @spaces.GPU
184
  def load_ckpt():
 
 
 
 
 
185
  weight_dtype = torch.float16
186
 
187
  block_out_channels = [128, 128, 256, 512, 512]
@@ -292,11 +298,14 @@ def load_ckpt():
292
  print('loaded pipeline')
293
 
294
  load_ckpt()
295
-
296
- global cur_style
297
  cur_style = 'line + shadow'
298
  @spaces.GPU
299
  def change_ckpt(style):
 
 
 
 
 
300
  weight_dtype = torch.float16
301
 
302
  if style == 'line':
@@ -348,6 +357,11 @@ def process_multi_images(files):
348
 
349
  @spaces.GPU
350
  def extract_lines(image):
 
 
 
 
 
351
  src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
352
 
353
  rows = int(np.ceil(src.shape[0] / 16)) * 16
@@ -373,6 +387,11 @@ def extract_lines(image):
373
 
374
  @spaces.GPU
375
  def extract_line_image(query_image_, resolution):
 
 
 
 
 
376
  tar_width, tar_height = resolution
377
  query_image = query_image_.resize((tar_width, tar_height))
378
  query_image = query_image.convert('L').convert('RGB')
@@ -418,6 +437,11 @@ def extract_sketch_line_image(query_image_, input_style):
418
 
419
  @spaces.GPU(duration=120)
420
  def colorize_image(extracted_line, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask=None, hint_color=None, query_image_origin=None, extracted_image_ori=None):
 
 
 
 
 
421
  if extracted_line is None:
422
  gr.Info("Please preprocess the image first")
423
  raise ValueError("Please preprocess the image first")
 
179
  global MultiResNetModel
180
  global causal_dit
181
  global controlnet
182
+ global cur_style
183
 
184
  @spaces.GPU
185
  def load_ckpt():
186
+ global pipeline
187
+ global MultiResNetModel
188
+ global causal_dit
189
+ global controlnet
190
+ global cur_style
191
  weight_dtype = torch.float16
192
 
193
  block_out_channels = [128, 128, 256, 512, 512]
 
298
  print('loaded pipeline')
299
 
300
  load_ckpt()
 
 
301
  cur_style = 'line + shadow'
302
  @spaces.GPU
303
  def change_ckpt(style):
304
+ global pipeline
305
+ global MultiResNetModel
306
+ global causal_dit
307
+ global controlnet
308
+ global cur_style
309
  weight_dtype = torch.float16
310
 
311
  if style == 'line':
 
357
 
358
  @spaces.GPU
359
  def extract_lines(image):
360
+ global pipeline
361
+ global MultiResNetModel
362
+ global causal_dit
363
+ global controlnet
364
+ global cur_style
365
  src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
366
 
367
  rows = int(np.ceil(src.shape[0] / 16)) * 16
 
387
 
388
  @spaces.GPU
389
  def extract_line_image(query_image_, resolution):
390
+ global pipeline
391
+ global MultiResNetModel
392
+ global causal_dit
393
+ global controlnet
394
+ global cur_style
395
  tar_width, tar_height = resolution
396
  query_image = query_image_.resize((tar_width, tar_height))
397
  query_image = query_image.convert('L').convert('RGB')
 
437
 
438
  @spaces.GPU(duration=120)
439
  def colorize_image(extracted_line, reference_images, resolution, seed, num_inference_steps, top_k, hint_mask=None, hint_color=None, query_image_origin=None, extracted_image_ori=None):
440
+ global pipeline
441
+ global MultiResNetModel
442
+ global causal_dit
443
+ global controlnet
444
+ global cur_style
445
  if extracted_line is None:
446
  gr.Info("Please preprocess the image first")
447
  raise ValueError("Please preprocess the image first")