# train_model.py (Training Script)
import argparse
import logging
import os

from datasets import load_dataset
from huggingface_hub import HfApi, login
from transformers import (
    GPT2Config,
    GPT2LMHeadModel,
    BertConfig,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding,
)


def setup_logging(log_file_path):
    """
    Sets up logging to both console and a file.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Create handlers
    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler(log_file_path)
    c_handler.setLevel(logging.INFO)
    f_handler.setLevel(logging.INFO)

    # Create formatters and add to handlers
    c_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    f_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    c_handler.setFormatter(c_format)
    f_handler.setFormatter(f_format)

    # Add handlers to the logger
    logger.addHandler(c_handler)
    logger.addHandler(f_handler)


def parse_arguments():
    """
    Parses command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Train a custom LLM.")
    parser.add_argument("--task", type=str, required=True, choices=["generation", "classification"],
                        help="Task type: 'generation' or 'classification'")
    parser.add_argument("--model_name", type=str, required=True, help="Name of the model")
    parser.add_argument("--dataset_name", type=str, required=True,
                        help="Name of the Hugging Face dataset (e.g., 'wikitext/wikitext-2-raw-v1')")
    parser.add_argument("--num_layers", type=int, default=12, help="Number of hidden layers")
    parser.add_argument("--attention_heads", type=int, default=1, help="Number of attention heads")
    parser.add_argument("--hidden_size", type=int, default=64, help="Hidden size of the model")
    parser.add_argument("--vocab_size", type=int, default=30000, help="Vocabulary size")
    parser.add_argument("--sequence_length", type=int, default=512, help="Maximum sequence length")
    args = parser.parse_args()
    return args
""" logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...") try: if task == "generation": # Check if dataset_name includes config if '/' in dataset_name: dataset, config = dataset_name.split('/', 1) dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train[:1%]', use_auth_token=True) else: dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train[:1%]', use_auth_token=True) logging.info("Dataset loaded successfully for generation task.") def tokenize_function(examples): return tokenizer(examples['text'], truncation=True, max_length=sequence_length) elif task == "classification": if '/' in dataset_name: dataset, config = dataset_name.split('/', 1) dataset = load_dataset("stanfordnlp/imdb", split='train[:1%]', use_auth_token=True) else: dataset = load_dataset("stanfordnlp/imdb", split='train[:1%]', use_auth_token=True) logging.info("Dataset loaded successfully for classification task.") # Assuming the dataset has 'text' and 'label' columns def tokenize_function(examples): return tokenizer(examples['text'], truncation=True, max_length=sequence_length) else: raise ValueError("Unsupported task type") # Shuffle and select a subset tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True) logging.info("Dataset tokenization complete.") return tokenized_datasets except Exception as e: logging.error(f"Error loading or tokenizing dataset: {str(e)}") raise e def initialize_model(task, model_name, vocab_size, sequence_length, hidden_size, num_layers, attention_heads): """ Initializes the model configuration and model based on the task. """ logging.info(f"Initializing model for task '{task}'...") try: if task == "generation": config = GPT2Config( vocab_size=vocab_size, n_positions=sequence_length, n_ctx=sequence_length, n_embd=hidden_size, num_hidden_layers=num_layers, num_attention_heads=attention_heads, intermediate_size=4 * hidden_size, hidden_act='gelu', use_cache=True ) model = GPT2LMHeadModel(config) logging.info("GPT2LMHeadModel initialized successfully.") elif task == "classification": config = BertConfig( vocab_size=vocab_size, max_position_embeddings=sequence_length, hidden_size=hidden_size, num_hidden_layers=num_layers, num_attention_heads=attention_heads, intermediate_size=4 * hidden_size, hidden_act='gelu', num_labels=2 # Adjust based on your classification task ) model = BertForSequenceClassification(config) logging.info("BertForSequenceClassification initialized successfully.") else: raise ValueError("Unsupported task type") return model except Exception as e: logging.error(f"Error initializing model: {str(e)}") raise e def main(): # Parse arguments args = parse_arguments() # Setup logging log_file = "training.log" setup_logging(log_file) logging.info("Training script started.") # Initialize Hugging Face API api = HfApi() # Retrieve the Hugging Face API token from environment variables hf_token = os.getenv("HF_API_TOKEN") if not hf_token: logging.error("HF_API_TOKEN environment variable not set.") raise ValueError("HF_API_TOKEN environment variable not set.") # Perform login using the API token try: login(token=hf_token) logging.info("Successfully logged in to Hugging Face Hub.") except Exception as e: logging.error(f"Failed to log in to Hugging Face Hub: {str(e)}") raise e # Initialize tokenizer try: logging.info("Initializing tokenizer...") if args.task == "generation": tokenizer = AutoTokenizer.from_pretrained("gpt2") elif args.task == "classification": tokenizer = 
AutoTokenizer.from_pretrained("bert-base-uncased") else: raise ValueError("Unsupported task type") logging.info("Tokenizer initialized successfully.") except Exception as e: logging.error(f"Error initializing tokenizer: {str(e)}") raise e # Load and prepare dataset try: tokenized_datasets = load_and_prepare_dataset( task=args.task, dataset_name=args.dataset_name, tokenizer=tokenizer, sequence_length=args.sequence_length ) except Exception as e: logging.error("Failed to load and prepare dataset.") raise e # Initialize model try: model = initialize_model( task=args.task, model_name=args.model_name, vocab_size=args.vocab_size, sequence_length=args.sequence_length, hidden_size=args.hidden_size, num_layers=args.num_layers, attention_heads=args.attention_heads ) except Exception as e: logging.error("Failed to initialize model.") raise e # Define data collator if args.task == "generation": data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) elif args.task == "classification": data_collator = DataCollatorWithPadding(tokenizer=tokenizer) else: logging.error("Unsupported task type for data collator.") raise ValueError("Unsupported task type for data collator.") # Define training arguments if args.task == "generation": training_args = TrainingArguments( output_dir=f"./models/{args.model_name}", num_train_epochs=3, per_device_train_batch_size=8, save_steps=5000, save_total_limit=2, logging_steps=500, learning_rate=5e-4, remove_unused_columns=False, push_to_hub=False # We'll handle pushing manually ) elif args.task == "classification": training_args = TrainingArguments( output_dir=f"./models/{args.model_name}", num_train_epochs=3, per_device_train_batch_size=16, evaluation_strategy="epoch", save_steps=5000, save_total_limit=2, logging_steps=500, learning_rate=5e-5, remove_unused_columns=False, push_to_hub=False # We'll handle pushing manually ) else: logging.error("Unsupported task type for training arguments.") raise ValueError("Unsupported task type for training arguments.") # Initialize Trainer trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_datasets, data_collator=data_collator, ) # Start training logging.info("Starting training...") try: trainer.train() logging.info("Training completed successfully.") except Exception as e: logging.error(f"Error during training: {str(e)}") raise e # Save the final model and tokenizer try: trainer.save_model(training_args.output_dir) tokenizer.save_pretrained(training_args.output_dir) logging.info(f"Model and tokenizer saved to '{training_args.output_dir}'.") except Exception as e: logging.error(f"Error saving model or tokenizer: {str(e)}") raise e # Push the model to Hugging Face Hub model_repo = f"{api.whoami(token=hf_token)['name']}/{args.model_name}" try: logging.info(f"Pushing model to Hugging Face Hub at '{model_repo}'...") api.create_repo(repo_id=model_repo, private=False, token=hf_token) logging.info(f"Repository '{model_repo}' created successfully.") except Exception as e: logging.warning(f"Repository might already exist: {str(e)}") try: model.push_to_hub(model_repo, use_auth_token=hf_token) tokenizer.push_to_hub(model_repo, use_auth_token=hf_token) logging.info(f"Model and tokenizer pushed to Hugging Face Hub at '{model_repo}'.") except Exception as e: logging.error(f"Error pushing model to Hugging Face Hub: {str(e)}") raise e logging.info("Training script finished successfully.") if __name__ == "__main__": main()