"""Fine-tune Meta-Llama-3-8B on Python code plus Streamlit issues and docs.

The fine-tuned model is written to ``PyStreamlitGPT``. The heavy third-party
libraries (``transformers``, ``datasets``) are imported inside :func:`main`, so
the pure data helpers below can be imported and unit-tested without them.
"""

import itertools

MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
DATASET_NAMES = (
    "flytech/python-codes-25k",
    "andfanilo/streamlit-issues",
    "sai-lohith/streamlit_docs",
)
TEXT_COLUMN = "text"
# Upper bound on the training sequence length (same default as HF's run_clm).
# It is capped by tokenizer.model_max_length in main(); that attribute can be a
# huge sentinel value for tokenizers without a configured limit, so it is not
# used as the chunk size on its own.
BLOCK_SIZE = 1024
NUM_PROC = 4
OUTPUT_MODEL_DIR = "PyStreamlitGPT"


def tokenize_function(examples, tokenizer):
    """Tokenize a batch of raw documents.

    Args:
        examples: A batch from ``Dataset.map(batched=True)``; must contain the
            ``TEXT_COLUMN`` key mapping to a list of strings.
        tokenizer: Callable tokenizer applied to the list of texts.

    Returns:
        The tokenizer's output for the batch (for a Hugging Face tokenizer, a
        mapping with ``input_ids`` and ``attention_mask`` lists).
    """
    return tokenizer(examples[TEXT_COLUMN])


def group_texts(examples, block_size=BLOCK_SIZE):
    """Concatenate tokenized examples and cut them into fixed-length blocks.

    Every column (e.g. ``input_ids``, ``attention_mask``) is flattened into one
    sequence and split into chunks of exactly ``block_size`` items. The trailing
    remainder shorter than ``block_size`` is dropped, so no padding is needed.
    A ``labels`` column copying ``input_ids`` is added because causal-LM models
    compute their loss from ``labels``.

    Args:
        examples: Mapping of column name -> list of token lists (one batch).
        block_size: Number of tokens per output block; must be positive.

    Returns:
        Mapping of column name -> list of ``block_size``-long lists, plus
        ``labels`` when ``input_ids`` is present.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # chain.from_iterable flattens in linear time (sum(lists, []) is quadratic).
    concatenated = {
        key: list(itertools.chain.from_iterable(values))
        for key, values in examples.items()
    }
    total_length = len(next(iter(concatenated.values()), []))
    # Drop the remainder that does not fill a whole block.
    total_length -= total_length % block_size

    result = {
        key: [tokens[i : i + block_size] for i in range(0, total_length, block_size)]
        for key, tokens in concatenated.items()
    }
    if "input_ids" in result:
        result["labels"] = [list(chunk) for chunk in result["input_ids"]]
    return result


def main():
    """Load the model and corpora, fine-tune, and save the resulting model."""
    from datasets import concatenate_datasets, load_dataset
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        Trainer,
        TrainingArguments,
        default_data_collator,
    )

    # Load LLAMA3 8B model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

    # Load the "train" split of each corpus and keep only the text column so
    # the schemas match for concatenation. Assumes each dataset has a "train"
    # split and a "text" column -- TODO confirm for the two Streamlit datasets.
    parts = []
    for name in DATASET_NAMES:
        ds = load_dataset(name, split="train")
        parts.append(
            ds.remove_columns([c for c in ds.column_names if c != TEXT_COLUMN])
        )
    combined_dataset = concatenate_datasets(parts)

    block_size = min(BLOCK_SIZE, tokenizer.model_max_length)

    tokenized_datasets = combined_dataset.map(
        tokenize_function,
        batched=True,
        num_proc=NUM_PROC,
        # Drop the raw text so group_texts only concatenates token columns.
        remove_columns=combined_dataset.column_names,
        fn_kwargs={"tokenizer": tokenizer},
    )
    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=NUM_PROC,
        fn_kwargs={"block_size": block_size},
    )

    training_args = TrainingArguments(
        per_device_train_batch_size=2,
        num_train_epochs=3,
        logging_dir="./logs",
        output_dir="./output",
        overwrite_output_dir=True,
        report_to="none",  # disable external experiment trackers (W&B, TensorBoard, ...)
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_datasets,
        tokenizer=tokenizer,
        # All blocks have the same length, so plain stacking is enough. The
        # padding collator Trainer picks by default when a tokenizer is given
        # needs a pad token, which the base Llama-3 tokenizer does not set
        # (re-check if the checkpoint changes).
        data_collator=default_data_collator,
    )
    trainer.train()

    # Save the trained model
    trainer.save_model(OUTPUT_MODEL_DIR)


if __name__ == "__main__":
    main()