liuhuadai commited on
Commit
aa27086
·
verified ·
1 Parent(s): d30e1a3

Update ldm/modules/encoders/modules.py

Browse files
Files changed (1) hide show
  1. ldm/modules/encoders/modules.py +582 -582
ldm/modules/encoders/modules.py CHANGED
@@ -1,582 +1,582 @@
1
- import torch
2
- import torch.nn as nn
3
- from functools import partial
4
- from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
5
- from torch.utils.checkpoint import checkpoint
6
- from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, AutoTokenizer
7
- from importlib_resources import files
8
- from ldm.modules.encoders.CLAP.utils import read_config_as_args
9
- from ldm.modules.encoders.CLAP.clap import TextEncoder
10
- import copy
11
- from ldm.util import default, count_params
12
- import pytorch_lightning as pl
13
-
14
- class AbstractEncoder(pl.LightningModule):
15
- def __init__(self):
16
- super().__init__()
17
-
18
- def encode(self, *args, **kwargs):
19
- raise NotImplementedError
20
-
21
-
22
- class ClassEmbedder(nn.Module):
23
- def __init__(self, embed_dim, n_classes=1000, key='class'):
24
- super().__init__()
25
- self.key = key
26
- self.embedding = nn.Embedding(n_classes, embed_dim)
27
-
28
- def forward(self, batch, key=None):
29
- if key is None:
30
- key = self.key
31
- # this is for use in crossattn
32
- c = batch[key][:, None]# (bsz,1)
33
- c = self.embedding(c)
34
- return c
35
-
36
-
37
- class TransformerEmbedder(AbstractEncoder):
38
- """Some transformer encoder layers"""
39
- def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
40
- super().__init__()
41
- self.device = device
42
- self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
43
- attn_layers=Encoder(dim=n_embed, depth=n_layer))
44
-
45
- def forward(self, tokens):
46
- tokens = tokens.to(self.device) # meh
47
- z = self.transformer(tokens, return_embeddings=True)
48
- return z
49
-
50
- def encode(self, x):
51
- return self(x)
52
-
53
-
54
- class BERTTokenizer(AbstractEncoder):
55
- """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
56
- def __init__(self, device="cuda", vq_interface=True, max_length=77):
57
- super().__init__()
58
- from transformers import BertTokenizerFast # TODO: add to reuquirements
59
- self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
60
- self.device = device
61
- self.vq_interface = vq_interface
62
- self.max_length = max_length
63
-
64
- def forward(self, text):
65
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
66
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
67
- tokens = batch_encoding["input_ids"].to(self.device)
68
- return tokens
69
-
70
- @torch.no_grad()
71
- def encode(self, text):
72
- tokens = self(text)
73
- if not self.vq_interface:
74
- return tokens
75
- return None, None, [None, None, tokens]
76
-
77
- def decode(self, text):
78
- return text
79
-
80
-
81
- class BERTEmbedder(AbstractEncoder):# 这里不是用的pretrained bert,是用的transformers的BertTokenizer加自定义的TransformerWrapper
82
- """Uses the BERT tokenizr model and add some transformer encoder layers"""
83
- def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
84
- device="cuda",use_tokenizer=True, embedding_dropout=0.0):
85
- super().__init__()
86
- self.use_tknz_fn = use_tokenizer
87
- if self.use_tknz_fn:
88
- self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
89
- self.device = device
90
- self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
91
- attn_layers=Encoder(dim=n_embed, depth=n_layer),
92
- emb_dropout=embedding_dropout)
93
-
94
- def forward(self, text):
95
- if self.use_tknz_fn:
96
- tokens = self.tknz_fn(text)#.to(self.device)
97
- else:
98
- tokens = text
99
- z = self.transformer(tokens, return_embeddings=True)
100
- return z
101
-
102
- def encode(self, text):
103
- # output of length 77
104
- return self(text)
105
-
106
-
107
- class SpatialRescaler(nn.Module):
108
- def __init__(self,
109
- n_stages=1,
110
- method='bilinear',
111
- multiplier=0.5,
112
- in_channels=3,
113
- out_channels=None,
114
- bias=False):
115
- super().__init__()
116
- self.n_stages = n_stages
117
- assert self.n_stages >= 0
118
- assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
119
- self.multiplier = multiplier
120
- self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
121
- self.remap_output = out_channels is not None
122
- if self.remap_output:
123
- print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
124
- self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
125
-
126
- def forward(self,x):
127
- for stage in range(self.n_stages):
128
- x = self.interpolator(x, scale_factor=self.multiplier)
129
-
130
-
131
- if self.remap_output:
132
- x = self.channel_mapper(x)
133
- return x
134
-
135
- def encode(self, x):
136
- return self(x)
137
-
138
- def disabled_train(self, mode=True):
139
- """Overwrite model.train with this function to make sure train/eval mode
140
- does not change anymore."""
141
- return self
142
-
143
- class FrozenT5Embedder(AbstractEncoder):
144
- """Uses the T5 transformer encoder for text"""
145
- def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
146
- super().__init__()
147
- self.tokenizer = T5Tokenizer.from_pretrained(version)
148
- self.transformer = T5EncoderModel.from_pretrained(version)
149
- self.device = device
150
- self.max_length = max_length # TODO: typical value?
151
- if freeze:
152
- self.freeze()
153
-
154
- def freeze(self):
155
- self.transformer = self.transformer.eval()
156
- #self.train = disabled_train
157
- for param in self.parameters():
158
- param.requires_grad = False
159
-
160
- def forward(self, text):
161
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
162
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
163
- tokens = batch_encoding["input_ids"].to(self.device)
164
- outputs = self.transformer(input_ids=tokens)
165
-
166
- z = outputs.last_hidden_state
167
- return z
168
-
169
- def encode(self, text):
170
- return self(text)
171
-
172
- class FrozenFLANEmbedder(AbstractEncoder):
173
- """Uses the T5 transformer encoder for text"""
174
- def __init__(self, version="google/flan-t5-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
175
- super().__init__()
176
- self.tokenizer = T5Tokenizer.from_pretrained(version)
177
- self.transformer = T5EncoderModel.from_pretrained(version)
178
- self.device = device
179
- self.max_length = max_length # TODO: typical value?
180
- if freeze:
181
- self.freeze()
182
-
183
- def freeze(self):
184
- self.transformer = self.transformer.eval()
185
- #self.train = disabled_train
186
- for param in self.parameters():
187
- param.requires_grad = False
188
-
189
- def forward(self, text):
190
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
191
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
192
- tokens = batch_encoding["input_ids"].to(self.device)# tango的flanT5是不定长度的batch,这里做成定长的batch
193
- outputs = self.transformer(input_ids=tokens)
194
-
195
- z = outputs.last_hidden_state
196
- return z
197
-
198
- def encode(self, text):
199
- return self(text)
200
-
201
- class FrozenCLAPEmbedder(AbstractEncoder):
202
- """Uses the CLAP transformer encoder for text from microsoft"""
203
- def __init__(self, weights_path, freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
204
- super().__init__()
205
-
206
- model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
207
- match_params = dict()
208
- for key in list(model_state_dict.keys()):
209
- if 'caption_encoder' in key:
210
- match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]
211
-
212
- config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
213
- args = read_config_as_args(config_as_str, is_config_str=True)
214
-
215
- # To device
216
- self.tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
217
- self.caption_encoder = TextEncoder(
218
- args.d_proj, args.text_model, args.transformer_embed_dim
219
- )
220
-
221
- self.max_length = max_length
222
- self.device = device
223
- if freeze: self.freeze()
224
-
225
- print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
226
-
227
- def freeze(self):# only freeze
228
- self.caption_encoder.base = self.caption_encoder.base.eval()
229
- for param in self.caption_encoder.base.parameters():
230
- param.requires_grad = False
231
-
232
-
233
- def encode(self, text):
234
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
235
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
236
- tokens = batch_encoding["input_ids"].to(self.device)
237
-
238
- outputs = self.caption_encoder.base(input_ids=tokens)
239
- z = self.caption_encoder.projection(outputs.last_hidden_state)
240
- return z
241
-
242
- class FrozenLAIONCLAPEmbedder(AbstractEncoder):
243
- """Uses the CLAP transformer encoder for text from LAION-AI"""
244
- def __init__(self, weights_path, freeze=True,sentence=False, device="cuda", max_length=77): # clip-vit-base-patch32
245
- super().__init__()
246
- # To device
247
- from transformers import RobertaTokenizer
248
- from ldm.modules.encoders.open_clap import create_model
249
- self.sentence = sentence
250
-
251
- model, model_cfg = create_model(
252
- 'HTSAT-tiny',
253
- 'roberta',
254
- weights_path,
255
- enable_fusion=True,
256
- fusion_type='aff_2d'
257
- )
258
-
259
- del model.audio_branch, model.audio_transform, model.audio_projection
260
- self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
261
- self.model = model
262
-
263
- self.max_length = max_length
264
- self.device = device
265
- self.to(self.device)
266
- if freeze: self.freeze()
267
-
268
- param_num = sum(p.numel() for p in model.parameters())
269
- print(f'{self.model.__class__.__name__} comes with: {param_num / 1e6:.3f} M params.')
270
-
271
- def to(self,device):
272
- self.model.to(device=device)
273
- self.device=device
274
-
275
- def freeze(self):
276
- self.model = self.model.eval()
277
- for param in self.model.parameters():
278
- param.requires_grad = False
279
-
280
- def encode(self, text):
281
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt").to(self.device)
282
- if self.sentence:
283
- z = self.model.get_text_embedding(batch_encoding).unsqueeze(1)
284
- else:
285
- # text_branch is roberta
286
- outputs = self.model.text_branch(input_ids=batch_encoding["input_ids"].to(self.device), attention_mask=batch_encoding["attention_mask"].to(self.device))
287
- z = self.model.text_projection(outputs.last_hidden_state)
288
-
289
- return z
290
-
291
- class FrozenLAIONCLAPSetenceEmbedder(AbstractEncoder):
292
- """Uses the CLAP transformer encoder for text from LAION-AI"""
293
- def __init__(self, weights_path, freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
294
- super().__init__()
295
- # To device
296
- from transformers import RobertaTokenizer
297
- from ldm.modules.encoders.open_clap import create_model
298
-
299
-
300
- model, model_cfg = create_model(
301
- 'HTSAT-tiny',
302
- 'roberta',
303
- weights_path,
304
- enable_fusion=True,
305
- fusion_type='aff_2d'
306
- )
307
-
308
- del model.audio_branch, model.audio_transform, model.audio_projection
309
- self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
310
- self.model = model
311
-
312
- self.max_length = max_length
313
- self.device = device
314
- if freeze: self.freeze()
315
-
316
- param_num = sum(p.numel() for p in model.parameters())
317
- print(f'{self.model.__class__.__name__} comes with: {param_num / 1e+6:.3f} M params.')
318
-
319
- def freeze(self):
320
- self.model = self.model.eval()
321
- for param in self.model.parameters():
322
- param.requires_grad = False
323
-
324
- def tokenizer(self, text):
325
- result = self.tokenize(
326
- text,
327
- padding="max_length",
328
- truncation=True,
329
- max_length=512,
330
- return_tensors="pt",
331
- )
332
- return result
333
-
334
- def encode(self, text):
335
- with torch.no_grad():
336
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
337
- text_data = self.tokenizer(text)# input_ids shape:(b,512)
338
- embed = self.model.get_text_embedding(text_data)
339
- embed = embed.unsqueeze(1)# (b,1,512)
340
- return embed
341
-
342
- class FrozenCLAPOrderEmbedder2(AbstractEncoder):# 每个object后面都加上|
343
- """Uses the CLAP transformer encoder for text (from huggingface)"""
344
- def __init__(self, weights_path, freeze=True, device="cuda"):
345
- super().__init__()
346
-
347
- model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
348
- match_params = dict()
349
- for key in list(model_state_dict.keys()):
350
- if 'caption_encoder' in key:
351
- match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]
352
-
353
- config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
354
- args = read_config_as_args(config_as_str, is_config_str=True)
355
-
356
- # To device
357
- self.tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
358
- self.caption_encoder = TextEncoder(
359
- args.d_proj, args.text_model, args.transformer_embed_dim
360
- ).to(device)
361
- self.max_objs = 10
362
- self.max_length = args.text_len
363
- self.device = device
364
- self.order_to_label = self.build_order_dict()
365
- if freeze: self.freeze()
366
-
367
- print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
368
-
369
- def freeze(self):
370
- self.caption_encoder.base = self.caption_encoder.base.eval()
371
- for param in self.caption_encoder.base.parameters():
372
- param.requires_grad = False
373
-
374
- def build_order_dict(self):
375
- order2label = {}
376
- num_orders = 10
377
- time_stamps = ['start','mid','end']
378
- time_num = len(time_stamps)
379
- for i in range(num_orders):
380
- for j,time_stamp in enumerate(time_stamps):
381
- order2label[f'order {i} {time_stamp}'] = i * time_num + j
382
- order2label['all'] = num_orders*len(time_stamps)
383
- order2label['unknown'] = num_orders*len(time_stamps) + 1
384
- return order2label
385
-
386
- def encode(self, text):
387
- obj_list,orders_list = [],[]
388
- for raw in text:
389
- splits = raw.split('@') # raw example: '<man speaking& order 1 start>@<man speaking& order 2 mid>@<idle engine& all>'
390
- objs = []
391
- orders = []
392
- for split in splits:# <obj& order>
393
- split = split[1:-1]
394
- obj,order = split.split('&')
395
- objs.append(obj.strip())
396
- try:
397
- orders.append(self.order_to_label[order.strip()])
398
- except:
399
- print(order.strip(),raw)
400
- assert len(objs) == len(orders)
401
- obj_list.append(' | '.join(objs)+' |')# '|' after every word
402
- orders_list.append(orders)
403
- batch_encoding = self.tokenizer(obj_list, truncation=True, max_length=self.max_length, return_length=True,
404
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
405
- tokens = batch_encoding["input_ids"]
406
-
407
- outputs = self.caption_encoder.base(input_ids=tokens.to(self.device))
408
- z = self.caption_encoder.projection(outputs.last_hidden_state)
409
- return {'token_embedding':z,'token_ids':tokens,'orders':orders_list}
410
-
411
- class FrozenCLAPOrderEmbedder3(AbstractEncoder):# 相比于FrozenCLAPOrderEmbedder2移除了projection,使用正确的max_len,去除了order仅保留时间。
412
- """Uses the CLAP transformer encoder for text (from huggingface)"""
413
- def __init__(self, weights_path, freeze=True, device="cuda"): # clip-vit-base-patch32
414
- super().__init__()
415
-
416
- model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
417
- match_params = dict()
418
- for key in list(model_state_dict.keys()):
419
- if 'caption_encoder' in key:
420
- match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]
421
-
422
- config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
423
- args = read_config_as_args(config_as_str, is_config_str=True)
424
-
425
- # To device
426
- self.tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
427
- self.caption_encoder = TextEncoder(
428
- args.d_proj, args.text_model, args.transformer_embed_dim
429
- ).to(device)
430
- self.max_objs = 10
431
- self.max_length = args.text_len
432
- self.device = device
433
- self.order_to_label = self.build_order_dict()
434
- if freeze: self.freeze()
435
-
436
- print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
437
-
438
- def freeze(self):
439
- self.caption_encoder.base = self.caption_encoder.base.eval()
440
- for param in self.caption_encoder.base.parameters():
441
- param.requires_grad = False
442
-
443
- def build_order_dict(self):
444
- order2label = {}
445
- time_stamps = ['all','start','mid','end']
446
- for i,time_stamp in enumerate(time_stamps):
447
- order2label[time_stamp] = i
448
- return order2label
449
-
450
- def encode(self, text):
451
- obj_list,orders_list = [],[]
452
- for raw in text:
453
- splits = raw.split('@') # raw example: '<man speaking& order 1 start>@<man speaking& order 2 mid>@<idle engine& all>'
454
- objs = []
455
- orders = []
456
- for split in splits:# <obj& order>
457
- split = split[1:-1]
458
- obj,order = split.split('&')
459
- objs.append(obj.strip())
460
- try:
461
- orders.append(self.order_to_label[order.strip()])
462
- except:
463
- print(order.strip(),raw)
464
- assert len(objs) == len(orders)
465
- obj_list.append(' | '.join(objs)+' |')# '|' after every word
466
- orders_list.append(orders)
467
- batch_encoding = self.tokenizer(obj_list, truncation=True, max_length=self.max_length, return_length=True,
468
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
469
- tokens = batch_encoding["input_ids"]
470
- attn_mask = batch_encoding["attention_mask"]
471
- outputs = self.caption_encoder.base(input_ids=tokens.to(self.device))
472
- z = outputs.last_hidden_state
473
- return {'token_embedding':z,'token_ids':tokens,'orders':orders_list,'attn_mask':attn_mask}
474
-
475
- class FrozenCLAPT5Embedder(AbstractEncoder):
476
- """Uses the CLAP transformer encoder for text from microsoft"""
477
- def __init__(self, weights_path,t5version="google/flan-t5-large", freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
478
- super().__init__()
479
-
480
- model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
481
- match_params = dict()
482
- for key in list(model_state_dict.keys()):
483
- if 'caption_encoder' in key:
484
- match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]
485
-
486
- config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
487
- args = read_config_as_args(config_as_str, is_config_str=True)
488
-
489
- self.clap_tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
490
- self.caption_encoder = TextEncoder(
491
- args.d_proj, args.text_model, args.transformer_embed_dim
492
- )
493
-
494
- self.t5_tokenizer = T5Tokenizer.from_pretrained(t5version)
495
- self.t5_transformer = T5EncoderModel.from_pretrained(t5version)
496
-
497
- self.max_length = max_length
498
- self.to(device=device)
499
- if freeze: self.freeze()
500
-
501
- print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
502
-
503
- def freeze(self):
504
- self.caption_encoder = self.caption_encoder.eval()
505
- for param in self.caption_encoder.parameters():
506
- param.requires_grad = False
507
-
508
- def to(self,device):
509
- self.t5_transformer.to(device)
510
- self.caption_encoder.to(device)
511
- self.device = device
512
-
513
- def encode(self, text):
514
- ori_caption = text['ori_caption']
515
- struct_caption = text['struct_caption']
516
- # print(ori_caption,struct_caption)
517
- clap_batch_encoding = self.clap_tokenizer(ori_caption, truncation=True, max_length=self.max_length, return_length=True,
518
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
519
- ori_tokens = clap_batch_encoding["input_ids"].to(self.device)
520
- t5_batch_encoding = self.t5_tokenizer(struct_caption, truncation=True, max_length=self.max_length, return_length=True,
521
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
522
- struct_tokens = t5_batch_encoding["input_ids"].to(self.device)
523
- outputs = self.caption_encoder.base(input_ids=ori_tokens)
524
- z = self.caption_encoder.projection(outputs.last_hidden_state)
525
- z2 = self.t5_transformer(input_ids=struct_tokens).last_hidden_state
526
- return torch.concat([z,z2],dim=1)
527
-
528
-
529
- class FrozenCLAPFLANEmbedder(AbstractEncoder):
530
- """Uses the CLAP transformer encoder for text from microsoft"""
531
- def __init__(self, weights_path,t5version="../ldm/modules/encoders/CLAP/t5-v1_1-large", freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
532
- super().__init__()
533
-
534
- model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
535
- match_params = dict()
536
- for key in list(model_state_dict.keys()):
537
- if 'caption_encoder' in key:
538
- match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]
539
-
540
- config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yaml').read_text()
541
- args = read_config_as_args(config_as_str, is_config_str=True)
542
-
543
- self.clap_tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
544
- self.caption_encoder = TextEncoder(
545
- args.d_proj, args.text_model, args.transformer_embed_dim
546
- )
547
-
548
- self.t5_tokenizer = T5Tokenizer.from_pretrained(t5version)
549
- self.t5_transformer = T5EncoderModel.from_pretrained(t5version)
550
-
551
- self.max_length = max_length
552
- # self.to(device=device)
553
- if freeze: self.freeze()
554
-
555
- print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
556
-
557
- def freeze(self):
558
- self.caption_encoder = self.caption_encoder.eval()
559
- for param in self.caption_encoder.parameters():
560
- param.requires_grad = False
561
-
562
- def to(self,device):
563
- self.t5_transformer.to(device)
564
- self.caption_encoder.to(device)
565
- self.device = device
566
-
567
- def encode(self, text):
568
- ori_caption = text['ori_caption']
569
- struct_caption = text['struct_caption']
570
- # print(ori_caption,struct_caption)
571
- clap_batch_encoding = self.clap_tokenizer(ori_caption, truncation=True, max_length=self.max_length, return_length=True,
572
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
573
- ori_tokens = clap_batch_encoding["input_ids"].to(self.device)
574
- t5_batch_encoding = self.t5_tokenizer(struct_caption, truncation=True, max_length=self.max_length, return_length=True,
575
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
576
- struct_tokens = t5_batch_encoding["input_ids"].to(self.device)
577
- # if self.caption_encoder.device != ori_tokens.device:
578
- # self.to(self.device)
579
- outputs = self.caption_encoder.base(input_ids=ori_tokens)
580
- z = self.caption_encoder.projection(outputs.last_hidden_state)
581
- z2 = self.t5_transformer(input_ids=struct_tokens).last_hidden_state
582
- return torch.concat([z,z2],dim=1)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import partial
4
+ from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
5
+ from torch.utils.checkpoint import checkpoint
6
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, AutoTokenizer
7
+ from importlib_resources import files
8
+ from ldm.modules.encoders.CLAP.utils import read_config_as_args
9
+ from ldm.modules.encoders.CLAP.clap import TextEncoder
10
+ import copy
11
+ from ldm.util import default, count_params
12
+ import pytorch_lightning as pl
13
+
14
+ class AbstractEncoder(pl.LightningModule):
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def encode(self, *args, **kwargs):
19
+ raise NotImplementedError
20
+
21
+
22
+ class ClassEmbedder(nn.Module):
23
+ def __init__(self, embed_dim, n_classes=1000, key='class'):
24
+ super().__init__()
25
+ self.key = key
26
+ self.embedding = nn.Embedding(n_classes, embed_dim)
27
+
28
+ def forward(self, batch, key=None):
29
+ if key is None:
30
+ key = self.key
31
+ # this is for use in crossattn
32
+ c = batch[key][:, None]# (bsz,1)
33
+ c = self.embedding(c)
34
+ return c
35
+
36
+
37
+ class TransformerEmbedder(AbstractEncoder):
38
+ """Some transformer encoder layers"""
39
+ def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
40
+ super().__init__()
41
+ self.device = device
42
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
43
+ attn_layers=Encoder(dim=n_embed, depth=n_layer))
44
+
45
+ def forward(self, tokens):
46
+ tokens = tokens.to(self.device) # meh
47
+ z = self.transformer(tokens, return_embeddings=True)
48
+ return z
49
+
50
+ def encode(self, x):
51
+ return self(x)
52
+
53
+
54
+ class BERTTokenizer(AbstractEncoder):
55
+ """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
56
+ def __init__(self, device="cuda", vq_interface=True, max_length=77):
57
+ super().__init__()
58
+ from transformers import BertTokenizerFast # TODO: add to reuquirements
59
+ self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
60
+ self.device = device
61
+ self.vq_interface = vq_interface
62
+ self.max_length = max_length
63
+
64
+ def forward(self, text):
65
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
66
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
67
+ tokens = batch_encoding["input_ids"].to(self.device)
68
+ return tokens
69
+
70
+ @torch.no_grad()
71
+ def encode(self, text):
72
+ tokens = self(text)
73
+ if not self.vq_interface:
74
+ return tokens
75
+ return None, None, [None, None, tokens]
76
+
77
+ def decode(self, text):
78
+ return text
79
+
80
+
81
+ class BERTEmbedder(AbstractEncoder):# 这里不是用的pretrained bert,是用的transformers的BertTokenizer加自定义的TransformerWrapper
82
+ """Uses the BERT tokenizr model and add some transformer encoder layers"""
83
+ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
84
+ device="cuda",use_tokenizer=True, embedding_dropout=0.0):
85
+ super().__init__()
86
+ self.use_tknz_fn = use_tokenizer
87
+ if self.use_tknz_fn:
88
+ self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
89
+ self.device = device
90
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
91
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
92
+ emb_dropout=embedding_dropout)
93
+
94
+ def forward(self, text):
95
+ if self.use_tknz_fn:
96
+ tokens = self.tknz_fn(text)#.to(self.device)
97
+ else:
98
+ tokens = text
99
+ z = self.transformer(tokens, return_embeddings=True)
100
+ return z
101
+
102
+ def encode(self, text):
103
+ # output of length 77
104
+ return self(text)
105
+
106
+
107
+ class SpatialRescaler(nn.Module):
108
+ def __init__(self,
109
+ n_stages=1,
110
+ method='bilinear',
111
+ multiplier=0.5,
112
+ in_channels=3,
113
+ out_channels=None,
114
+ bias=False):
115
+ super().__init__()
116
+ self.n_stages = n_stages
117
+ assert self.n_stages >= 0
118
+ assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
119
+ self.multiplier = multiplier
120
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
121
+ self.remap_output = out_channels is not None
122
+ if self.remap_output:
123
+ print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
124
+ self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
125
+
126
+ def forward(self,x):
127
+ for stage in range(self.n_stages):
128
+ x = self.interpolator(x, scale_factor=self.multiplier)
129
+
130
+
131
+ if self.remap_output:
132
+ x = self.channel_mapper(x)
133
+ return x
134
+
135
+ def encode(self, x):
136
+ return self(x)
137
+
138
+ def disabled_train(self, mode=True):
139
+ """Overwrite model.train with this function to make sure train/eval mode
140
+ does not change anymore."""
141
+ return self
142
+
143
+ class FrozenT5Embedder(AbstractEncoder):
144
+ """Uses the T5 transformer encoder for text"""
145
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
146
+ super().__init__()
147
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
148
+ self.transformer = T5EncoderModel.from_pretrained(version)
149
+ self.device = device
150
+ self.max_length = max_length # TODO: typical value?
151
+ if freeze:
152
+ self.freeze()
153
+
154
+ def freeze(self):
155
+ self.transformer = self.transformer.eval()
156
+ #self.train = disabled_train
157
+ for param in self.parameters():
158
+ param.requires_grad = False
159
+
160
+ def forward(self, text):
161
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
162
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
163
+ tokens = batch_encoding["input_ids"].to(self.device)
164
+ outputs = self.transformer(input_ids=tokens)
165
+
166
+ z = outputs.last_hidden_state
167
+ return z
168
+
169
+ def encode(self, text):
170
+ return self(text)
171
+
172
+ class FrozenFLANEmbedder(AbstractEncoder):
173
+ """Uses the T5 transformer encoder for text"""
174
+ def __init__(self, version="google/flan-t5-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
175
+ super().__init__()
176
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
177
+ self.transformer = T5EncoderModel.from_pretrained(version)
178
+ self.device = device
179
+ self.max_length = max_length # TODO: typical value?
180
+ if freeze:
181
+ self.freeze()
182
+
183
+ def freeze(self):
184
+ self.transformer = self.transformer.eval()
185
+ #self.train = disabled_train
186
+ for param in self.parameters():
187
+ param.requires_grad = False
188
+
189
+ def forward(self, text):
190
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
191
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
192
+ tokens = batch_encoding["input_ids"].to(self.device)# tango的flanT5是不定长度的batch,这里做成定长的batch
193
+ outputs = self.transformer(input_ids=tokens)
194
+
195
+ z = outputs.last_hidden_state
196
+ return z
197
+
198
+ def encode(self, text):
199
+ return self(text)
200
+
201
+ class FrozenCLAPEmbedder(AbstractEncoder):
202
+ """Uses the CLAP transformer encoder for text from microsoft"""
203
+ def __init__(self, weights_path, freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
204
+ super().__init__()
205
+
206
+ model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
207
+ match_params = dict()
208
+ for key in list(model_state_dict.keys()):
209
+ if 'caption_encoder' in key:
210
+ match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]
211
+
212
+ config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
213
+ args = read_config_as_args(config_as_str, is_config_str=True)
214
+
215
+ # To device
216
+ self.tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
217
+ self.caption_encoder = TextEncoder(
218
+ args.d_proj, args.text_model, args.transformer_embed_dim
219
+ )
220
+
221
+ self.max_length = max_length
222
+ self.device = device
223
+ if freeze: self.freeze()
224
+
225
+ print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
226
+
227
+ def freeze(self):# only freeze
228
+ self.caption_encoder.base = self.caption_encoder.base.eval()
229
+ for param in self.caption_encoder.base.parameters():
230
+ param.requires_grad = False
231
+
232
+
233
+ def encode(self, text):
234
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
235
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
236
+ tokens = batch_encoding["input_ids"].to(self.device)
237
+
238
+ outputs = self.caption_encoder.base(input_ids=tokens)
239
+ z = self.caption_encoder.projection(outputs.last_hidden_state)
240
+ return z
241
+
242
+ class FrozenLAIONCLAPEmbedder(AbstractEncoder):
243
+ """Uses the CLAP transformer encoder for text from LAION-AI"""
244
+ def __init__(self, weights_path, freeze=True,sentence=False, device="cuda", max_length=77): # clip-vit-base-patch32
245
+ super().__init__()
246
+ # To device
247
+ from transformers import RobertaTokenizer
248
+ from ldm.modules.encoders.open_clap import create_model
249
+ self.sentence = sentence
250
+
251
+ model, model_cfg = create_model(
252
+ 'HTSAT-tiny',
253
+ 'roberta',
254
+ weights_path,
255
+ enable_fusion=True,
256
+ fusion_type='aff_2d'
257
+ )
258
+
259
+ del model.audio_branch, model.audio_transform, model.audio_projection
260
+ self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
261
+ self.model = model
262
+
263
+ self.max_length = max_length
264
+ self.device = device
265
+ self.to(self.device)
266
+ if freeze: self.freeze()
267
+
268
+ param_num = sum(p.numel() for p in model.parameters())
269
+ print(f'{self.model.__class__.__name__} comes with: {param_num / 1e6:.3f} M params.')
270
+
271
+ def to(self,device):
272
+ self.model.to(device=device)
273
+ self.device=device
274
+
275
+ def freeze(self):
276
+ self.model = self.model.eval()
277
+ for param in self.model.parameters():
278
+ param.requires_grad = False
279
+
280
+ def encode(self, text):
281
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt").to(self.device)
282
+ if self.sentence:
283
+ z = self.model.get_text_embedding(batch_encoding).unsqueeze(1)
284
+ else:
285
+ # text_branch is roberta
286
+ outputs = self.model.text_branch(input_ids=batch_encoding["input_ids"].to(self.device), attention_mask=batch_encoding["attention_mask"].to(self.device))
287
+ z = self.model.text_projection(outputs.last_hidden_state)
288
+
289
+ return z
290
+
291
+ class FrozenLAIONCLAPSetenceEmbedder(AbstractEncoder):
292
+ """Uses the CLAP transformer encoder for text from LAION-AI"""
293
+ def __init__(self, weights_path, freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
294
+ super().__init__()
295
+ # To device
296
+ from transformers import RobertaTokenizer
297
+ from ldm.modules.encoders.open_clap import create_model
298
+
299
+
300
+ model, model_cfg = create_model(
301
+ 'HTSAT-tiny',
302
+ 'roberta',
303
+ weights_path,
304
+ enable_fusion=True,
305
+ fusion_type='aff_2d'
306
+ )
307
+
308
+ del model.audio_branch, model.audio_transform, model.audio_projection
309
+ self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
310
+ self.model = model
311
+
312
+ self.max_length = max_length
313
+ self.device = device
314
+ if freeze: self.freeze()
315
+
316
+ param_num = sum(p.numel() for p in model.parameters())
317
+ print(f'{self.model.__class__.__name__} comes with: {param_num / 1e+6:.3f} M params.')
318
+
319
+ def freeze(self):
320
+ self.model = self.model.eval()
321
+ for param in self.model.parameters():
322
+ param.requires_grad = False
323
+
324
+ def tokenizer(self, text):
325
+ result = self.tokenize(
326
+ text,
327
+ padding="max_length",
328
+ truncation=True,
329
+ max_length=512,
330
+ return_tensors="pt",
331
+ )
332
+ return result
333
+
334
+ def encode(self, text):
335
+ with torch.no_grad():
336
+ # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
337
+ text_data = self.tokenizer(text)# input_ids shape:(b,512)
338
+ embed = self.model.get_text_embedding(text_data)
339
+ embed = embed.unsqueeze(1)# (b,1,512)
340
+ return embed
341
+
342
+ class FrozenCLAPOrderEmbedder2(AbstractEncoder):# 每个object后面都加上|
343
+ """Uses the CLAP transformer encoder for text (from huggingface)"""
344
+ def __init__(self, weights_path, freeze=True, device="cuda"):
345
+ super().__init__()
346
+
347
+ model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
348
+ match_params = dict()
349
+ for key in list(model_state_dict.keys()):
350
+ if 'caption_encoder' in key:
351
+ match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]
352
+
353
+ config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
354
+ args = read_config_as_args(config_as_str, is_config_str=True)
355
+
356
+ # To device
357
+ self.tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
358
+ self.caption_encoder = TextEncoder(
359
+ args.d_proj, args.text_model, args.transformer_embed_dim
360
+ ).to(device)
361
+ self.max_objs = 10
362
+ self.max_length = args.text_len
363
+ self.device = device
364
+ self.order_to_label = self.build_order_dict()
365
+ if freeze: self.freeze()
366
+
367
+ print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
368
+
369
+ def freeze(self):
370
+ self.caption_encoder.base = self.caption_encoder.base.eval()
371
+ for param in self.caption_encoder.base.parameters():
372
+ param.requires_grad = False
373
+
374
+ def build_order_dict(self):
375
+ order2label = {}
376
+ num_orders = 10
377
+ time_stamps = ['start','mid','end']
378
+ time_num = len(time_stamps)
379
+ for i in range(num_orders):
380
+ for j,time_stamp in enumerate(time_stamps):
381
+ order2label[f'order {i} {time_stamp}'] = i * time_num + j
382
+ order2label['all'] = num_orders*len(time_stamps)
383
+ order2label['unknown'] = num_orders*len(time_stamps) + 1
384
+ return order2label
385
+
386
+ def encode(self, text):
387
+ obj_list,orders_list = [],[]
388
+ for raw in text:
389
+ splits = raw.split('@') # raw example: '<man speaking& order 1 start>@<man speaking& order 2 mid>@<idle engine& all>'
390
+ objs = []
391
+ orders = []
392
+ for split in splits:# <obj& order>
393
+ split = split[1:-1]
394
+ obj,order = split.split('&')
395
+ objs.append(obj.strip())
396
+ try:
397
+ orders.append(self.order_to_label[order.strip()])
398
+ except:
399
+ print(order.strip(),raw)
400
+ assert len(objs) == len(orders)
401
+ obj_list.append(' | '.join(objs)+' |')# '|' after every word
402
+ orders_list.append(orders)
403
+ batch_encoding = self.tokenizer(obj_list, truncation=True, max_length=self.max_length, return_length=True,
404
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
405
+ tokens = batch_encoding["input_ids"]
406
+
407
+ outputs = self.caption_encoder.base(input_ids=tokens.to(self.device))
408
+ z = self.caption_encoder.projection(outputs.last_hidden_state)
409
+ return {'token_embedding':z,'token_ids':tokens,'orders':orders_list}
410
+
411
+ class FrozenCLAPOrderEmbedder3(AbstractEncoder):# 相比于FrozenCLAPOrderEmbedder2移除了projection,使用正确的max_len,去除了order仅保留时间。
412
+ """Uses the CLAP transformer encoder for text (from huggingface)"""
413
+ def __init__(self, weights_path, freeze=True, device="cuda"): # clip-vit-base-patch32
414
+ super().__init__()
415
+
416
+ model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
417
+ match_params = dict()
418
+ for key in list(model_state_dict.keys()):
419
+ if 'caption_encoder' in key:
420
+ match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]
421
+
422
+ config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
423
+ args = read_config_as_args(config_as_str, is_config_str=True)
424
+
425
+ # To device
426
+ self.tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
427
+ self.caption_encoder = TextEncoder(
428
+ args.d_proj, args.text_model, args.transformer_embed_dim
429
+ ).to(device)
430
+ self.max_objs = 10
431
+ self.max_length = args.text_len
432
+ self.device = device
433
+ self.order_to_label = self.build_order_dict()
434
+ if freeze: self.freeze()
435
+
436
+ print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
437
+
438
+ def freeze(self):
439
+ self.caption_encoder.base = self.caption_encoder.base.eval()
440
+ for param in self.caption_encoder.base.parameters():
441
+ param.requires_grad = False
442
+
443
+ def build_order_dict(self):
444
+ order2label = {}
445
+ time_stamps = ['all','start','mid','end']
446
+ for i,time_stamp in enumerate(time_stamps):
447
+ order2label[time_stamp] = i
448
+ return order2label
449
+
450
+ def encode(self, text):
451
+ obj_list,orders_list = [],[]
452
+ for raw in text:
453
+ splits = raw.split('@') # raw example: '<man speaking& order 1 start>@<man speaking& order 2 mid>@<idle engine& all>'
454
+ objs = []
455
+ orders = []
456
+ for split in splits:# <obj& order>
457
+ split = split[1:-1]
458
+ obj,order = split.split('&')
459
+ objs.append(obj.strip())
460
+ try:
461
+ orders.append(self.order_to_label[order.strip()])
462
+ except:
463
+ print(order.strip(),raw)
464
+ assert len(objs) == len(orders)
465
+ obj_list.append(' | '.join(objs)+' |')# '|' after every word
466
+ orders_list.append(orders)
467
+ batch_encoding = self.tokenizer(obj_list, truncation=True, max_length=self.max_length, return_length=True,
468
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
469
+ tokens = batch_encoding["input_ids"]
470
+ attn_mask = batch_encoding["attention_mask"]
471
+ outputs = self.caption_encoder.base(input_ids=tokens.to(self.device))
472
+ z = outputs.last_hidden_state
473
+ return {'token_embedding':z,'token_ids':tokens,'orders':orders_list,'attn_mask':attn_mask}
474
+
475
+ class FrozenCLAPT5Embedder(AbstractEncoder):
476
+ """Uses the CLAP transformer encoder for text from microsoft"""
477
+ def __init__(self, weights_path,t5version="google/flan-t5-large", freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
478
+ super().__init__()
479
+
480
+ model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
481
+ match_params = dict()
482
+ for key in list(model_state_dict.keys()):
483
+ if 'caption_encoder' in key:
484
+ match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]
485
+
486
+ config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yml').read_text()
487
+ args = read_config_as_args(config_as_str, is_config_str=True)
488
+
489
+ self.clap_tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
490
+ self.caption_encoder = TextEncoder(
491
+ args.d_proj, args.text_model, args.transformer_embed_dim
492
+ )
493
+
494
+ self.t5_tokenizer = T5Tokenizer.from_pretrained(t5version)
495
+ self.t5_transformer = T5EncoderModel.from_pretrained(t5version)
496
+
497
+ self.max_length = max_length
498
+ self.to(device=device)
499
+ if freeze: self.freeze()
500
+
501
+ print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
502
+
503
+ def freeze(self):
504
+ self.caption_encoder = self.caption_encoder.eval()
505
+ for param in self.caption_encoder.parameters():
506
+ param.requires_grad = False
507
+
508
+ def to(self,device):
509
+ self.t5_transformer.to(device)
510
+ self.caption_encoder.to(device)
511
+ self.device = device
512
+
513
+ def encode(self, text):
514
+ ori_caption = text['ori_caption']
515
+ struct_caption = text['struct_caption']
516
+ # print(ori_caption,struct_caption)
517
+ clap_batch_encoding = self.clap_tokenizer(ori_caption, truncation=True, max_length=self.max_length, return_length=True,
518
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
519
+ ori_tokens = clap_batch_encoding["input_ids"].to(self.device)
520
+ t5_batch_encoding = self.t5_tokenizer(struct_caption, truncation=True, max_length=self.max_length, return_length=True,
521
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
522
+ struct_tokens = t5_batch_encoding["input_ids"].to(self.device)
523
+ outputs = self.caption_encoder.base(input_ids=ori_tokens)
524
+ z = self.caption_encoder.projection(outputs.last_hidden_state)
525
+ z2 = self.t5_transformer(input_ids=struct_tokens).last_hidden_state
526
+ return torch.concat([z,z2],dim=1)
527
+
528
+
529
+ class FrozenCLAPFLANEmbedder(AbstractEncoder):
530
+ """Uses the CLAP transformer encoder for text from microsoft"""
531
+ def __init__(self, weights_path,t5version="ldm/modules/encoders/CLAP/t5-v1_1-large", freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
532
+ super().__init__()
533
+
534
+ model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
535
+ match_params = dict()
536
+ for key in list(model_state_dict.keys()):
537
+ if 'caption_encoder' in key:
538
+ match_params[key.replace('caption_encoder.', '')] = model_state_dict[key]
539
+
540
+ config_as_str = files('ldm').joinpath('modules/encoders/CLAP/config.yaml').read_text()
541
+ args = read_config_as_args(config_as_str, is_config_str=True)
542
+
543
+ self.clap_tokenizer = AutoTokenizer.from_pretrained(args.text_model) # args.text_model
544
+ self.caption_encoder = TextEncoder(
545
+ args.d_proj, args.text_model, args.transformer_embed_dim
546
+ )
547
+
548
+ self.t5_tokenizer = T5Tokenizer.from_pretrained(t5version)
549
+ self.t5_transformer = T5EncoderModel.from_pretrained(t5version)
550
+
551
+ self.max_length = max_length
552
+ # self.to(device=device)
553
+ if freeze: self.freeze()
554
+
555
+ print(f"{self.caption_encoder.__class__.__name__} comes with {count_params(self.caption_encoder) * 1.e-6:.2f} M params.")
556
+
557
+ def freeze(self):
558
+ self.caption_encoder = self.caption_encoder.eval()
559
+ for param in self.caption_encoder.parameters():
560
+ param.requires_grad = False
561
+
562
+ def to(self,device):
563
+ self.t5_transformer.to(device)
564
+ self.caption_encoder.to(device)
565
+ self.device = device
566
+
567
+ def encode(self, text):
568
+ ori_caption = text['ori_caption']
569
+ struct_caption = text['struct_caption']
570
+ # print(ori_caption,struct_caption)
571
+ clap_batch_encoding = self.clap_tokenizer(ori_caption, truncation=True, max_length=self.max_length, return_length=True,
572
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
573
+ ori_tokens = clap_batch_encoding["input_ids"].to(self.device)
574
+ t5_batch_encoding = self.t5_tokenizer(struct_caption, truncation=True, max_length=self.max_length, return_length=True,
575
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
576
+ struct_tokens = t5_batch_encoding["input_ids"].to(self.device)
577
+ # if self.caption_encoder.device != ori_tokens.device:
578
+ # self.to(self.device)
579
+ outputs = self.caption_encoder.base(input_ids=ori_tokens)
580
+ z = self.caption_encoder.projection(outputs.last_hidden_state)
581
+ z2 = self.t5_transformer(input_ids=struct_tokens).last_hidden_state
582
+ return torch.concat([z,z2],dim=1)