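"""Learnable motion embeddings for the temporal attention blocks of a video
diffusion UNet, with helpers to inject them into a model, save and load them,
and set their scale and trained length at inference time."""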
import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class MotionEmbedding(nn.Module):

    def __init__(self, embed_dim: int = None, max_seq_length: int = 32, wh: int = 1):
        super().__init__()
        # Learnable embedding of shape [wh, frames, dim]. wh > 1 stores one
        # embedding per spatial location (see inject_motion_embeddings).
        self.embed = nn.Parameter(torch.zeros(wh, max_seq_length, embed_dim))
        self.scale = 1.0
        self.trained_length = -1

    def set_scale(self, scale: float):
        self.scale = scale

    def set_lengths(self, trained_length: int):
        if trained_length > self.embed.shape[1] or trained_length <= 0:
            raise ValueError("Trained length is out of bounds")
        self.trained_length = trained_length

    def _resampled_embeddings(self, seq_length: int):
        # When x's length differs from the trained length, linearly resample
        # the trained frames along the frame axis to match; otherwise just
        # slice the embedding to the target length.
        if self.trained_length != -1 and seq_length != self.trained_length:
            embeddings = self.embed[:, :self.trained_length].permute(0, 2, 1)  # [batch, dim, frames]
            embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False)
            embeddings = embeddings.permute(0, 2, 1)  # back to [batch, frames, dim]
        else:
            embeddings = self.embed[:, :seq_length]

        if embeddings.shape[1] != seq_length:
            raise ValueError(
                f"Interpolated embeddings sequence length {embeddings.shape[1]} "
                f"does not match x's sequence length {seq_length}"
            )
        return embeddings

    def _add_scaled(self, x, embeddings):
        # Tile over the batch when x stacks several samples per embedding
        # (assumes x's batch size is a multiple of the embedding's).
        if x.shape[0] != embeddings.shape[0]:
            embeddings = embeddings.repeat(x.shape[0] // embeddings.shape[0], 1, 1)
        return x + embeddings * self.scale

    def forward(self, x):
        _, seq_length, _ = x.shape  # target sequence length for x
        return self._add_scaled(x, self._resampled_embeddings(seq_length))

    def forward_average(self, x):
        # Variant: subtract the per-sequence mean so that only the residual
        # motion signal is added to x.
        _, seq_length, _ = x.shape
        embeddings = self._resampled_embeddings(seq_length)
        embeddings = embeddings - embeddings.mean(dim=1, keepdim=True)
        return self._add_scaled(x, embeddings)

    def forward_frameSubtraction(self, x):
        # Variant: replace frames 1..N-1 with frame-to-frame differences.
        # The base copy is detached, so gradients flow only through the
        # differences, not through the untouched first frame.
        _, seq_length, _ = x.shape
        embeddings = self._resampled_embeddings(seq_length)
        subtraction = embeddings[:, 1:] - embeddings[:, :-1]
        embeddings = embeddings.clone().detach()
        embeddings[:, 1:] = subtraction
        return self._add_scaled(x, embeddings)

class MotionEmbeddingSpatial(MotionEmbedding):
    """Per-spatial-location variant: identical to MotionEmbedding with
    wh = h * w, i.e. one embedding row per spatial position."""

    def __init__(self, h: int = None, w: int = None, embed_dim: int = None, max_seq_length: int = 32):
        super().__init__(embed_dim=embed_dim, max_seq_length=max_seq_length, wh=h * w)


def inject_motion_embeddings(model, combinations=None, config=None):
    # Replace the pos_embed slot of each selected temporal transformer block
    # with a trainable MotionEmbedding, add a per-location 'vSpatial'
    # embedding next to each temporal attention, and freeze all other weights.
    spatial_shape = np.array([config.dataset.height, config.dataset.width])
    shape32 = np.ceil(spatial_shape / 32).astype(int)  # latent H, W at 1/32 scale
    shape16 = np.ceil(spatial_shape / 16).astype(int)  # latent H, W at 1/16 scale
    spatial_name = 'vSpatial'
    replacement_dict = {}
    max_seq_length = 32  # support up to 32 frames
    inject_layers = []

    for name, module in model.named_modules():
        # Only transformer blocks inside temporal attention are candidates.
        is_temporal = '.temp_attentions.' in name
        if not (is_temporal and re.search(r'transformer_blocks\.\d+$', name)):
            continue

        # combinations holds the [place_in_net, embed_dim] pairs to inject into.
        if [name.split('_')[0], module.norm1.normalized_shape[0]] not in combinations:
            continue

        replacement_dict[f'{name}.pos_embed'] = MotionEmbedding(
            max_seq_length=max_seq_length,
            embed_dim=module.norm1.normalized_shape[0],
        ).to(dtype=model.dtype, device=model.device)

    # Keep a deterministic key order: the modulo pairing below relies on
    # temp_attn_list and embed_dims being built from the same ordering.
    replacement_keys = list(replacement_dict.keys())
    temp_attn_list = (
        [name.replace('pos_embed', 'attn1') for name in replacement_keys]
        + [name.replace('pos_embed', 'attn2') for name in replacement_keys]
    )
    embed_dims = [replacement_dict[key].embed.shape[2] for key in replacement_keys]

    for temp_attn_index, temp_attn in enumerate(temp_attn_list):
        place_in_net = temp_attn.split('_')[0]  # 'up' or 'down'
        index_in_net = re.search(r'(\d+)\.temp_attentions', temp_attn).group(1)

        # Map the block's position in the UNet to its latent resolution.
        h, w = None, None
        if place_in_net == 'up':
            if index_in_net == "1":
                h, w = shape32
            elif index_in_net == "2":
                h, w = shape16
        elif place_in_net == 'down':
            if index_in_net == "1":
                h, w = shape16
            elif index_in_net == "2":
                h, w = shape32
        if h is None:
            raise ValueError(f"Unsupported temporal attention location: {temp_attn}")

        replacement_dict[f'{temp_attn}.{spatial_name}'] = MotionEmbedding(
            wh=h * w,
            embed_dim=embed_dims[temp_attn_index % len(replacement_keys)],
        ).to(dtype=model.dtype, device=model.device)

    for name, new_module in replacement_dict.items():
        parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
        module_name = name.rsplit('.', 1)[-1]
        parent_module = model
        if parent_name:
            parent_module = dict(model.named_modules())[parent_name]

        if [parent_name.split('_')[0], new_module.embed.shape[-1]] in combinations:
            inject_layers.append(name)
            setattr(parent_module, module_name, new_module)

    inject_layers = list(set(inject_layers))

    # Train only the injected embeddings; freeze every other parameter.
    parameters_list = []
    for name, para in model.named_parameters():
        if 'pos_embed' in name or spatial_name in name:
            parameters_list.append(para)
            para.requires_grad = True
        else:
            para.requires_grad = False

    return parameters_list, inject_layers

def save_motion_embeddings(model, file_path):
    # Collect the learned embedding tensors from every MotionEmbedding (and
    # MotionEmbeddingSpatial) instance, keyed by module path.
    motion_embeddings = {
        name: module.embed
        for name, module in model.named_modules()
        if isinstance(module, (MotionEmbedding, MotionEmbeddingSpatial))
    }
    torch.save(motion_embeddings, file_path)

def load_motion_embeddings(model, saved_embeddings):
    for key, embedding in saved_embeddings.items():
        # Split the module path into parent and attribute name.
        parent_name = key.rsplit('.', 1)[0] if '.' in key else ''
        module_name = key.rsplit('.', 1)[-1]

        parent_module = model
        if parent_name:
            parent_module = dict(model.named_modules())[parent_name]

        # Rebuild a MotionEmbedding with dimensions taken from the saved
        # tensor ([wh, frames, dim]) and move it to the model's dtype/device.
        new_module = MotionEmbedding(
            wh=embedding.shape[0],
            embed_dim=embedding.shape[-1],
            max_seq_length=embedding.shape[-2],
        )
        new_module.embed = nn.Parameter(embedding.to(dtype=model.dtype, device=model.device))

        # Replace the corresponding module in the model.
        setattr(parent_module, module_name, new_module)

def set_motion_embedding_scale(model, scale_value):
    # Set the scale on every MotionEmbedding in the model.
    for _, module in model.named_modules():
        if isinstance(module, MotionEmbedding):
            module.scale = scale_value

def set_motion_embedding_length(model, trained_length):
    # Set the trained length on every MotionEmbedding in the model.
    for _, module in model.named_modules():
        if isinstance(module, MotionEmbedding):
            module.trained_length = trained_length
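
# A minimal smoke test for standalone use of MotionEmbedding; the dimensions,
# frame counts, and batch sizes below are illustrative only, not taken from
# any particular model.
if __name__ == "__main__":
    embed = MotionEmbedding(embed_dim=64, max_seq_length=32)
    embed.set_scale(1.0)
    embed.set_lengths(16)  # pretend 16 frames were trained

    # 24-frame input: the 16 trained frames are linearly resampled to 24.
    out = embed(torch.randn(1, 24, 64))
    print(out.shape)  # torch.Size([1, 24, 64])

    # A doubled batch sharing one embedding is tiled before the addition.
    out = embed(torch.randn(2, 24, 64))
    print(out.shape)  # torch.Size([2, 24, 64])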