bhuvanmdev commited on
Commit
8d62889
1 Parent(s): 5180050

Update lgm/lgm.py

Browse files
Files changed (1) hide show
  1. lgm/lgm.py +16 -16
lgm/lgm.py CHANGED
@@ -237,7 +237,7 @@ class LGM(ModelMixin, ConfigMixin):
237
  self.rot_act = F.normalize
238
  self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5
239
 
240
- def prepare_default_rays(self, device, elevation=0,views = 5):
241
  # cam_poses = np.stack(
242
  # [
243
  # orbit_camera(elevation, 0, radius=self.radius),
@@ -279,7 +279,7 @@ class LGM(ModelMixin, ConfigMixin):
279
  B, V, C, H, W = images.shape
280
  images = images.view(B * V, C, H, W)
281
 
282
- x = self.unet(images)
283
  x = self.conv(x)
284
 
285
  x = x.reshape(B, 4, 14, self.splat_size, self.splat_size)
@@ -518,21 +518,21 @@ class MVAttention(nn.Module):
518
  dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop
519
  )
520
 
521
- def forward(self, x):
522
  BV, C, H, W = x.shape
523
- B = BV // self.num_frames
524
 
525
  res = x
526
  x = self.norm(x)
527
 
528
  x = (
529
- x.reshape(B, self.num_frames, C, H, W)
530
  .permute(0, 1, 3, 4, 2)
531
  .reshape(B, -1, C)
532
  )
533
  x = self.attn(x)
534
  x = (
535
- x.reshape(B, self.num_frames, H, W, C)
536
  .permute(0, 1, 4, 2, 3)
537
  .reshape(BV, C, H, W)
538
  )
@@ -634,12 +634,12 @@ class DownBlock(nn.Module):
634
  out_channels, out_channels, kernel_size=3, stride=2, padding=1
635
  )
636
 
637
- def forward(self, x):
638
  xs = []
639
  for attn, net in zip(self.attns, self.nets):
640
  x = net(x)
641
  if attn:
642
- x = attn(x)
643
  xs.append(x)
644
  if self.downsample:
645
  x = self.downsample(x)
@@ -672,11 +672,11 @@ class MidBlock(nn.Module):
672
  self.nets = nn.ModuleList(nets)
673
  self.attns = nn.ModuleList(attns)
674
 
675
- def forward(self, x):
676
  x = self.nets[0](x)
677
  for attn, net in zip(self.attns, self.nets[1:]):
678
  if attn:
679
- x = attn(x)
680
  x = net(x)
681
  return x
682
 
@@ -717,14 +717,14 @@ class UpBlock(nn.Module):
717
  out_channels, out_channels, kernel_size=3, stride=1, padding=1
718
  )
719
 
720
- def forward(self, x, xs):
721
  for attn, net in zip(self.attns, self.nets):
722
  res_x = xs[-1]
723
  xs = xs[:-1]
724
  x = torch.cat([x, res_x], dim=1)
725
  x = net(x)
726
  if attn:
727
- x = attn(x)
728
  if self.upsample:
729
  x = F.interpolate(x, scale_factor=2.0, mode="nearest")
730
  x = self.upsample(x)
@@ -798,17 +798,17 @@ class UNet(nn.Module):
798
  up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1
799
  )
800
 
801
- def forward(self, x):
802
  x = self.conv_in(x)
803
  xss = [x]
804
  for block in self.down_blocks:
805
- x, xs = block(x)
806
  xss.extend(xs)
807
- x = self.mid_block(x)
808
  for block in self.up_blocks:
809
  xs = xss[-len(block.nets) :]
810
  xss = xss[: -len(block.nets)]
811
- x = block(x, xs)
812
  x = self.norm_out(x)
813
  x = F.silu(x)
814
  x = self.conv_out(x)
 
237
  self.rot_act = F.normalize
238
  self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5
239
 
240
+ def prepare_default_rays(self, device, elevation=0,views = 4):
241
  # cam_poses = np.stack(
242
  # [
243
  # orbit_camera(elevation, 0, radius=self.radius),
 
279
  B, V, C, H, W = images.shape
280
  images = images.view(B * V, C, H, W)
281
 
282
+ x = self.unet(images,V) ###
283
  x = self.conv(x)
284
 
285
  x = x.reshape(B, 4, 14, self.splat_size, self.splat_size)
 
518
  dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop
519
  )
520
 
521
+ def forward(self, x,views): ###
522
  BV, C, H, W = x.shape
523
+ B = BV // views###self.num_frames
524
 
525
  res = x
526
  x = self.norm(x)
527
 
528
  x = (
529
+ x.reshape(B, views, C, H, W)###
530
  .permute(0, 1, 3, 4, 2)
531
  .reshape(B, -1, C)
532
  )
533
  x = self.attn(x)
534
  x = (
535
+ x.reshape(B, views, H, W, C)###
536
  .permute(0, 1, 4, 2, 3)
537
  .reshape(BV, C, H, W)
538
  )
 
634
  out_channels, out_channels, kernel_size=3, stride=2, padding=1
635
  )
636
 
637
+ def forward(self, x,views):
638
  xs = []
639
  for attn, net in zip(self.attns, self.nets):
640
  x = net(x)
641
  if attn:
642
+ x = attn(x,views)
643
  xs.append(x)
644
  if self.downsample:
645
  x = self.downsample(x)
 
672
  self.nets = nn.ModuleList(nets)
673
  self.attns = nn.ModuleList(attns)
674
 
675
+ def forward(self, x, views):
676
  x = self.nets[0](x)
677
  for attn, net in zip(self.attns, self.nets[1:]):
678
  if attn:
679
+ x = attn(x,views)
680
  x = net(x)
681
  return x
682
 
 
717
  out_channels, out_channels, kernel_size=3, stride=1, padding=1
718
  )
719
 
720
+ def forward(self, x, xs,views): ###
721
  for attn, net in zip(self.attns, self.nets):
722
  res_x = xs[-1]
723
  xs = xs[:-1]
724
  x = torch.cat([x, res_x], dim=1)
725
  x = net(x)
726
  if attn:
727
+ x = attn(x,views) ##
728
  if self.upsample:
729
  x = F.interpolate(x, scale_factor=2.0, mode="nearest")
730
  x = self.upsample(x)
 
798
  up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1
799
  )
800
 
801
+ def forward(self, x,views):###
802
  x = self.conv_in(x)
803
  xss = [x]
804
  for block in self.down_blocks:
805
+ x, xs = block(x,views)###
806
  xss.extend(xs)
807
+ x = self.mid_block(x,views) ###
808
  for block in self.up_blocks:
809
  xs = xss[-len(block.nets) :]
810
  xss = xss[: -len(block.nets)]
811
+ x = block(x, xs,views)
812
  x = self.norm_out(x)
813
  x = F.silu(x)
814
  x = self.conv_out(x)