Text-to-Image · Diffusers · Safetensors
tolgacangoz committed (verified) · Commit 028489c · Parent: a3238d3

Delete text_embedding_module/frozen_clip_embedder_t3.py

text_embedding_module/frozen_clip_embedder_t3.py DELETED
@@ -1,214 +0,0 @@
- import torch
- from torch import nn
- from transformers import CLIPTextModel, CLIPTokenizer
- from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
-
-
- class AbstractEncoder(nn.Module):
-     def __init__(self):
-         super().__init__()
-
-     def encode(self, *args, **kwargs):
-         raise NotImplementedError
-
-
- class FrozenCLIPEmbedderT3(AbstractEncoder):
-     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-
-     def __init__(
-         self,
-         version="openai/clip-vit-large-patch14",
-         device="cpu",
-         max_length=77,
-         freeze=True,
-         use_fp16=False,
-     ):
-         super().__init__()
-         self.tokenizer = CLIPTokenizer.from_pretrained(version)
-         self.transformer = CLIPTextModel.from_pretrained(
-             version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32
-         ).to(device)
-         self.device = device
-         self.max_length = max_length
-         if freeze:
-             self.freeze()
-
-         def embedding_forward(
-             self,
-             input_ids=None,
-             position_ids=None,
-             inputs_embeds=None,
-             embedding_manager=None,
-         ):
-             seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-             if position_ids is None:
-                 position_ids = self.position_ids[:, :seq_length]
-             if inputs_embeds is None:
-                 inputs_embeds = self.token_embedding(input_ids)
-             if embedding_manager is not None:
-                 inputs_embeds = embedding_manager(input_ids, inputs_embeds)
-             position_embeddings = self.position_embedding(position_ids)
-             embeddings = inputs_embeds + position_embeddings
-             return embeddings
-
-         self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
-             self.transformer.text_model.embeddings
-         )
-
-         def encoder_forward(
-             self,
-             inputs_embeds,
-             attention_mask=None,
-             causal_attention_mask=None,
-             output_attentions=None,
-             output_hidden_states=None,
-             return_dict=None,
-         ):
-             output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-             output_hidden_states = (
-                 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-             )
-             return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-             encoder_states = () if output_hidden_states else None
-             all_attentions = () if output_attentions else None
-             hidden_states = inputs_embeds
-             for idx, encoder_layer in enumerate(self.layers):
-                 if output_hidden_states:
-                     encoder_states = encoder_states + (hidden_states,)
-                 layer_outputs = encoder_layer(
-                     hidden_states,
-                     attention_mask,
-                     causal_attention_mask,
-                     output_attentions=output_attentions,
-                 )
-                 hidden_states = layer_outputs[0]
-                 if output_attentions:
-                     all_attentions = all_attentions + (layer_outputs[1],)
-             if output_hidden_states:
-                 encoder_states = encoder_states + (hidden_states,)
-             return hidden_states
-
-         self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
-
-         def text_encoder_forward(
-             self,
-             input_ids=None,
-             attention_mask=None,
-             position_ids=None,
-             output_attentions=None,
-             output_hidden_states=None,
-             return_dict=None,
-             embedding_manager=None,
-         ):
-             output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-             output_hidden_states = (
-                 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-             )
-             return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-             if input_ids is None:
-                 raise ValueError("You have to specify either input_ids")
-             input_shape = input_ids.size()
-             input_ids = input_ids.view(-1, input_shape[-1])
-             hidden_states = self.embeddings(
-                 input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager
-             )
-             # CLIP's text model uses causal mask, prepare it here.
-             # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-             causal_attention_mask = _create_4d_causal_attention_mask(
-                 input_shape, hidden_states.dtype, device=hidden_states.device
-             )
-             # expand attention_mask
-             if attention_mask is not None:
-                 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-                 attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
-             last_hidden_state = self.encoder(
-                 inputs_embeds=hidden_states,
-                 attention_mask=attention_mask,
-                 causal_attention_mask=causal_attention_mask,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 return_dict=return_dict,
-             )
-             last_hidden_state = self.final_layer_norm(last_hidden_state)
-             return last_hidden_state
-
-         self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
-
-         def transformer_forward(
-             self,
-             input_ids=None,
-             attention_mask=None,
-             position_ids=None,
-             output_attentions=None,
-             output_hidden_states=None,
-             return_dict=None,
-             embedding_manager=None,
-         ):
-             return self.text_model(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 position_ids=position_ids,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 return_dict=return_dict,
-                 embedding_manager=embedding_manager,
-             )
-
-         self.transformer.forward = transformer_forward.__get__(self.transformer)
-
-     def freeze(self):
-         self.transformer = self.transformer.eval()
-         for param in self.parameters():
-             param.requires_grad = False
-
-     def forward(self, text, **kwargs):
-         batch_encoding = self.tokenizer(
-             text,
-             truncation=False,
-             max_length=self.max_length,
-             return_length=True,
-             return_overflowing_tokens=False,
-             padding="longest",
-             return_tensors="pt",
-         )
-         input_ids = batch_encoding["input_ids"]
-         tokens_list = self.split_chunks(input_ids)
-         z_list = []
-         for tokens in tokens_list:
-             tokens = tokens.to(self.device)
-             _z = self.transformer(input_ids=tokens, **kwargs)
-             z_list += [_z]
-         return torch.cat(z_list, dim=1)
-
-     def encode(self, text, **kwargs):
-         return self(text, **kwargs)
-
-     def split_chunks(self, input_ids, chunk_size=75):
-         tokens_list = []
-         bs, n = input_ids.shape
-         id_start = input_ids[:, 0].unsqueeze(1)  # dim --> [bs, 1]
-         id_end = input_ids[:, -1].unsqueeze(1)
-         if n == 2:  # empty caption
-             tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1))
-
-         trimmed_encoding = input_ids[:, 1:-1]
-         num_full_groups = (n - 2) // chunk_size
-
-         for i in range(num_full_groups):
-             group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size]
-             group_pad = torch.cat((id_start, group, id_end), dim=1)
-             tokens_list.append(group_pad)
-
-         remaining_columns = (n - 2) % chunk_size
-         if remaining_columns > 0:
-             remaining_group = trimmed_encoding[:, -remaining_columns:]
-             padding_columns = chunk_size - remaining_group.shape[1]
-             padding = id_end.expand(bs, padding_columns)
-             remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1)
-             tokens_list.append(remaining_group_pad)
-         return tokens_list
-
-     def to(self, *args, **kwargs):
-         self.transformer = self.transformer.to(*args, **kwargs)
-         self.device = self.transformer.device
-         return self
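
For reference, a minimal usage sketch of the deleted encoder, assuming the file as it existed before this commit (module path text_embedding_module/frozen_clip_embedder_t3.py and the FrozenCLIPEmbedderT3 class shown above; that import path is hypothetical once this deletion lands). It illustrates the behaviour the class was built for: the caption is tokenized without truncation, split_chunks regroups the ids into BOS + 75-token + EOS blocks of length 77, and the per-chunk CLIP hidden states are concatenated along the sequence axis, so prompts longer than CLIP's 77-token window still yield a single embedding tensor.

import torch

# Hypothetical import: this module is the file removed by this commit.
from text_embedding_module.frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3

device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = FrozenCLIPEmbedderT3(
    version="openai/clip-vit-large-patch14",  # default CLIP text tower
    device=device,
    use_fp16=(device == "cuda"),
)

# A caption longer than CLIP's 77-token context window.
long_prompt = "a storefront sign that reads 'HELLO WORLD', " * 20

with torch.no_grad():
    z = encoder.encode(long_prompt)

# One 77-token chunk per BOS/75-token/EOS block, concatenated on dim=1:
# z.shape == [1, 77 * num_chunks, 768] for the ViT-L/14 text encoder.
print(z.shape)

Any extra keyword arguments passed to encode/forward are forwarded through the monkey-patched transformer forward, which is how an embedding_manager can be threaded down to the token embedding layer.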