Spaces:
Runtime error
Runtime error
Update min_dalle/min_dalle.py
Browse files- min_dalle/min_dalle.py +4 -1
min_dalle/min_dalle.py
CHANGED
@@ -10,6 +10,7 @@ from typing import Iterator
|
|
10 |
from .text_tokenizer import TextTokenizer
|
11 |
from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
|
12 |
import streamlit as st
|
|
|
13 |
|
14 |
torch.set_grad_enabled(False)
|
15 |
torch.set_num_threads(os.cpu_count())
|
@@ -235,12 +236,14 @@ class MinDalle:
|
|
235 |
device=self.device
|
236 |
)
|
237 |
for i in range(IMAGE_TOKEN_COUNT):
|
|
|
238 |
if(st.session_state.page != 0):
|
239 |
break
|
240 |
st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
|
241 |
|
242 |
#torch.cuda.empty_cache()
|
243 |
#torch.cpu.empty_cache()
|
|
|
244 |
|
245 |
image_tokens[i + 1], attention_state = self.decoder.forward(
|
246 |
settings=settings,
|
@@ -257,7 +260,7 @@ class MinDalle:
|
|
257 |
is_seamless=is_seamless,
|
258 |
is_verbose=is_verbose
|
259 |
)
|
260 |
-
|
261 |
|
262 |
def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
|
263 |
image_stream = self.generate_raw_image_stream(*args, **kwargs)
|
|
|
10 |
from .text_tokenizer import TextTokenizer
|
11 |
from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
|
12 |
import streamlit as st
|
13 |
+
import gc
|
14 |
|
15 |
torch.set_grad_enabled(False)
|
16 |
torch.set_num_threads(os.cpu_count())
|
|
|
236 |
device=self.device
|
237 |
)
|
238 |
for i in range(IMAGE_TOKEN_COUNT):
|
239 |
+
|
240 |
if(st.session_state.page != 0):
|
241 |
break
|
242 |
st.session_state.bar.progress(i/IMAGE_TOKEN_COUNT)
|
243 |
|
244 |
#torch.cuda.empty_cache()
|
245 |
#torch.cpu.empty_cache()
|
246 |
+
gc.collect()
|
247 |
|
248 |
image_tokens[i + 1], attention_state = self.decoder.forward(
|
249 |
settings=settings,
|
|
|
260 |
is_seamless=is_seamless,
|
261 |
is_verbose=is_verbose
|
262 |
)
|
263 |
+
|
264 |
|
265 |
def generate_image_stream(self, *args, **kwargs) -> Iterator[Image.Image]:
|
266 |
image_stream = self.generate_raw_image_stream(*args, **kwargs)
|