Update app.py
Browse files
app.py
CHANGED
@@ -85,19 +85,19 @@ def load_model(model_type, selected_model):
|
|
85 |
|
86 |
if model_type == "Full Fine-Tuned":
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
)
|
94 |
|
95 |
#model = AutoModelForCausalLM.from_pretrained(
|
96 |
#selected_model,
|
97 |
# torch_dtype=torch.bfloat16,
|
98 |
# device_map="auto",
|
99 |
# token=HF_TOKEN
|
100 |
-
#
|
|
|
101 |
else:
|
102 |
base_model = AutoModelForCausalLM.from_pretrained(
|
103 |
BASE_MODEL_NAME,
|
|
|
85 |
|
86 |
if model_type == "Full Fine-Tuned":
|
87 |
|
88 |
+
model = AutoModelForMaskedLM.from_pretrained(
|
89 |
+
selected_model,
|
90 |
+
torch_dtype=torch.bfloat16, # or float32 for compatibility
|
91 |
+
token=HF_TOKEN
|
92 |
+
).to("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
93 |
|
94 |
#model = AutoModelForCausalLM.from_pretrained(
|
95 |
#selected_model,
|
96 |
# torch_dtype=torch.bfloat16,
|
97 |
# device_map="auto",
|
98 |
# token=HF_TOKEN
|
99 |
+
#
|
100 |
+
|
101 |
else:
|
102 |
base_model = AutoModelForCausalLM.from_pretrained(
|
103 |
BASE_MODEL_NAME,
|