JMalott commited on
Commit
77bd19b
·
1 Parent(s): 7f311ef

Update min_dalle/min_dalle.py

Browse files
Files changed (1) hide show
  1. min_dalle/min_dalle.py +8 -8
min_dalle/min_dalle.py CHANGED
@@ -66,7 +66,7 @@ class MinDalle:
66
  self.init_decoder()
67
  self.init_detokenizer()
68
 
69
-
70
  def download_tokenizer(self):
71
  if self.is_verbose: print("downloading tokenizer params")
72
  suffix = '' if self.is_mega else '_mini'
@@ -76,27 +76,27 @@ class MinDalle:
76
  with open(self.vocab_path, 'wb') as f: f.write(vocab.content)
77
  with open(self.merges_path, 'wb') as f: f.write(merges.content)
78
 
79
-
80
  def download_encoder(self):
81
  if self.is_verbose: print("downloading encoder params")
82
  suffix = '' if self.is_mega else '_mini'
83
  params = requests.get(MIN_DALLE_REPO + 'encoder{}.pt'.format(suffix))
84
  with open(self.encoder_params_path, 'wb') as f: f.write(params.content)
85
 
86
-
87
  def download_decoder(self):
88
  if self.is_verbose: print("downloading decoder params")
89
  suffix = '' if self.is_mega else '_mini'
90
  params = requests.get(MIN_DALLE_REPO + 'decoder{}.pt'.format(suffix))
91
  with open(self.decoder_params_path, 'wb') as f: f.write(params.content)
92
 
93
-
94
  def download_detokenizer(self):
95
  if self.is_verbose: print("downloading detokenizer params")
96
  params = requests.get(MIN_DALLE_REPO + 'detoker.pt')
97
  with open(self.detoker_params_path, 'wb') as f: f.write(params.content)
98
 
99
-
100
  def init_tokenizer(self):
101
  is_downloaded = os.path.exists(self.vocab_path)
102
  is_downloaded &= os.path.exists(self.merges_path)
@@ -108,7 +108,7 @@ class MinDalle:
108
  merges = f.read().split("\n")[1:-1]
109
  self.tokenizer = TextTokenizer(vocab, merges)
110
 
111
-
112
  def init_encoder(self):
113
  is_downloaded = os.path.exists(self.encoder_params_path)
114
  if not is_downloaded: self.download_encoder()
@@ -127,7 +127,7 @@ class MinDalle:
127
  del params
128
  self.encoder = self.encoder.to(device=self.device)
129
 
130
-
131
  def init_decoder(self):
132
  is_downloaded = os.path.exists(self.decoder_params_path)
133
  if not is_downloaded: self.download_decoder()
@@ -145,7 +145,7 @@ class MinDalle:
145
  del params
146
  self.decoder = self.decoder.to(device=self.device)
147
 
148
-
149
  def init_detokenizer(self):
150
  is_downloaded = os.path.exists(self.detoker_params_path)
151
  if not is_downloaded: self.download_detokenizer()
 
66
  self.init_decoder()
67
  self.init_detokenizer()
68
 
69
+ @st.cache
70
  def download_tokenizer(self):
71
  if self.is_verbose: print("downloading tokenizer params")
72
  suffix = '' if self.is_mega else '_mini'
 
76
  with open(self.vocab_path, 'wb') as f: f.write(vocab.content)
77
  with open(self.merges_path, 'wb') as f: f.write(merges.content)
78
 
79
+ @st.cache
80
  def download_encoder(self):
81
  if self.is_verbose: print("downloading encoder params")
82
  suffix = '' if self.is_mega else '_mini'
83
  params = requests.get(MIN_DALLE_REPO + 'encoder{}.pt'.format(suffix))
84
  with open(self.encoder_params_path, 'wb') as f: f.write(params.content)
85
 
86
+ @st.cache
87
  def download_decoder(self):
88
  if self.is_verbose: print("downloading decoder params")
89
  suffix = '' if self.is_mega else '_mini'
90
  params = requests.get(MIN_DALLE_REPO + 'decoder{}.pt'.format(suffix))
91
  with open(self.decoder_params_path, 'wb') as f: f.write(params.content)
92
 
93
+ @st.cache
94
  def download_detokenizer(self):
95
  if self.is_verbose: print("downloading detokenizer params")
96
  params = requests.get(MIN_DALLE_REPO + 'detoker.pt')
97
  with open(self.detoker_params_path, 'wb') as f: f.write(params.content)
98
 
99
+ @st.cache
100
  def init_tokenizer(self):
101
  is_downloaded = os.path.exists(self.vocab_path)
102
  is_downloaded &= os.path.exists(self.merges_path)
 
108
  merges = f.read().split("\n")[1:-1]
109
  self.tokenizer = TextTokenizer(vocab, merges)
110
 
111
+ @st.cache
112
  def init_encoder(self):
113
  is_downloaded = os.path.exists(self.encoder_params_path)
114
  if not is_downloaded: self.download_encoder()
 
127
  del params
128
  self.encoder = self.encoder.to(device=self.device)
129
 
130
+ @st.cache
131
  def init_decoder(self):
132
  is_downloaded = os.path.exists(self.decoder_params_path)
133
  if not is_downloaded: self.download_decoder()
 
145
  del params
146
  self.decoder = self.decoder.to(device=self.device)
147
 
148
+ @st.cache
149
  def init_detokenizer(self):
150
  is_downloaded = os.path.exists(self.detoker_params_path)
151
  if not is_downloaded: self.download_detokenizer()