poltextlab committed
Commit 4e1f9ee · verified · Parent: 8b7617d

Delete binary_trainer_finetune_climate.py

Files changed (1)
  1. binary_trainer_finetune_climate.py +0 -198
binary_trainer_finetune_climate.py DELETED
@@ -1,198 +0,0 @@
- import os
- os.environ["WANDB_DISABLED"] = "true"
-
- import pandas as pd
- import numpy as np
- import torch
- import evaluate
-
- from sklearn.metrics import classification_report
- from datasets import Dataset
- from transformers import (
-     AutoTokenizer,
-     AutoModelForSequenceClassification,
-     Trainer,
-     TrainingArguments,
-     EarlyStoppingCallback
- )
- from sklearn.metrics import precision_recall_curve, f1_score
-
- # --- Settings ---
- language_model = 'xlm-roberta-large'
- train_path = './data/climate/binary_train_illframes_climate.csv'
- val_path = './data/climate/binary_val_illframes_climate.csv'
- test_path = './data/climate/binary_test_illframes_climate.csv'
-
- lr = 5e-6
- batch_size = 8
- epochs = 5
- maxlen = 256
- output_dir = "./binary_model_output"
-
- data_train = pd.read_csv(train_path)
- data_val = pd.read_csv(val_path)
- data_test = pd.read_csv(test_path)
-
- def balance_dataframe(df):
-     class_counts = df['label'].value_counts()
-     min_class = class_counts.idxmin()
-     max_class = class_counts.idxmax()
-     n = class_counts.min()
-
-     df_min = df[df['label'] == min_class]
-     df_max = df[df['label'] == max_class].sample(n=n, random_state=42)
-
-     return pd.concat([df_min, df_max]).sample(frac=1, random_state=42).reset_index(drop=True)
-
- val_bal = balance_dataframe(data_val)
- test_bal = balance_dataframe(data_test)
-
- # --- Label maps ---
- id2label = {0: "No_frame", 1: "Frame"}
- label2id = {v: k for k, v in id2label.items()}
-
- # --- Tokenizer ---
- tokenizer = AutoTokenizer.from_pretrained(language_model)
-
- def tokenize(batch):
-     return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=maxlen)
-
- # --- Hugging Face Datasets ---
- train_ds = Dataset.from_pandas(data_train)
- test_ds = Dataset.from_pandas(data_test)
- test_bal_ds = Dataset.from_pandas(test_bal)
- val_ds = Dataset.from_pandas(data_val)
- val_bal_ds = Dataset.from_pandas(val_bal)
-
- train_ds = train_ds.map(tokenize, batched=True)
- test_ds = test_ds.map(tokenize, batched=True)
- test_bal_ds = test_bal_ds.map(tokenize, batched=True)
- val_ds = val_ds.map(tokenize, batched=True)
- val_bal_ds = val_bal_ds.map(tokenize, batched=True)
-
- # Remove unnecessary columns
- train_ds = train_ds.remove_columns(["text", "__index_level_0__"]) if "__index_level_0__" in train_ds.column_names else train_ds.remove_columns(["text"])
-
- test_ds = test_ds.remove_columns(["text", "__index_level_0__"]) if "__index_level_0__" in test_ds.column_names else test_ds.remove_columns(["text"])
- test_bal_ds = test_bal_ds.remove_columns(["text", "__index_level_0__"]) if "__index_level_0__" in test_bal_ds.column_names else test_bal_ds.remove_columns(["text"])
-
- val_ds = val_ds.remove_columns(["text", "__index_level_0__"]) if "__index_level_0__" in val_ds.column_names else val_ds.remove_columns(["text"])
- val_bal_ds = val_bal_ds.remove_columns(["text", "__index_level_0__"]) if "__index_level_0__" in val_bal_ds.column_names else val_bal_ds.remove_columns(["text"])
-
-
- # --- Model ---
- model = AutoModelForSequenceClassification.from_pretrained(
-     language_model,
-     num_labels=2,
-     id2label=id2label,
-     label2id=label2id
- )
-
- # --- Metrics ---
- def compute_metrics(eval_pred):
-     metric = evaluate.load("f1")
-     logits, labels = eval_pred
-     preds = np.argmax(logits, axis=1)
-     return metric.compute(predictions=preds, references=labels, average="weighted")
-
- # --- Trainer ---
- training_args = TrainingArguments(
-     output_dir=output_dir,
-     evaluation_strategy="epoch",
-     save_strategy="epoch",
-     learning_rate=lr,
-     per_device_train_batch_size=batch_size,
-     per_device_eval_batch_size=batch_size,
-     num_train_epochs=epochs,
-     weight_decay=0.01,
-     logging_dir="./logs",
-     load_best_model_at_end=True,
-     save_total_limit=1,
- )
-
- trainer = Trainer(
-     model=model,
-     args=training_args,
-     train_dataset=train_ds,
-     eval_dataset=val_bal_ds,
-     tokenizer=tokenizer,
-     compute_metrics=compute_metrics,
-     callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
- )
-
- # --- Train ---
- trainer.train()
-
- #import subprocess
- #subprocess.check_call(["pip", "install", "safetensors"])
- #import safetensors
-
- #model = trainer.model
-
- # Save using safetensors
- #model.save_pretrained(
- #    f"{output_dir}/model_3",
- #    safe_serialization=True
- #)
- trainer.save_model(f"{output_dir}/model_3")
-
- # --- Inference on unbalanced data ---
- val_outputs = trainer.predict(val_ds)
- val_probs = torch.softmax(torch.tensor(val_outputs.predictions), dim=1)[:, 1].numpy()
- val_labels = val_outputs.label_ids
-
- # Find best threshold based on val F1
- prec, rec, thresholds = precision_recall_curve(val_labels, val_probs)
- f1s = 2 * (prec * rec) / (prec + rec + 1e-8)
- best_thresh = thresholds[np.argmax(f1s)]
-
- print(f"Best threshold from validation: {best_thresh:.3f}")
-
- # --- Predict & Save with best threshold ---
-
- test_outputs = trainer.predict(test_ds)
- test_probs = torch.softmax(torch.tensor(test_outputs.predictions), dim=1)[:, 1].numpy()
- test_labels = test_outputs.label_ids
-
- # Apply threshold
- test_preds = (test_probs >= best_thresh).astype(int)
-
- # Save results
- test_results_df = data_test.copy()
- test_results_df['prob'] = test_probs
- test_results_df['pred'] = test_preds
-
- # Save classification report
- cr = classification_report(test_labels, test_preds, output_dict=True)
- pd.DataFrame(cr).transpose().to_csv("threshold_classification_report.csv")
- print(cr)
-
- # Save predictions
- test_results_df.to_csv("threshold_test_predictions.csv", index=False)
-
- #preds = trainer.predict(test_bal_ds)
- #preds = trainer.predict(test_ds)
- #test_preds = np.argmax(preds.predictions, axis=1)
-
- # Add to DataFrame
- #test_bal['pred'] = test_preds
- #data_test['pred'] = test_preds
-
- # Save prediction and classification report
- #data_test.to_csv("./binary_results.csv", index=False)
-
- #report = classification_report(data_test["label"], data_test["pred"], output_dict=True)
- #report_df = pd.DataFrame(report).transpose()
- #report_df.to_csv("./binary_classification_report.csv")
-
- #report = classification_report(test_bal["label"], test_bal["pred"], output_dict=True)
- #report_df = pd.DataFrame(report).transpose()
- #report_df.to_csv("./binary_classification_report.csv")
-
- #print(report_df)