Azperia commited on
Commit
9873cb7
·
verified ·
1 Parent(s): 8adf38c

Create train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +47 -0
train_model.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling
2
+
3
+ # Charger le modèle et le tokenizer GPT-2
4
+ model_name = "gpt2"
5
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
6
+ model = GPT2LMHeadModel.from_pretrained(model_name)
7
+
8
+ # Préparer ton dataset (assure-toi que 'train.txt' existe avec tes données)
9
+ def load_dataset(file_path, tokenizer, block_size=128):
10
+ return TextDataset(
11
+ tokenizer=tokenizer,
12
+ file_path=file_path,
13
+ block_size=block_size
14
+ )
15
+
16
+ # Charger le dataset
17
+ train_dataset = load_dataset("train.txt", tokenizer)
18
+
19
+ # Préparer les arguments pour l'entraînement
20
+ training_args = TrainingArguments(
21
+ output_dir="./thought_model",
22
+ overwrite_output_dir=True,
23
+ num_train_epochs=3,
24
+ per_device_train_batch_size=2,
25
+ save_steps=500,
26
+ save_total_limit=2
27
+ )
28
+
29
+ # Préparer le data collator
30
+ data_collator = DataCollatorForLanguageModeling(
31
+ tokenizer=tokenizer,
32
+ mlm=False
33
+ )
34
+
35
+ # Lancer l'entraînement
36
+ trainer = Trainer(
37
+ model=model,
38
+ args=training_args,
39
+ data_collator=data_collator,
40
+ train_dataset=train_dataset
41
+ )
42
+
43
+ trainer.train()
44
+
45
+ # Sauvegarder le modèle fine-tuné
46
+ trainer.save_model("./thought_model")
47
+ tokenizer.save_pretrained("./thought_model")