Delete binary_trainer_finetune_climate.py
Browse files
binary_trainer_finetune_climate.py
DELETED
@@ -1,198 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
os.environ["WANDB_DISABLED"] = "true"
|
3 |
-
|
4 |
-
import pandas as pd
|
5 |
-
import numpy as np
|
6 |
-
import torch
|
7 |
-
import evaluate
|
8 |
-
|
9 |
-
from sklearn.metrics import classification_report
|
10 |
-
from datasets import Dataset
|
11 |
-
from transformers import (
|
12 |
-
AutoTokenizer,
|
13 |
-
AutoModelForSequenceClassification,
|
14 |
-
Trainer,
|
15 |
-
TrainingArguments,
|
16 |
-
EarlyStoppingCallback
|
17 |
-
)
|
18 |
-
from sklearn.metrics import precision_recall_curve, f1_score
|
19 |
-
|
20 |
-
# --- Settings ---
|
21 |
-
language_model = 'xlm-roberta-large'
|
22 |
-
train_path = './data/climate/binary_train_illframes_climate.csv'
|
23 |
-
val_path = './data/climate/binary_val_illframes_climate.csv'
|
24 |
-
test_path = './data/climate/binary_test_illframes_climate.csv'
|
25 |
-
|
26 |
-
lr = 5e-6
|
27 |
-
batch_size = 8
|
28 |
-
epochs = 5
|
29 |
-
maxlen = 256
|
30 |
-
output_dir = "./binary_model_output"
|
31 |
-
|
32 |
-
data_train = pd.read_csv(train_path)
|
33 |
-
data_val = pd.read_csv(val_path)
|
34 |
-
data_test = pd.read_csv(test_path)
|
35 |
-
|
36 |
-
def balance_dataframe(df):
|
37 |
-
class_counts = df['label'].value_counts()
|
38 |
-
min_class = class_counts.idxmin()
|
39 |
-
max_class = class_counts.idxmax()
|
40 |
-
n = class_counts.min()
|
41 |
-
|
42 |
-
df_min = df[df['label'] == min_class]
|
43 |
-
df_max = df[df['label'] == max_class].sample(n=n, random_state=42)
|
44 |
-
|
45 |
-
return pd.concat([df_min, df_max]).sample(frac=1, random_state=42).reset_index(drop=True)
|
46 |
-
|
47 |
-
val_bal = balance_dataframe(data_val)
|
48 |
-
test_bal = balance_dataframe(data_test)
|
49 |
-
|
50 |
-
# --- Label maps ---
|
51 |
-
id2label = {0: "No_frame", 1: "Frame"}
|
52 |
-
label2id = {v: k for k, v in id2label.items()}
|
53 |
-
|
54 |
-
# --- Tokenizer ---
|
55 |
-
tokenizer = AutoTokenizer.from_pretrained(language_model)
|
56 |
-
|
57 |
-
def tokenize(batch):
|
58 |
-
return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=maxlen)
|
59 |
-
|
60 |
-
# --- Hugging Face Datasets ---
|
61 |
-
train_ds = Dataset.from_pandas(data_train)
|
62 |
-
test_ds = Dataset.from_pandas(data_test)
|
63 |
-
test_bal_ds = Dataset.from_pandas(test_bal)
|
64 |
-
val_ds = Dataset.from_pandas(data_val)
|
65 |
-
val_bal_ds = Dataset.from_pandas(val_bal)
|
66 |
-
|
67 |
-
train_ds = train_ds.map(tokenize, batched=True)
|
68 |
-
test_ds = test_ds.map(tokenize, batched=True)
|
69 |
-
test_bal_ds = test_bal_ds.map(tokenize, batched=True)
|
70 |
-
val_ds = val_ds.map(tokenize, batched=True)
|
71 |
-
val_bal_ds = val_bal_ds.map(tokenize, batched=True)
|
72 |
-
|
73 |
-
# Remove unnecessary columns
|
74 |
-
train_ds = train_ds.remove_columns(["text", "__index_level_0__"]) if "__index_level_0__" in train_ds.column_names else train_ds.remove_columns(["text"])
|
75 |
-
|
76 |
-
test_ds = test_ds.remove_columns(["text", "__index_level_0__"]) if "__index_level_0__" in test_ds.column_names else test_ds.remove_columns(["text"])
|
77 |
-
test_bal_ds = test_bal_ds.remove_columns(["text", "__index_level_0__"]) if "__index_level_0__" in test_bal_ds.column_names else test_bal_ds.remove_columns(["text"])
|
78 |
-
|
79 |
-
val_ds = val_ds.remove_columns(["text", "__index_level_0__"]) if "__index_level_0__" in val_ds.column_names else val_ds.remove_columns(["text"])
|
80 |
-
val_bal_ds = val_bal_ds.remove_columns(["text", "__index_level_0__"]) if "__index_level_0__" in val_bal_ds.column_names else val_bal_ds.remove_columns(["text"])
|
81 |
-
|
82 |
-
|
83 |
-
# --- Model ---
|
84 |
-
model = AutoModelForSequenceClassification.from_pretrained(
|
85 |
-
language_model,
|
86 |
-
num_labels=2,
|
87 |
-
id2label=id2label,
|
88 |
-
label2id=label2id
|
89 |
-
)
|
90 |
-
|
91 |
-
# --- Metrics ---
|
92 |
-
def compute_metrics(eval_pred):
|
93 |
-
metric = evaluate.load("f1")
|
94 |
-
logits, labels = eval_pred
|
95 |
-
preds = np.argmax(logits, axis=1)
|
96 |
-
return metric.compute(predictions=preds, references=labels, average="weighted")
|
97 |
-
|
98 |
-
# --- Trainer ---
|
99 |
-
training_args = TrainingArguments(
|
100 |
-
output_dir=output_dir,
|
101 |
-
evaluation_strategy="epoch",
|
102 |
-
save_strategy="epoch",
|
103 |
-
learning_rate=lr,
|
104 |
-
per_device_train_batch_size=batch_size,
|
105 |
-
per_device_eval_batch_size=batch_size,
|
106 |
-
num_train_epochs=epochs,
|
107 |
-
weight_decay=0.01,
|
108 |
-
logging_dir="./logs",
|
109 |
-
load_best_model_at_end=True,
|
110 |
-
save_total_limit=1,
|
111 |
-
)
|
112 |
-
|
113 |
-
trainer = Trainer(
|
114 |
-
model=model,
|
115 |
-
args=training_args,
|
116 |
-
train_dataset=train_ds,
|
117 |
-
eval_dataset=val_bal_ds,
|
118 |
-
tokenizer=tokenizer,
|
119 |
-
compute_metrics=compute_metrics,
|
120 |
-
callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
|
121 |
-
)
|
122 |
-
|
123 |
-
# --- Train ---
|
124 |
-
trainer.train()
|
125 |
-
|
126 |
-
from transformers import AutoModelForSequenceClassification
|
127 |
-
#import subprocess
|
128 |
-
#subprocess.check_call(["pip", "install", "safetensors"])
|
129 |
-
#import safetensors
|
130 |
-
|
131 |
-
#model = trainer.model
|
132 |
-
|
133 |
-
# Save using safetensors
|
134 |
-
#model.save_pretrained(
|
135 |
-
# f"{output_dir}/model_3",
|
136 |
-
# safe_serialization=True
|
137 |
-
#)
|
138 |
-
trainer.save_model(f"{output_dir}/model_3")
|
139 |
-
|
140 |
-
# --- Inference on unbalanced data ---
|
141 |
-
val_outputs = trainer.predict(val_ds)
|
142 |
-
val_probs = torch.softmax(torch.tensor(val_outputs.predictions), dim=1)[:, 1].numpy()
|
143 |
-
val_labels = val_outputs.label_ids
|
144 |
-
|
145 |
-
# Find best threshold based on val F1
|
146 |
-
prec, rec, thresholds = precision_recall_curve(val_labels, val_probs)
|
147 |
-
f1s = 2 * (prec * rec) / (prec + rec + 1e-8)
|
148 |
-
best_thresh = thresholds[np.argmax(f1s)]
|
149 |
-
|
150 |
-
print(f"Best threshold from validation: {best_thresh:.3f}")
|
151 |
-
|
152 |
-
# --- Predict & Save with best threshold ---
|
153 |
-
|
154 |
-
test_outputs = trainer.predict(test_ds)
|
155 |
-
test_probs = torch.softmax(torch.tensor(test_outputs.predictions), dim=1)[:, 1].numpy()
|
156 |
-
test_labels = test_outputs.label_ids
|
157 |
-
|
158 |
-
# Apply threshold
|
159 |
-
test_preds = (test_probs >= best_thresh).astype(int)
|
160 |
-
|
161 |
-
# Save results
|
162 |
-
test_results_df = data_test.copy()
|
163 |
-
test_results_df['prob'] = test_probs
|
164 |
-
test_results_df['pred'] = test_preds
|
165 |
-
|
166 |
-
# Save classification report
|
167 |
-
cr = classification_report(test_labels, test_preds, output_dict=True)
|
168 |
-
pd.DataFrame(cr).transpose().to_csv("threshold_classification_report.csv")
|
169 |
-
print(cr)
|
170 |
-
|
171 |
-
# Save predictions
|
172 |
-
test_results_df.to_csv("threshold_test_predictions.csv", index=False)
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
#preds = trainer.predict(test_bal_ds)
|
180 |
-
#preds = trainer.predict(test_ds)
|
181 |
-
#test_preds = np.argmax(preds.predictions, axis=1)
|
182 |
-
|
183 |
-
# Add to DataFrame
|
184 |
-
#test_bal['pred'] = test_preds
|
185 |
-
#data_test['pred'] = test_preds
|
186 |
-
|
187 |
-
# Save prediction and classification report
|
188 |
-
#data_test.to_csv("./binary_results.csv", index=False)
|
189 |
-
|
190 |
-
#report = classification_report(data_test["label"], data_test["pred"], output_dict=True)
|
191 |
-
#report_df = pd.DataFrame(report).transpose()
|
192 |
-
#report_df.to_csv("./binary_classification_report.csv")
|
193 |
-
|
194 |
-
#report = classification_report(test_bal["label"], test_bal["pred"], output_dict=True)
|
195 |
-
#report_df = pd.DataFrame(report).transpose()
|
196 |
-
#report_df.to_csv("./binary_classification_report.csv")
|
197 |
-
|
198 |
-
#print(report_df)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|