ashawkey commited on
Commit
70375cc
1 Parent(s): 307e9e0

finally, everything works locally

Browse files
README.md CHANGED
@@ -1,15 +1,27 @@
1
- # MVDream-hf
2
 
3
- modified from https://github.com/KokeCacao/mvdream-hf.
4
 
5
- ### convert weights
 
 
6
 
7
- MVDream:
 
 
 
 
 
8
  ```bash
9
  # dependency
10
- pip install -U omegaconf diffusers safetensors huggingface_hub transformers accelerate
 
 
 
11
 
12
- # download original ckpt
 
 
13
  cd models
14
  wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt
15
  wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml
@@ -21,18 +33,31 @@ python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4vi
21
 
22
  ImageDream:
23
  ```bash
24
- # download original ckpt
25
- wget https://huggingface.co/Peng-Wang/ImageDream/resolve/main/sd-v2.1-base-4view-ipmv-local.pt
26
- wget https://raw.githubusercontent.com/bytedance/ImageDream/main/extern/ImageDream/imagedream/configs/sd_v2_base_ipmv_local.yaml
 
 
27
 
28
  # convert
29
- python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view-ipmv-local.pt --dump_path ./weights_imagedream --original_config_file models/sd_v2_base_ipmv_local.yaml --half --to_safetensors --test
30
  ```
31
 
32
- ### usage
33
-
34
- example:
35
- ```bash
36
- python run_mvdream.py "a cute owl"
37
- python run_imagedream.py data/anya_rgba.png
38
- ```
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MVDream-diffusers
2
 
3
+ A **unified** diffusers implementation of [MVDream](https://github.com/bytedance/MVDream) and [ImageDream](https://github.com/bytedance/ImageDream).
4
 
5
+ We provide converted `fp16` weights on [huggingface](TODO).
6
+
7
+ ### Usage
8
 
9
+ ```bash
10
+ python run_mvdream.py "a cute owl"
11
+ python run_imagedream.py data/anya_rgba.png
12
+ ```
13
+
14
+ ### Install
15
  ```bash
16
  # dependency
17
+ pip install -r requirements.txt
18
+ ```
19
+
20
+ ### Convert weights
21
 
22
+ MVDream:
23
+ ```bash
24
+ # download original ckpt (we only support the SD 2.1 version)
25
  cd models
26
  wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt
27
  wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml
 
33
 
34
  ImageDream:
35
  ```bash
36
+ # download original ckpt (we only support the pixel-controller version)
37
+ cd models
38
+ wget https://huggingface.co/Peng-Wang/ImageDream/resolve/main/sd-v2.1-base-4view-ipmv.pt
39
+ wget https://raw.githubusercontent.com/bytedance/ImageDream/main/extern/ImageDream/imagedream/configs/sd_v2_base_ipmv.yaml
40
+ cd ..
41
 
42
  # convert
43
+ python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view-ipmv.pt --dump_path ./weights_imagedream --original_config_file models/sd_v2_base_ipmv.yaml --half --to_safetensors --test
44
  ```
45
 
