File size: 904 Bytes
26fd00c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import torch
class CoordStage(object):
    """Parameter-free stage that quantizes a single-channel map with values in [0, 1].

    The class exposes ``eval``/``encode``/``decode`` so it can be used where a
    VQ-model-like object is expected (the original docstring calls this a
    "fake vqmodel interface"). It holds no learnable state.

    Args:
        n_embed: Scale applied to the [0, 1] input before rounding; also the
            divisor used by ``decode``.
        down_factor: Spatial down-sampling factor used by ``encode`` and the
            up-sampling factor used by ``decode``.
    """

    def __init__(self, n_embed, down_factor):
        self.n_embed = n_embed
        self.down_factor = down_factor

    def eval(self):
        """Return ``self`` unchanged; there is no train/eval state to switch."""
        return self

    def encode(self, c: torch.Tensor):
        """Down-sample ``c`` and quantize it to integer-valued levels.

        Args:
            c: Tensor with four dimensions whose second dimension has size 1
                and whose values lie in [0, 1].

        Returns:
            A 3-tuple ``(c_quant, None, (None, None, c_ind))`` where
            ``c_quant`` is the rounded tensor (same dtype as the scaled
            input) and ``c_ind`` is ``c_quant`` cast to ``torch.long``.
            The ``None`` entries are placeholders kept for interface
            compatibility.

        Raises:
            AssertionError: If values fall outside [0, 1] (NaN included) or
                the second dimension is not 1.
            ValueError: If ``c`` does not have exactly four dimensions
                (raised by the shape unpacking).
        """
        # Explicit raises instead of `assert` so the checks are not stripped
        # under `python -O`; AssertionError is kept so the exception type that
        # existing callers may observe does not change.
        if not (0.0 <= c.min() and c.max() <= 1.0):
            raise AssertionError("CoordStage.encode expects values in [0, 1]")
        _, ch, _, _ = c.shape
        if ch != 1:
            raise AssertionError(
                "CoordStage.encode expects exactly 1 channel, got {}".format(ch)
            )

        c = torch.nn.functional.interpolate(
            c, scale_factor=1 / self.down_factor, mode="area"
        )
        c = c.clamp(0.0, 1.0)
        # NOTE(review): rounding n_embed * c with c in [0, 1] yields values in
        # [0, n_embed] inclusive, i.e. n_embed + 1 distinct levels. Confirm
        # that whatever consumes c_ind tolerates the value n_embed.
        c = self.n_embed * c
        c_quant = c.round()
        c_ind = c_quant.to(dtype=torch.long)

        info = None, None, c_ind
        return c_quant, None, info

    def decode(self, c: torch.Tensor) -> torch.Tensor:
        """Rescale ``c`` by ``1 / n_embed`` and up-sample it by ``down_factor``.

        Args:
            c: Tensor of quantized levels, e.g. ``c_quant`` from ``encode``.

        Returns:
            ``c / n_embed`` up-sampled with nearest-neighbour interpolation.
        """
        c = c / self.n_embed
        c = torch.nn.functional.interpolate(
            c, scale_factor=self.down_factor, mode="nearest"
        )
        return c
|