chaojiemao committed
Commit 8fc110a · verified · 1 Parent(s): bf6a8e1

Update model/flux.py

Files changed (1):
  1. model/flux.py +378 -1

model/flux.py CHANGED
@@ -1,7 +1,10 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import math
 from collections import OrderedDict
 from functools import partial
-
+import warnings
+from contextlib import nullcontext
 import torch
 from einops import rearrange, repeat
 from scepter.modules.model.base_model import BaseModel
@@ -12,11 +15,385 @@ from scepter.modules.utils.file_system import FS
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.checkpoint import checkpoint_sequential
+import torch.nn.functional as F
+import torch.utils.dlpack
+import transformers
+from scepter.modules.model.embedder.base_embedder import BaseEmbedder
+from scepter.modules.model.registry import EMBEDDERS
+from scepter.modules.model.tokenizer.tokenizer_component import (
+    basic_clean, canonicalize, heavy_clean, whitespace_clean)
+try:
+    from transformers import AutoTokenizer, T5EncoderModel
+except Exception as e:
+    warnings.warn(
+        f'Import transformers error, please deal with this problem: {e}')
 
 from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
                      MLPEmbedder, SingleStreamBlock,
                      timestep_embedding)
 
+
+
+@EMBEDDERS.register_class()
+class ACETextEmbedder(BaseEmbedder):
+    """
+    Uses a frozen T5 (umt5) transformer encoder for text.
+    """
+    para_dict = {
+        'PRETRAINED_MODEL': {
+            'value': 'google/umt5-small',
+            'description': 'Pretrained model for umt5, modelcard path or local path.'
+        },
+        'TOKENIZER_PATH': {
+            'value': 'google/umt5-small',
+            'description': 'Tokenizer path for umt5, modelcard path or local path.'
+        },
+        'FREEZE': {
+            'value': True,
+            'description': ''
+        },
+        'USE_GRAD': {
+            'value': False,
+            'description': 'Compute grad or not.'
+        },
+        'CLEAN': {
+            'value': 'whitespace',
+            'description': 'Set the clean strategy for the tokenizer, used when TOKENIZER_PATH is not None.'
+        },
+        'LAYER': {
+            'value': 'last',
+            'description': ''
+        },
+        'LEGACY': {
+            'value': True,
+            'description': 'Whether to use the legacy returned feature or not, default True.'
+        }
+    }
+
+    def __init__(self, cfg, logger=None):
+        super().__init__(cfg, logger=logger)
+        pretrained_path = cfg.get('PRETRAINED_MODEL', None)
+        self.t5_dtype = cfg.get('T5_DTYPE', 'float32')
+        assert pretrained_path
+        with FS.get_dir_to_local_dir(pretrained_path,
+                                     wait_finish=True) as local_path:
+            self.model = T5EncoderModel.from_pretrained(
+                local_path,
+                torch_dtype=getattr(
+                    torch,
+                    'float' if self.t5_dtype == 'float32' else self.t5_dtype))
+        tokenizer_path = cfg.get('TOKENIZER_PATH', None)
+        self.length = cfg.get('LENGTH', 77)
+        self.use_grad = cfg.get('USE_GRAD', False)
+        self.clean = cfg.get('CLEAN', 'whitespace')
+        self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
+        if tokenizer_path:
+            self.tokenize_kargs = {'return_tensors': 'pt'}
+            with FS.get_dir_to_local_dir(tokenizer_path,
+                                         wait_finish=True) as local_path:
+                # Both branches of the original added-identifier check loaded
+                # the same tokenizer, so a single call suffices.
+                self.tokenizer = AutoTokenizer.from_pretrained(local_path)
+            if self.length is not None:
+                self.tokenize_kargs.update({
+                    'padding': 'max_length',
+                    'truncation': True,
+                    'max_length': self.length
+                })
+            self.eos_token = self.tokenizer(
+                self.tokenizer.eos_token)['input_ids'][0]
+        else:
+            self.tokenizer = None
+            self.tokenize_kargs = {}
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    # encode && encode_text
+    def forward(self, tokens, return_mask=False, use_mask=True):
+        # Run the T5 encoder; gradients flow only when USE_GRAD is set.
+        embedding_context = nullcontext if self.use_grad else torch.no_grad
+        with embedding_context():
+            if use_mask:
+                x = self.model(tokens.input_ids.to(we.device_id),
+                               tokens.attention_mask.to(we.device_id))
+            else:
+                x = self.model(tokens.input_ids.to(we.device_id))
+            x = x.last_hidden_state
+
+        if return_mask:
+            return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)
+        else:
+            return x.detach() + 0.0, None
+
+    def _clean(self, text):
+        if self.clean == 'whitespace':
+            text = whitespace_clean(basic_clean(text))
+        elif self.clean == 'lower':
+            text = whitespace_clean(basic_clean(text)).lower()
+        elif self.clean == 'canonicalize':
+            text = canonicalize(basic_clean(text))
+        elif self.clean == 'heavy':
+            text = heavy_clean(basic_clean(text))
+        return text
+
+    def encode(self, text, return_mask=False, use_mask=True):
+        if isinstance(text, str):
+            text = [text]
+        if self.clean:
+            text = [self._clean(u) for u in text]
+        assert self.tokenizer is not None
+        cont, mask = [], []
+        with torch.autocast(device_type='cuda',
+                            enabled=self.t5_dtype in ('float16', 'bfloat16'),
+                            dtype=getattr(torch, self.t5_dtype)):
+            for tt in text:
+                tokens = self.tokenizer([tt], **self.tokenize_kargs)
+                one_cont, one_mask = self(tokens,
+                                          return_mask=return_mask,
+                                          use_mask=use_mask)
+                cont.append(one_cont)
+                mask.append(one_mask)
+        if return_mask:
+            return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
+        else:
+            return torch.cat(cont, dim=0)
+
+    def encode_list(self, text_list, return_mask=True):
+        cont_list = []
+        mask_list = []
+        for pp in text_list:
+            cont, cont_mask = self.encode(pp, return_mask=return_mask)
+            cont_list.append(cont)
+            mask_list.append(cont_mask)
+        if return_mask:
+            return cont_list, mask_list
+        else:
+            return cont_list
+
+    @staticmethod
+    def get_config_template():
+        return dict_to_yaml('MODELS',
+                            __class__.__name__,
+                            ACETextEmbedder.para_dict,
+                            set_name=True)
+
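A minimal usage sketch of the ACETextEmbedder added above, assuming scepter's Config accepts a raw dict via cfg_dict with load=False and that EMBEDDERS.build dispatches on the NAME key; the paths are the para_dict defaults, not weights pinned by this commit.

```python
# Hypothetical usage sketch, not part of this commit.
from scepter.modules.model.registry import EMBEDDERS
from scepter.modules.utils.config import Config  # assumed import path

t5_cfg = Config(load=False,
                cfg_dict={
                    'NAME': 'ACETextEmbedder',
                    'PRETRAINED_MODEL': 'google/umt5-small',
                    'TOKENIZER_PATH': 'google/umt5-small',
                    'LENGTH': 77,           # pad/truncate every prompt to 77 tokens
                    'CLEAN': 'whitespace',  # cleaning strategy applied in _clean()
                    'T5_DTYPE': 'bfloat16',
                })
embedder = EMBEDDERS.build(t5_cfg, logger=None)
embedder.freeze()  # eval mode, requires_grad=False on all parameters

# encode() cleans and tokenizes each prompt separately, runs the frozen T5
# encoder, and stacks the last hidden states.
ctx, mask = embedder.encode(['a cat sitting on a windowsill'], return_mask=True)
print(ctx.shape, mask.shape)  # e.g. [1, 77, hidden_dim], [1, 77]
```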
+@EMBEDDERS.register_class()
+class ACEHFEmbedder(BaseEmbedder):
+    para_dict = {
+        'HF_MODEL_CLS': {
+            'value': None,
+            'description': 'Hugging Face model class name in transformers.'
+        },
+        'MODEL_PATH': {
+            'value': None,
+            'description': 'Model folder path.'
+        },
+        'HF_TOKENIZER_CLS': {
+            'value': None,
+            'description': 'Hugging Face tokenizer class name in transformers.'
+        },
+        'TOKENIZER_PATH': {
+            'value': None,
+            'description': 'Tokenizer folder path.'
+        },
+        'MAX_LENGTH': {
+            'value': 77,
+            'description': 'Max length of input.'
+        },
+        'OUTPUT_KEY': {
+            'value': 'last_hidden_state',
+            'description': 'Key of the encoder output to return.'
+        },
+        'D_TYPE': {
+            'value': 'float',
+            'description': 'Torch dtype.'
+        },
+        'BATCH_INFER': {
+            'value': False,
+            'description': 'Encode the whole list in one pass instead of per text.'
+        }
+    }
+    para_dict.update(BaseEmbedder.para_dict)
+
+    def __init__(self, cfg, logger=None):
+        super().__init__(cfg, logger=logger)
+        hf_model_cls = cfg.get('HF_MODEL_CLS', None)
+        model_path = cfg.get('MODEL_PATH', None)
+        hf_tokenizer_cls = cfg.get('HF_TOKENIZER_CLS', None)
+        tokenizer_path = cfg.get('TOKENIZER_PATH', None)
+        self.max_length = cfg.get('MAX_LENGTH', 77)
+        self.output_key = cfg.get('OUTPUT_KEY', 'last_hidden_state')
+        self.d_type = cfg.get('D_TYPE', 'float')
+        self.clean = cfg.get('CLEAN', 'whitespace')
+        self.batch_infer = cfg.get('BATCH_INFER', False)
+        self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
+        torch_dtype = getattr(torch, self.d_type)
+
+        assert hf_model_cls is not None and hf_tokenizer_cls is not None
+        assert model_path is not None and tokenizer_path is not None
+        with FS.get_dir_to_local_dir(tokenizer_path,
+                                     wait_finish=True) as local_path:
+            self.tokenizer = getattr(
+                transformers, hf_tokenizer_cls).from_pretrained(
+                    local_path,
+                    max_length=self.max_length,
+                    additional_special_tokens=self.added_identifier)
+
+        with FS.get_dir_to_local_dir(model_path,
+                                     wait_finish=True) as local_path:
+            self.hf_module = getattr(
+                transformers, hf_model_cls).from_pretrained(
+                    local_path, torch_dtype=torch_dtype)
+        self.hf_module = self.hf_module.eval().requires_grad_(False)
+
+    def forward(self, text: list[str], return_mask=False):
+        batch_encoding = self.tokenizer(
+            text,
+            truncation=True,
+            max_length=self.max_length,
+            return_length=False,
+            return_overflowing_tokens=False,
+            padding='max_length',
+            return_tensors='pt',
+        )
+        outputs = self.hf_module(
+            input_ids=batch_encoding['input_ids'].to(self.hf_module.device),
+            attention_mask=None,
+            output_hidden_states=False,
+        )
+        if return_mask:
+            return outputs[self.output_key], batch_encoding[
+                'attention_mask'].to(self.hf_module.device)
+        else:
+            return outputs[self.output_key], None
+
+    def encode(self, text, return_mask=False):
+        if isinstance(text, str):
+            text = [text]
+        if self.clean:
+            text = [self._clean(u) for u in text]
+        if not self.batch_infer:
+            cont, mask = [], []
+            for tt in text:
+                one_cont, one_mask = self([tt], return_mask=return_mask)
+                cont.append(one_cont)
+                mask.append(one_mask)
+            if return_mask:
+                return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
+            else:
+                return torch.cat(cont, dim=0)
+        else:
+            ret_data = self(text, return_mask=return_mask)
+            if return_mask:
+                return ret_data
+            else:
+                return ret_data[0]
+
+    def encode_list(self, text_list, return_mask=True):
+        cont_list = []
+        mask_list = []
+        for pp in text_list:
+            cont = self.encode(pp, return_mask=return_mask)
+            if return_mask:
+                cont_list.append(cont[0])
+                mask_list.append(cont[1])
+            else:
+                cont_list.append(cont)
+                mask_list.append(None)
+        if return_mask:
+            return cont_list, mask_list
+        else:
+            return cont_list
+
+    def encode_list_of_list(self, text_list, return_mask=True):
+        cont_list = []
+        mask_list = []
+        for pp in text_list:
+            cont = self.encode_list(pp, return_mask=return_mask)
+            if return_mask:
+                cont_list.append(cont[0])
+                mask_list.append(cont[1])
+            else:
+                cont_list.append(cont)
+                mask_list.append(None)
+        if return_mask:
+            return cont_list, mask_list
+        else:
+            return cont_list
+
+    def _clean(self, text):
+        if self.clean == 'whitespace':
+            text = whitespace_clean(basic_clean(text))
+        elif self.clean == 'lower':
+            text = whitespace_clean(basic_clean(text)).lower()
+        elif self.clean == 'canonicalize':
+            text = canonicalize(basic_clean(text))
+        return text
+
+    @staticmethod
+    def get_config_template():
+        return dict_to_yaml('EMBEDDER',
+                            __class__.__name__,
+                            ACEHFEmbedder.para_dict,
+                            set_name=True)
+
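ACEHFEmbedder above is generic: it resolves both the encoder and the tokenizer by class name through getattr(transformers, ...). A hedged sketch, with illustrative class names and placeholder paths (nothing here is pinned by the commit):

```python
# Hypothetical config, not part of this commit: any transformers class name
# works for HF_MODEL_CLS / HF_TOKENIZER_CLS; CLIP is just one example.
clip_cfg = Config(load=False,
                  cfg_dict={
                      'NAME': 'ACEHFEmbedder',
                      'HF_MODEL_CLS': 'CLIPTextModel',
                      'MODEL_PATH': '/path/to/clip_text_encoder',
                      'HF_TOKENIZER_CLS': 'CLIPTokenizer',
                      'TOKENIZER_PATH': '/path/to/clip_text_encoder',
                      'MAX_LENGTH': 77,
                      'OUTPUT_KEY': 'pooler_output',  # pooled sentence vector
                      'D_TYPE': 'bfloat16',
                      'BATCH_INFER': True,  # one forward pass for the whole list
                  })
clip_embedder = EMBEDDERS.build(clip_cfg, logger=None)
pooled = clip_embedder.encode(['a cat'])  # [1, hidden_dim] for pooler_output
```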
+@EMBEDDERS.register_class()
+class T5ACEPlusClipFluxEmbedder(BaseEmbedder):
+    """
+    Combines a T5 text encoder with a CLIP text encoder, as expected by Flux.
+    """
+    para_dict = {
+        'T5_MODEL': {},
+        'CLIP_MODEL': {}
+    }
+
+    def __init__(self, cfg, logger=None):
+        super().__init__(cfg, logger=logger)
+        self.t5_model = EMBEDDERS.build(cfg.T5_MODEL, logger=logger)
+        self.clip_model = EMBEDDERS.build(cfg.CLIP_MODEL, logger=logger)
+
+    def encode(self, text, return_mask=False):
+        t5_embeds = self.t5_model.encode(text, return_mask=return_mask)
+        clip_embeds = self.clip_model.encode(text, return_mask=return_mask)
+        # change embedding strategy here
+        return {
+            'context': t5_embeds,
+            'y': clip_embeds,
+        }
+
+    def encode_list(self, text, return_mask=False):
+        t5_embeds = self.t5_model.encode_list(text, return_mask=return_mask)
+        clip_embeds = self.clip_model.encode_list(text, return_mask=return_mask)
+        # change embedding strategy here
+        return {
+            'context': t5_embeds,
+            'y': clip_embeds,
+        }
+
+    def encode_list_of_list(self, text, return_mask=False):
+        t5_embeds = self.t5_model.encode_list_of_list(text,
+                                                      return_mask=return_mask)
+        clip_embeds = self.clip_model.encode_list_of_list(
+            text, return_mask=return_mask)
+        # change embedding strategy here
+        return {
+            'context': t5_embeds,
+            'y': clip_embeds,
+        }
+
+    @staticmethod
+    def get_config_template():
+        return dict_to_yaml('EMBEDDER',
+                            __class__.__name__,
+                            T5ACEPlusClipFluxEmbedder.para_dict,
+                            set_name=True)
+
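The combined embedder wires the two together: the T5 branch supplies per-token 'context' features and the CLIP branch a pooled 'y' vector, the two conditioning inputs the Flux backbone below consumes. A sketch with assumed sub-configs, reusing ACEHFEmbedder for both branches:

```python
# Hypothetical wiring, not part of this commit; all paths and lengths are
# illustrative. Nested T5_MODEL / CLIP_MODEL configs are built via EMBEDDERS.build.
joint_cfg = Config(load=False,
                   cfg_dict={
                       'NAME': 'T5ACEPlusClipFluxEmbedder',
                       'T5_MODEL': {
                           'NAME': 'ACEHFEmbedder',
                           'HF_MODEL_CLS': 'T5EncoderModel',
                           'MODEL_PATH': '/path/to/t5_encoder',
                           'HF_TOKENIZER_CLS': 'T5Tokenizer',
                           'TOKENIZER_PATH': '/path/to/t5_encoder',
                           'MAX_LENGTH': 512,
                           'OUTPUT_KEY': 'last_hidden_state',
                       },
                       'CLIP_MODEL': {
                           'NAME': 'ACEHFEmbedder',
                           'HF_MODEL_CLS': 'CLIPTextModel',
                           'MODEL_PATH': '/path/to/clip_text_encoder',
                           'HF_TOKENIZER_CLS': 'CLIPTokenizer',
                           'TOKENIZER_PATH': '/path/to/clip_text_encoder',
                           'MAX_LENGTH': 77,
                           'OUTPUT_KEY': 'pooler_output',
                       },
                   })
joint = EMBEDDERS.build(joint_cfg, logger=None)
cond = joint.encode(['a cat'])
# cond['context']: T5 token features; cond['y']: CLIP pooled features.
```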
 @BACKBONES.register_class()
 class Flux(BaseModel):
     """