zetavg commited on
Commit
89dd922
·
unverified ·
1 Parent(s): 90c428d
llama_lora/models.py CHANGED
@@ -150,134 +150,3 @@ def unload_models():
150
  Global.loaded_models.clear()
151
  Global.loaded_tokenizers.clear()
152
  clear_cache()
153
-
154
-
155
-
156
-
157
-
158
- ########
159
-
160
- # def get_base_model():
161
- # load_base_model()
162
- # return Global.loaded_base_model
163
-
164
-
165
- # def get_model_with_lora(lora_weights_name_or_path: str = "tloen/alpaca-lora-7b"):
166
- # # Global.model_has_been_used = True
167
- # #
168
- # #
169
- # if Global.loaded_tokenizer is None:
170
- # Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
171
- # Global.base_model
172
- # )
173
-
174
- # if Global.cached_lora_models:
175
- # model_from_cache = Global.cached_lora_models.get(lora_weights_name_or_path)
176
- # if model_from_cache:
177
- # return model_from_cache
178
-
179
- # Global.cached_lora_models.prepare_to_set()
180
-
181
- # if device == "cuda":
182
- # model = PeftModel.from_pretrained(
183
- # get_new_base_model(),
184
- # lora_weights_name_or_path,
185
- # torch_dtype=torch.float16,
186
- # device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
187
- # )
188
- # elif device == "mps":
189
- # model = PeftModel.from_pretrained(
190
- # get_new_base_model(),
191
- # lora_weights_name_or_path,
192
- # device_map={"": device},
193
- # torch_dtype=torch.float16,
194
- # )
195
- # else:
196
- # model = PeftModel.from_pretrained(
197
- # get_new_base_model(),
198
- # lora_weights_name_or_path,
199
- # device_map={"": device},
200
- # )
201
-
202
- # model.config.pad_token_id = get_tokenizer().pad_token_id = 0
203
- # model.config.bos_token_id = 1
204
- # model.config.eos_token_id = 2
205
-
206
- # if not Global.load_8bit:
207
- # model.half() # seems to fix bugs for some users.
208
-
209
- # model.eval()
210
- # if torch.__version__ >= "2" and sys.platform != "win32":
211
- # model = torch.compile(model)
212
-
213
- # if Global.cached_lora_models:
214
- # Global.cached_lora_models.set(lora_weights_name_or_path, model)
215
-
216
- # clear_cache()
217
-
218
- # return model
219
-
220
-
221
-
222
-
223
-
224
- # def load_base_model():
225
- # return;
226
-
227
- # if Global.ui_dev_mode:
228
- # return
229
-
230
- # if Global.loaded_tokenizer is None:
231
- # Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
232
- # Global.base_model
233
- # )
234
- # if Global.loaded_base_model is None:
235
- # if device == "cuda":
236
- # Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
237
- # Global.base_model,
238
- # load_in_8bit=Global.load_8bit,
239
- # torch_dtype=torch.float16,
240
- # # device_map="auto",
241
- # device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
242
- # )
243
- # elif device == "mps":
244
- # Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
245
- # Global.base_model,
246
- # device_map={"": device},
247
- # torch_dtype=torch.float16,
248
- # )
249
- # else:
250
- # Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
251
- # Global.base_model, device_map={"": device}, low_cpu_mem_usage=True
252
- # )
253
-
254
- # Global.loaded_base_model.config.pad_token_id = get_tokenizer().pad_token_id = 0
255
- # Global.loaded_base_model.config.bos_token_id = 1
256
- # Global.loaded_base_model.config.eos_token_id = 2
257
-
258
-
259
- # def clear_cache():
260
- # gc.collect()
261
-
262
- # # if not shared.args.cpu: # will not be running on CPUs anyway
263
- # with torch.no_grad():
264
- # torch.cuda.empty_cache()
265
-
266
-
267
- # def unload_models():
268
- # del Global.loaded_base_model
269
- # Global.loaded_base_model = None
270
-
271
- # del Global.loaded_tokenizer
272
- # Global.loaded_tokenizer = None
273
-
274
- # Global.cached_lora_models.clear()
275
-
276
- # clear_cache()
277
-
278
- # Global.model_has_been_used = False
279
-
280
-
281
- # def unload_models_if_already_used():
282
- # if Global.model_has_been_used:
283
- # unload_models()
 
150
  Global.loaded_models.clear()
151
  Global.loaded_tokenizers.clear()
152
  clear_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
llama_lora/ui/inference_ui.py CHANGED
@@ -95,7 +95,7 @@ def do_inference(
95
  return
96
  time.sleep(1)
97
  yield (
98
- gr.Textbox.update(value=message, lines=1), # TODO
99
  json.dumps(list(range(len(message.split()))), indent=2)
100
  )
101
  return
 
95
  return
96
  time.sleep(1)
97
  yield (
98
+ gr.Textbox.update(value=message, lines=inference_output_lines),
99
  json.dumps(list(range(len(message.split()))), indent=2)
100
  )
101
  return