46
+ ### Acknowledgement
47
+
48
+ * The original papers:
49
+ ```bibtex
50
+ @article{shi2023MVDream,
51
+ author = {Shi, Yichun and Wang, Peng and Ye, Jianglong and Mai, Long and Li, Kejie and Yang, Xiao},
52
+ title = {MVDream: Multi-view Diffusion for 3D Generation},
53
+ journal = {arXiv:2308.16512},
54
+ year = {2023},
55
+ }
56
+ @article{wang2023imagedream,
57
+ title={ImageDream: Image-Prompt Multi-view Diffusion for 3D Generation},
58
+ author={Wang, Peng and Shi, Yichun},
59
+ journal={arXiv preprint arXiv:2312.02201},
60
+ year={2023}
61
+ }
62
+ ```
63
+ * This codebase is modified from [mvdream-hf](https://github.com/KokeCacao/mvdream-hf).
convert_mvdream_to_diffusers.py CHANGED
@@ -568,7 +568,7 @@ if __name__ == "__main__":
568
  images = pipe(
569
  image=input_image,
570
  prompt="",
571
- negative_prompt="painting, bad quality, flat",
572
  output_type="pil",
573
  guidance_scale=5.0,
574
  num_inference_steps=50,
@@ -582,7 +582,7 @@ if __name__ == "__main__":
582
  images = loaded_pipe(
583
  image=input_image,
584
  prompt="",
585
- negative_prompt="painting, bad quality, flat",
586
  output_type="pil",
587
  guidance_scale=5.0,
588
  num_inference_steps=50,
 
568
  images = pipe(
569
  image=input_image,
570
  prompt="",
571
+ negative_prompt="",
572
  output_type="pil",
573
  guidance_scale=5.0,
574
  num_inference_steps=50,
 
582
  images = loaded_pipe(
583
  image=input_image,
584
  prompt="",
585
+ negative_prompt="",
586
  output_type="pil",
587
  guidance_scale=5.0,
588
  num_inference_steps=50,
mvdream/adaptor.py CHANGED
@@ -73,34 +73,6 @@ class PerceiverAttention(nn.Module):
73
  return self.to_out(out)
74
 
75
 
76
- class ImageProjModel(torch.nn.Module):
77
- """Projection Model"""
78
-
79
- def __init__(
80
- self,
81
- cross_attention_dim=1024,
82
- clip_embeddings_dim=1024,
83
- clip_extra_context_tokens=4,
84
- ):
85
- super().__init__()
86
- self.cross_attention_dim = cross_attention_dim
87
- self.clip_extra_context_tokens = clip_extra_context_tokens
88
-
89
- # from 1024 -> 4 * 1024
90
- self.proj = torch.nn.Linear(
91
- clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
92
- )
93
- self.norm = torch.nn.LayerNorm(cross_attention_dim)
94
-
95
- def forward(self, image_embeds):
96
- embeds = image_embeds
97
- clip_extra_context_tokens = self.proj(embeds).reshape(
98
- -1, self.clip_extra_context_tokens, self.cross_attention_dim
99
- )
100
- clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
101
- return clip_extra_context_tokens
102
-
103
-
104
  class Resampler(nn.Module):
105
  def __init__(
106
  self,
 
73
  return self.to_out(out)
74
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  class Resampler(nn.Module):
77
  def __init__(
78
  self,
mvdream/attention.py CHANGED
@@ -88,7 +88,7 @@ class MemoryEfficientCrossAttention(nn.Module):
88
  context = default(context, x)
89
 
90
  if self.ip_dim > 0:
91
- # context dim [(b frame_num), (77 + img_token), 1024]
92
  token_len = context.shape[1]
93
  context_ip = context[:, -self.ip_dim :, :]
94
  k_ip = self.to_k_ip(context_ip)
@@ -212,9 +212,7 @@ class SpatialTransformer3D(nn.Module):
212
  self.in_channels = in_channels
213
 
214
  inner_dim = n_heads * d_head
215
- self.norm = nn.GroupNorm(
216
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
217
- )
218
  self.proj_in = nn.Linear(in_channels, inner_dim)
219
 
220
  self.transformer_blocks = nn.ModuleList(
 
88
  context = default(context, x)
89
 
90
  if self.ip_dim > 0:
91
+ # context [B, 77 + 16(ip), 1024]
92
  token_len = context.shape[1]
93
  context_ip = context[:, -self.ip_dim :, :]
94
  k_ip = self.to_k_ip(context_ip)
 
212
  self.in_channels = in_channels
213
 
214
  inner_dim = n_heads * d_head
215
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 
 
216
  self.proj_in = nn.Linear(in_channels, inner_dim)
217
 
218
  self.transformer_blocks = nn.ModuleList(
mvdream/models.py CHANGED
@@ -14,7 +14,7 @@ from .util import (
14
  timestep_embedding,
15
  )
16
  from .attention import SpatialTransformer3D
17
- from .adaptor import Resampler, ImageProjModel
18
 
19
  import kiui
20
 
@@ -266,15 +266,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
266
  num_heads_upsample=-1,
267
  use_scale_shift_norm=False,
268
  resblock_updown=False,
269
- transformer_depth=1, # custom transformer support
270
- context_dim=None, # custom transformer support
271
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
272
- disable_self_attentions=None,
273
  num_attention_blocks=None,
274
- disable_middle_self_attn=False,
275
  adm_in_channels=None,
276
  camera_dim=None,
277
- ip_dim=0,
278
  ip_weight=1.0,
279
  **kwargs,
280
  ):
@@ -604,7 +602,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
604
 
605
  # imagedream variant
606
  if self.ip_dim > 0:
607
- x[(num_frames - 1) :: num_frames, :, :, :] = ip_img
608
  ip_emb = self.image_embed(ip)
609
  context = torch.cat((context, ip_emb), 1)
610
 
 
14
  timestep_embedding,
15
  )
16
  from .attention import SpatialTransformer3D
17
+ from .adaptor import Resampler
18
 
19
  import kiui
20
 
 
266
  num_heads_upsample=-1,
267
  use_scale_shift_norm=False,
268
  resblock_updown=False,
269
+ transformer_depth=1,
270
+ context_dim=None,
271
+ n_embed=None,
 
272
  num_attention_blocks=None,
 
273
  adm_in_channels=None,
274
  camera_dim=None,
275
+ ip_dim=0, # imagedream uses ip_dim > 0
276
  ip_weight=1.0,
277
  **kwargs,
278
  ):
 
602
 
603
  # imagedream variant
604
  if self.ip_dim > 0:
605
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
606
  ip_emb = self.image_embed(ip)
607
  context = torch.cat((context, ip_emb), 1)
608
 
mvdream/pipeline_mvdream.py CHANGED
@@ -405,29 +405,27 @@ class MVDreamPipeline(DiffusionPipeline):
405
  def encode_image(self, image, device, num_images_per_prompt):
406
  dtype = next(self.image_encoder.parameters()).dtype
407
 
408
- image = (image * 255).astype(np.uint8)
 
 
409
  image = self.feature_extractor(image, return_tensors="pt").pixel_values
410
-
411
  image = image.to(device=device, dtype=dtype)
412
 
413
- image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
414
- image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
415
 
416
- # imagedream directly use zero as uncond image embeddings
417
- uncond_image_enc_hidden_states = torch.zeros_like(image_enc_hidden_states)
418
-
419
- return uncond_image_enc_hidden_states, image_enc_hidden_states
420
 
421
  def encode_image_latents(self, image, device, num_images_per_prompt):
422
 
423
- image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2) # [1, 3, H, W]
424
- image = image.to(device=device)
425
- image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
426
  dtype = next(self.image_encoder.parameters()).dtype
 
 
 
 
427
  image = image.to(dtype=dtype)
428
 
429
  posterior = self.vae.encode(image).latent_dist
430
-
431
  latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
432
  latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
433
 
@@ -436,13 +434,13 @@ class MVDreamPipeline(DiffusionPipeline):
436
  @torch.no_grad()
437
  def __call__(
438
  self,
439
- prompt: str = "a car",
440
  image: Optional[np.ndarray] = None,
441
  height: int = 256,
442
  width: int = 256,
443
  num_inference_steps: int = 50,
444
  guidance_scale: float = 7.0,
445
- negative_prompt: str = "bad quality",
446
  num_images_per_prompt: int = 1,
447
  eta: float = 0.0,
448
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -454,7 +452,6 @@ class MVDreamPipeline(DiffusionPipeline):
454
  ):
455
  self.unet = self.unet.to(device=device)
456
  self.vae = self.vae.to(device=device)
457
-
458
  self.text_encoder = self.text_encoder.to(device=device)
459
 
460
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
@@ -466,10 +463,9 @@ class MVDreamPipeline(DiffusionPipeline):
466
  self.scheduler.set_timesteps(num_inference_steps, device=device)
467
  timesteps = self.scheduler.timesteps
468
 
469
- # imagedream variant (TODO: debug)
470
  if image is not None:
471
  assert isinstance(image, np.ndarray) and image.dtype == np.float32
472
-
473
  self.image_encoder = self.image_encoder.to(device=device)
474
  image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
475
  image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
@@ -496,7 +492,11 @@ class MVDreamPipeline(DiffusionPipeline):
496
  None,
497
  )
498
 
499
- camera = get_camera(num_frames, extra_view=(actual_num_frames != num_frames)).to(dtype=latents.dtype, device=device)
 
 
 
 
500
 
501
  # Prepare extra step kwargs.
502
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -508,10 +508,7 @@ class MVDreamPipeline(DiffusionPipeline):
508
  # expand the latents if we are doing classifier free guidance
509
  multiplier = 2 if do_classifier_free_guidance else 1
510
  latent_model_input = torch.cat([latents] * multiplier)
511
- latent_model_input = self.scheduler.scale_model_input(
512
- latent_model_input, t
513
- )
514
-
515
 
516
  unet_inputs = {
517
  'x': latent_model_input,
 
405
  def encode_image(self, image, device, num_images_per_prompt):
406
  dtype = next(self.image_encoder.parameters()).dtype
407
 
408
+ if image.dtype == np.float32:
409
+ image = (image * 255).astype(np.uint8)
410
+
411
  image = self.feature_extractor(image, return_tensors="pt").pixel_values
 
412
  image = image.to(device=device, dtype=dtype)
413
 
414
+ image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
415
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
416
 
417
+ return torch.zeros_like(image_embeds), image_embeds
 
 
 
418
 
419
  def encode_image_latents(self, image, device, num_images_per_prompt):
420
 
 
 
 
421
  dtype = next(self.image_encoder.parameters()).dtype
422
+
423
+ image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W]
424
+ image = 2 * image - 1
425
+ image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
426
  image = image.to(dtype=dtype)
427
 
428
  posterior = self.vae.encode(image).latent_dist
 
429
  latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
430
  latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
431
 
 
434
  @torch.no_grad()
435
  def __call__(
436
  self,
437
+ prompt: str = "",
438
  image: Optional[np.ndarray] = None,
439
  height: int = 256,
440
  width: int = 256,
441
  num_inference_steps: int = 50,
442
  guidance_scale: float = 7.0,
443
+ negative_prompt: str = "",
444
  num_images_per_prompt: int = 1,
445
  eta: float = 0.0,
446
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
 
452
  ):
453
  self.unet = self.unet.to(device=device)
454
  self.vae = self.vae.to(device=device)
 
455
  self.text_encoder = self.text_encoder.to(device=device)
456
 
457
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
 
463
  self.scheduler.set_timesteps(num_inference_steps, device=device)
464
  timesteps = self.scheduler.timesteps
465
 
466
+ # imagedream variant
467
  if image is not None:
468
  assert isinstance(image, np.ndarray) and image.dtype == np.float32
 
469
  self.image_encoder = self.image_encoder.to(device=device)
470
  image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
471
  image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
 
492
  None,
493
  )
494
 
495
+ if image is not None:
496
+ camera = get_camera(num_frames, elevation=5, extra_view=True).to(dtype=latents.dtype, device=device)
497
+ else:
498
+ camera = get_camera(num_frames, elevation=15, extra_view=False).to(dtype=latents.dtype, device=device)
499
+ camera = camera.repeat_interleave(num_images_per_prompt, dim=0)
500
 
501
  # Prepare extra step kwargs.
502
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
508
  # expand the latents if we are doing classifier free guidance
509
  multiplier = 2 if do_classifier_free_guidance else 1
510
  latent_model_input = torch.cat([latents] * multiplier)
511
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
 
 
512
 
513
  unet_inputs = {
514
  'x': latent_model_input,
requirements.lock.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ omegaconf == 2.3.0
2
+ diffusers == 0.23.1
3
+ safetensors == 0.4.1
4
+ huggingface_hub == 0.19.4
5
+ transformers == 4.35.2
6
+ accelerate == 0.25.0.dev0
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ omegaconf
2
+ diffusers
3
+ safetensors
4
+ huggingface_hub
5
+ transformers
6
+ accelerate
run_imagedream.py CHANGED
@@ -17,9 +17,9 @@ parser.add_argument("image", type=str, default='data/anya_rgba.png')
17
  parser.add_argument("--prompt", type=str, default="")
18
  args = parser.parse_args()
19
 
20
- while True:
21
  input_image = kiui.read_image(args.image, mode='float')
22
- image = pipe(args.prompt, input_image)
23
  grid = np.concatenate(
24
  [
25
  np.concatenate([image[0], image[2]], axis=0),
@@ -28,5 +28,4 @@ while True:
28
  axis=1,
29
  )
30
  # kiui.vis.plot_image(grid)
31
- kiui.write_image('test_imagedream.jpg', grid)
32
- break
 
17
  parser.add_argument("--prompt", type=str, default="")
18
  args = parser.parse_args()
19
 
20
+ for i in range(5):
21
  input_image = kiui.read_image(args.image, mode='float')
22
+ image = pipe(args.prompt, input_image, guidance_scale=5)
23
  grid = np.concatenate(
24
  [
25
  np.concatenate([image[0], image[2]], axis=0),
 
28
  axis=1,
29
  )
30
  # kiui.vis.plot_image(grid)
31
+ kiui.write_image(f'test_imagedream_{i}.jpg', grid)
 
run_mvdream.py CHANGED
@@ -5,7 +5,7 @@ import argparse
5
  from mvdream.pipeline_mvdream import MVDreamPipeline
6
 
7
  pipe = MVDreamPipeline.from_pretrained(
8
- "./weights", # local weights
9
  # "ashawkey/mvdream-sd2.1-diffusers",
10
  torch_dtype=torch.float16
11
  )
@@ -16,7 +16,7 @@ parser = argparse.ArgumentParser(description="MVDream")
16
  parser.add_argument("prompt", type=str, default="a cute owl 3d model")
17
  args = parser.parse_args()
18
 
19
- while True:
20
  image = pipe(args.prompt)
21
  grid = np.concatenate(
22
  [
@@ -26,5 +26,4 @@ while True:
26
  axis=1,
27
  )
28
  # kiui.vis.plot_image(grid)
29
- kiui.write_image('test_mvdream.jpg', grid)
30
- break
 
5
  from mvdream.pipeline_mvdream import MVDreamPipeline
6
 
7
  pipe = MVDreamPipeline.from_pretrained(
8
+ "./weights_mvdream", # local weights
9
  # "ashawkey/mvdream-sd2.1-diffusers",
10
  torch_dtype=torch.float16
11
  )
 
16
  parser.add_argument("prompt", type=str, default="a cute owl 3d model")
17
  args = parser.parse_args()
18
 
19
+ for i in range(5):
20
  image = pipe(args.prompt)
21
  grid = np.concatenate(
22
  [
 
26
  axis=1,
27
  )
28
  # kiui.vis.plot_image(grid)
29
+ kiui.write_image(f'test_mvdream_{i}.jpg', grid)