Tao Wu committed on
Commit
94129fe
·
1 Parent(s): 28a9b71
Files changed (1) hide show
  1. app/embedding_setup.py +1 -6
app/embedding_setup.py CHANGED
@@ -34,11 +34,6 @@ retriever = db.as_retriever(search_kwargs={"k": TOP_K})
34
  lora_weights_rec = REC_LORA_MODEL
35
  lora_weights_exp = EXP_LORA_MODEL
36
  hf_auth = os.environ.get("hf_token")
37
- quantization_config = BitsAndBytesConfig(
38
- load_in_4bit=True,
39
- bnb_4bit_compute_dtype=torch.float16,
40
- bnb_4bit_quant_type="nf4"
41
- )
42
 
43
  tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL, token=hf_auth)
44
 
@@ -50,7 +45,7 @@ first_id = tokenizer.convert_tokens_to_ids(first_token)
50
  second_id = tokenizer.convert_tokens_to_ids(second_token)
51
  model = AutoModelForCausalLM.from_pretrained(
52
  LLM_MODEL,
53
- quantization_config=quantization_config,
54
  torch_dtype=torch.float16,
55
  device_map="auto",
56
  token=hf_auth,
 
34
  lora_weights_rec = REC_LORA_MODEL
35
  lora_weights_exp = EXP_LORA_MODEL
36
  hf_auth = os.environ.get("hf_token")
 
 
 
 
 
37
 
38
  tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL, token=hf_auth)
39
 
 
45
  second_id = tokenizer.convert_tokens_to_ids(second_token)
46
  model = AutoModelForCausalLM.from_pretrained(
47
  LLM_MODEL,
48
+ load_in_8bit=True,
49
  torch_dtype=torch.float16,
50
  device_map="auto",
51
  token=hf_auth,