import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple

from einops import rearrange, repeat

import numpy as np
import torch
import torch.nn as nn

from transformers.utils import ModelOutput
from transformers.modeling_outputs import BaseModelOutputWithPooling

from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config

from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType


@dataclass
class ImageBindSegmaskAudioEncoderOutput(ModelOutput):
	"""
	Args:
		text_embeds(`torch.Tensor` of shape `(batch_size, output_dim`):
			The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
		image_embeds(`torch.Tensor` of shape `(batch_size, output_dim`):
			The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
		text_model_output(`BaseModelOutputWithPooling`):
			The output of the [`CLIPTextModel`].
		vision_model_output(`BaseModelOutputWithPooling`):
			The output of the [`CLIPVisionModel`].
	"""
	audio_embeds: Optional[torch.Tensor] = None
	audio_encodings: Optional[torch.Tensor] = None
	audio_segment_masks: Optional[torch.BoolTensor] = None
	
	def to_tuple(self) -> Tuple[Any, ...]:
		return tuple(self[k] for k in self.keys())


class ImageBindSegmaskAudioEncoder(ModelMixin, ConfigMixin):
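	"""
	Wraps the pretrained ImageBind audio branch and, besides the usual pooled
	embedding, returns per-token encodings together with boolean masks that
	split the audio token sequence into `n_segment` temporal segments.
	"""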
	
	@register_to_config
	def __init__(self,
	             n_segment=4,
	             pretrained_model_name="imagebind-huge"
	             ):
		super().__init__()
		self.n_segment = n_segment
		
		self.pretrained_model_name = pretrained_model_name
		if pretrained_model_name == "imagebind-huge":
			pretrained_model = imagebind_model.imagebind_huge(pretrained=True)
		else:
			raise ValueError(f"Unsupported pretrained_model_name: {pretrained_model_name!r}")
		
		self.preprocessor = pretrained_model.modality_preprocessors[ModalityType.AUDIO]
		self.trunk = pretrained_model.modality_trunks[ModalityType.AUDIO]
		self.head = pretrained_model.modality_heads[ModalityType.AUDIO]
		self.postprocessor = pretrained_model.modality_postprocessors[ModalityType.AUDIO]
		self.final_layer_norm = nn.LayerNorm(normalized_shape=768, eps=1e-6)
	
	def _auto_split(self, n, n_chunk):
		"""
		Split `n` positions into `n_chunk` boolean masks, each covering a
		window of `chunk_size = ceil(n / n_chunk)` consecutive positions. The
		window starts are spread evenly over `[0, n - chunk_size]`, so when `n`
		is not divisible by `n_chunk`, adjacent windows overlap slightly rather
		than leaving any position uncovered.
		"""
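		# Worked example: n=19, n_chunk=4 gives chunk_size=5 and start indices
		# [0, 5, 9, 14], i.e. windows [0:5], [5:10], [9:14], [14:19]; the two
		# middle windows overlap at position 9.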
		chunk_size = int(math.ceil(n / n_chunk))
		assert chunk_size >= 1, chunk_size
		
		chunk_start_indices = np.round(np.linspace(0, n - chunk_size, n_chunk, endpoint=True)).astype(np.int32)
		
		mask = torch.zeros(n_chunk, n).bool()
		for chunk_index, chunk_start_index in enumerate(chunk_start_indices):
			mask[chunk_index, chunk_start_index:chunk_start_index + chunk_size] = 1
		mask = mask.contiguous()
		assert mask.long().sum() == chunk_size * n_chunk, mask.long().sum()
		
		return mask
	
	def forward(self,
	            input_features: Optional[torch.Tensor],
	            normalize: bool = False,
	            return_dict: Optional[bool] = None):
		
		n_segment = self.n_segment
		
		# 1. record the batch size for broadcasting the segment masks later
		batchsize = input_features.size(0)
		
		# 2. patchify the audio spectrogram and add positional (and CLS) embeddings
		audio_inputs = self.preprocessor(input_features)
		trunk_inputs = audio_inputs["trunk"]  # dict of {"tokens": (b, l, d)}
		
		# 3. get audio encoder output
		audio_encodings = self.trunk(**trunk_inputs)  # w/o layer norm (b, seq_len, c)
		head_inputs = audio_inputs["head"]
		cls_embeds = self.head(audio_encodings, **head_inputs)
		# normalize and logit scaling
		if normalize:
			cls_embeds = self.postprocessor(cls_embeds)  # (b, c)
		audio_encodings = self.final_layer_norm(audio_encodings)
		
		# 4. get segment masks
		n, t = 12, 19  # ImageBind-huge audio token grid: 12 (freq) x 19 (time) patches
		segment_mask = self._auto_split(t, n_segment).unsqueeze(1).expand(n_segment, n, t).contiguous()  # (s, n, t)
		segment_mask = rearrange(
			segment_mask, "s n t -> s (n t)"
		)
		segment_mask = torch.cat([
			torch.ones(n_segment, 1).bool(),
			segment_mask
		], dim=1)  # (s, 1+n*t)
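		# Assumes the preprocessor prepends exactly one CLS token, so the full
		# sequence is 1 + n*t tokens long and position 0 is visible to every segment.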
		
		segment_masks = repeat(segment_mask, "s l -> b s l", b=batchsize).contiguous().bool().to(self.device)
		
		if not return_dict:
			return cls_embeds, audio_encodings, segment_masks
		
		return ImageBindSegmaskAudioEncoderOutput(
			audio_embeds=cls_embeds,
			audio_encodings=audio_encodings,
			audio_segment_masks=segment_masks
		)
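

if __name__ == "__main__":
	# Minimal smoke-test sketch, not part of the original module. It assumes the
	# ImageBind-huge audio defaults: log-mel spectrogram inputs of shape
	# (batch, 1, num_mel_bins=128, target_len=204) and that
	# imagebind_model.imagebind_huge(pretrained=True) can fetch its checkpoint.
	encoder = ImageBindSegmaskAudioEncoder(n_segment=4)
	encoder.eval()
	
	dummy_audio = torch.randn(2, 1, 128, 204)  # (b, c, mel_bins, time_frames)
	with torch.no_grad():
		out = encoder(dummy_audio, normalize=True, return_dict=True)
	
	# Expected shapes under the assumptions above:
	print(out.audio_embeds.shape)         # (2, 1024)  pooled, projected embeddings
	print(out.audio_encodings.shape)      # (2, 229, 768)  1 CLS + 12*19 patch tokens
	print(out.audio_segment_masks.shape)  # (2, 4, 229)  one boolean mask per segment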