jadechoghari committed
Commit a051d95
1 Parent(s): 2899431

Create merge.py

Files changed (1): merge.py (+767, -0)
merge.py ADDED

import torch
from typing import Tuple, Callable


def do_nothing(x: torch.Tensor, mode: str = None):
    # Identity op used when no merging is performed.
    return x


def mps_gather_workaround(input, dim, index):
    # Workaround for a gather issue on the MPS backend when the last
    # dimension has size 1: gather on an unsqueezed view instead.
    if input.shape[-1] == 1:
        return torch.gather(
            input.unsqueeze(-1),
            dim - 1 if dim < 0 else dim,
            index.unsqueeze(-1)
        ).squeeze(-1)
    else:
        return torch.gather(input, dim, index)
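

# A minimal sanity-check sketch (not part of the original commit): on CPU the
# workaround should agree with plain torch.gather; the size-1 trailing-dim
# special case only matters on the MPS backend. Shapes here are illustrative.
def _check_mps_gather_workaround():
    inp = torch.randn(4, 8, 1)
    idx = torch.randint(0, 8, (4, 5, 1))
    expected = torch.gather(inp, 1, idx)
    actual = mps_gather_workaround(inp, 1, idx)
    assert torch.equal(actual, expected)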


# For Local Token Merging
def bipartite_soft_matching_randframe(metric: torch.Tensor,
                                      F: int, ratio: float, unm_pre: int, generator: torch.Generator,
                                      target_stride: int = 4, align_batch: bool = False,
                                      merge_mode: str = "replace") -> Tuple[Callable, Callable, dict]:
    """
    Partitions the multi-frame tokens into src and dst and merges `ratio` of the src tokens into dst.
    Dst tokens are taken from one randomly chosen frame.

    Args:
        - metric [B, N, C]: metric to use for similarity.
        - F: number of frames.
        - ratio: ratio of src tokens to be removed (by merging).
        - unm_pre: number of src tokens left unmerged by the previous ToMe step. Sequence layout: [unm_pre | F_0 | F_1 | ...]
        - generator: random number generator.
        - target_stride: stride used to pick the target (dst) frame.
        - align_batch: whether to share one similarity matching result across the batch. True when using PnP.
        - merge_mode: how to merge tokens. "mean": token -> Mean(src_token, dst_token); "replace": token -> dst_token.

    Returns:
        Merge and unmerge operations according to the matching result, plus a dict of extra values.
    """
    B, N, _ = metric.shape
    # Compute per-frame token number. N = unm_pre + tnum * F.
    tnum = (N - unm_pre) // F

    if ratio <= 0:
        return do_nothing, do_nothing, {"unm_num": tnum}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        # Prepare idx buffer. Ignore previously unmerged tokens.
        idx_buffer = torch.arange(
            N - unm_pre, device=metric.device, dtype=torch.int64)

        # Select the random target frame.
        target_stride = min(target_stride, F)
        randf = torch.randint(0, target_stride, torch.Size(
            [1]), generator=generator, device=generator.device)
        dst_select = ((torch.div(idx_buffer, tnum, rounding_mode='floor')) %
                      target_stride == randf).to(torch.bool)

        # a_idx: src index. b_idx: dst index
        a_idx = idx_buffer[None, ~dst_select, None] + unm_pre
        b_idx = idx_buffer[None, dst_select, None] + unm_pre

        # Add previously unmerged tokens to dst.
        unm_buffer = torch.arange(unm_pre, device=metric.device, dtype=torch.int64)[
            None, :, None]
        b_idx = torch.cat([b_idx, unm_buffer], dim=1)

        # We're finished with these
        del idx_buffer, unm_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            # Split src, dst tokens
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between src and dst tokens
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)

        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], int(a.shape[1] * ratio))

        if align_batch:
            # Cat scores of all samples in the batch. When using PnP, samples are (src, neg, pos).
            # Find the most similar greedily among all samples.
            scores = torch.cat([*scores], dim=-1)
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None],
                             dim=-2, index=src_idx) % num_dst  # Map index to (0, num_dst - 1)

            # Use the same matching result for all samples
            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            dst_idx = dst_idx.expand(B, -1, -1)
        else:
            # Find the most similar greedily
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None) -> torch.Tensor:
        # Merge tokens according to the matching result.
        src, dst = split(x)
        n, t1, c = src.shape
        u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx

        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))
        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
            # In other modes such as mean, combine matched src and dst tokens.
            dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c),
                                     src, reduce=mode, include_self=True)
        # In replace mode, just cat unmerged tokens and dst tokens. Ignore src tokens.
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, **kwarg) -> torch.Tensor:
        # Unmerge tokens back to the original size according to the matching result.
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx
        # Restored src tokens take their value from the matched dst tokens.
        src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c))

        # Combine back to the original shape
        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        # Scatter dst tokens
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        # Scatter unmerged tokens
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        # Scatter src tokens
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=s_idx).expand(-1, -1, c), src=src)

        return out

    # Return the number of unmerged tokens.
    ret_dict = {"unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict
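

# A hedged usage sketch (not part of the original commit): drives the local
# merge/unmerge pair on random tokens. The shapes, ratio, and zero unm_pre
# are illustrative assumptions; merge_mode defaults to "replace".
def _example_randframe_usage():
    B, F, tnum, C = 2, 8, 16, 64   # batch, frames, tokens per frame, channels
    x = torch.randn(B, F * tnum, C)
    gen = torch.Generator(device="cpu").manual_seed(0)
    merge, unmerge, info = bipartite_soft_matching_randframe(
        x, F, ratio=0.5, unm_pre=0, generator=gen)
    merged = merge(x)              # [B, unm + num_dst, C]: fewer tokens
    restored = unmerge(merged)     # back to [B, N, C]
    assert restored.shape == x.shape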


def bipartite_soft_matching_random2d_hier(metric: torch.Tensor, frame_num: int, ratio: float, unm_pre: int, generator: torch.Generator, target_stride: int = 4, adhere_src: bool = False, merge_mode: str = "replace", scores=None, coord=None, rec_field=2) -> Tuple[Callable, Callable, dict]:
    """
    Hierarchical variant: partitions the multi-frame tokens into src and dst and merges
    `ratio` of the src tokens into dst. Dst tokens are taken from one randomly chosen
    frame, as in bipartite_soft_matching_randframe.

    Args:
        - metric [B, N, C]: metric to use for similarity.
        - frame_num: number of frames F.
        - ratio: ratio of src tokens to be removed (by merging).
        - unm_pre: number of src tokens left unmerged by the previous ToMe step. Sequence layout: [unm_pre | F_0 | F_1 | ...]
        - generator: random number generator.
        - target_stride: stride used to pick the target (dst) frame.
        - adhere_src: whether to share one similarity matching result across the batch.
        - merge_mode: how to merge tokens. "mean": token -> Mean(src_token, dst_token); "replace": token -> dst_token.
        - scores: unused (similarity scores are recomputed internally).
        - coord [B, N, D]: optional token coordinates; pairs farther apart than rec_field are excluded from matching.
        - rec_field: receptive field radius used with coord.

    Returns:
        Merge and unmerge operations according to the matching result, plus a dict of extra values.
    """
    B, N, _ = metric.shape
    F = frame_num
    nf = (N - unm_pre) // F

    if ratio <= 0:
        return do_nothing, do_nothing, {"unm_num": nf}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        # Prepare idx buffer. Ignore previously unmerged tokens.
        idx_buffer = torch.arange(N - unm_pre, device=metric.device, dtype=torch.int64)

        # Select the random target frame.
        max_f = min(target_stride, F)
        randn = torch.randint(0, max_f, torch.Size([1]), generator=generator, device=generator.device)
        dst_select = ((torch.div(idx_buffer, nf, rounding_mode='floor')) % max_f == randn).to(torch.bool)
        a_idx = idx_buffer[None, ~dst_select, None] + unm_pre
        b_idx = idx_buffer[None, dst_select, None] + unm_pre

        # Add previously unmerged tokens to dst.
        unm_buffer = torch.arange(unm_pre, device=metric.device, dtype=torch.int64)[None, :, None]
        b_idx = torch.cat([b_idx, unm_buffer], dim=1)

        # We're finished with these
        del idx_buffer, unm_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            # Split src, dst tokens
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        def split_coord(coord):
            # Split src, dst coordinates
            b, n, c = coord.shape
            src = gather(coord, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(coord, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)

        if coord is not None:
            src_coord, dst_coord = split_coord(coord)
            mask = torch.norm(src_coord[:, :, None, :] - dst_coord[:, None, :, :], dim=-1) > rec_field

        scores = a @ b.transpose(-1, -2)

        if coord is not None:
            # Zero out pairs outside the receptive field.
            scores[mask] = 0

        # Can't reduce more than the # tokens in src
        r = int(a.shape[1] * ratio)
        r = min(a.shape[1], r)

        if adhere_src:
            # Cat scores of all samples in the batch and match greedily across them,
            # so every sample shares one matching result.
            scores = torch.cat([*scores], dim=-1)
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) % num_dst

            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            dst_idx = dst_idx.expand(B, -1, -1)
        else:
            # Find the most similar greedily
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None, b_select=None, **kwarg) -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape
        if b_select is not None:
            # Apply the matching of the selected batch entries only.
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx

        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))
        src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            # In other modes such as mean, combine matched src and dst tokens.
            dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c), src, reduce=mode, include_self=True)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, b_select=None, unm_modi=None, **kwarg) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        if b_select is not None:
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx
        if unm_modi is not None:
            # "zero": zero out the unmerged tokens instead of restoring them.
            if unm_modi == "zero":
                unm = torch.zeros_like(unm)
        # Restored src tokens take their value from the matched dst tokens.
        src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c))

        # Combine back to the original shape
        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=s_idx).expand(-1, -1, c), src=src)

        return out

    ret_dict = {"unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict
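

# A hedged sketch (not part of the original commit) of the per-sample controls
# b_select / unm_modi that distinguish this variant: the matching computed on
# the full batch can be applied to a single sample. All shapes are assumptions.
def _example_hier_b_select():
    B, F, nf, C = 3, 4, 8, 32
    x = torch.randn(B, F * nf, C)
    gen = torch.Generator(device="cpu").manual_seed(0)
    merge, unmerge, info = bipartite_soft_matching_random2d_hier(
        x, frame_num=F, ratio=0.5, unm_pre=0, generator=gen)
    merged0 = merge(x[0:1], b_select=0)       # use sample 0's matching
    restored0 = unmerge(merged0, b_select=0)  # restore sample 0 only
    zeroed0 = unmerge(merged0, b_select=0, unm_modi="zero")  # zero unmerged slots
    assert restored0.shape == (1, F * nf, C)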


# For Global Token Merging.
def bipartite_soft_matching_2s(metric: torch.Tensor,
                               src_len: int, ratio: float, align_batch: bool,
                               merge_mode: str = "replace", unmerge_chunk: int = 0) -> Tuple[Callable, Callable, dict]:
    """
    Partitions the tokens into src and dst and merges `ratio` of the src tokens into dst.
    The first src_len tokens are src; the rest are dst.

    Args:
        - metric [B, N, C]: metric to use for similarity.
        - src_len: src token length. [ src | dst ]: [ src_len | N - src_len ]
        - ratio: ratio of src tokens to be removed (by merging).
        - align_batch: whether to share one similarity matching result across the batch. True when using PnP.
        - merge_mode: how to merge tokens. "mean": token -> Mean(src_token, dst_token); "replace": token -> dst_token.
        - unmerge_chunk: which partition unmerge returns. 0 for src and 1 for dst.

    Returns:
        Merge and unmerge operations according to the matching result, plus a dict of extra values.
    """
    B, N, _ = metric.shape

    if ratio <= 0:
        # Nothing to merge; all src tokens stay unmerged.
        return do_nothing, do_nothing, {"unm_num": src_len}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        idx_buffer = torch.arange(N, device=metric.device, dtype=torch.int64)

        # [ src | dst ]: [ src_len | N - src_len ]
        a_idx = idx_buffer[None, :src_len, None]
        b_idx = idx_buffer[None, src_len:, None]

        del idx_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            # Split src, dst tokens
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between src and dst tokens
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)

        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], int(a.shape[1] * ratio))

        if align_batch:
            # Cat scores of all samples in the batch. When using PnP, samples are (src, neg, pos).
            # Find the most similar greedily among all samples.
            scores = torch.cat([*scores], dim=-1)
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None],
                             dim=-2, index=src_idx) % num_dst  # Map index to (0, num_dst - 1)

            # Use the same matching result for all samples
            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            dst_idx = dst_idx.expand(B, -1, -1)
        else:
            # Find the most similar greedily
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None) -> torch.Tensor:
        # Merge tokens according to the matching result.
        src, dst = split(x)
        n, t1, c = src.shape
        u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx

        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))
        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
            # In other modes such as mean, combine matched src and dst tokens.
            dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c),
                                     src, reduce=mode, include_self=True)
        # In replace mode, just cat unmerged tokens and dst tokens. Discard src tokens.
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, **kwarg) -> torch.Tensor:
        # Unmerge tokens back to the original size according to the matching result.
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx
        # Restored src tokens take their value from the matched dst tokens.
        src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c))

        # Combine back to the original shape
        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        # Scatter dst tokens
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        # Scatter unmerged tokens
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        # Scatter src tokens
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1),
                     dim=1, index=s_idx).expand(-1, -1, c), src=src)

        # Return only the requested chunk.
        out = out[:, :src_len, :] if unmerge_chunk == 0 else out[:, src_len:, :]
        return out

    ret_dict = {"unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict
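

# A hedged usage sketch (not part of the original commit) for global merging
# across a [ src | dst ] concatenation, e.g. two chunks of frame tokens.
# src_len, ratio, and unmerge_chunk=0 (restore the src chunk) are assumptions.
def _example_2s_usage():
    B, C, src_len = 2, 64, 64
    x = torch.randn(B, src_len * 2, C)   # [ src | dst ]
    merge, unmerge, info = bipartite_soft_matching_2s(
        x, src_len=src_len, ratio=0.5, align_batch=False)
    merged = merge(x)
    src_restored = unmerge(merged)       # unmerge_chunk == 0 -> src chunk
    assert src_restored.shape == (B, src_len, C)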


# Original ToMe
def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False,
                                     generator: torch.Generator = None) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
        - metric [B, N, C]: metric to use for similarity
        - w: image width in tokens
        - h: image height in tokens
        - sx: stride in the x dimension for dst, must divide w
        - sy: stride in the y dimension for dst, must divide h
        - r: number of tokens to remove (by merging)
        - no_rand: if true, disable randomness (use top left corner only)
        - generator: if no_rand is false and not None, RNG used to pick dst tokens
    """
    B, N, _ = metric.shape

    if r <= 0:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(
                hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(
                sy*sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device)

        # The image might not divide evenly by sx and sy, so work on a view of the top left of the idx buffer instead
        idx_buffer_view = torch.zeros(
            hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(
            dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(
            hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy, so move it into a new buffer
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(
                h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B,
                     a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B,
                     a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge
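

# A hedged usage sketch (not part of the original commit) of the original ToMe
# matching on a 16x16 token grid; w, h, strides, and r are illustrative.
def _example_random2d_usage():
    B, C, w, h = 1, 64, 16, 16
    x = torch.randn(B, w * h, C)
    gen = torch.Generator(device="cpu").manual_seed(0)
    merge, unmerge = bipartite_soft_matching_random2d(
        x, w=w, h=h, sx=2, sy=2, r=128, generator=gen)
    merged = merge(x)              # [B, w*h - r, C]
    restored = unmerge(merged)     # [B, w*h, C]
    assert restored.shape == x.shape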


def bipartite_soft_matching_2f(metric: torch.Tensor, src_len: int, ratio: float, adhere_src: bool, merge_mode: str = "replace", scores=None, coord=None, rec_field=2, unmerge_chunk=0) -> Tuple[Callable, Callable, dict]:
    """
    Partitions the tokens into src and dst and merges `ratio` of the src tokens into dst.
    The first src_len tokens are src; the rest are dst.

    Args:
        - metric [B, N, C]: metric to use for similarity.
        - src_len: src token length. [ src | dst ]: [ src_len | N - src_len ]
        - ratio: ratio of src tokens to be removed (by merging).
        - adhere_src: whether to share one similarity matching result across the batch.
        - merge_mode: how to merge tokens. "mean": token -> Mean(src_token, dst_token); "replace": token -> dst_token.
        - scores: unused (similarity scores are recomputed internally).
        - coord [B, N, D]: optional token coordinates; pairs farther apart than rec_field are excluded from matching.
        - rec_field: receptive field radius used with coord.
        - unmerge_chunk: which partition unmerge returns. 0 for src and 1 for dst.

    Returns:
        Merge and unmerge operations according to the matching result, plus a dict of extra values.
    """
    B, N, _ = metric.shape

    if ratio <= 0:
        # Nothing to merge; all src tokens stay unmerged.
        return do_nothing, do_nothing, {"unm_num": src_len}

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        idx_buffer = torch.arange(N, device=metric.device, dtype=torch.int64)

        # [ src | dst ]: [ src_len | N - src_len ]
        a_idx = idx_buffer[None, :src_len, None]
        b_idx = idx_buffer[None, src_len:, None]

        # We're finished with this
        del idx_buffer

        num_dst = b_idx.shape[1]

        def split(x):
            # Split src, dst tokens
            b, n, c = x.shape
            src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        def split_coord(coord):
            # Split src, dst coordinates
            b, n, c = coord.shape
            src = gather(coord, dim=1, index=a_idx.expand(b, n - num_dst, c))
            dst = gather(coord, dim=1, index=b_idx.expand(b, num_dst, c))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)

        if coord is not None:
            src_coord, dst_coord = split_coord(coord)
            mask = torch.norm(src_coord[:, :, None, :] - dst_coord[:, None, :, :], dim=-1) > rec_field

        scores = a @ b.transpose(-1, -2)

        if coord is not None:
            # Zero out pairs outside the receptive field.
            scores[mask] = 0

        # Can't reduce more than the # tokens in src
        r = int(a.shape[1] * ratio)
        r = min(a.shape[1], r)

        if adhere_src:
            # Cat scores of all samples in the batch and match greedily across them,
            # so every sample shares one matching result.
            scores = torch.cat([*scores], dim=-1)
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) % num_dst

            unm_idx = unm_idx.expand(B, -1, -1)
            src_idx = src_idx.expand(B, -1, -1)
            dst_idx = dst_idx.expand(B, -1, -1)
        else:
            # Find the most similar greedily
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
            src_idx = edge_idx[..., :r, :]  # Merged Tokens
            dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode=None, b_select=None) -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape
        if b_select is not None:
            # Apply the matching of the selected batch entries only.
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx

        unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c))
        mode = mode if mode is not None else merge_mode
        if mode != "replace":
            # Gather the matched src tokens before reducing them into dst.
            src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c))
            dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c), src, reduce=mode, include_self=True)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor, b_select=None, unm_modi=None) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        b, _, c = unm.shape
        if b_select is not None:
            if not isinstance(b_select, list):
                b_select = [b_select]
            u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select]
        else:
            u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx
        if unm_modi is not None:
            # "zero": zero out the unmerged tokens instead of restoring them.
            if unm_modi == "zero":
                unm = torch.zeros_like(unm)
        # Restored src tokens take their value from the matched dst tokens.
        src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c))

        # Combine back to the original shape
        out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=u_idx).expand(-1, -1, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=s_idx).expand(-1, -1, c), src=src)

        # Return only the requested chunk.
        if unmerge_chunk == 0:
            out = out[:, :src_len, :]
        else:
            out = out[:, src_len:, :]

        return out

    ret_dict = {"unm_num": unm_idx.shape[1]}
    return merge, unmerge, ret_dict
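

# A hedged usage sketch (not part of the original commit): same [ src | dst ]
# layout as bipartite_soft_matching_2s, plus the optional coord/rec_field
# constraint. The coordinates and shapes below are illustrative assumptions.
def _example_2f_usage():
    B, C, src_len = 2, 64, 64
    x = torch.randn(B, src_len * 2, C)
    coord = torch.randn(B, src_len * 2, 2)   # hypothetical 2D token positions
    merge, unmerge, info = bipartite_soft_matching_2f(
        x, src_len=src_len, ratio=0.5, adhere_src=False,
        coord=coord, rec_field=2)
    merged = merge(x)
    src_restored = unmerge(merged)           # unmerge_chunk == 0 -> src chunk
    assert src_restored.shape == (B, src_len, C)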