Spaces:
Runtime error
Runtime error
File size: 786 Bytes
589b7f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
import torch.nn as nn
from ldm.modules.attention import BasicTransformerBlock
from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
import torch.nn.functional as F
class GroundingDownsampler(nn.Module):
    """Parameter-free downsampler for a HED edge-map grounding input.

    The module has no learnable weights: it keeps the first channel of the
    input and resizes it with bicubic interpolation.

    Args:
        out_dim: Expected number of output channels. Because ``forward``
            keeps exactly one channel, the check in ``forward`` only passes
            when this is ``1``.
        out_size: Spatial size passed to ``interpolate`` as the target
            ``(height, width)``. Defaults to ``(64, 64)``, the value that
            used to be hard-coded.
    """

    def __init__(self, out_dim: int = 1, *, out_size=(64, 64)) -> None:
        super().__init__()
        self.out_dim = out_dim
        self.out_size = out_size
        # No learnable params for the HED edge map; it is only downsampled
        # with bicubic interpolation.

    def forward(self, grounding_extra_input: torch.Tensor) -> torch.Tensor:
        """Return the first input channel resized to ``self.out_size``.

        Args:
            grounding_extra_input: Tensor indexed as ``[:, 0]`` and then
                bicubically resized to a 2-D size, so it needs to be a 4-D
                ``(batch, channels, height, width)`` tensor.

        Returns:
            Tensor of shape ``(batch, 1, *self.out_size)``.

        Raises:
            AssertionError: If the channel count of the result differs from
                ``self.out_dim`` (i.e. ``out_dim`` was not ``1``).
        """
        # Per the original author's note: the map is actually grayscale but
        # was converted to RGB in the dataset, so the extra channels are
        # redundant. Keep channel 0 and restore the channel axis.
        single_channel = grounding_extra_input[:, 0].unsqueeze(1)
        out = torch.nn.functional.interpolate(
            single_channel, self.out_size, mode='bicubic'
        )
        # Configuration sanity check: interpolation preserves the channel
        # count, so this can only fail when out_dim != 1.
        assert out.shape[1] == self.out_dim
        return out
|