ashawkey commited on
Commit
cbfb133
·
1 Parent(s): 9167592
Files changed (4) hide show
  1. README.md +8 -1
  2. main.py +13 -3
  3. mvdream/models.py +13 -25
  4. mvdream/util.py +0 -196
README.md CHANGED
@@ -12,7 +12,14 @@ wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd
12
  python convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v2.1-base-4view.pt --dump_path ./weights --original_config_file ./sd-v2-base.yaml --half --to_safetensors --test
13
  ```
14
 
15
- ### run pipeline
 
 
 
 
 
 
 
16
  ```python
17
  import torch
18
  import kiui
 
12
  python convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v2.1-base-4view.pt --dump_path ./weights --original_config_file ./sd-v2-base.yaml --half --to_safetensors --test
13
  ```
14
 
15
+ ### usage
16
+
17
+ example:
18
+ ```bash
19
+ python main.py "a cute owl"
20
+ ```
21
+
22
+ detailed usage:
23
  ```python
24
  import torch
25
  import kiui
main.py CHANGED
@@ -1,11 +1,21 @@
1
  import torch
2
  import kiui
 
 
3
  from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
4
 
5
  pipe = MVDreamStableDiffusionPipeline.from_pretrained('./weights', torch_dtype=torch.float16)
6
  pipe = pipe.to("cuda")
7
 
8
- prompt = "a photo of an astronaut riding a horse on mars"
9
- image = pipe(prompt)
10
 
11
- kiui.vis.plot_image(image)
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import kiui
3
+ import numpy as np
4
+ import argparse
5
  from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
6
 
7
  pipe = MVDreamStableDiffusionPipeline.from_pretrained('./weights', torch_dtype=torch.float16)
8
  pipe = pipe.to("cuda")
9
 
 
 
10
 
11
+ parser = argparse.ArgumentParser(description='MVDream')
12
+ parser.add_argument('prompt', type=str, default="a cute owl 3d model")
13
+ args = parser.parse_args()
14
+
15
+ while True:
16
+ image = pipe(args.prompt)
17
+ grid = np.concatenate([
18
+ np.concatenate([image[0], image[2]], axis=0),
19
+ np.concatenate([image[1], image[3]], axis=0),
20
+ ], axis=1)
21
+ kiui.vis.plot_image(grid)
mvdream/models.py CHANGED
@@ -10,10 +10,8 @@ from abc import abstractmethod
10
  from .util import (
11
  checkpoint,
12
  conv_nd,
13
- linear,
14
  avg_pool_nd,
15
  zero_module,
16
- normalization,
17
  timestep_embedding,
18
  )
19
  from .attention import SpatialTransformer, SpatialTransformer3D
@@ -56,7 +54,7 @@ class MultiViewUNetWrapperModel(ModelMixin, ConfigMixin):
56
  adm_in_channels=None,
57
  camera_dim=None,):
58
  super().__init__()
