dar-tau committed
Commit b233c7d
Parent(s): 7595dc5

Update app.py

Files changed (1):
  1. app.py +4 -1
app.py CHANGED
@@ -52,12 +52,15 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
     global_state.layers_format = model_args.pop('layers_format')
     tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
     use_ctransformers = model_args.pop('ctransformers', False)
+    dont_cuda = model_args.pop('dont_cuda', False)
     AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
 
     # get model
     global_state.model, global_state.tokenizer, global_state.hidden_states = None, None, None
     gc.collect()
-    global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).to('cuda')
+    global_state.model = AutoModelClass.from_pretrained(model_path, **model_args)
+    if not dont_cuda:
+        global_state.model.to('cuda')
     global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
     gc.collect()
     if with_extra_components:
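
The pattern here is to pop app-level keys out of `model_args` before the `**model_args` splat, so `from_pretrained()` never receives a kwarg it does not understand. A minimal standalone sketch of that pattern follows; the function and config names are illustrative, not part of the app:

```python
# Hypothetical sketch of the flag-popping pattern this commit adds.
def load_model(model_path, **model_args):
    # Pop the app-level flag first, so a real from_pretrained() call
    # would only see the remaining, model-related kwargs.
    dont_cuda = model_args.pop('dont_cuda', False)
    print(f"from_pretrained({model_path!r}, **{model_args})")
    if not dont_cuda:
        print("model.to('cuda')")

load_model('some/model', torch_dtype='auto')                  # moved to CUDA
load_model('some/model', torch_dtype='auto', dont_cuda=True)  # left off CUDA
```

Note that for PyTorch models `nn.Module.to()` moves parameters in place and returns the module, so `global_state.model.to('cuda')` works without reassignment; the new flag presumably exists for configurations (e.g. ctransformers-backed models) where moving the model to CUDA this way is not wanted.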