drmasad committed on
Commit
8dbaa52
·
verified ·
1 Parent(s): ee59722

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -49,24 +49,30 @@ st.sidebar.write(f"You're now chatting with **{selected_model}**")
49
  st.sidebar.markdown(model_info[selected_model]["description"])
50
  st.sidebar.image(model_info[selected_model]["logo"])
51
 
52
- # Load the appropriate model
53
  def load_model():
54
  model_name = model_links["HAH-2024-v0.1"]
55
  base_model = "mistralai/Mistral-7B-Instruct-v0.2"
56
 
57
- # Load model with quantization configuration
58
  bnb_config = BitsAndBytesConfig(
59
  load_in_4bit=True,
60
  bnb_4bit_quant_type="nf4",
61
  bnb_4bit_compute_dtype=torch.bfloat16,
62
  bnb_4bit_use_double_quant=False,
 
63
  )
64
 
 
 
 
 
 
 
65
  model = AutoModelForCausalLM.from_pretrained(
66
  model_name,
67
  quantization_config=bnb_config,
68
  torch_dtype=torch.bfloat16,
69
- device_map="auto",
70
  trust_remote_code=True,
71
  )
72
 
@@ -88,7 +94,6 @@ def load_model():
88
 
89
  return model, tokenizer
90
 
91
- model, tokenizer = load_model()
92
 
93
  # Initialize chat history
94
  if "messages" not in st.session_state:
 
49
  st.sidebar.markdown(model_info[selected_model]["description"])
50
  st.sidebar.image(model_info[selected_model]["logo"])
51
 
 
52
  def load_model():
53
  model_name = model_links["HAH-2024-v0.1"]
54
  base_model = "mistralai/Mistral-7B-Instruct-v0.2"
55
 
56
+ # Load model with quantization and device map configurations
57
  bnb_config = BitsAndBytesConfig(
58
  load_in_4bit=True,
59
  bnb_4bit_quant_type="nf4",
60
  bnb_4bit_compute_dtype=torch.bfloat16,
61
  bnb_4bit_use_double_quant=False,
62
+ llm_int8_enable_fp32_cpu_offload=True # Enable CPU offloading for certain parts
63
  )
64
 
65
+ # Custom device map to manage resource utilization
66
+ device_map = {
67
+ 'encoder': 'cuda', # Keep encoder on GPU
68
+ 'decoder': 'cpu', # Offload decoder to CPU if GPU RAM is insufficient
69
+ }
70
+
71
  model = AutoModelForCausalLM.from_pretrained(
72
  model_name,
73
  quantization_config=bnb_config,
74
  torch_dtype=torch.bfloat16,
75
+ device_map=device_map, # Apply custom device map
76
  trust_remote_code=True,
77
  )
78
 
 
94
 
95
  return model, tokenizer
96
 
 
97
 
98
  # Initialize chat history
99
  if "messages" not in st.session_state: