Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from typing import Tuple | |
from diffusers.models.modeling_utils import ModelMixin | |
class PointNet(ModelMixin):
    """Convolutional encoder that emits one flattened feature map per stage.

    The network is a chain of stages. Each stage is an ``nn.Sequential`` made
    of ``downsample // 2`` pairs of (3x3 convolution with stride 2 and
    padding 1, SiLU). Stages are applied one after another and the output of
    every stage is collected by :meth:`forward`.

    Args:
        conditioning_channels: Number of channels of the input tensor.
        out_channels: Output channel count of each stage.
        downsamples: Per-stage depth control. A stage gets
            ``downsample // 2`` stride-2 convolutions (e.g. ``6`` -> 3
            convolutions, ``2`` -> 1 convolution). A value below 2 yields an
            empty stage that passes its input through unchanged.

    Raises:
        ValueError: If ``out_channels`` and ``downsamples`` differ in length.
    """

    def __init__(
        self,
        conditioning_channels: int = 1,
        out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        downsamples: Tuple[int, ...] = (6, 2, 2, 2),
    ):
        super().__init__()

        # zip() would silently drop the surplus entries, producing a model
        # with fewer stages than requested; fail loudly instead.
        if len(out_channels) != len(downsamples):
            raise ValueError(
                "out_channels and downsamples must have the same length, "
                f"got {len(out_channels)} and {len(downsamples)}"
            )

        self.blocks = nn.ModuleList()
        current_channels = conditioning_channels

        # Build the convolutional stages.
        for out_channel, downsample in zip(out_channels, downsamples):
            layers = []
            for _ in range(downsample // 2):
                layers.append(
                    nn.Conv2d(
                        in_channels=current_channels,
                        out_channels=out_channel,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                    )
                )
                layers.append(nn.SiLU())
                # Every conv after the first one in a stage (and the first
                # conv of the next stage) consumes this stage's width.
                current_channels = out_channel
            self.blocks.append(nn.Sequential(*layers))

    def forward(self, x: torch.Tensor) -> list:
        """Run all stages and return each stage's output as a token sequence.

        Args:
            x: Input of shape ``(B, conditioning_channels, H, W)``.

        Returns:
            A list with one tensor per stage, in stage order. Each tensor has
            shape ``(B, H_i * W_i, C_i)``, where ``C_i, H_i, W_i`` are the
            channel count and spatial size of that stage's output.
        """
        embeddings = []
        embedding = x
        for block in self.blocks:
            embedding = block(embedding)
            # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C). flatten() is used
            # rather than view() because view() raises on non-contiguous
            # activations (e.g. channels_last memory format).
            embeddings.append(embedding.flatten(2).transpose(1, 2))
        return embeddings
if __name__ == "__main__":
    # Smoke test: push one random input through the default model and print
    # the shape of every returned embedding.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f'Using device: {device}')

    model = PointNet()
    model = model.to(device)

    # Batch size = 1, Channels = 1, Height = 288, Width = 512
    input_shape = (1, 1, 288, 512)
    dummy_input = torch.randn(*input_shape).to(device)

    embeddings = model(dummy_input)
    for i, stage_output in enumerate(embeddings):
        print(f"Output at layer {i + 1}:", stage_output.shape)