gated model updates
Browse files
app.py
CHANGED
@@ -26,11 +26,13 @@ lm_model = AutoModelForCausalLM.from_pretrained(
|
|
26 |
device_map="auto"
|
27 |
)
|
28 |
|
29 |
-
# Load the reward model
|
|
|
30 |
RM = AutoModelForCausalLMWithValueHead.from_pretrained(
|
31 |
'ray24724919/plan2align_rm',
|
32 |
torch_dtype=torch_dtype,
|
33 |
-
device_map="
|
|
|
34 |
)
|
35 |
RM.eval()
|
36 |
print("Models loaded successfully!")
|
|
|
26 |
device_map="auto"
|
27 |
)
|
28 |
|
29 |
+
# Load the reward model - fix the offloading issue
|
30 |
+
print("Loading reward model...")
|
31 |
RM = AutoModelForCausalLMWithValueHead.from_pretrained(
|
32 |
'ray24724919/plan2align_rm',
|
33 |
torch_dtype=torch_dtype,
|
34 |
+
device_map={"": 0}, # Force model to stay on GPU (device 0)
|
35 |
+
offload_folder=None, # Disable offloading
|
36 |
)
|
37 |
RM.eval()
|
38 |
print("Models loaded successfully!")
|