keerthanaBasavaraj committed
Commit ee8f33b · 1 Parent(s): 1d0a22d

bitsandbytes for Model Loading

Files changed (1):
  sql_query_generator/generator.py (+6 -2)
sql_query_generator/generator.py CHANGED

@@ -2,10 +2,14 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 def load_model(model_name="chatdb/natural-sql-7b"):
     """
-    Loads the SQL generation model and tokenizer from Hugging Face.
+    Loads the SQL generation model with 8-bit precision.
     """
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name)
+    quantization_config = BitsAndBytesConfig(
+        load_in_8bit=True,        # Enable 8-bit loading
+        llm_int8_threshold=6.0    # Fine-tune threshold if needed
+    )
+    model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
     return tokenizer, model
 
 def generate_sql(question, prompt_inputs, tokenizer, model, device="cpu"):
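
Note that the new load_model body references BitsAndBytesConfig, but the hunk does not extend the transformers import to include it, so the function as committed would raise a NameError. Below is a minimal runnable sketch of the updated loader, assuming transformers with bitsandbytes (and a CUDA device for 8-bit weights) is available; the device_map="auto" argument is added here for illustration and is not part of the commit.

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

def load_model(model_name="chatdb/natural-sql-7b"):
    """
    Loads the SQL generation model and tokenizer with 8-bit quantization.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,       # quantize linear-layer weights to int8 at load time
        llm_int8_threshold=6.0,  # outlier threshold for the LLM.int8() mixed-precision matmul
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="auto",  # assumption: place the quantized weights on the available GPU(s)
    )
    return tokenizer, model

Loading in 8-bit roughly halves the memory footprint of the 7B model relative to fp16, at the cost of requiring a GPU and the bitsandbytes dependency.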