import torch
import torch.nn as nn

from tok.ar_dtok.ar_model import ARModel
from tok.ar_dtok.vqvae import VQVAE
from tok.ta_tok import TextAlignedTokenizer


class MMAutoEncoder(nn.Module):
    """Text-aligned encoder -> autoregressive model -> VQVAE decoder pipeline."""

    def __init__(self,
                 ar_path,
                 encoder_path, decoder_path,
                 encoder_args={}, decoder_args={}):
        super().__init__()
        self.ar_model = ARModel.from_checkpoint(ar_path)
        self.encoder = TextAlignedTokenizer.from_checkpoint(encoder_path, load_teacher=False, **encoder_args)
        self.decoder = VQVAE.from_checkpoint(decoder_path, **decoder_args)

    def ar_sample(self, x, args):
        # Autoregressive sampling over the encoder features, with optional CFG and
        # temperature / top-k / top-p controls taken from `args`.
        x = self.ar_model.sample(
            x,
            cfg_scale=args.get('cfg_scale', 1.0),
            cfg_interval=args.get('cfg_interval', -1),
            temperature=args.get('temperature', 1.0),
            top_k=args.get('top_k', 0),
            top_p=args.get('top_p', 1.0)
        )
        return x

    def post_process(self, x):
        # Convert decoder output in [0, 1] to uint8 images in [b, h, w, c] layout.
        x = x.cpu().float().clamp(0., 1.) * 255.
        x = x.permute(0, 2, 3, 1)  # [b, h, w, c]
        x = x.to(torch.uint8)
        return x

    def encode(self, x):
        # img -> encoder feats
        return self.encoder(x.to(self.encoder.dtype))['encoded']

    def get_encoder_indices(self, x):
        # img -> encoder -> indices
        return self.encoder(x.to(self.encoder.dtype))['bottleneck_rep']

    def decode_from_encoder_indices(self, indices, args={}):
        # indices -> encoder feats -> ar -> decoder
        encoder_x = self.encoder.decode_from_bottleneck(indices)
        ar_indices = self.ar_sample(encoder_x, args)
        decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
        x = self.post_process(decoder_x)
        return x

    def decode_from_vqvae_indices(self, indices):
        # indices -> decoder -> uint8 image
        decoder_x = self.decoder.decode_from_bottleneck(indices)
        x = self.post_process(decoder_x)
        return x

    def forward(self, x, args={}):
        # img -> encoder feats -> ar -> decoder -> uint8 image
        encoder_x = self.encoder(x.to(self.encoder.dtype))['encoded']
        ar_indices = self.ar_sample(encoder_x, args)
        decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
        x = self.post_process(decoder_x)
        return x
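

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module above): the checkpoint paths and the
# 256x256 RGB input shape below are hypothetical placeholders, not values from
# the source; this only illustrates how MMAutoEncoder reconstructs an image
# batch end to end (encode -> AR sample -> VQVAE decode -> uint8 images).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = MMAutoEncoder(
        ar_path="checkpoints/ar_model.pt",      # hypothetical path
        encoder_path="checkpoints/ta_tok.pt",   # hypothetical path
        decoder_path="checkpoints/vqvae.pt",    # hypothetical path
    ).eval()

    # Dummy batch of two RGB images in [0, 1]; real inputs would come from a dataloader.
    images = torch.rand(2, 3, 256, 256)

    # Sampling controls use the keys read in ar_sample(); defaults shown explicitly.
    sample_args = {'cfg_scale': 1.0, 'temperature': 1.0, 'top_k': 0, 'top_p': 1.0}

    with torch.no_grad():
        recon = model(images, args=sample_args)

    print(recon.shape, recon.dtype)  # expected: [2, H, W, 3], torch.uint8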