Update app.py
app.py CHANGED
@@ -49,24 +49,30 @@ st.sidebar.write(f"You're now chatting with **{selected_model}**")
 st.sidebar.markdown(model_info[selected_model]["description"])
 st.sidebar.image(model_info[selected_model]["logo"])

-# Load the appropriate model
 def load_model():
     model_name = model_links["HAH-2024-v0.1"]
     base_model = "mistralai/Mistral-7B-Instruct-v0.2"

-    # Load model with quantization
+    # Load model with quantization and device map configurations
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_quant_type="nf4",
         bnb_4bit_compute_dtype=torch.bfloat16,
         bnb_4bit_use_double_quant=False,
+        llm_int8_enable_fp32_cpu_offload=True  # Enable CPU offloading for certain parts
     )

+    # Custom device map to manage resource utilization
+    device_map = {
+        'encoder': 'cuda',  # Keep encoder on GPU
+        'decoder': 'cpu',   # Offload decoder to CPU if GPU RAM is insufficient
+    }
+
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         quantization_config=bnb_config,
         torch_dtype=torch.bfloat16,
-        device_map=
+        device_map=device_map,  # Apply custom device map
         trust_remote_code=True,
     )

@@ -88,7 +94,6 @@ def load_model():

     return model, tokenizer

-model, tokenizer = load_model()

 # Initialize chat history
 if "messages" not in st.session_state:
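A note on the new dict-style device map: in 4-bit NF4 the ~7B Mistral weights need roughly 7B × 0.5 bytes ≈ 3.5 GB (versus ~14 GB in bfloat16), so offloading only matters on small GPUs. More importantly, transformers/accelerate require dict keys to match the model's actual submodule names, and MistralForCausalLM is decoder-only: it has no modules named 'encoder' or 'decoder', so from_pretrained will most likely reject this map with a ValueError about uncovered parameters. A minimal sketch using the submodule names the stock Mistral architecture actually exposes (verify with model.named_modules(); the placement choices below are assumptions, not part of this commit):

# Hypothetical replacement for the 'encoder'/'decoder' map above; module
# names follow the stock MistralForCausalLM layout, devices are examples.
device_map = {
    "model.embed_tokens": 0,  # token embeddings on GPU 0
    "model.layers": 0,        # all decoder blocks on GPU 0
    "model.norm": 0,          # final RMSNorm on GPU 0
    "lm_head": "cpu",         # offload only the LM head when VRAM is tight
}

With llm_int8_enable_fp32_cpu_offload=True, modules mapped to "cpu" are kept in fp32 instead of being quantized, which is what makes the CPU-offloaded portion usable.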
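If hand-maintaining module names is undesirable, device_map="auto" plus a max_memory budget lets accelerate derive the GPU/CPU split itself; the caps below are placeholder values, not measured requirements:

# Alternative sketch: let accelerate infer per-module placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,            # same BitsAndBytesConfig as above
    torch_dtype=torch.bfloat16,
    device_map="auto",                          # accelerate places each module
    max_memory={0: "10GiB", "cpu": "24GiB"},    # assumed budgets for GPU 0 / CPU
    trust_remote_code=True,
)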
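The second hunk drops the module-level model, tokenizer = load_model() call, and the diff does not show where the call moves. One common Streamlit pattern (an assumption, not shown in this commit) is to cache the loader with st.cache_resource so the 7B model loads once per process instead of on every rerun:

import streamlit as st

@st.cache_resource(show_spinner="Loading HAH-2024-v0.1 ...")  # hypothetical wiring
def get_model():
    return load_model()  # load_model() as defined in app.py above

model, tokenizer = get_model()  # later reruns reuse the cached pair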