drmasad committed on
Commit
ce8c007
·
verified ·
1 Parent(s): 50b517b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -20
app.py CHANGED
@@ -46,29 +46,23 @@ def load_model(selected_model_name):
46
  st.info("Loading the model, please wait...")
47
  model_name = model_links[selected_model_name]
48
 
49
- # Load the model without a device map
50
- model = AutoModelForCausalLM.from_pretrained(model_name)
51
-
52
- # Check the availability of CUDA
53
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
54
- # Manually move the model to the device
55
- model = model.to(device)
56
-
57
- # Apply quantization configuration if required
58
- if device == 'cuda': # Only apply BitsAndBytesConfig if on CUDA
59
- bnb_config = BitsAndBytesConfig(
60
- load_in_4bit=True,
61
- bnb_4bit_quant_type="nf4",
62
- bnb_4bit_compute_dtype=torch.bfloat16,
63
- bnb_4bit_use_double_quant=False,
64
- llm_int8_enable_fp32_cpu_offload=False,
65
- )
66
- # Assume quantization applies here, adjust as per actual use case
67
- # model = apply_quantization(model, bnb_config)
68
 
 
 
 
 
 
 
 
 
 
 
69
  model.config.use_cache = False
70
  model = prepare_model_for_kbit_training(model)
71
 
 
72
  peft_config = LoraConfig(
73
  lora_alpha=16,
74
  lora_dropout=0.1,
@@ -77,7 +71,6 @@ def load_model(selected_model_name):
77
  task_type="CAUSAL_LM",
78
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
79
  )
80
-
81
  model = get_peft_model(model, peft_config)
82
 
83
  tokenizer = AutoTokenizer.from_pretrained(
 
46
  st.info("Loading the model, please wait...")
47
  model_name = model_links[selected_model_name]
48
 
49
+ # Ensure the device is properly set for CUDA availability
50
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # Load the model with quantization settings
53
+ model = AutoModelForCausalLM.from_pretrained(
54
+ model_name,
55
+ trust_remote_code=True,
56
+ )
57
+
58
+ # Ensure every part of the model is assigned to the correct device
59
+ model.to(device) # This should correctly set devices for all components
60
+
61
+ # Additional configurations and training enhancements
62
  model.config.use_cache = False
63
  model = prepare_model_for_kbit_training(model)
64
 
65
+ # If using PEFT or other enhancements, configure here
66
  peft_config = LoraConfig(
67
  lora_alpha=16,
68
  lora_dropout=0.1,
 
71
  task_type="CAUSAL_LM",
72
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
73
  )
 
74
  model = get_peft_model(model, peft_config)
75
 
76
  tokenizer = AutoTokenizer.from_pretrained(