Charbel Malo commited on
Commit
a41ef40
·
verified ·
1 Parent(s): 1118f2d

Update face_parsing/swap.py

Browse files
Files changed (1) hide show
  1. face_parsing/swap.py +7 -1
face_parsing/swap.py CHANGED
@@ -4,6 +4,7 @@ import torch.nn.functional as F
4
  import torchvision.transforms as transforms
5
  import cv2
6
  import numpy as np
 
7
 
8
  from .model import BiSeNet
9
 
@@ -61,7 +62,9 @@ class SoftErosion(nn.Module):
61
 
62
  device = "cpu"
63
 
64
- def init_parser(pth_path, mode="cpu"):
 
 
65
  global device
66
  device = mode
67
  n_classes = 19
@@ -75,6 +78,7 @@ def init_parser(pth_path, mode="cpu"):
75
  return net
76
 
77
 
 
78
  def image_to_parsing(img, net):
79
  img = cv2.resize(img, (512, 512))
80
  img = img[:,:,::-1]
@@ -99,6 +103,7 @@ def get_mask(parsing, classes):
99
  return res
100
 
101
 
 
102
  def swap_regions(source, target, net, smooth_mask, includes=[1,2,3,4,5,10,11,12,13], blur=10):
103
  parsing = image_to_parsing(source, net)
104
 
@@ -125,6 +130,7 @@ def swap_regions(source, target, net, smooth_mask, includes=[1,2,3,4,5,10,11,12,
125
 
126
  return result
127
 
 
128
  def mask_regions_to_list(values):
129
  out_ids = []
130
  for value in values:
 
4
  import torchvision.transforms as transforms
5
  import cv2
6
  import numpy as np
7
+ import spaces
8
 
9
  from .model import BiSeNet
10
 
 
62
 
63
  device = "cpu"
64
 
65
+
66
+ @spaces.GPU(enable_queue=True)
67
+ def init_parser(pth_path, mode="cuda"):
68
  global device
69
  device = mode
70
  n_classes = 19
 
78
  return net
79
 
80
 
81
+ @spaces.GPU(enable_queue=True)
82
  def image_to_parsing(img, net):
83
  img = cv2.resize(img, (512, 512))
84
  img = img[:,:,::-1]
 
103
  return res
104
 
105
 
106
+ @spaces.GPU(enable_queue=True)
107
  def swap_regions(source, target, net, smooth_mask, includes=[1,2,3,4,5,10,11,12,13], blur=10):
108
  parsing = image_to_parsing(source, net)
109
 
 
130
 
131
  return result
132
 
133
+ @spaces.GPU(enable_queue=True)
134
  def mask_regions_to_list(values):
135
  out_ids = []
136
  for value in values: