Julien Simon commited on
Commit
b536abf
·
1 Parent(s): ead8a38

Training in progress, epoch 3

Browse files
Files changed (2) hide show
  1. pytorch_model.bin +1 -1
  2. train-xlm.py +114 -0
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b99978a873521a77c36ee174badbc6198d2fbd242c9618c12a524f45b64a14d2
3
  size 3114359925
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d62f1c5ab88bf2f7b3820b4b411f1b51a423796b4c6ad6fa37f8e21629d5c28d
3
  size 3114359925
train-xlm.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import numpy as np
3
+ from datasets import load_dataset
4
+ from transformers import (
5
+ AutoTokenizer,
6
+ AutoModelForSequenceClassification,
7
+ Trainer,
8
+ TrainingArguments,
9
+ )
10
+
11
+ dataset_id = "google/fleurs"
12
+ model_id = "facebook/xlm-v-base"
13
+ metric_name = "accuracy"
14
+
15
+ # Keep only the raw transcription and the language id (which we'll use as label)
16
+ columns_to_remove = [
17
+ "audio",
18
+ "id",
19
+ "num_samples",
20
+ "path",
21
+ "transcription",
22
+ "gender",
23
+ "language",
24
+ "lang_group_id",
25
+ ]
26
+
27
+ train, val = load_dataset(
28
+ dataset_id, "all", split=["train", "validation"], ignore_verifications=True
29
+ )
30
+
31
+ # Build the label2id and id2label dictionaries
32
+
33
+ unique_langs = set()
34
+ label2id = {}
35
+ id2label = {}
36
+ for lang, lang_id in zip(val["language"], val["lang_id"]):
37
+ if lang not in unique_langs:
38
+ unique_langs.add(lang)
39
+ id2label[lang_id] = lang
40
+ label2id[lang] = lang_id
41
+
42
+ id2label = dict(sorted(id2label.items(), key=lambda item: item[0]))
43
+ label2id = dict(sorted(label2id.items(), key=lambda item: item[1]))
44
+
45
+ train = train.remove_columns(columns_to_remove)
46
+ val = val.remove_columns(columns_to_remove)
47
+ train = train.rename_column("raw_transcription", "text")
48
+ val = val.rename_column("raw_transcription", "text")
49
+ train = train.rename_column("lang_id", "label")
50
+ val = val.rename_column("lang_id", "label")
51
+
52
+ train = train.shuffle(seed=42)
53
+ val = val.shuffle(seed=42)
54
+
55
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
56
+
57
+
58
+ def preprocess(data):
59
+ return tokenizer(data["text"], truncation=True)
60
+
61
+
62
+ processed_train = train.map(preprocess, batched=True)
63
+ processed_val = val.map(preprocess, batched=True)
64
+
65
+ print(processed_train)
66
+ print(processed_val)
67
+
68
+ # Fine-tune the model
69
+
70
+ model = AutoModelForSequenceClassification.from_pretrained(
71
+ model_id,
72
+ num_labels=len(id2label),
73
+ label2id=label2id,
74
+ id2label=id2label,
75
+ ignore_mismatched_sizes=True,
76
+ )
77
+
78
+ args = TrainingArguments(
79
+ "xlm-v-base-language-id",
80
+ learning_rate=3e-5,
81
+ warmup_ratio=0.1,
82
+ per_device_train_batch_size=16,
83
+ gradient_accumulation_steps=4,
84
+ per_device_eval_batch_size=16,
85
+ num_train_epochs=5,
86
+ load_best_model_at_end=True,
87
+ metric_for_best_model=metric_name,
88
+ evaluation_strategy="epoch",
89
+ save_strategy="epoch",
90
+ logging_steps=10,
91
+ fp16=True,
92
+ push_to_hub=True,
93
+ )
94
+
95
+ metric = evaluate.load(metric_name)
96
+
97
+
98
+ def compute_metrics(eval_pred):
99
+ predictions = np.argmax(eval_pred.predictions, axis=1)
100
+ return metric.compute(predictions=predictions, references=eval_pred.label_ids)
101
+
102
+
103
+ trainer = Trainer(
104
+ model,
105
+ args,
106
+ train_dataset=processed_train,
107
+ eval_dataset=processed_val,
108
+ tokenizer=tokenizer,
109
+ compute_metrics=compute_metrics,
110
+ )
111
+
112
+ trainer.train()
113
+
114
+ trainer.save_model("./my_model")