Vishwas1 commited on
Commit
8da495a
·
verified ·
1 Parent(s): 0b2d673

Create train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +149 -0
train_model.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # training_space/train_model.py (Training Script)
2
+ import argparse
3
+ from transformers import (
4
+ GPT2Config, GPT2LMHeadModel,
5
+ BertConfig, BertForSequenceClassification,
6
+ Trainer, TrainingArguments, AutoTokenizer,
7
+ DataCollatorForLanguageModeling, DataCollatorWithPadding
8
+ )
9
+ from datasets import load_dataset, Dataset
10
+ import torch
11
+ import os
12
+ from huggingface_hub import HfApi, HfFolder
13
+
14
+ def main():
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("--task", type=str, required=True, help="Task type: generation or classification")
17
+ parser.add_argument("--model_name", type=str, required=True, help="Name of the model")
18
+ parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset")
19
+ parser.add_argument("--num_layers", type=int, default=12)
20
+ parser.add_argument("--attention_heads", type=int, default=1)
21
+ parser.add_argument("--hidden_size", type=int, default=64)
22
+ parser.add_argument("--vocab_size", type=int, default=30000)
23
+ parser.add_argument("--sequence_length", type=int, default=512)
24
+ args = parser.parse_args()
25
+
26
+ # Define output directory
27
+ output_dir = f"./models/{args.model_name}"
28
+ os.makedirs(output_dir, exist_ok=True)
29
+
30
+ # Initialize Hugging Face API
31
+ api = HfApi()
32
+ hf_token = HfFolder.get_token()
33
+
34
+ # Initialize tokenizer (adjust based on task)
35
+ if args.task == "generation":
36
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
37
+ elif args.task == "classification":
38
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
39
+ else:
40
+ raise ValueError("Unsupported task type")
41
+
42
+ # Load and prepare dataset
43
+ if args.task == "generation":
44
+ dataset = load_dataset('text', data_files={'train': args.dataset})
45
+ def tokenize_function(examples):
46
+ return tokenizer(examples['text'], truncation=True, max_length=args.sequence_length)
47
+ elif args.task == "classification":
48
+ # For classification, assume the dataset is a simple text file with "text\tlabel" per line
49
+ with open(args.dataset, "r", encoding="utf-8") as f:
50
+ lines = f.readlines()
51
+ texts = []
52
+ labels = []
53
+ for line in lines:
54
+ parts = line.strip().split("\t")
55
+ if len(parts) == 2:
56
+ texts.append(parts[0])
57
+ labels.append(int(parts[1]))
58
+ dataset = Dataset.from_dict({"text": texts, "label": labels})
59
+ def tokenize_function(examples):
60
+ return tokenizer(examples['text'], truncation=True, max_length=args.sequence_length)
61
+ else:
62
+ raise ValueError("Unsupported task type")
63
+
64
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
65
+
66
+ if args.task == "generation":
67
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
68
+ elif args.task == "classification":
69
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
70
+
71
+ # Initialize model based on task
72
+ if args.task == "generation":
73
+ config = GPT2Config(
74
+ vocab_size=args.vocab_size,
75
+ n_positions=args.sequence_length,
76
+ n_ctx=args.sequence_length,
77
+ n_embd=args.hidden_size,
78
+ num_hidden_layers=args.num_layers,
79
+ num_attention_heads=args.attention_heads,
80
+ intermediate_size=4 * args.hidden_size,
81
+ hidden_act='gelu',
82
+ use_cache=True
83
+ )
84
+ model = GPT2LMHeadModel(config)
85
+ elif args.task == "classification":
86
+ config = BertConfig(
87
+ vocab_size=args.vocab_size,
88
+ max_position_embeddings=args.sequence_length,
89
+ hidden_size=args.hidden_size,
90
+ num_hidden_layers=args.num_layers,
91
+ num_attention_heads=args.attention_heads,
92
+ intermediate_size=4 * args.hidden_size,
93
+ hidden_act='gelu',
94
+ num_labels=2 # Adjust based on your classification task
95
+ )
96
+ model = BertForSequenceClassification(config)
97
+ else:
98
+ raise ValueError("Unsupported task type")
99
+
100
+ # Define training arguments
101
+ if args.task == "generation":
102
+ training_args = TrainingArguments(
103
+ output_dir=output_dir,
104
+ num_train_epochs=3,
105
+ per_device_train_batch_size=8,
106
+ save_steps=5000,
107
+ save_total_limit=2,
108
+ logging_steps=500,
109
+ learning_rate=5e-4,
110
+ remove_unused_columns=False
111
+ )
112
+ elif args.task == "classification":
113
+ training_args = TrainingArguments(
114
+ output_dir=output_dir,
115
+ num_train_epochs=3,
116
+ per_device_train_batch_size=16,
117
+ evaluation_strategy="epoch",
118
+ save_steps=5000,
119
+ save_total_limit=2,
120
+ logging_steps=500,
121
+ learning_rate=5e-5,
122
+ remove_unused_columns=False
123
+ )
124
+
125
+ # Initialize Trainer
126
+ trainer = Trainer(
127
+ model=model,
128
+ args=training_args,
129
+ train_dataset=tokenized_datasets['train'],
130
+ data_collator=data_collator,
131
+ )
132
+
133
+ # Start training
134
+ trainer.train()
135
+
136
+ # Save the final model
137
+ trainer.save_model(output_dir)
138
+ tokenizer.save_pretrained(output_dir)
139
+
140
+ # Push to Hugging Face Hub
141
+ model_repo = f"your-username/{args.model_name}"
142
+ api.create_repo(repo_id=model_repo, private=False, token=hf_token)
143
+ model.push_to_hub(model_repo, use_auth_token=hf_token)
144
+ tokenizer.push_to_hub(model_repo, use_auth_token=hf_token)
145
+
146
+ print(f"Model '{args.model_name}' trained and pushed to Hugging Face Hub at '{model_repo}'.")
147
+
148
+ if __name__ == "__main__":
149
+ main()