bhuvanmdev committed
Commit: 8d62889
Parent(s): 5180050
Update lgm/lgm.py
lgm/lgm.py: +16 -16
lgm/lgm.py CHANGED
@@ -237,7 +237,7 @@ class LGM(ModelMixin, ConfigMixin):
         self.rot_act = F.normalize
         self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5

-    def prepare_default_rays(self, device, elevation=0):
+    def prepare_default_rays(self, device, elevation=0,views = 4):
         # cam_poses = np.stack(
         #     [
         #         orbit_camera(elevation, 0, radius=self.radius),
@@ -279,7 +279,7 @@ class LGM(ModelMixin, ConfigMixin):
         B, V, C, H, W = images.shape
         images = images.view(B * V, C, H, W)

-        x = self.unet(images)
+        x = self.unet(images,V) ###
         x = self.conv(x)

         x = x.reshape(B, 4, 14, self.splat_size, self.splat_size)
@@ -518,21 +518,21 @@ class MVAttention(nn.Module):
             dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop
         )

-    def forward(self, x):
+    def forward(self, x,views): ###
         BV, C, H, W = x.shape
-        B = BV // self.num_frames
+        B = BV // views###self.num_frames

         res = x
         x = self.norm(x)

         x = (
-            x.reshape(B, self.num_frames, C, H, W)
+            x.reshape(B, views, C, H, W)###
             .permute(0, 1, 3, 4, 2)
             .reshape(B, -1, C)
         )
         x = self.attn(x)
         x = (
-            x.reshape(B, self.num_frames, H, W, C)
+            x.reshape(B, views, H, W, C)###
             .permute(0, 1, 4, 2, 3)
             .reshape(BV, C, H, W)
         )
@@ -634,12 +634,12 @@ class DownBlock(nn.Module):
             out_channels, out_channels, kernel_size=3, stride=2, padding=1
         )

-    def forward(self, x):
+    def forward(self, x,views):
         xs = []
         for attn, net in zip(self.attns, self.nets):
             x = net(x)
             if attn:
-                x = attn(x)
+                x = attn(x,views)
             xs.append(x)
         if self.downsample:
             x = self.downsample(x)
@@ -672,11 +672,11 @@ class MidBlock(nn.Module):
         self.nets = nn.ModuleList(nets)
         self.attns = nn.ModuleList(attns)

-    def forward(self, x):
+    def forward(self, x, views):
         x = self.nets[0](x)
         for attn, net in zip(self.attns, self.nets[1:]):
             if attn:
-                x = attn(x)
+                x = attn(x,views)
             x = net(x)
         return x

@@ -717,14 +717,14 @@ class UpBlock(nn.Module):
             out_channels, out_channels, kernel_size=3, stride=1, padding=1
         )

-    def forward(self, x, xs):
+    def forward(self, x, xs,views): ###
         for attn, net in zip(self.attns, self.nets):
             res_x = xs[-1]
             xs = xs[:-1]
             x = torch.cat([x, res_x], dim=1)
             x = net(x)
             if attn:
-                x = attn(x)
+                x = attn(x,views) ##
         if self.upsample:
             x = F.interpolate(x, scale_factor=2.0, mode="nearest")
             x = self.upsample(x)
@@ -798,17 +798,17 @@ class UNet(nn.Module):
             up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1
         )

-    def forward(self, x):
+    def forward(self, x,views):###
         x = self.conv_in(x)
         xss = [x]
         for block in self.down_blocks:
-            x, xs = block(x)
+            x, xs = block(x,views)###
             xss.extend(xs)
-        x = self.mid_block(x)
+        x = self.mid_block(x,views) ###
         for block in self.up_blocks:
             xs = xss[-len(block.nets) :]
             xss = xss[: -len(block.nets)]
-            x = block(x, xs)
+            x = block(x, xs,views)
         x = self.norm_out(x)
         x = F.silu(x)
         x = self.conv_out(x)
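Taken together, the commit replaces the view count that was baked into the attention module (self.num_frames, 4 views by default in LGM) with a views argument that is threaded from the LGM forward pass through UNet, DownBlock, MidBlock and UpBlock into MVAttention, so the multi-view attention can run on however many input views the batch actually contains. Below is a minimal, self-contained sketch of the reshape that this change parameterizes; the helper names (mv_tokens, mv_untokens) and the tensor sizes are invented for illustration and are not part of lgm/lgm.py.

# Illustrative sketch only -- mv_tokens/mv_untokens and the sizes below are
# hypothetical, not repository code. It mirrors the reshape that
# MVAttention.forward performs now that `views` is passed at call time.
import torch

def mv_tokens(x: torch.Tensor, views: int) -> torch.Tensor:
    # (B*views, C, H, W) -> (B, views*H*W, C): one token sequence per scene,
    # so self-attention can mix information across all of that scene's views.
    BV, C, H, W = x.shape
    assert BV % views == 0, "batch size must be a multiple of the view count"
    B = BV // views
    return x.reshape(B, views, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C)

def mv_untokens(tokens: torch.Tensor, views: int, H: int, W: int) -> torch.Tensor:
    # Inverse mapping: (B, views*H*W, C) -> (B*views, C, H, W).
    B, _, C = tokens.shape
    return tokens.reshape(B, views, H, W, C).permute(0, 1, 4, 2, 3).reshape(B * views, C, H, W)

# Example: 2 scenes with 6 views each instead of the default 4.
x = torch.randn(2 * 6, 32, 16, 16)
t = mv_tokens(x, views=6)                 # shape (2, 6*16*16, 32)
y = mv_untokens(t, views=6, H=16, W=16)
assert torch.equal(x, y)                  # the reshape round-trips exactly

The sketch also shows why B = BV // views is the only place the view count enters: the attention itself is shape-agnostic, so passing views as an argument lets the same weights process, say, 6 views instead of 4 without relying on a fixed self.num_frames.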