"""Gradio app that fine-tunes GPT-2 on several public datasets and generates text."""

import dataclasses

import gradio as gr
import torch
from datasets import concatenate_datasets, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Predefined datasets with their configurations (None means "no config name").
dataset_names = {
    'imdb': None,
    'ag_news': None,
    'squad': None,
    'cnn_dailymail': '1.0.0',  # Specify configuration for cnn_dailymail
    'wiki40b': 'en',  # Specify language for wiki40b
}

# Global model and tokenizer; both stay None until train_model() has loaded them.
model = None
tokenizer = None

# Columns that may hold the raw text, in order of preference.
_TEXT_FIELD_CANDIDATES = ('text', 'content', 'article', 'context')
# Splits usable for evaluation, in order of preference.
_EVAL_SPLIT_CANDIDATES = ('test', 'validation')


def _extract_text_column(split):
    """Reduce *split* to a single 'text' column.

    Returns None (after printing a warning) when none of the candidate text
    columns is present. Dropping every other column is what makes the splits
    of different datasets concatenable: concatenate_datasets requires all
    inputs to share the same features.
    """
    field = next((c for c in _TEXT_FIELD_CANDIDATES if c in split.features), None)
    if field is None:
        print(f"Warning: No suitable text field found in {split.features}")
        return None

    extra_columns = [c for c in split.column_names if c != field]
    text_only = split.remove_columns(extra_columns) if extra_columns else split
    if field != 'text':
        text_only = text_only.rename_column(field, 'text')
    return text_only


def _eval_strategy_kwargs(strategy):
    """Return the evaluation-strategy keyword under the name TrainingArguments accepts.

    The argument is called ``evaluation_strategy`` in older transformers
    releases and ``eval_strategy`` in newer ones, so the name is chosen by
    inspecting the dataclass fields instead of being hard-coded.
    """
    field_names = {f.name for f in dataclasses.fields(TrainingArguments)}
    key = 'eval_strategy' if 'eval_strategy' in field_names else 'evaluation_strategy'
    return {key: strategy}


def load_and_prepare_datasets():
    """Load every dataset in ``dataset_names`` and merge them for training.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple. Each is the concatenation
        of the corresponding splits, reduced to a single 'text' column.

    Raises:
        ValueError: if no usable training split or no usable evaluation
            split was found.
    """
    train_datasets = []
    eval_datasets = []

    for name, config in dataset_names.items():
        ds = load_dataset(name, config)

        if 'train' in ds:
            # Print dataset features for debugging
            print(f"Dataset: {name}, Features: {ds['train'].features}")
            train_split = _extract_text_column(ds['train'])
            if train_split is not None:
                train_datasets.append(train_split)

        # Not every dataset ships a 'test' split, so fall back to 'validation'.
        eval_split_name = next((s for s in _EVAL_SPLIT_CANDIDATES if s in ds), None)
        if eval_split_name is not None:
            eval_split = _extract_text_column(ds[eval_split_name])
            if eval_split is not None:
                eval_datasets.append(eval_split)

    if not train_datasets:
        raise ValueError("No training split with a usable text field was found.")
    if not eval_datasets:
        raise ValueError("No evaluation split with a usable text field was found.")

    return concatenate_datasets(train_datasets), concatenate_datasets(eval_datasets)


def preprocess_function(examples):
    """Tokenize a batch of examples with the global tokenizer.

    Args:
        examples: a batch mapping that contains a 'text' entry.

    Returns:
        The tokenizer output, truncated and padded to 512 tokens.
    """
    return tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=512,
    )


def train_model():
    """Load GPT-2, fine-tune it on the merged datasets, and return a status message.

    Side effects: sets the module-level ``model`` and ``tokenizer`` and writes
    checkpoints to ``./results`` and logs to ``./logs``.
    """
    global model, tokenizer

    # Load model and tokenizer
    model_name = 'gpt2'  # You can choose another model if desired
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Padding to max_length needs a pad token; reuse EOS when the tokenizer has none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load and prepare datasets
    train_dataset, eval_dataset = load_and_prepare_datasets()

    # Tokenize both splits; the raw text column is not needed afterwards.
    train_dataset = train_dataset.map(
        preprocess_function, batched=True, remove_columns=['text']
    )
    eval_dataset = eval_dataset.map(
        preprocess_function, batched=True, remove_columns=['text']
    )

    # Set training arguments
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=1000,
        # Without an explicit value the evaluation interval follows logging_steps,
        # i.e. a full evaluation every 10 steps; align it with save_steps instead.
        eval_steps=1000,
        learning_rate=5e-5,  # Adjust learning rate if necessary
        **_eval_strategy_kwargs("steps"),
    )

    # For causal LM training the collator derives the labels from input_ids
    # (mlm=False); without labels the Trainer has no loss to optimise.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Train the model
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    trainer.train()

    return "Model trained successfully!"


def generate_text(prompt):
    """Generate a continuation of *prompt* with the trained model.

    Args:
        prompt: the text to continue.

    Returns:
        The decoded generated text, or a short message when the model has not
        been loaded yet or the prompt is empty.
    """
    if tokenizer is None or model is None:
        return "Tokenizer not initialized. Please train the model first."
    if not prompt or not prompt.strip():
        # An empty prompt encodes to zero tokens, which generate() cannot continue.
        return "Please enter a prompt."

    # Encode with an attention mask and move the tensors to the model's device,
    # which may differ from the CPU once training has run.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)

    # Training leaves the model in train mode; switch dropout off for inference.
    model.eval()
    # temperature/top_k only take effect when sampling is enabled.
    output = model.generate(
        **inputs,
        max_length=100,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# LLM Training and Text Generation")

    with gr.Row():
        with gr.Column():
            train_button = gr.Button("Train Model")
            output_message = gr.Textbox(label="Training Status", interactive=False)

        with gr.Column():
            prompt_input = gr.Textbox(label="Enter prompt for text generation")
            generate_button = gr.Button("Generate Text")
            generated_output = gr.Textbox(label="Generated Text", interactive=False)

    # Button actions
    train_button.click(train_model, outputs=output_message)
    generate_button.click(generate_text, inputs=prompt_input, outputs=generated_output)

# Launch the app with share=True to create a public link
demo.launch(share=True)