YulianSa committed
Commit
911a293
·
1 Parent(s): 01c0065
Files changed (1)
  1. infer_api.py +18 -0
infer_api.py CHANGED
@@ -1,3 +1,4 @@
+import spaces
 from PIL import Image
 import glob
 
@@ -102,6 +103,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
 
 
+@spaces.GPU
 def set_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
@@ -165,6 +167,7 @@ def process_image(image, totensor, width, height):
     return totensor(image)
 
 
+@spaces.GPU
 @torch.no_grad()
 def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
               text_encoder, pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type,
@@ -268,6 +271,7 @@ def save_image_numpy(ndarr):
     im = im.resize((1024, 1024), Image.LANCZOS)
     return im
 
+@spaces.GPU
 def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3):
     if cfg.seed is None:
         generator = None
@@ -333,6 +337,7 @@ def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3):
     return results
 
 
+@spaces.GPU
 def load_multiview_pipeline(cfg):
     pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
         cfg.pretrained_path,
@@ -450,6 +455,7 @@ def calc_horizontal_offset2(target_mask, source_img):
     return best_offset_value
 
 
+@spaces.GPU
 def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None, thres=0.25, ratio=0.50, outside_thres=0.10, outside_ratio=0.20):
     distract_area = np.abs(color_0 - color_1).sum(axis=-1) > thres
     if normal_0 is not None and normal_1 is not None:
@@ -516,6 +522,7 @@ def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None,
 
 
 class InferRefineAPI:
+    @spaces.GPU
     def __init__(self, config):
         self.sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
         self.generator = SamAutomaticMaskGenerator(
@@ -529,6 +536,7 @@ class InferRefineAPI:
         )
         self.outside_ratio = 0.20
 
+    @spaces.GPU
     def refine(self, meshes, imgs):
         fixed_v, fixed_f, fixed_t = None, None, None
         flow_vert, flow_vector = None, None
@@ -680,6 +688,7 @@ class InferRefineAPI:
 
 
 class InferSlrmAPI:
+    @spaces.GPU
     def __init__(self, config):
         self.config_path = config['config_path']
         self.config = OmegaConf.load(self.config_path)
@@ -694,6 +703,7 @@ class InferSlrmAPI:
         self.model.init_flexicubes_geometry(self.device, fovy=30.0, is_ortho=self.model.is_ortho)
         self.model = self.model.eval()
 
+    @spaces.GPU
     def gen(self, imgs):
         imgs = [ cv2.imread(img[0])[:, :, ::-1] for img in imgs ]
         imgs = np.stack(imgs, axis=0).astype(np.float32) / 255.0
@@ -701,6 +711,7 @@ class InferSlrmAPI:
         mesh_glb_fpaths = self.make3d(imgs)
         return mesh_glb_fpaths[1:4] + mesh_glb_fpaths[0:1]
 
+    @spaces.GPU
     def make3d(self, images):
         input_cameras = torch.tensor(np.load('slrm/cameras.npy')).to(device)
 
@@ -724,6 +735,7 @@ class InferSlrmAPI:
 
         return mesh_glb_fpaths
 
+    @spaces.GPU
     def make_mesh(self, mesh_fpath, planes, level=None):
         mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
         mesh_dirname = os.path.dirname(mesh_fpath)
@@ -751,6 +763,7 @@ class InferSlrmAPI:
 
 
 class InferMultiviewAPI:
+    @spaces.GPU
     def __init__(self, config):
         parser = argparse.ArgumentParser()
         parser.add_argument("--seed", type=int, default=42)
@@ -784,6 +797,7 @@ class InferMultiviewAPI:
         return im
 
 
+    @spaces.GPU
     def gen(self, img, seed, num_levels):
         set_seed(seed)
         data = {}
@@ -801,6 +815,7 @@ class InferMultiviewAPI:
 
 
 class InferCanonicalAPI:
+    @spaces.GPU
     def __init__(self, config):
         self.config = config
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -810,6 +825,7 @@ class InferCanonicalAPI:
 
         self.setup(**self.loaded_config)
 
+    @spaces.GPU
    def setup(self,
              validation: Dict,
              pretrained_model_path: str,
@@ -858,6 +874,7 @@ class InferCanonicalAPI:
 
         self.bkg_remover = BkgRemover()
 
+    @spaces.GPU
     def canonicalize(self, image, seed):
         generator = torch.Generator(device=device).manual_seed(seed)
         return inference(
@@ -866,6 +883,7 @@ class InferCanonicalAPI:
             use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
         )
 
+    @spaces.GPU
     def gen(self, img_input, seed=0):
         if np.array(img_input).shape[-1] == 4 and np.array(img_input)[..., 3].min() == 255:
             # convert to RGB
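For context: this is the standard Hugging Face ZeroGPU pattern. On a ZeroGPU Space a GPU is attached to the process only while a function decorated with `@spaces.GPU` is executing, and the decorator is a no-op on ordinary hardware, so the commit imports `spaces` and decorates every entry point that touches CUDA, including the `__init__`/`setup` methods that move weights onto the GPU (e.g. `InferRefineAPI.__init__`, which calls `.cuda()` on the SAM model). A minimal sketch of the pattern, assuming a toy `model` and hypothetical `predict` functions that are not part of this repo:

import spaces
import torch

# Created at import time: on ZeroGPU the main process starts without a
# CUDA device, so module-level setup like this stays on the CPU.
model = torch.nn.Linear(8, 8)

@spaces.GPU  # a GPU is attached only for the duration of each call
def predict(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return model.to("cuda")(x.to("cuda")).cpu()

@spaces.GPU(duration=120)  # request a longer slot (in seconds) for slow calls
def predict_slow(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return model.to("cuda")(x.to("cuda")).cpu()

Decorating the constructors, as this commit does, presumably keeps the `.cuda()` calls made during model loading inside a GPU-attached window; the trade-off is that building each API object consumes a ZeroGPU slot.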