YulianSa commited on
Commit
e431330
·
1 Parent(s): 79c1f1a
Files changed (1) hide show
  1. infer_api.py +49 -48
infer_api.py CHANGED
@@ -758,53 +758,55 @@ class InferSlrmAPI:
758
 
759
  return mesh_fpath
760
 
761
- class InferMultiviewAPI:
762
- def __init__(self, config):
763
- parser = argparse.ArgumentParser()
764
- parser.add_argument("--seed", type=int, default=42)
765
- parser.add_argument("--num_views", type=int, default=6)
766
- parser.add_argument("--num_levels", type=int, default=3)
767
- parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024')
768
- parser.add_argument("--height", type=int, default=1024)
769
- parser.add_argument("--width", type=int, default=576)
770
- self.cfg = parser.parse_args()
771
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
772
- self.pipeline = load_multiview_pipeline(self.cfg)
773
- self.results = {}
774
- if torch.cuda.is_available():
775
- self.pipeline.to(device)
776
-
777
- self.image_transforms = [transforms.Resize(int(max(self.cfg.height, self.cfg.width))),
778
- transforms.CenterCrop((self.cfg.height, self.cfg.width)),
779
- transforms.ToTensor(),
780
- transforms.Lambda(lambda x: x * 2. - 1),
781
- ]
782
- self.image_transforms = transforms.Compose(self.image_transforms)
783
-
784
- prompt_embeds_path = './multiview/fixed_prompt_embeds_6view'
785
- self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
786
- self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')
787
- self.total_views = self.cfg.num_views
788
-
789
-
790
- def process_im(self, im):
791
- im = self.image_transforms(im)
792
- return im
793
-
794
- def gen(self, img, seed, num_levels):
795
- set_seed(seed)
796
- data = {}
797
-
798
- cond_im_rgb = self.process_im(img)
799
- cond_im_rgb = torch.stack([cond_im_rgb] * self.total_views, dim=0)
800
- data["image_cond_rgb"] = cond_im_rgb[None, ...]
801
- data["normal_prompt_embeddings"] = self.normal_text_embeds[None, ...]
802
- data["color_prompt_embeddings"] = self.color_text_embeds[None, ...]
803
-
804
- results = run_multiview_infer(data, self.pipeline, self.cfg, num_levels=num_levels)
805
- for k in results:
806
- self.results[k] = results[k]
807
- return results
 
 
808
 
809
  repo_id = "hyz317/StdGEN"
810
  all_files = list_repo_files(repo_id, revision="main")
@@ -824,7 +826,6 @@ print(f"Using device!!!!!!!!!!!!: {infer_canonicalize_device}", file=sys.stderr)
824
  infer_canonicalize_config_path = infer_canonicalize_config['config_path']
825
  infer_canonicalize_loaded_config = OmegaConf.load(infer_canonicalize_config_path)
826
 
827
- # infer_canonicalize_setup(**infer_canonicalize_loaded_config)
828
 
829
  def infer_canonicalize_setup(
830
  validation: Dict,
 
758
 
759
  return mesh_fpath
760
 
761
+
762
# Module-level multiview inference setup (replaces the former InferMultiviewAPI
# class): parse config, build the diffusion pipeline, the image preprocessing
# transforms, and load the fixed per-view text prompt embeddings.
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--num_views", type=int, default=6)
parser.add_argument("--num_levels", type=int, default=3)
parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024')
parser.add_argument("--height", type=int, default=1024)
parser.add_argument("--width", type=int, default=576)
infer_multiview_cfg = parser.parse_args()

infer_multiview_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
infer_multiview_pipeline = load_multiview_pipeline(infer_multiview_cfg)
infer_multiview_results = {}
if torch.cuda.is_available():
    # BUG FIX: was `infer_multiview_pipeline.to(device)` — `device` is not
    # defined at module level (NameError); move the pipeline to the device
    # selected above instead.
    infer_multiview_pipeline.to(infer_multiview_device)

# Preprocessing: resize so the short side covers the target, center-crop to
# (height, width), convert to tensor, then rescale [0, 1] -> [-1, 1].
infer_multiview_image_transforms = transforms.Compose([
    transforms.Resize(int(max(infer_multiview_cfg.height, infer_multiview_cfg.width))),
    transforms.CenterCrop((infer_multiview_cfg.height, infer_multiview_cfg.width)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x * 2. - 1),
])

# Fixed prompt embeddings precomputed for the 6-view setup.
prompt_embeds_path = './multiview/fixed_prompt_embeds_6view'
infer_multiview_normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
infer_multiview_color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')
infer_multiview_total_views = infer_multiview_cfg.num_views
787
+
788
+
789
@spaces.GPU
def process_im(im):
    """Apply the module-level multiview preprocessing transforms to an image.

    BUG FIX: the class->module refactor left a stray ``self`` parameter and a
    ``self.image_transforms`` reference behind. ``infer_multiview_gen`` calls
    this as ``process_im(img)``, which bound the image to ``self`` and crashed
    on the attribute lookup. Use the module-level transform pipeline instead.

    Args:
        im: input image (PIL image or tensor accepted by torchvision transforms).

    Returns:
        Preprocessed tensor of shape (C, height, width), values in [-1, 1].
    """
    return infer_multiview_image_transforms(im)
793
+
794
+
795
@spaces.GPU
def infer_multiview_gen(img, seed, num_levels):
    """Run multiview inference for a single conditioning image.

    Args:
        img: input image, preprocessed via ``process_im``.
        seed: RNG seed passed to ``set_seed`` for reproducible sampling.
        num_levels: number of refinement levels forwarded to
            ``run_multiview_infer``.

    Returns:
        The results dict produced by ``run_multiview_infer``.
    """
    set_seed(seed)

    cond_im_rgb = process_im(img)
    # Repeat the single conditioning image once per target view, then add a
    # leading batch dimension.
    cond_im_rgb = torch.stack([cond_im_rgb] * infer_multiview_total_views, dim=0)

    data = {
        "image_cond_rgb": cond_im_rgb[None, ...],
        "normal_prompt_embeddings": infer_multiview_normal_text_embeds[None, ...],
        "color_prompt_embeddings": infer_multiview_color_text_embeds[None, ...],
    }

    # NOTE: the old class cached results in self.results; that dead,
    # commented-out code referenced a nonexistent `self` and has been removed.
    return run_multiview_infer(data, infer_multiview_pipeline, infer_multiview_cfg,
                               num_levels=num_levels)
810
 
811
  repo_id = "hyz317/StdGEN"
812
  all_files = list_repo_files(repo_id, revision="main")
 
826
  infer_canonicalize_config_path = infer_canonicalize_config['config_path']
827
  infer_canonicalize_loaded_config = OmegaConf.load(infer_canonicalize_config_path)
828
 
 
829
 
830
  def infer_canonicalize_setup(
831
  validation: Dict,