shauray commited on
Commit
b95c9fa
·
1 Parent(s): 03a8fc7

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -6
model.py CHANGED
@@ -6,24 +6,20 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIt
6
 
7
  model_id = 'abacaj/starcoderbase-1b-sft'
8
 
 
 
9
  if torch.cuda.is_available():
10
- config = AutoConfig.from_pretrained(model_id)
11
- config.pretraining_tp = 1
12
  model = AutoModelForCausalLM.from_pretrained(
13
  model_id,
14
  config=config,
15
  torch_dtype=torch.float16,
16
- load_in_4bit=True,
17
  device_map='cuda',
18
  )
19
  else:
20
- config = AutoConfig.from_pretrained(model_id)
21
- config.pretraining_tp = 1
22
  model = AutoModelForCausalLM.from_pretrained(
23
  model_id,
24
  config=config,
25
  torch_dtype=torch.float32,
26
- load_in_4bit=True,
27
  )
28
  tokenizer = AutoTokenizer.from_pretrained(model_id)
29
 
 
6
 
7
  model_id = 'abacaj/starcoderbase-1b-sft'
8
 
9
+ config = AutoConfig.from_pretrained(model_id)
10
+ config.pretraining_tp = 1
11
  if torch.cuda.is_available():
 
 
12
  model = AutoModelForCausalLM.from_pretrained(
13
  model_id,
14
  config=config,
15
  torch_dtype=torch.float16,
 
16
  device_map='cuda',
17
  )
18
  else:
 
 
19
  model = AutoModelForCausalLM.from_pretrained(
20
  model_id,
21
  config=config,
22
  torch_dtype=torch.float32,
 
23
  )
24
  tokenizer = AutoTokenizer.from_pretrained(model_id)
25