Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License. | |
import torch | |
import torch.nn as nn | |
class BufferList(nn.Module):
    """Hold an ordered sequence of tensors as non-persistent module buffers.

    Each tensor is registered under its positional index (``'0'``, ``'1'``,
    ...) with ``persistent=False``, so it is tracked by the module as a
    buffer but is excluded from ``state_dict()``.

    Args:
        buffers: Iterable of tensors to register, in order.
    """

    def __init__(self, buffers):
        super().__init__()
        for i, buffer in enumerate(buffers):
            # Buffer names must be strings; the index keeps insertion order.
            self.register_buffer(str(i), buffer, persistent=False)

    def __len__(self):
        """Return the number of registered buffers."""
        return len(self._buffers)

    def __iter__(self):
        """Iterate over the registered tensors in registration order."""
        return iter(self._buffers.values())
class PointGenerator(nn.Module):
    """Generate cached 1-D points for every level of a feature pyramid.

    For each stride a table with one row per point is precomputed up to
    ``buffer_size``. Every row has four columns:
    ``(position, range_lower, range_upper, stride)``.

    The range of a level spans from the previous level's upper bound
    (``0`` for the first level) to the next level's stride; the last level is
    open-ended (upper bound ``inf``).

    Args:
        strides: Sequence of per-level strides. ``forward`` asserts that the
            first stride equals ``1``.
        buffer_size: Exclusive upper bound of the cached positions.
        offset: If ``True``, shift every position by half a stride.
    """

    def __init__(self, strides, buffer_size, offset=False):
        super().__init__()

        # Build the (lower, upper) range of each level from consecutive strides.
        reg_range, last = [], 0
        for stride in strides[1:]:
            reg_range.append((last, stride))
            last = stride
        reg_range.append((last, float('inf')))

        self.strides = strides
        self.reg_range = reg_range
        self.buffer_size = buffer_size
        self.offset = offset

        self.buffer = self._cache_points()

    def _level_points(self, stride, reg_range):
        """Build the point table of a single level.

        Args:
            stride: Stride of the level.
            reg_range: ``(lower, upper)`` range attached to every point.

        Returns:
            Tensor of shape ``(num_points, 4)`` with columns
            ``(position, range_lower, range_upper, stride)``.
        """
        level_range = torch.Tensor([reg_range])
        level_stride = torch.Tensor([stride])

        # Create positions directly in the floating dtype of the other columns.
        # With integer strides ``torch.arange`` would otherwise return int64,
        # and the in-place half-stride shift below would raise because a float
        # result cannot be written into an integer tensor.
        points = torch.arange(
            0, self.buffer_size, stride, dtype=level_range.dtype)[:, None]

        if self.offset:
            points += 0.5 * stride

        num_points = points.size(0)
        return torch.cat(
            (points, level_range.repeat(num_points, 1),
             level_stride.repeat(num_points, 1)),
            dim=1)

    def _cache_points(self):
        """Precompute the point tables of all levels and wrap them as buffers."""
        buffer_list = [
            self._level_points(stride, reg_range)
            for stride, reg_range in zip(self.strides, self.reg_range)
        ]
        return BufferList(buffer_list)

    def forward(self, pymid):
        """Return the concatenated points matching the given pyramid.

        Args:
            pymid: Sequence of per-level tensors; ``size(1)`` of each one is
                taken as the number of points needed for that level. Levels
                without a corresponding entry, or with size ``0``, are skipped.

        Returns:
            Tensor of shape ``(total_points, 4)`` holding copies of the cached
            rows of all non-empty levels, in level order.

        Raises:
            AssertionError: If the first stride is not ``1`` or a level needs
                more points than were cached.
        """
        assert self.strides[0] == 1

        # NOTE: positions are returned in absolute units; they are not
        # normalized by the length of the first level.
        points = []
        for feat, buffer in zip(pymid, self.buffer):
            size = feat.size(1)
            if size == 0:
                continue
            assert size <= buffer.size(0), 'reached max buffer size'
            # Clone so callers can modify the result without touching the cache.
            points.append(buffer[:size, :].clone())

        return torch.cat(points)