liuyuan-pal committed on
Commit
959adf1
·
1 Parent(s): d0f39be
Files changed (1) hide show
  1. app.py +42 -4
app.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  import os
8
  import fire
9
 
 
10
  from ldm.util import add_margin
11
 
12
  _TITLE = '''SyncDreamer: Generating Multiview-consistent Images from a Single-view Image'''
@@ -41,10 +42,40 @@ def resize_inputs(image_input, crop_size):
41
  results = add_margin(ref_img_, size=256)
42
  return results
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def run_demo():
46
- device = f"cuda:0" if torch.cuda.is_available() else "cpu"
47
- models = None # init_model(device, os.path.join(code_dir, ckpt))
 
48
 
49
  # init sam model
50
  mask_predictor = None # sam_init(device_idx)
@@ -86,9 +117,15 @@ def run_demo():
86
 
87
  with gr.Column(scale=1):
88
  input_block = gr.Image(type='pil', image_mode='RGB', label="Input to SyncDreamer", height=256, interactive=False)
89
- elevation_slider = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
 
 
 
 
90
  run_btn = gr.Button('Run Generation', variant='primary', interactive=False)
91
 
 
 
92
  update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT)
93
  image_block.change(fn=partial(mask_prediction, mask_predictor), inputs=[image_block], outputs=[sam_block], queue=False)\
94
  .success(fn=partial(update_guide, _USER_GUIDE1), outputs=[guide_text], queue=False)
@@ -96,7 +133,8 @@ def run_demo():
96
  crop_size_slider.change(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
97
  .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
98
 
99
- run_btn.click
 
100
 
101
  demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
102
 
 
7
  import os
8
  import fire
9
 
10
+ from generate import load_model
11
  from ldm.util import add_margin
12
 
13
  _TITLE = '''SyncDreamer: Generating Multiview-consistent Images from a Single-view Image'''
 
42
  results = add_margin(ref_img_, size=256)
43
  return results
44
 
45
def generate(model, seed, batch_view_num, sample_num, cfg_scale, image_input, elevation_input):
    """Sample multiview images for one RGBA input and tile them into one image.

    Args:
        model: Object exposing ``sample(data, cfg_scale, batch_view_num)``. Its
            result is unpacked here as a 5-D tensor ``(B, N, C, H, W)`` with
            values clamped to ``[-1, 1]`` before conversion.
        seed: Random seed applied to both torch and NumPy. Cast to ``int``
            because ``np.random.seed`` rejects floats, which numeric UI
            fields commonly deliver.
        batch_view_num: Forwarded unchanged to ``model.sample``.
        sample_num: How many times the input is repeated along the batch axis.
        cfg_scale: Forwarded unchanged to ``model.sample``.
        image_input: Image convertible with ``np.asarray`` to an ``(H, W, 4)``
            array; values are divided by 255 (presumably 8-bit RGBA -- confirm
            against the caller). The 4th channel is used as the foreground mask.
        elevation_input: Elevation angle in degrees (converted to radians).

    Returns:
        A PIL image in which each row is one sample and each column one view.

    Raises:
        ValueError: If ``image_input`` does not have exactly four channels.
    """
    seed = int(seed)
    torch.random.manual_seed(seed)
    np.random.seed(seed)

    # prepare data
    rgba = np.asarray(image_input).astype(np.float32) / 255.0
    if rgba.ndim != 3 or rgba.shape[2] != 4:
        # Without this check the mask multiply below fails with an opaque
        # broadcasting ValueError; keep the exception type, improve the message.
        raise ValueError(
            f"generate() expects an RGBA image of shape (H, W, 4), got array of shape {rgba.shape}"
        )
    ref_mask = rgba[:, :, 3:]
    rgb = rgba[:, :, :3] * ref_mask + 1 - ref_mask  # white background
    rgb = rgb * 2.0 - 1.0  # rescale [0, 1] -> [-1, 1]

    image_tensor = torch.from_numpy(rgb.astype(np.float32))
    elevation_tensor = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
    data = {"input_image": image_tensor, "input_elevation": elevation_tensor}
    for k, v in data.items():
        # Add a batch axis, move to the GPU, then replicate once per sample.
        data[k] = torch.repeat_interleave(v.unsqueeze(0).cuda(), sample_num, dim=0)

    x_sample = model.sample(data, cfg_scale, batch_view_num)

    # Unpacking five names doubles as a check that the result is 5-D.
    _, _, _, _, _ = x_sample.shape
    x_sample = (torch.clamp(x_sample, max=1.0, min=-1.0) + 1) * 0.5
    x_sample = x_sample.permute(0, 1, 3, 4, 2).cpu().numpy() * 255
    x_sample = x_sample.astype(np.uint8)

    # x_sample is a NumPy array from here on, so tiling must use NumPy
    # (torch.concat only accepts tensors). Views go side by side, samples stack
    # vertically.
    rows = [np.concatenate(list(views), axis=1) for views in x_sample]
    return Image.fromarray(np.concatenate(rows, axis=0))
74
 
75
  def run_demo():
76
+ # device = f"cuda:0" if torch.cuda.is_available() else "cpu"
77
+ # models = None # init_model(device, os.path.join(code_dir, ckpt))
78
+ model = load_model('configs/syncdreamer', 'ckpt/syncdreamer-pretrain.ckpt', strict=True)
79
 
80
  # init sam model
81
  mask_predictor = None # sam_init(device_idx)
 
117
 
118
  with gr.Column(scale=1):
119
  input_block = gr.Image(type='pil', image_mode='RGB', label="Input to SyncDreamer", height=256, interactive=False)
120
+ elevation = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
121
+ cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
122
+ # sample_num = gr.Slider(1, 2, 2, step=1, label='Sample Num', interactive=True, info='How many instance (16 images per instance)')
123
+ # batch_view_num = gr.Slider(1, 16, 8, step=1, label='', interactive=True)
124
+ seed = gr.Number(6033, label='Random seed', interactive=True)
125
  run_btn = gr.Button('Run Generation', variant='primary', interactive=False)
126
 
127
+ output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
128
+
129
  update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT)
130
  image_block.change(fn=partial(mask_prediction, mask_predictor), inputs=[image_block], outputs=[sam_block], queue=False)\
131
  .success(fn=partial(update_guide, _USER_GUIDE1), outputs=[guide_text], queue=False)
 
133
  crop_size_slider.change(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
134
  .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
135
 
136
+ run_btn.click(partial(generate, model, seed, 16, 1, cfg_scale, input_block, elevation), outputs=[output_block])\
137
+ .success(fn=partial(update_guide, _USER_GUIDE0), outputs=[guide_text], queue=False)
138
 
139
  demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
140