# Tar / tok / mm_autoencoder.py
# Jiaming Han
# init
# 3c55139
import torch
import torch.nn as nn
from tok.ar_dtok.ar_model import ARModel
from tok.ar_dtok.vqvae import VQVAE
from tok.ta_tok import TextAlignedTokenizer
class MMAutoEncoder(nn.Module):
    """Autoencoder chaining a text-aligned encoder, an AR model and a VQVAE decoder.

    All three sub-models are loaded from checkpoints when the module is
    constructed. The full pipeline implemented by ``forward`` is::

        x -> encoder(...)['encoded'] -> ar_model.sample(...)
          -> decoder.decode_from_bottleneck(...) -> post_process(...)
    """

    def __init__(self,
                 ar_path,
                 encoder_path, decoder_path,
                 encoder_args=None, decoder_args=None):
        """Load the three sub-models from their checkpoints.

        Args:
            ar_path: Checkpoint passed to ``ARModel.from_checkpoint``.
            encoder_path: Checkpoint passed to
                ``TextAlignedTokenizer.from_checkpoint``.
            decoder_path: Checkpoint passed to ``VQVAE.from_checkpoint``.
            encoder_args: Optional dict of extra keyword arguments forwarded
                to the encoder loader. ``None`` means no extra arguments.
            decoder_args: Optional dict of extra keyword arguments forwarded
                to the decoder loader. ``None`` means no extra arguments.
        """
        super().__init__()
        # Defaults are ``None`` rather than ``{}`` so that no dict object is
        # shared between separate constructor calls.
        encoder_args = {} if encoder_args is None else encoder_args
        decoder_args = {} if decoder_args is None else decoder_args

        self.ar_model = ARModel.from_checkpoint(ar_path)
        # The encoder is always loaded with load_teacher=False.
        self.encoder = TextAlignedTokenizer.from_checkpoint(
            encoder_path, load_teacher=False, **encoder_args)
        self.decoder = VQVAE.from_checkpoint(decoder_path, **decoder_args)

    def ar_sample(self, x, args):
        """Run ``self.ar_model.sample`` on ``x`` with options taken from ``args``.

        Args:
            x: Conditioning input handed to the AR model unchanged.
            args: Mapping of sampling options, or ``None`` for all defaults.
                Recognised keys and their defaults: ``cfg_scale`` (1.0),
                ``cfg_interval`` (-1), ``temperature`` (1.0), ``top_k`` (0),
                ``top_p`` (1.0). Other keys are ignored.

        Returns:
            Whatever ``self.ar_model.sample`` returns.
        """
        args = {} if args is None else args
        return self.ar_model.sample(
            x,
            cfg_scale=args.get('cfg_scale', 1.0),
            cfg_interval=args.get('cfg_interval', -1),
            temperature=args.get('temperature', 1.0),
            top_k=args.get('top_k', 0),
            top_p=args.get('top_p', 1.0),
        )

    def post_process(self, x):
        """Convert a decoder output tensor into a channels-last uint8 CPU tensor.

        Values are clamped to [0, 1] and scaled by 255 before the cast, so
        anything outside that range saturates at 0 or 255.

        Args:
            x: 4-D tensor; presumably laid out as [b, c, h, w] (the permute
                below yields [b, h, w, c]).

        Returns:
            ``torch.uint8`` tensor on the CPU with dims reordered (0, 2, 3, 1).
        """
        x = x.cpu().float().clamp(0., 1.) * 255.
        x = x.permute(0, 2, 3, 1)  # [b, h, w, c]
        return x.to(torch.uint8)

    def encode(self, x):
        """Return the ``'encoded'`` entry of the encoder output for ``x``.

        ``x`` is cast to the encoder's dtype before the call.
        """
        return self.encoder(x.to(self.encoder.dtype))['encoded']

    def get_encoder_indices(self, x):
        """Return the ``'bottleneck_rep'`` entry of the encoder output for ``x``.

        ``x`` is cast to the encoder's dtype before the call.
        """
        # img -> encoder -> indices
        return self.encoder(x.to(self.encoder.dtype))['bottleneck_rep']

    @torch.inference_mode()
    def decode_from_encoder_indices(self, indices, args=None):
        """Decode encoder bottleneck indices into a post-processed output.

        Args:
            indices: Input for ``self.encoder.decode_from_bottleneck``.
            args: Optional sampling options; see ``ar_sample``.

        Returns:
            The uint8 tensor produced by ``post_process``.
        """
        # indices -> encoder feats -> ar -> decoder
        encoder_x = self.encoder.decode_from_bottleneck(indices)
        ar_indices = self.ar_sample(encoder_x, args)
        decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
        return self.post_process(decoder_x)

    def decode_from_vqvae_indices(self, indices):
        """Decode decoder (VQVAE) bottleneck indices directly, skipping the AR model.

        Returns:
            The uint8 tensor produced by ``post_process``.
        """
        decoder_x = self.decoder.decode_from_bottleneck(indices)
        return self.post_process(decoder_x)

    @torch.inference_mode()
    def forward(self, x, args=None):
        """Run the full encode -> AR sample -> decode -> post-process pipeline.

        Args:
            x: Encoder input; cast to the encoder's dtype before encoding.
            args: Optional sampling options; see ``ar_sample``.

        Returns:
            The uint8 tensor produced by ``post_process``.
        """
        encoder_x = self.encode(x)
        ar_indices = self.ar_sample(encoder_x, args)
        decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
        return self.post_process(decoder_x)