Haiyu Wu commited on
Commit
2223153
·
1 Parent(s): 8408f97
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -10,6 +10,7 @@ from sixdrepnet.model import SixDRepNet
10
  import pixel_generator.vec2face.model_vec2face as model_vec2face
11
  MAX_SEED = np.iinfo(np.int32).max
12
  import torch
 
13
  from time import time
14
 
15
 
@@ -79,7 +80,7 @@ def initialize_models():
79
 
80
  return generator, id_model, pose_model, quality_model
81
 
82
-
83
  def image_generation(input_image, quality, use_target_pose, pose, dimension, progress=gr.Progress()):
84
  generator, id_model, pose_model, quality_model = initialize_models()
85
 
@@ -120,7 +121,7 @@ def image_generation(input_image, quality, use_target_pose, pose, dimension, pro
120
 
121
  return generated_images
122
 
123
-
124
  def process_input(image_input, num1, num2, num3, num4, random_seed, target_quality, use_target_pose, target_pose, progress=gr.Progress()):
125
  # Ensure all dimension numbers are within [0, 512)
126
  num1, num2, num3, num4 = [max(0, min(int(n), 511)) for n in [num1, num2, num3, num4]]
 
10
  import pixel_generator.vec2face.model_vec2face as model_vec2face
11
  MAX_SEED = np.iinfo(np.int32).max
12
  import torch
13
+ import spaces
14
  from time import time
15
 
16
 
 
80
 
81
  return generator, id_model, pose_model, quality_model
82
 
83
+ @spaces.GPU
84
  def image_generation(input_image, quality, use_target_pose, pose, dimension, progress=gr.Progress()):
85
  generator, id_model, pose_model, quality_model = initialize_models()
86
 
 
121
 
122
  return generated_images
123
 
124
+ @spaces.GPU
125
  def process_input(image_input, num1, num2, num3, num4, random_seed, target_quality, use_target_pose, target_pose, progress=gr.Progress()):
126
  # Ensure all dimension numbers are within [0, 512)
127
  num1, num2, num3, num4 = [max(0, min(int(n), 511)) for n in [num1, num2, num3, num4]]