JMalott committed
Commit a56dc89 · 1 Parent(s): f9a4923

Update min_dalle/min_dalle.py

Files changed (1):
  1. min_dalle/min_dalle.py +4 -1
min_dalle/min_dalle.py CHANGED
@@ -10,6 +10,7 @@ from typing import Iterator
 from .text_tokenizer import TextTokenizer
 from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
 import streamlit as st
+import gc
 
 torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())
@@ -235,12 +236,14 @@ class MinDalle:
             device=self.device
         )
         for i in range(IMAGE_TOKEN_COUNT):
+
             if(st.session_state.page != 0):
                 break
             st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
 
             #torch.cuda.empty_cache()
             #torch.cpu.empty_cache()
+            gc.collect()
 
             image_tokens[i + 1], attention_state = self.decoder.forward(
                 settings=settings,
@@ -257,7 +260,7 @@ class MinDalle:
             is_seamless=is_seamless,
             is_verbose=is_verbose
         )
-        del attention_state
+
 
     def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
         image_stream = self.generate_raw_image_stream(*args, **kwargs)
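
In effect, the patch imports gc, calls gc.collect() once per decoded image token (next to the commented-out empty_cache() calls), and drops the trailing del attention_state. A minimal sketch of the same loop pattern is below; decode_step is a hypothetical callback standing in for self.decoder.forward, and TOKEN_COUNT is a stand-in for IMAGE_TOKEN_COUNT. It assumes the Streamlit app has already stored a page index and a progress bar in st.session_state, as the patched code does.

import gc

import streamlit as st

TOKEN_COUNT = 256  # stand-in for IMAGE_TOKEN_COUNT in min_dalle

def decode_all_tokens(decode_step) -> None:
    """Long per-token loop in the style of the patched generate_raw_image_stream."""
    for i in range(TOKEN_COUNT):
        # Stop early if the app has navigated away from the generation page.
        if st.session_state.page != 0:
            break
        # Report progress via a bar the app stored in session state (st.progress).
        st.session_state.bar.progress(i / TOKEN_COUNT)

        # Ask Python's garbage collector to release objects dropped on the
        # previous iteration before allocating the next decoder state.
        gc.collect()

        decode_step(i)  # hypothetical callback standing in for self.decoder.forward(...)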