JMalott committed on
Commit
497fa64
·
1 Parent(s): a56dc89

Update min_dalle/min_dalle.py

Browse files
Files changed (1) hide show
  1. min_dalle/min_dalle.py +9 -7
min_dalle/min_dalle.py CHANGED
@@ -14,8 +14,8 @@ import gc
14
 
15
  torch.set_grad_enabled(False)
16
  torch.set_num_threads(os.cpu_count())
17
- torch.backends.cudnn.enabled = True
18
- torch.backends.cudnn.allow_tf32 = True
19
 
20
  MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
21
  IMAGE_TOKEN_COUNT = 256
@@ -25,7 +25,7 @@ class MinDalle:
25
  def __init__(
26
  self,
27
  models_root: str = 'pretrained',
28
- dtype: torch.dtype = torch.float32,
29
  device: str = None,
30
  is_mega: bool = True,
31
  is_reusable: bool = True,
@@ -188,7 +188,7 @@ class MinDalle:
188
  if len(tokens) > self.text_token_count:
189
  tokens = tokens[:self.text_token_count]
190
  if is_verbose: print("{} text tokens".format(len(tokens)), tokens)
191
- text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
192
  text_tokens[0, :2] = [tokens[0], tokens[-1]]
193
  text_tokens[1, :len(tokens)] = tokens
194
  text_tokens = torch.tensor(
@@ -232,9 +232,11 @@ class MinDalle:
232
  token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=self.device)
233
  settings = torch.tensor(
234
  [temperature, top_k, supercondition_factor],
235
- dtype=torch.float32,
236
  device=self.device
237
  )
 
 
238
  for i in range(IMAGE_TOKEN_COUNT):
239
 
240
  if(st.session_state.page != 0):
@@ -243,7 +245,7 @@ class MinDalle:
243
 
244
  #torch.cuda.empty_cache()
245
  #torch.cpu.empty_cache()
246
- gc.collect()
247
 
248
  image_tokens[i + 1], attention_state = self.decoder.forward(
249
  settings=settings,
@@ -254,7 +256,7 @@ class MinDalle:
254
  token_index=token_indices[[i]]
255
  )
256
 
257
- if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256:
258
  yield self.image_grid_from_tokens(
259
  image_tokens=image_tokens[1:].T,
260
  is_seamless=is_seamless,
 
14
 
15
  torch.set_grad_enabled(False)
16
  torch.set_num_threads(os.cpu_count())
17
+ torch.backends.cudnn.enabled = False
18
+ torch.backends.cudnn.allow_tf16 = False
19
 
20
  MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
21
  IMAGE_TOKEN_COUNT = 256
 
25
  def __init__(
26
  self,
27
  models_root: str = 'pretrained',
28
+ dtype: torch.dtype = torch.float16,
29
  device: str = None,
30
  is_mega: bool = True,
31
  is_reusable: bool = True,
 
188
  if len(tokens) > self.text_token_count:
189
  tokens = tokens[:self.text_token_count]
190
  if is_verbose: print("{} text tokens".format(len(tokens)), tokens)
191
+ text_tokens = numpy.ones((2, 64), dtype=numpy.int16)
192
  text_tokens[0, :2] = [tokens[0], tokens[-1]]
193
  text_tokens[1, :len(tokens)] = tokens
194
  text_tokens = torch.tensor(
 
232
  token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=self.device)
233
  settings = torch.tensor(
234
  [temperature, top_k, supercondition_factor],
235
+ dtype=torch.float16,
236
  device=self.device
237
  )
238
+
239
+
240
  for i in range(IMAGE_TOKEN_COUNT):
241
 
242
  if(st.session_state.page != 0):
 
245
 
246
  #torch.cuda.empty_cache()
247
  #torch.cpu.empty_cache()
248
+ #gc.collect()
249
 
250
  image_tokens[i + 1], attention_state = self.decoder.forward(
251
  settings=settings,
 
256
  token_index=token_indices[[i]]
257
  )
258
 
259
+ if ((i + 1) % 16 == 0 and progressive_outputs) or i + 1 == 256:
260
  yield self.image_grid_from_tokens(
261
  image_tokens=image_tokens[1:].T,
262
  is_seamless=is_seamless,