Text-to-Image · Diffusers · Safetensors
Committed by tolgacangoz · Commit 3809f07 (verified) · Parent: a786082

Upload frozen_clip_embedder_t3.py

text_embedding_module/frozen_clip_embedder_t3.py ADDED
@@ -0,0 +1,214 @@
import torch
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPEmbedderT3(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cpu",
        max_length=77,
        freeze=True,
        use_fp16=False,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(
            version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32
        ).to(device)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

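        # The nested functions below re-implement the forward passes of the CLIP text model's
        # submodules and are bound onto the loaded `transformer` via `__get__`. The point of the
        # patching is to thread an optional `embedding_manager` hook through the embedding layer
        # (e.g. for swapping in learned token embeddings) without modifying transformers itself.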
        def embedding_forward(
            self,
            input_ids=None,
            position_ids=None,
            inputs_embeds=None,
            embedding_manager=None,
        ):
            seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
            if position_ids is None:
                position_ids = self.position_ids[:, :seq_length]
            if inputs_embeds is None:
                inputs_embeds = self.token_embedding(input_ids)
            if embedding_manager is not None:
                inputs_embeds = embedding_manager(input_ids, inputs_embeds)
            position_embeddings = self.position_embedding(position_ids)
            embeddings = inputs_embeds + position_embeddings
            return embeddings

        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
            self.transformer.text_model.embeddings
        )

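        # Patched encoder forward: runs the stack of CLIP encoder layers and returns the raw
        # last hidden state directly instead of a BaseModelOutput, since only the hidden states
        # are needed downstream.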
        def encoder_forward(
            self,
            inputs_embeds,
            attention_mask=None,
            causal_attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        ):
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
            encoder_states = () if output_hidden_states else None
            all_attentions = () if output_attentions else None
            hidden_states = inputs_embeds
            for idx, encoder_layer in enumerate(self.layers):
                if output_hidden_states:
                    encoder_states = encoder_states + (hidden_states,)
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1],)
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            return hidden_states

        self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)

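        # Patched text-model forward: builds the causal attention mask (and expands the optional
        # padding mask) before running the encoder, then applies the final layer norm. Unlike the
        # stock implementation it returns only the last hidden state and skips the pooled output.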
        def text_encoder_forward(
            self,
            input_ids=None,
            attention_mask=None,
            position_ids=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            embedding_manager=None,
        ):
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
            if input_ids is None:
                raise ValueError("You have to specify input_ids")
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            hidden_states = self.embeddings(
                input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager
            )
            # CLIP's text model uses a causal mask; prepare it here.
            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
            causal_attention_mask = _create_4d_causal_attention_mask(
                input_shape, hidden_states.dtype, device=hidden_states.device
            )
            # Expand attention_mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            if attention_mask is not None:
                attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
            last_hidden_state = self.encoder(
                inputs_embeds=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            last_hidden_state = self.final_layer_norm(last_hidden_state)
            return last_hidden_state

        self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)

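        # Top-level wrapper bound onto the CLIPTextModel so that `embedding_manager` can be
        # passed straight through `self.transformer(input_ids=..., embedding_manager=...)`.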
        def transformer_forward(
            self,
            input_ids=None,
            attention_mask=None,
            position_ids=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            embedding_manager=None,
        ):
            return self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                embedding_manager=embedding_manager,
            )

        self.transformer.forward = transformer_forward.__get__(self.transformer)

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

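    # Encoding entry point: tokenize without truncation, split the ids into 77-token windows
    # (see `split_chunks`), encode each window separately, and concatenate the resulting hidden
    # states along the sequence dimension.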
    def forward(self, text, **kwargs):
        batch_encoding = self.tokenizer(
            text,
            truncation=False,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="longest",
            return_tensors="pt",
        )
        input_ids = batch_encoding["input_ids"]
        tokens_list = self.split_chunks(input_ids)
        z_list = []
        for tokens in tokens_list:
            tokens = tokens.to(self.device)
            _z = self.transformer(input_ids=tokens, **kwargs)
            z_list += [_z]
        return torch.cat(z_list, dim=1)

    def encode(self, text, **kwargs):
        return self(text, **kwargs)

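    # Split `input_ids` of shape [bs, n] into chunks of `chunk_size` prompt tokens, each
    # re-wrapped with the sequence's first and last ids (BOS/EOS); the final partial chunk is
    # padded with the last id so every chunk has length chunk_size + 2 (77 by default).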
    def split_chunks(self, input_ids, chunk_size=75):
        tokens_list = []
        bs, n = input_ids.shape
        id_start = input_ids[:, 0].unsqueeze(1)  # dim --> [bs, 1]
        id_end = input_ids[:, -1].unsqueeze(1)
        if n == 2:  # empty caption
            tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1))

        trimmed_encoding = input_ids[:, 1:-1]
        num_full_groups = (n - 2) // chunk_size

        for i in range(num_full_groups):
            group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size]
            group_pad = torch.cat((id_start, group, id_end), dim=1)
            tokens_list.append(group_pad)

        remaining_columns = (n - 2) % chunk_size
        if remaining_columns > 0:
            remaining_group = trimmed_encoding[:, -remaining_columns:]
            padding_columns = chunk_size - remaining_group.shape[1]
            padding = id_end.expand(bs, padding_columns)
            remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1)
            tokens_list.append(remaining_group_pad)
        return tokens_list

    def to(self, *args, **kwargs):
        self.transformer = self.transformer.to(*args, **kwargs)
        self.device = self.transformer.device
        return self
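For reference, a minimal usage sketch of the module above (not part of the uploaded file). It only assumes the file is importable as `frozen_clip_embedder_t3`; the default `openai/clip-vit-large-patch14` checkpoint is downloaded on first use:

import torch
from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3  # assumes the file is on PYTHONPATH

embedder = FrozenCLIPEmbedderT3(device="cuda" if torch.cuda.is_available() else "cpu")
long_prompt = "a storefront sign that reads 'OPEN 24 HOURS', " * 20  # well over 77 tokens
with torch.no_grad():
    z = embedder.encode(long_prompt)
# One row per prompt; the sequence dimension grows in multiples of 77, one window per
# chunk produced by split_chunks.
print(z.shape)  # e.g. torch.Size([1, 77 * num_chunks, 768])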