YulianSa committed on
Commit
bf1c514
·
1 Parent(s): e431330
Files changed (2) hide show
  1. app.py +5 -4
  2. infer_api.py +9 -15
app.py CHANGED
@@ -71,7 +71,8 @@ def arbitrary_to_apose(image, seed):
71
  def apose_to_multiview(apose_img, seed):
72
  # convert image to PIL.Image
73
  apose_img = Image.fromarray(apose_img)
74
- return infer_api.genStage2(apose_img, seed, num_levels=1)[0]["images"]
 
75
 
76
  def multiview_to_mesh(images):
77
  mesh_files = infer_api.genStage3(images)
@@ -79,9 +80,9 @@ def multiview_to_mesh(images):
79
 
80
  def refine_mesh(apose_img, mesh1, mesh2, mesh3, seed):
81
  apose_img = Image.fromarray(apose_img)
82
- infer_api.genStage2(apose_img, seed, num_levels=2)
83
- print(infer_api.multiview_infer.results.keys())
84
- refined = infer_api.genStage4([mesh1, mesh2, mesh3], infer_api.multiview_infer.results)
85
  return refined
86
 
87
  with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation from Single Images") as demo:
 
71
  def apose_to_multiview(apose_img, seed):
72
  # convert image to PIL.Image
73
  apose_img = Image.fromarray(apose_img)
74
+ results, _ = infer_api.genStage2(apose_img, seed, num_levels=1)
75
+ return results[0]["images"]
76
 
77
  def multiview_to_mesh(images):
78
  mesh_files = infer_api.genStage3(images)
 
80
 
81
  def refine_mesh(apose_img, mesh1, mesh2, mesh3, seed):
82
  apose_img = Image.fromarray(apose_img)
83
+ _, all_results = infer_api.genStage2(apose_img, seed, num_levels=2)
84
+ print(all_results.keys())
85
+ refined = infer_api.genStage4([mesh1, mesh2, mesh3], all_results)
86
  return refined
87
 
88
  with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation from Single Images") as demo:
infer_api.py CHANGED
@@ -341,16 +341,6 @@ def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3):
341
  torch.cuda.empty_cache()
342
  return results
343
 
344
- @spaces.GPU
345
- def load_multiview_pipeline(cfg):
346
- pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
347
- cfg.pretrained_path,
348
- torch_dtype=torch.float16,)
349
- pipeline.unet.enable_xformers_memory_efficient_attention()
350
- if torch.cuda.is_available():
351
- pipeline.to(device)
352
- return pipeline
353
-
354
 
355
  class InferAPI:
356
  def __init__(self,
@@ -768,10 +758,13 @@ parser.add_argument("--height", type=int, default=1024)
768
  parser.add_argument("--width", type=int, default=576)
769
  infer_multiview_cfg = parser.parse_args()
770
  infer_multiview_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
771
- infer_multiview_pipeline = load_multiview_pipeline(infer_multiview_cfg)
772
- infer_multiview_results = {}
 
 
773
  if torch.cuda.is_available():
774
  infer_multiview_pipeline.to(device)
 
775
 
776
  infer_multiview_image_transforms = [transforms.Resize(int(max(infer_multiview_cfg.height, infer_multiview_cfg.width))),
777
  transforms.CenterCrop((infer_multiview_cfg.height, infer_multiview_cfg.width)),
@@ -791,6 +784,7 @@ def process_im(self, im):
791
  im = self.image_transforms(im)
792
  return im
793
 
 
794
 
795
  @spaces.GPU
796
  def infer_multiview_gen(img, seed, num_levels):
@@ -804,9 +798,9 @@ def infer_multiview_gen(img, seed, num_levels):
804
  data["color_prompt_embeddings"] = infer_multiview_color_text_embeds[None, ...]
805
 
806
  results = run_multiview_infer(data, infer_multiview_pipeline, infer_multiview_cfg, num_levels=num_levels)
807
- # for k in results:
808
- # self.results[k] = results[k]
809
- return results
810
 
811
  repo_id = "hyz317/StdGEN"
812
  all_files = list_repo_files(repo_id, revision="main")
 
341
  torch.cuda.empty_cache()
342
  return results
343
 
 
 
 
 
 
 
 
 
 
 
344
 
345
  class InferAPI:
346
  def __init__(self,
 
758
  parser.add_argument("--width", type=int, default=576)
759
  infer_multiview_cfg = parser.parse_args()
760
  infer_multiview_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
761
+ infer_multiview_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
762
+ infer_multiview_cfg.pretrained_path,
763
+ torch_dtype=torch.float16,)
764
+ infer_multiview_pipeline.unet.enable_xformers_memory_efficient_attention()
765
  if torch.cuda.is_available():
766
  infer_multiview_pipeline.to(device)
767
+ infer_multiview_results = {}
768
 
769
  infer_multiview_image_transforms = [transforms.Resize(int(max(infer_multiview_cfg.height, infer_multiview_cfg.width))),
770
  transforms.CenterCrop((infer_multiview_cfg.height, infer_multiview_cfg.width)),
 
784
  im = self.image_transforms(im)
785
  return im
786
 
787
+ all_results = {}
788
 
789
  @spaces.GPU
790
  def infer_multiview_gen(img, seed, num_levels):
 
798
  data["color_prompt_embeddings"] = infer_multiview_color_text_embeds[None, ...]
799
 
800
  results = run_multiview_infer(data, infer_multiview_pipeline, infer_multiview_cfg, num_levels=num_levels)
801
+ for k in results:
802
+ all_results[k] = results[k]
803
+ return results, all_results
804
 
805
  repo_id = "hyz317/StdGEN"
806
  all_files = list_repo_files(repo_id, revision="main")