
Causal Tokenizer may not be perfectly causal.

#2
by koukyo1994 - opened

I found an odd phenomenon when I tested Cosmos-Tokenizer-DV4x8x8.

I prepared a video with more than 9 frames (24 in the example below). When I passed the first 4 frames to the encoder like this, I got quantized indices of shape (1, 2, 36, 64).

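```python
from cosmos_tokenizer.video_lib import CausalVideoTokenizer

# frames has shape (1, 3, 24, 288, 512), i.e. [B, C, T, H, W]
encoder = CausalVideoTokenizer(checkpoint_enc="/path/to/Cosmos-Tokenizer-DV4x8x8/encoder.jit")
input_ids_4_frames = encoder.encode(frames[:, :, :4])[0]  # [0] selects the discrete indices
input_ids_4_frames.size()
# ==> (1, 2, 36, 64)
```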

When I passed the first 5 frames like this, I also got quantized indices of shape (1, 2, 36, 64).

```python
input_ids_5_frames = encoder.encode(frames[:, :, :5])[0]
input_ids_5_frames.size()
# ==> (1, 2, 36, 64)
```
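Both the 4-frame and the 5-frame inputs yield 2 latent frames. A minimal sketch that reproduces the shapes reported in this thread, assuming that "DV4x8x8" means 4x temporal and 8x8 spatial compression, and that the first frame gets its own latent while later frames are grouped in fours (a partial group is rounded up). `expected_index_shape` is a hypothetical helper for illustration, not part of the Cosmos-Tokenizer API.

```python
def expected_index_shape(batch, num_frames, height, width, t_stride=4, s_stride=8):
    """Hypothetical helper: predicted shape of the index tensor.

    Assumption: frame 0 maps to latent 0, and every following group of up to
    `t_stride` frames maps to one more latent (ceil division for a partial
    group); height and width are divided by `s_stride`.
    """
    later_latents = (num_frames - 1 + t_stride - 1) // t_stride  # ceil((T - 1) / t_stride)
    return (batch, 1 + later_latents, height // s_stride, width // s_stride)


for n in (4, 5, 9):
    print(n, expected_index_shape(1, n, 288, 512))
# 4 (1, 2, 36, 64)
# 5 (1, 2, 36, 64)
# 9 (1, 3, 36, 64)
```

Under this assumption, frame 4 (the fifth frame) completes the first group of four, which is why adding it does not create a new latent frame.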

I compared the 0-th latent frame of input_ids_4_frames and input_ids_5_frames, and they matched exactly. This makes sense, since the latent at index 0 corresponds to the very first frame, so in a causal encoder it should not depend on any later frames.

```python
(input_ids_4_frames[0, 0] == input_ids_5_frames[0, 0]).all()
# ==> tensor(True, device='cuda:0')
```
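The same reasoning extends to later latents. Assuming the layout sketched above (frame 0 gets its own latent and each later latent covers the next group of four frames), latent k may only depend on input frames 0 through 4k. The hypothetical helper below lists, for each prefix length used in this thread, which latent indices are fully determined by that prefix; under strict causality those latents should be identical across all prefixes that cover them.

```python
def covered_latents(num_frames, t_stride=4):
    """Hypothetical helper: latent indices k whose causal window
    (input frames 0 .. t_stride * k) lies entirely inside a prefix of
    `num_frames` frames. Assumes the grouping described in the lead-in."""
    return list(range((num_frames - 1) // t_stride + 1))


for n in (4, 5, 9):
    print(n, covered_latents(n))
# 4 [0]
# 5 [0, 1]
# 9 [0, 1, 2]
```

By this reasoning, latent 0 should agree across the 4-, 5- and 9-frame runs, and latent 1 should also agree between the 5-frame and 9-frame runs.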

However, when I passed the first 9 frames and compared the 0-th latent frame of the output with that of input_ids_5_frames, they did not match: only about 87% of the indices agreed.

```python
input_ids_9_frames = encoder.encode(frames[:, :, :9])[0]
input_ids_9_frames.size()
# ==> (1, 3, 36, 64)
(input_ids_5_frames[0, 0] == input_ids_9_frames[0, 0]).all()
# ==> tensor(False, device='cuda:0')
# fraction of indices in latent frame 0 that agree
(input_ids_5_frames[0, 0] == input_ids_9_frames[0, 0]).float().mean()
# ==> tensor(0.8711, device='cuda:0')
```

In other words, the encoding of an earlier frame changed when the number of subsequent frames changed, which suggests that the encoder may not be fully causal. What do you think?
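To probe this more systematically, here is a minimal sketch of a prefix-consistency check: encode several prefixes of the same clip and, for every pair of prefix lengths, compare the latent frames whose causal window fits inside the shorter prefix (again assuming frame 0 gets its own latent and later frames are grouped in fours). `prefix_match_rates` and `toy_encode_indices` are hypothetical helpers. The toy encoder only illustrates what a pass and a fail look like and makes no claim about how Cosmos-Tokenizer is built internally. With the real tokenizer, one would pass something like `lambda x: encoder.encode(x)[0]` together with the original `frames`.

```python
import torch


def prefix_match_rates(encode_indices, frames, prefix_lengths, temporal_stride=4):
    """Encode several prefixes of `frames` and compare their shared latents.

    encode_indices: callable mapping a (B, C, T, H, W) tensor to integer
        indices of shape (B, T', H', W'), e.g. lambda x: encoder.encode(x)[0].
    Returns {(a, b): [match rate of latent k, for each k covered by prefix a]}.
    Under strict causality every rate should be exactly 1.0.
    """
    lengths = sorted(prefix_lengths)
    outputs = {n: encode_indices(frames[:, :, :n]) for n in lengths}
    rates = {}
    for i, a in enumerate(lengths):
        # Latents 0 .. (a - 1) // stride only see frames inside the shorter prefix.
        num_covered = (a - 1) // temporal_stride + 1
        for b in lengths[i + 1:]:
            rates[(a, b)] = [
                (outputs[a][:, k] == outputs[b][:, k]).float().mean().item()
                for k in range(num_covered)
            ]
    return rates


def toy_encode_indices(frames, leak_future=False, temporal_stride=4):
    """Hypothetical stand-in for a tokenizer; NOT the Cosmos encoder.

    Latent 0 summarizes frame 0, latent k >= 1 summarizes frames
    stride*(k-1)+1 .. stride*k, and the result is floored to integer "indices".
    """
    x = frames[:, 0]  # (B, T, H, W): keep a single channel
    if leak_future:
        # Deliberately non-causal: subtract the mean over ALL frames of the
        # clip, so every latent depends on frames that come after it.
        x = x - x.mean(dim=1, keepdim=True)
    num_frames = x.shape[1]
    num_latents = 1 + (num_frames - 1 + temporal_stride - 1) // temporal_stride
    latents = []
    for k in range(num_latents):
        start = max(0, temporal_stride * (k - 1) + 1)
        stop = min(num_frames, temporal_stride * k + 1)
        latents.append(torch.floor(8 * x[:, start:stop].mean(dim=1)))
    return torch.stack(latents, dim=1).long()


torch.manual_seed(0)
frames = torch.rand(1, 3, 24, 8, 8)
causal = prefix_match_rates(toy_encode_indices, frames, [4, 5, 9])
leaky = prefix_match_rates(
    lambda x: toy_encode_indices(x, leak_future=True), frames, [4, 5, 9]
)
print(causal)  # every rate is 1.0
print(leaky)   # latent-0 rates drop below 1.0
```

On the real encoder, a rate below 1.0 for a covered latent (such as the 0.8711 reported above for latent 0) would indicate that later frames influence earlier latents in some way.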
