import torch
import torch.nn as nn

from tok.ar_dtok.ar_model import ARModel
from tok.ar_dtok.vqvae import VQVAE
from tok.ta_tok import TextAlignedTokenizer


class MMAutoEncoder(nn.Module):
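    """Auto-encoding pipeline: a TextAlignedTokenizer encoder, per-resolution
    autoregressive (AR) de-tokenizer models, and a VQVAE pixel decoder."""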

    def __init__(self,
                 ar_path_dict,
                 encoder_path, decoder_path,
                 encoder_args={}, decoder_args={}):
        super().__init__()
        # One AR de-tokenizer checkpoint per target resolution (e.g. "1024px").
        self.ar_model = nn.ModuleDict({
            resolution: ARModel.from_checkpoint(ar_path)
            for resolution, ar_path in ar_path_dict.items()
        })

        # Text-aligned tokenizer as the encoder, VQVAE as the pixel decoder.
        self.encoder = TextAlignedTokenizer.from_checkpoint(encoder_path, load_teacher=False, **encoder_args)
        self.decoder = VQVAE.from_checkpoint(decoder_path, **decoder_args)

    def ar_sample(self, x, args):
        # Pick the AR model for the requested resolution and sample VQVAE token
        # indices with the given decoding hyperparameters.
        resolution = args.get('resolution', '1024px')
        x = self.ar_model[resolution].sample(
            x,
            cfg_scale=args.get('cfg_scale', 1.0),
            cfg_interval=args.get('cfg_interval', -1),
            temperature=args.get('temperature', 1.0),
            top_k=args.get('top_k', 0),
            top_p=args.get('top_p', 1.0)
        )
        return x

    def post_process(self, x):
        # Convert decoder output in [0, 1], shaped (B, C, H, W), to uint8 images (B, H, W, C).
        x = x.cpu().float().clamp(0., 1.) * 255.
        x = x.permute(0, 2, 3, 1)
        x = x.to(torch.uint8)
        return x

    def encode(self, x):
        # The tokenizer's 'encoded' features for an image batch.
        return self.encoder(x.to(self.encoder.dtype))['encoded']

    def get_encoder_indices(self, x):
        # The tokenizer's discrete 'bottleneck_rep' indices for an image batch.
        return self.encoder(x.to(self.encoder.dtype))['bottleneck_rep']

    @torch.inference_mode()
    def decode_from_encoder_indices(self, indices, args={}):
        # Tokenizer indices -> encoder features -> AR-sampled VQVAE indices -> pixels.
        encoder_x = self.encoder.decode_from_bottleneck(indices)
        ar_indices = self.ar_sample(encoder_x, args)
        decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
        x = self.post_process(decoder_x)
        return x

    def decode_from_vqvae_indices(self, indices):
        # Decode pixels directly from VQVAE indices, skipping the AR model.
        decoder_x = self.decoder.decode_from_bottleneck(indices)
        x = self.post_process(decoder_x)
        return x

    @torch.inference_mode()
    def forward(self, x, args={}):
        # Full image-to-image round trip: encode, AR-sample VQVAE indices, decode.
        encoder_x = self.encoder(x.to(self.encoder.dtype))['encoded']
        ar_indices = self.ar_sample(encoder_x, args)
        decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
        x = self.post_process(decoder_x)
        return x
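

# Minimal usage sketch (not part of the original module): the checkpoint paths,
# resolution key, and dummy input below are assumptions for illustration only;
# adjust them to your own checkpoints and preprocessing.
if __name__ == '__main__':
    model = MMAutoEncoder(
        ar_path_dict={'1024px': 'checkpoints/ar_1024px.pt'},  # hypothetical path
        encoder_path='checkpoints/ta_tok.pt',                 # hypothetical path
        decoder_path='checkpoints/vqvae.pt',                  # hypothetical path
    ).eval()
    # One dummy RGB image in [0, 1]; real inputs should match the encoder's
    # expected preprocessing and resolution.
    images = torch.rand(1, 3, 1024, 1024)
    out = model(images, args={'resolution': '1024px', 'cfg_scale': 4.0})
    print(out.shape, out.dtype)  # (1, H, W, 3), torch.uint8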