yjgjhgjh committed
Commit 14279ce
1 Parent(s): 4d210e0

Create Train your own codexchan checkpoint if you prefer using this.py

Train your own codexchan checkpoint if you prefer using this.py ADDED
@@ -0,0 +1,112 @@
# This script lets you train your own distilgpt2 checkpoint or fine-tune the one in checkpoint-4000.
import os
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, TrainerCallback  # TrainerCallback is needed for the progress callback below
from datasets import load_dataset
from datetime import datetime

# Data preparation
data_dir = r"https://github.com/zrebarchak/Codexchan.exe-Archive"  # replace with a local folder of .txt files
#
"""Replace data_dir above with a local folder of txt files.
The GitHub link is the base dataset: it includes all of codexchan's videos where they spoke.
There's nothing wrong with the errored folder; you should combine it with the base set
and train on both (a commented-out sketch of that follows the load_dataset call below).
Note that this dataset doesn't include the FAQ
(https://etherpad.mit.edu/p/r.46c0a7842e569d53dc22b44afed6bc40)
or this: https://www.onlinegdb.com/fork/IrQRJkyX0
Also note that checkpoint-4000 was not trained on those either, just this base dataset. Have fun!"""
#
dataset = load_dataset("text", data_files=os.path.join(data_dir, "*.txt"))
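
# A minimal sketch (commented out) of the "combine it and train on both" note above.
# The two folder names are placeholders for wherever you put the base and errored txt files:
# dataset = load_dataset(
#     "text",
#     data_files=[
#         os.path.join("codexchan_base", "*.txt"),
#         os.path.join("codexchan_errored", "*.txt"),
#     ],
# )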

# Model and tokenizer setup
model_name = "distilgpt2"
base_output_dir = "./distilgpt2-fine-tuned"

# Generate a unique name for this training run
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(base_output_dir, f"distilgpt2_continuous_{current_time}")

# Function to find the most recent model directory
def find_most_recent_model(base_dir):
    if not os.path.exists(base_dir):
        return None
    subdirs = [os.path.join(base_dir, d) for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]
    valid_dirs = [d for d in subdirs if os.path.exists(os.path.join(d, 'config.json'))]
    return max(valid_dirs, key=os.path.getmtime) if valid_dirs else None

most_recent_dir = find_most_recent_model(base_output_dir)

if most_recent_dir:
    print(f"Loading most recent saved model from: {most_recent_dir}")
    try:
        model = GPT2LMHeadModel.from_pretrained(most_recent_dir)
        tokenizer = GPT2Tokenizer.from_pretrained(most_recent_dir)
    except Exception as e:
        print(f"Error loading saved model: {e}")
        print("Starting with fresh model instead.")
        model = GPT2LMHeadModel.from_pretrained(model_name)
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
else:
    print("No valid saved model found. Starting with fresh model...")
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)

tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    save_steps=1000,
    save_total_limit=5,
    fp16=True,  # mixed-precision training; requires a CUDA GPU
    gradient_checkpointing=True,
    learning_rate=1e-4,
    warmup_steps=100,
    logging_steps=10,  # Log frequently
    max_steps=-1,  # -1 means no step cap; training length is governed by num_train_epochs
    num_train_epochs=215,  # Effectively "train until you stop it" (Ctrl+C below saves the model)
)

# Custom callback to print progress
class ProgressCallback(TrainerCallback):
    def __init__(self, total_steps=1000000):  # A large number, but not so large it causes display issues
        self.total_steps = total_steps

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.global_step % 10 == 0:  # Print every 10 steps
            loss = logs.get("loss") if logs else None
            loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else "N/A"
            print(f"Step: {state.global_step}/{self.total_steps} - Loss: {loss_str}")

# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    callbacks=[ProgressCallback()]
)

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Start training
print(f"Starting long-running training. Models will be saved to {output_dir}")
print("Press Ctrl+C to stop...")
try:
    trainer.train()
    trainer.save_model()  # also save the final model when training finishes normally
except KeyboardInterrupt:
    print("\nTraining interrupted. Saving model...")
    trainer.save_model()
    print(f"Model saved to {output_dir}. You can resume training later by running this script again.")

print("Training completed or interrupted. Final model saved.")
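
# --- Optional: quick generation check (a minimal sketch, not part of the training run) ---
# After training (or to try an existing checkpoint such as checkpoint-4000), load whichever
# folder contains config.json and sample from it. The prompt below is just an example.
#
# from transformers import GPT2LMHeadModel, GPT2Tokenizer
# model = GPT2LMHeadModel.from_pretrained(output_dir)
# tokenizer = GPT2Tokenizer.from_pretrained(output_dir)
# inputs = tokenizer("hello codexchan", return_tensors="pt")
# output_ids = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95,
#                             pad_token_id=tokenizer.eos_token_id)
# print(tokenizer.decode(output_ids[0], skip_special_tokens=True))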