JMalott commited on
Commit
07d6419
·
1 Parent(s): da99fb7

Upload min_dalle.py

Browse files
Files changed (1) hide show
  1. min_dalle/min_dalle.py +291 -0
min_dalle/min_dalle.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import numpy
4
+ from torch import LongTensor, FloatTensor
5
+ import torch
6
+ import torch.backends.cudnn, torch.backends.cuda
7
+ import json
8
+ import requests
9
+ from typing import Iterator
10
+ from .text_tokenizer import TextTokenizer
11
+ from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
12
+ import streamlit as st
13
+
14
+ torch.set_grad_enabled(False)
15
+ torch.set_num_threads(os.cpu_count())
16
+ torch.backends.cudnn.enabled = True
17
+ torch.backends.cudnn.allow_tf32 = True
18
+
19
+ MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
20
+ IMAGE_TOKEN_COUNT = 256
21
+
22
+
23
+ class MinDalle:
24
+ def __init__(
25
+ self,
26
+ models_root: str = 'pretrained',
27
+ dtype: torch.dtype = torch.float32,
28
+ device: str = None,
29
+ is_mega: bool = True,
30
+ is_reusable: bool = True,
31
+ is_verbose = True
32
+ ):
33
+ if device == None:
34
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
35
+ if is_verbose: print("using device", device)
36
+ self.device = device
37
+ self.is_mega = is_mega
38
+ self.is_reusable = is_reusable
39
+ self.dtype = dtype
40
+ self.is_verbose = is_verbose
41
+ self.text_token_count = 64
42
+ self.layer_count = 24 if is_mega else 12
43
+ self.attention_head_count = 32 if is_mega else 16
44
+ self.embed_count = 2048 if is_mega else 1024
45
+ self.glu_embed_count = 4096 if is_mega else 2730
46
+ self.text_vocab_count = 50272 if is_mega else 50264
47
+ self.image_vocab_count = 16415 if is_mega else 16384
48
+
49
+ model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
50
+ dalle_path = os.path.join(models_root, model_name)
51
+ vqgan_path = os.path.join(models_root, 'vqgan')
52
+ if not os.path.exists(dalle_path): os.makedirs(dalle_path)
53
+ if not os.path.exists(vqgan_path): os.makedirs(vqgan_path)
54
+ self.vocab_path = os.path.join(dalle_path, 'vocab.json')
55
+ self.merges_path = os.path.join(dalle_path, 'merges.txt')
56
+ self.encoder_params_path = os.path.join(dalle_path, 'encoder.pt')
57
+ self.decoder_params_path = os.path.join(dalle_path, 'decoder.pt')
58
+ self.detoker_params_path = os.path.join(vqgan_path, 'detoker.pt')
59
+
60
+ self.init_tokenizer()
61
+ if is_reusable:
62
+ self.init_encoder()
63
+ self.init_decoder()
64
+ self.init_detokenizer()
65
+
66
+
67
+ def download_tokenizer(self):
68
+ if self.is_verbose: print("downloading tokenizer params")
69
+ suffix = '' if self.is_mega else '_mini'
70
+ _ = requests.get(MIN_DALLE_REPO + 'config.json') # trigger HF download
71
+ vocab = requests.get(MIN_DALLE_REPO + 'vocab{}.json'.format(suffix))
72
+ merges = requests.get(MIN_DALLE_REPO + 'merges{}.txt'.format(suffix))
73
+ with open(self.vocab_path, 'wb') as f: f.write(vocab.content)
74
+ with open(self.merges_path, 'wb') as f: f.write(merges.content)
75
+
76
+
77
+ def download_encoder(self):
78
+ if self.is_verbose: print("downloading encoder params")
79
+ suffix = '' if self.is_mega else '_mini'
80
+ params = requests.get(MIN_DALLE_REPO + 'encoder{}.pt'.format(suffix))
81
+ with open(self.encoder_params_path, 'wb') as f: f.write(params.content)
82
+
83
+
84
+ def download_decoder(self):
85
+ if self.is_verbose: print("downloading decoder params")
86
+ suffix = '' if self.is_mega else '_mini'
87
+ params = requests.get(MIN_DALLE_REPO + 'decoder{}.pt'.format(suffix))
88
+ with open(self.decoder_params_path, 'wb') as f: f.write(params.content)
89
+
90
+
91
+ def download_detokenizer(self):
92
+ if self.is_verbose: print("downloading detokenizer params")
93
+ params = requests.get(MIN_DALLE_REPO + 'detoker.pt')
94
+ with open(self.detoker_params_path, 'wb') as f: f.write(params.content)
95
+
96
+
97
+ def init_tokenizer(self):
98
+ is_downloaded = os.path.exists(self.vocab_path)
99
+ is_downloaded &= os.path.exists(self.merges_path)
100
+ if not is_downloaded: self.download_tokenizer()
101
+ if self.is_verbose: print("intializing TextTokenizer")
102
+ with open(self.vocab_path, 'r', encoding='utf8') as f:
103
+ vocab = json.load(f)
104
+ with open(self.merges_path, 'r', encoding='utf8') as f:
105
+ merges = f.read().split("\n")[1:-1]
106
+ self.tokenizer = TextTokenizer(vocab, merges)
107
+
108
+
109
+ def init_encoder(self):
110
+ is_downloaded = os.path.exists(self.encoder_params_path)
111
+ if not is_downloaded: self.download_encoder()
112
+ if self.is_verbose: print("initializing DalleBartEncoder")
113
+ self.encoder = DalleBartEncoder(
114
+ attention_head_count = self.attention_head_count,
115
+ embed_count = self.embed_count,
116
+ glu_embed_count = self.glu_embed_count,
117
+ text_token_count = self.text_token_count,
118
+ text_vocab_count = self.text_vocab_count,
119
+ layer_count = self.layer_count,
120
+ device=self.device
121
+ ).to(self.dtype).eval()
122
+ params = torch.load(self.encoder_params_path)
123
+ self.encoder.load_state_dict(params, strict=False)
124
+ del params
125
+ self.encoder = self.encoder.to(device=self.device)
126
+
127
+
128
+ def init_decoder(self):
129
+ is_downloaded = os.path.exists(self.decoder_params_path)
130
+ if not is_downloaded: self.download_decoder()
131
+ if self.is_verbose: print("initializing DalleBartDecoder")
132
+ self.decoder = DalleBartDecoder(
133
+ image_vocab_count = self.image_vocab_count,
134
+ attention_head_count = self.attention_head_count,
135
+ embed_count = self.embed_count,
136
+ glu_embed_count = self.glu_embed_count,
137
+ layer_count = self.layer_count,
138
+ device=self.device
139
+ ).to(self.dtype).eval()
140
+ params = torch.load(self.decoder_params_path)
141
+ self.decoder.load_state_dict(params, strict=False)
142
+ del params
143
+ self.decoder = self.decoder.to(device=self.device)
144
+
145
+
146
+ def init_detokenizer(self):
147
+ is_downloaded = os.path.exists(self.detoker_params_path)
148
+ if not is_downloaded: self.download_detokenizer()
149
+ if self.is_verbose: print("initializing VQGanDetokenizer")
150
+ self.detokenizer = VQGanDetokenizer().eval()
151
+ params = torch.load(self.detoker_params_path)
152
+ self.detokenizer.load_state_dict(params)
153
+ del params
154
+ self.detokenizer = self.detokenizer.to(device=self.device)
155
+
156
+
157
+ def image_grid_from_tokens(
158
+ self,
159
+ image_tokens: LongTensor,
160
+ is_seamless: bool,
161
+ is_verbose: bool = False
162
+ ) -> FloatTensor:
163
+ if not self.is_reusable: del self.decoder
164
+ torch.cuda.empty_cache()
165
+ if not self.is_reusable: self.init_detokenizer()
166
+ if is_verbose: print("detokenizing image")
167
+ images = self.detokenizer.forward(is_seamless, image_tokens)
168
+ if not self.is_reusable: del self.detokenizer
169
+ return images
170
+
171
+
172
+ def generate_raw_image_stream(
173
+ self,
174
+ text: str,
175
+ seed: int,
176
+ grid_size: int,
177
+ progressive_outputs: bool = False,
178
+ is_seamless: bool = False,
179
+ temperature: float = 1,
180
+ top_k: int = 256,
181
+ supercondition_factor: int = 16,
182
+ is_verbose: bool = False
183
+ ) -> Iterator[FloatTensor]:
184
+ image_count = grid_size ** 2
185
+ if is_verbose: print("tokenizing text")
186
+ tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
187
+ if len(tokens) > self.text_token_count:
188
+ tokens = tokens[:self.text_token_count]
189
+ if is_verbose: print("{} text tokens".format(len(tokens)), tokens)
190
+ text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
191
+ text_tokens[0, :2] = [tokens[0], tokens[-1]]
192
+ text_tokens[1, :len(tokens)] = tokens
193
+ text_tokens = torch.tensor(
194
+ text_tokens,
195
+ dtype=torch.long,
196
+ device=self.device
197
+ )
198
+
199
+ if not self.is_reusable: self.init_encoder()
200
+ if is_verbose: print("encoding text tokens")
201
+ with torch.cuda.amp.autocast(dtype=self.dtype):
202
+ encoder_state = self.encoder.forward(text_tokens)
203
+ if not self.is_reusable: del self.encoder
204
+ torch.cuda.empty_cache()
205
+
206
+ if not self.is_reusable: self.init_decoder()
207
+
208
+ with torch.cuda.amp.autocast(dtype=self.dtype):
209
+ expanded_indices = [0] * image_count + [1] * image_count
210
+ text_tokens = text_tokens[expanded_indices]
211
+ encoder_state = encoder_state[expanded_indices]
212
+ attention_mask = text_tokens.not_equal(1)
213
+ attention_state = torch.zeros(
214
+ size=(
215
+ self.layer_count,
216
+ image_count * 4,
217
+ IMAGE_TOKEN_COUNT,
218
+ self.embed_count
219
+ ),
220
+ device=self.device
221
+ )
222
+ image_tokens = torch.full(
223
+ (IMAGE_TOKEN_COUNT + 1, image_count),
224
+ self.image_vocab_count,
225
+ dtype=torch.long,
226
+ device=self.device
227
+ )
228
+
229
+ if seed > 0: torch.manual_seed(seed)
230
+
231
+ token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=self.device)
232
+ settings = torch.tensor(
233
+ [temperature, top_k, supercondition_factor],
234
+ dtype=torch.float32,
235
+ device=self.device
236
+ )
237
+ for i in range(IMAGE_TOKEN_COUNT):
238
+ if(st.session_state.page != 0):
239
+ break
240
+ st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
241
+
242
+ torch.cuda.empty_cache()
243
+ with torch.cuda.amp.autocast(dtype=self.dtype):
244
+ image_tokens[i + 1], attention_state = self.decoder.forward(
245
+ settings=settings,
246
+ attention_mask=attention_mask,
247
+ encoder_state=encoder_state,
248
+ attention_state=attention_state,
249
+ prev_tokens=image_tokens[i],
250
+ token_index=token_indices[[i]]
251
+ )
252
+
253
+ with torch.cuda.amp.autocast(dtype=torch.float32):
254
+ if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
255
+ yield self.image_grid_from_tokens(
256
+ image_tokens=image_tokens[1:].T,
257
+ is_seamless=is_seamless,
258
+ is_verbose=is_verbose
259
+ )
260
+
261
+ def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
262
+ image_stream = self.generate_raw_image_stream(*args, **kwargs)
263
+ for image in image_stream:
264
+ image = image.to(torch.uint8).to('cpu').numpy()
265
+ yield Image.fromarray(image)
266
+
267
+
268
+ def generate_images_stream(self, *args, **kwargs) -> Iterator[FloatTensor]:
269
+ image_stream = self.generate_raw_image_stream(*args, **kwargs)
270
+ for image in image_stream:
271
+ grid_size = kwargs['grid_size']
272
+ image = image.view([grid_size * 256, grid_size, 256, 3])
273
+ image = image.transpose(1, 0)
274
+ image = image.reshape([grid_size ** 2, 2 ** 8, 2 ** 8, 3])
275
+ yield image
276
+
277
+
278
+ def generate_image(self, *args, **kwargs) -> Image.Image:
279
+ image_stream = self.generate_image_stream(
280
+ *args, **kwargs,
281
+ progressive_outputs=False
282
+ )
283
+ return next(image_stream)
284
+
285
+
286
+ def generate_images(self, *args, **kwargs) -> Image.Image:
287
+ images_stream = self.generate_images_stream(
288
+ *args, **kwargs,
289
+ progressive_outputs=False
290
+ )
291
+ return next(images_stream)