59
- self.unet: MultiViewUNetModel = MultiViewUNetModel(
60
  image_size=image_size,
61
  in_channels=in_channels,
62
  model_channels=model_channels,
@@ -218,7 +216,7 @@ class ResBlock(TimestepBlock):
218
  self.use_scale_shift_norm = use_scale_shift_norm
219
 
220
  self.in_layers = nn.Sequential(
221
- normalization(channels),
222
  nn.SiLU(),
223
  conv_nd(dims, channels, self.out_channels, 3, padding=1),
224
  )
@@ -236,13 +234,13 @@ class ResBlock(TimestepBlock):
236
 
237
  self.emb_layers = nn.Sequential(
238
  nn.SiLU(),
239
- linear(
240
  emb_channels,
241
  2 * self.out_channels if use_scale_shift_norm else self.out_channels,
242
  ),
243
  )
244
  self.out_layers = nn.Sequential(
245
- normalization(self.out_channels),
246
  nn.SiLU(),
247
  nn.Dropout(p=dropout),
248
  zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
@@ -310,7 +308,7 @@ class AttentionBlock(nn.Module):
310
  assert (channels % num_head_channels == 0), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
311
  self.num_heads = channels // num_head_channels
312
  self.use_checkpoint = use_checkpoint
313
- self.norm = normalization(channels)
314
  self.qkv = conv_nd(1, channels, channels * 3, 1)
315
  if use_new_attention_order:
316
  # split qkv before split heads
@@ -418,16 +416,6 @@ class QKVAttention(nn.Module):
418
  return count_flops_attn(model, _x, y)
419
 
420
 
421
- class Timestep(nn.Module):
422
-
423
- def __init__(self, dim):
424
- super().__init__()
425
- self.dim = dim
426
-
427
- def forward(self, t):
428
- return timestep_embedding(t, self.dim)
429
-
430
-
431
  class MultiViewUNetModel(nn.Module):
432
  """
433
  The full multi-view UNet model with attention, timestep embedding and camera embedding.
@@ -545,17 +533,17 @@ class MultiViewUNetModel(nn.Module):
545
 
546
  time_embed_dim = model_channels * 4
547
  self.time_embed = nn.Sequential(
548
- linear(model_channels, time_embed_dim),
549
  nn.SiLU(),
550
- linear(time_embed_dim, time_embed_dim),
551
  )
552
 
553
  if camera_dim is not None:
554
  time_embed_dim = model_channels * 4
555
  self.camera_embed = nn.Sequential(
556
- linear(camera_dim, time_embed_dim),
557
  nn.SiLU(),
558
- linear(time_embed_dim, time_embed_dim),
559
  )
560
 
561
  if self.num_classes is not None:
@@ -567,9 +555,9 @@ class MultiViewUNetModel(nn.Module):
567
  elif self.num_classes == "sequential":
568
  assert adm_in_channels is not None
569
  self.label_emb = nn.Sequential(nn.Sequential(
570
- linear(adm_in_channels, time_embed_dim),
571
  nn.SiLU(),
572
- linear(time_embed_dim, time_embed_dim),
573
  ))
574
  else:
575
  raise ValueError()
@@ -722,13 +710,13 @@ class MultiViewUNetModel(nn.Module):
722
  self._feature_size += ch
723
 
724
  self.out = nn.Sequential(
725
- normalization(ch),
726
  nn.SiLU(),
727
  zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
728
  )
729
  if self.predict_codebook_ids:
730
  self.id_predictor = nn.Sequential(
731
- normalization(ch),
732
  conv_nd(dims, model_channels, n_embed, 1),
733
  #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
734
  )
 
10
  from .util import (
11
  checkpoint,
12
  conv_nd,
 
13
  avg_pool_nd,
14
  zero_module,
 
15
  timestep_embedding,
16
  )
17
  from .attention import SpatialTransformer, SpatialTransformer3D
 
54
  adm_in_channels=None,
55
  camera_dim=None,):
56
  super().__init__()
57
+ self.unet = MultiViewUNetModel(
58
  image_size=image_size,
59
  in_channels=in_channels,
60
  model_channels=model_channels,
 
216
  self.use_scale_shift_norm = use_scale_shift_norm
217
 
218
  self.in_layers = nn.Sequential(
219
+ nn.GroupNorm(32, channels),
220
  nn.SiLU(),
221
  conv_nd(dims, channels, self.out_channels, 3, padding=1),
222
  )
 
234
 
235
  self.emb_layers = nn.Sequential(
236
  nn.SiLU(),
237
+ nn.Linear(
238
  emb_channels,
239
  2 * self.out_channels if use_scale_shift_norm else self.out_channels,
240
  ),
241
  )
242
  self.out_layers = nn.Sequential(
243
+ nn.GroupNorm(32, self.out_channels),
244
  nn.SiLU(),
245
  nn.Dropout(p=dropout),
246
  zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
 
308
  assert (channels % num_head_channels == 0), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
309
  self.num_heads = channels // num_head_channels
310
  self.use_checkpoint = use_checkpoint
311
+ self.norm = nn.GroupNorm(32, channels)
312
  self.qkv = conv_nd(1, channels, channels * 3, 1)
313
  if use_new_attention_order:
314
  # split qkv before split heads
 
416
  return count_flops_attn(model, _x, y)
417
 
418
 
 
 
 
 
 
 
 
 
 
 
419
  class MultiViewUNetModel(nn.Module):
420
  """
421
  The full multi-view UNet model with attention, timestep embedding and camera embedding.
 
533
 
534
  time_embed_dim = model_channels * 4
535
  self.time_embed = nn.Sequential(
536
+ nn.Linear(model_channels, time_embed_dim),
537
  nn.SiLU(),
538
+ nn.Linear(time_embed_dim, time_embed_dim),
539
  )
540
 
541
  if camera_dim is not None:
542
  time_embed_dim = model_channels * 4
543
  self.camera_embed = nn.Sequential(
544
+ nn.Linear(camera_dim, time_embed_dim),
545
  nn.SiLU(),
546
+ nn.Linear(time_embed_dim, time_embed_dim),
547
  )
548
 
549
  if self.num_classes is not None:
 
555
  elif self.num_classes == "sequential":
556
  assert adm_in_channels is not None
557
  self.label_emb = nn.Sequential(nn.Sequential(
558
+ nn.Linear(adm_in_channels, time_embed_dim),
559
  nn.SiLU(),
560
+ nn.Linear(time_embed_dim, time_embed_dim),
561
  ))
562
  else:
563
  raise ValueError()
 
710
  self._feature_size += ch
711
 
712
  self.out = nn.Sequential(
713
+ nn.GroupNorm(32, ch),
714
  nn.SiLU(),
715
  zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
716
  )
717
  if self.predict_codebook_ids:
718
  self.id_predictor = nn.Sequential(
719
+ nn.GroupNorm(32, ch),
720
  conv_nd(dims, model_channels, n_embed, 1),
721
  #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
722
  )
mvdream/util.py CHANGED
@@ -10,136 +10,7 @@
10
  import math
11
  import torch
12
  import torch.nn as nn
13
- import numpy as np
14
- import importlib
15
  from einops import repeat
16
- from typing import Any
17
-
18
-
19
- def instantiate_from_config(config):
20
- if not "target" in config:
21
- if config == '__is_first_stage__':
22
- return None
23
- elif config == "__is_unconditional__":
24
- return None
25
- raise KeyError("Expected key `target` to instantiate.")
26
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
27
-
28
-
29
- def get_obj_from_str(string, reload=False):
30
- module, cls = string.rsplit(".", 1)
31
- if reload:
32
- module_imp = importlib.import_module(module)
33
- importlib.reload(module_imp)
34
- return getattr(importlib.import_module(module, package=None), cls)
35
-
36
-
37
- def make_beta_schedule(schedule,
38
- n_timestep,
39
- linear_start=1e-4,
40
- linear_end=2e-2,
41
- cosine_s=8e-3):
42
- if schedule == "linear":
43
- betas = (torch.linspace(linear_start**0.5,
44
- linear_end**0.5,
45
- n_timestep,
46
- dtype=torch.float64)**2)
47
-
48
- elif schedule == "cosine":
49
- timesteps = (
50
- torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep +
51
- cosine_s)
52
- alphas = timesteps / (1 + cosine_s) * np.pi / 2
53
- alphas = torch.cos(alphas).pow(2)
54
- alphas = alphas / alphas[0]
55
- betas = 1 - alphas[1:] / alphas[:-1]
56
- betas = np.clip(betas, a_min=0, a_max=0.999)
57
-
58
- elif schedule == "sqrt_linear":
59
- betas = torch.linspace(linear_start,
60
- linear_end,
61
- n_timestep,
62
- dtype=torch.float64)
63
- elif schedule == "sqrt":
64
- betas = torch.linspace(linear_start,
65
- linear_end,
66
- n_timestep,
67
- dtype=torch.float64)**0.5
68
- else:
69
- raise ValueError(f"schedule '{schedule}' unknown.")
70
- return betas.numpy() # type: ignore
71
-
72
-
73
- def make_ddim_timesteps(ddim_discr_method,
74
- num_ddim_timesteps,
75
- num_ddpm_timesteps,
76
- verbose=True):
77
- if ddim_discr_method == 'uniform':
78
- c = num_ddpm_timesteps // num_ddim_timesteps
79
- ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
80
- elif ddim_discr_method == 'quad':
81
- ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
82
- num_ddim_timesteps))**2).astype(int)
83
- else:
84
- raise NotImplementedError(
85
- f'There is no ddim discretization method called "{ddim_discr_method}"'
86
- )
87
-
88
- # assert ddim_timesteps.shape[0] == num_ddim_timesteps
89
- # add one to get the final alpha values right (the ones from first scale to data during sampling)
90
- steps_out = ddim_timesteps + 1
91
- if verbose:
92
- print(f'Selected timesteps for ddim sampler: {steps_out}')
93
- return steps_out
94
-
95
-
96
- def make_ddim_sampling_parameters(alphacums,
97
- ddim_timesteps,
98
- eta,
99
- verbose=True):
100
- # select alphas for computing the variance schedule
101
- alphas = alphacums[ddim_timesteps]
102
- alphas_prev = np.asarray([alphacums[0]] +
103
- alphacums[ddim_timesteps[:-1]].tolist())
104
-
105
- # according the the formula provided in https://arxiv.org/abs/2010.02502
106
- sigmas = eta * np.sqrt(
107
- (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
108
- if verbose:
109
- print(
110
- f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
111
- )
112
- print(
113
- f'For the chosen value of eta, which is {eta}, '
114
- f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
115
- )
116
- return sigmas, alphas, alphas_prev
117
-
118
-
119
- def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
120
- """
121
- Create a beta schedule that discretizes the given alpha_t_bar function,
122
- which defines the cumulative product of (1-beta) over time from t = [0,1].
123
- :param num_diffusion_timesteps: the number of betas to produce.
124
- :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
125
- produces the cumulative product of (1-beta) up to that
126
- part of the diffusion process.
127
- :param max_beta: the maximum beta to use; use values lower than 1 to
128
- prevent singularities.
129
- """
130
- betas = []
131
- for i in range(num_diffusion_timesteps):
132
- t1 = i / num_diffusion_timesteps
133
- t2 = (i + 1) / num_diffusion_timesteps
134
- betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
135
- return np.array(betas)
136
-
137
-
138
- def extract_into_tensor(a, t, x_shape):
139
- b, *_ = t.shape
140
- out = a.gather(-1, t)
141
- return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
142
-
143
 
144
  def checkpoint(func, inputs, params, flag):
145
  """
@@ -227,45 +98,6 @@ def zero_module(module):
227
  p.detach().zero_()
228
  return module
229
 
230
-
231
- def scale_module(module, scale):
232
- """
233
- Scale the parameters of a module and return it.
234
- """
235
- for p in module.parameters():
236
- p.detach().mul_(scale)
237
- return module
238
-
239
-
240
- def mean_flat(tensor):
241
- """
242
- Take the mean over all non-batch dimensions.
243
- """
244
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
245
-
246
-
247
- def normalization(channels):
248
- """
249
- Make a standard normalization layer.
250
- :param channels: number of input channels.
251
- :return: an nn.Module for normalization.
252
- """
253
- return GroupNorm32(32, channels)
254
-
255
-
256
- # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
257
- class SiLU(nn.Module):
258
-
259
- def forward(self, x):
260
- return x * torch.sigmoid(x)
261
-
262
-
263
- class GroupNorm32(nn.GroupNorm):
264
-
265
- def forward(self, x):
266
- return super().forward(x)
267
-
268
-
269
  def conv_nd(dims, *args, **kwargs):
270
  """
271
  Create a 1D, 2D, or 3D convolution module.
@@ -279,13 +111,6 @@ def conv_nd(dims, *args, **kwargs):
279
  raise ValueError(f"unsupported dimensions: {dims}")
280
 
281
 
282
- def linear(*args, **kwargs):
283
- """
284
- Create a linear module.
285
- """
286
- return nn.Linear(*args, **kwargs)
287
-
288
-
289
  def avg_pool_nd(dims, *args, **kwargs):
290
  """
291
  Create a 1D, 2D, or 3D average pooling module.
@@ -297,24 +122,3 @@ def avg_pool_nd(dims, *args, **kwargs):
297
  elif dims == 3:
298
  return nn.AvgPool3d(*args, **kwargs)
299
  raise ValueError(f"unsupported dimensions: {dims}")
300
-
301
-
302
- class HybridConditioner(nn.Module):
303
-
304
- def __init__(self, c_concat_config, c_crossattn_config):
305
- super().__init__()
306
- self.concat_conditioner: Any = instantiate_from_config(c_concat_config)
307
- self.crossattn_conditioner: Any = instantiate_from_config(
308
- c_crossattn_config)
309
-
310
- def forward(self, c_concat, c_crossattn):
311
- c_concat = self.concat_conditioner(c_concat)
312
- c_crossattn = self.crossattn_conditioner(c_crossattn)
313
- return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
314
-
315
-
316
- def noise_like(shape, device, repeat=False):
317
- repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
318
- shape[0], *((1, ) * (len(shape) - 1)))
319
- noise = lambda: torch.randn(shape, device=device)
320
- return repeat_noise() if repeat else noise()
 
10
  import math
11
  import torch
12
  import torch.nn as nn
 
 
13
  from einops import repeat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def checkpoint(func, inputs, params, flag):
16
  """
 
98
  p.detach().zero_()
99
  return module
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  def conv_nd(dims, *args, **kwargs):
102
  """
103
  Create a 1D, 2D, or 3D convolution module.
 
111
  raise ValueError(f"unsupported dimensions: {dims}")
112
 
113
 
 
 
 
 
 
 
 
114
  def avg_pool_nd(dims, *args, **kwargs):
115
  """
116
  Create a 1D, 2D, or 3D average pooling module.
 
122
  elif dims == 3:
123
  return nn.AvgPool3d(*args, **kwargs)
124
  raise ValueError(f"unsupported dimensions: {dims}")