Spaces:
Running
on
L40S
Running
on
L40S
update
Browse files- infer_api.py +49 -48
infer_api.py
CHANGED
@@ -758,53 +758,55 @@ class InferSlrmAPI:
|
|
758 |
|
759 |
return mesh_fpath
|
760 |
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
|
782 |
-
|
783 |
-
|
784 |
-
|
785 |
-
|
786 |
-
|
787 |
-
|
788 |
-
|
789 |
-
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
-
|
797 |
-
|
798 |
-
|
799 |
-
|
800 |
-
|
801 |
-
|
802 |
-
|
803 |
-
|
804 |
-
|
805 |
-
|
806 |
-
|
807 |
-
|
|
|
|
|
808 |
|
809 |
repo_id = "hyz317/StdGEN"
|
810 |
all_files = list_repo_files(repo_id, revision="main")
|
@@ -824,7 +826,6 @@ print(f"Using device!!!!!!!!!!!!: {infer_canonicalize_device}", file=sys.stderr)
|
|
824 |
infer_canonicalize_config_path = infer_canonicalize_config['config_path']
|
825 |
infer_canonicalize_loaded_config = OmegaConf.load(infer_canonicalize_config_path)
|
826 |
|
827 |
-
# infer_canonicalize_setup(**infer_canonicalize_loaded_config)
|
828 |
|
829 |
def infer_canonicalize_setup(
|
830 |
validation: Dict,
|
|
|
758 |
|
759 |
return mesh_fpath
|
760 |
|
761 |
# ---------------------------------------------------------------------------
# Multiview inference: command-line configuration, device selection and
# pipeline construction (module-level script state, runs at import time).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--num_views", type=int, default=6)
parser.add_argument("--num_levels", type=int, default=3)
parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024')
parser.add_argument("--height", type=int, default=1024)
parser.add_argument("--width", type=int, default=576)
infer_multiview_cfg = parser.parse_args()

# Prefer the first CUDA device when available, otherwise fall back to CPU.
infer_multiview_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
infer_multiview_pipeline = load_multiview_pipeline(infer_multiview_cfg)
infer_multiview_results = {}
if torch.cuda.is_available():
    # BUG FIX: the original called `.to(device)`, but no `device` name is
    # defined in this section — the device created above is
    # `infer_multiview_device`.  Using the wrong/stale global would either
    # raise NameError or silently move the pipeline to the wrong device.
    infer_multiview_pipeline.to(infer_multiview_device)
|
775 |
# Preprocessing for the conditioning image: resize so the shorter target
# dimension is covered, centre-crop to (height, width), convert to a tensor,
# then rescale pixel values from [0, 1] to [-1, 1].
_mv_h = infer_multiview_cfg.height
_mv_w = infer_multiview_cfg.width
infer_multiview_image_transforms = transforms.Compose([
    transforms.Resize(int(max(_mv_h, _mv_w))),
    transforms.CenterCrop((_mv_h, _mv_w)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x * 2. - 1),
])

# Fixed, pre-computed text-prompt embeddings for the 6-view setup
# (one file for normal maps, one for color renders).
prompt_embeds_path = './multiview/fixed_prompt_embeds_6view'
infer_multiview_normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
infer_multiview_color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')
infer_multiview_total_views = infer_multiview_cfg.num_views
|
787 |
+
|
788 |
+
|
789 |
@spaces.GPU
def process_im(im):
    """Apply the module-level multiview image transforms to one image.

    BUG FIX: the original signature was ``process_im(self, im)`` and the body
    read ``self.image_transforms``, even though this is a module-level
    function (apparently pasted in from a class method).  The module-level
    caller ``infer_multiview_gen`` invokes it as ``process_im(img)``, which
    would have bound the image to ``self`` and raised a TypeError for the
    missing ``im`` argument.  The function now takes the image directly and
    uses the module-level ``infer_multiview_image_transforms`` pipeline.

    Args:
        im: input image (PIL image or tensor accepted by the transforms).

    Returns:
        The transformed image tensor in [-1, 1].
    """
    return infer_multiview_image_transforms(im)
|
793 |
+
|
794 |
+
|
795 |
@spaces.GPU
def infer_multiview_gen(img, seed, num_levels):
    """Run multiview generation for a single conditioning image.

    The image is preprocessed, replicated once per target view, packed
    together with the fixed text-prompt embeddings, and handed to the
    multiview pipeline.

    Args:
        img: conditioning image accepted by ``process_im``.
        seed: RNG seed for reproducible sampling.
        num_levels: number of detail levels to generate.

    Returns:
        The result mapping produced by ``run_multiview_infer``.
    """
    set_seed(seed)

    # One copy of the transformed conditioning image per output view,
    # plus a leading batch dimension of 1.
    rgb = process_im(img)
    rgb = torch.stack([rgb] * infer_multiview_total_views, dim=0)

    batch = {
        "image_cond_rgb": rgb[None, ...],
        "normal_prompt_embeddings": infer_multiview_normal_text_embeds[None, ...],
        "color_prompt_embeddings": infer_multiview_color_text_embeds[None, ...],
    }
    return run_multiview_infer(batch, infer_multiview_pipeline,
                               infer_multiview_cfg, num_levels=num_levels)
|
810 |
|
811 |
repo_id = "hyz317/StdGEN"
|
812 |
all_files = list_repo_files(repo_id, revision="main")
|
|
|
826 |
infer_canonicalize_config_path = infer_canonicalize_config['config_path']
|
827 |
infer_canonicalize_loaded_config = OmegaConf.load(infer_canonicalize_config_path)
|
828 |
|
|
|
829 |
|
830 |
def infer_canonicalize_setup(
|
831 |
validation: Dict,
|