File size: 2,239 Bytes
3c55139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn

from tok.ar_dtok.ar_model import ARModel
from tok.ar_dtok.vqvae import VQVAE
from tok.ta_tok import TextAlignedTokenizer


class MMAutoEncoder(nn.Module):
    """Multimodal autoencoder chaining three checkpointed sub-models.

    Pipeline: image -> TextAlignedTokenizer (encoder) -> ARModel sampling
    -> VQVAE decoding -> uint8 image tensor in [b, h, w, c] layout.
    """

    def __init__(self,
        ar_path,
        encoder_path, decoder_path,
        encoder_args=None, decoder_args=None):
        """Load the AR prior, encoder, and decoder from checkpoint paths.

        Args:
            ar_path: checkpoint path for the autoregressive prior (ARModel).
            encoder_path: checkpoint path for the TextAlignedTokenizer.
            decoder_path: checkpoint path for the VQVAE decoder.
            encoder_args: optional extra kwargs for TextAlignedTokenizer.from_checkpoint.
            decoder_args: optional extra kwargs for VQVAE.from_checkpoint.
        """
        super().__init__()
        # `None` sentinels instead of mutable `{}` defaults: a shared default
        # dict would be aliased across every instance of this class.
        encoder_args = {} if encoder_args is None else encoder_args
        decoder_args = {} if decoder_args is None else decoder_args
        self.ar_model = ARModel.from_checkpoint(ar_path)

        self.encoder = TextAlignedTokenizer.from_checkpoint(
            encoder_path, load_teacher=False, **encoder_args)
        self.decoder = VQVAE.from_checkpoint(decoder_path, **decoder_args)

    def ar_sample(self, x, args):
        """Sample token indices from the AR prior conditioned on encoder features.

        Args:
            x: conditioning features from the encoder.
            args: dict of sampling hyperparameters; missing keys fall back to
                the defaults below (cfg_scale=1.0, cfg_interval=-1,
                temperature=1.0, top_k=0, top_p=1.0).
        """
        x = self.ar_model.sample(
            x,
            cfg_scale=args.get('cfg_scale', 1.0),
            cfg_interval=args.get('cfg_interval', -1),
            temperature=args.get('temperature', 1.0),
            top_k=args.get('top_k', 0),
            top_p=args.get('top_p', 1.0)
        )
        return x

    def post_process(self, x):
        """Convert a float image tensor in [0, 1] to a uint8 [b, h, w, c] tensor.

        Values are clamped to [0, 1], scaled to [0, 255], moved to CPU, and
        truncated (not rounded) by the uint8 cast.
        """
        x = x.cpu().float().clamp(0., 1.) * 255.
        x = x.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
        x = x.to(torch.uint8)
        return x

    def encode(self, x):
        """Return the encoder's continuous features ('encoded') for image x."""
        return self.encoder(x.to(self.encoder.dtype))['encoded']

    def get_encoder_indices(self, x):
        """Return the encoder's discrete bottleneck indices for image x."""
        # img -> encoder -> indices
        return self.encoder(x.to(self.encoder.dtype))['bottleneck_rep']

    @torch.inference_mode()
    def decode_from_encoder_indices(self, indices, args=None):
        """Decode encoder bottleneck indices to a uint8 image.

        Args:
            indices: encoder bottleneck indices.
            args: optional dict of AR sampling hyperparameters (see ar_sample).
        """
        # indices -> encoder feats -> ar -> decoder
        args = {} if args is None else args
        encoder_x = self.encoder.decode_from_bottleneck(indices)
        ar_indices = self.ar_sample(encoder_x, args)
        decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
        return self.post_process(decoder_x)

    def decode_from_vqvae_indices(self, indices):
        """Decode VQVAE bottleneck indices directly to a uint8 image."""
        decoder_x = self.decoder.decode_from_bottleneck(indices)
        return self.post_process(decoder_x)

    @torch.inference_mode()
    def forward(self, x, args=None):
        """Full reconstruction: image -> encode -> AR sample -> decode -> uint8 image.

        Args:
            x: input image batch, cast to the encoder's dtype internally.
            args: optional dict of AR sampling hyperparameters (see ar_sample).
        """
        args = {} if args is None else args
        # Reuse encode() rather than duplicating the encoder call inline.
        encoder_x = self.encode(x)
        ar_indices = self.ar_sample(encoder_x, args)
        decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
        return self.post_process(decoder_x)