import gradio as gr
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Predefined datasets with their configurations
dataset_names = {
    'imdb': None,
    'ag_news': None,
    'squad': None,
    'cnn_dailymail': '1.0.0',  # cnn_dailymail requires an explicit version config
    'wiki40b': 'en',           # wiki40b requires a language config
}

# Column holding the raw text in each dataset; everything is normalized to a
# common 'text' column so splits from different datasets can be concatenated.
text_fields = {
    'imdb': 'text',
    'ag_news': 'text',
    'squad': 'context',         # use the passage text for language modeling
    'cnn_dailymail': 'article',
    'wiki40b': 'text',
}

# Global variables for model and tokenizer
model = None
tokenizer = None

# Function to load and prepare datasets
def load_and_prepare_datasets():
    def to_text_only(split, field):
        # Derive the common 'text' column and drop every other column so the
        # schemas match across datasets before concatenation.
        if field != 'text':
            split = split.map(lambda ex: {'text': ex[field]})
        return split.remove_columns([c for c in split.column_names if c != 'text'])

    train_parts, eval_parts = [], []
    for name, config in dataset_names.items():
        ds = load_dataset(name, config)
        field = text_fields[name]
        if 'train' in ds:
            train_parts.append(to_text_only(ds['train'], field))
        # Some datasets (e.g. squad) only expose a 'validation' split for eval
        eval_split = 'test' if 'test' in ds else 'validation' if 'validation' in ds else None
        if eval_split is not None:
            eval_parts.append(to_text_only(ds[eval_split], field))

    # Concatenate train splits for training and held-out splits for evaluation
    train_dataset = concatenate_datasets(train_parts)
    eval_dataset = concatenate_datasets(eval_parts)
    return train_dataset, eval_dataset

# Function to preprocess data
def preprocess_function(examples):
    return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)

# Function to train the model
def train_model():
    global model, tokenizer

    # Load model and tokenizer
    model_name = 'gpt2'  # you can choose another causal LM if desired
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    # Load and prepare datasets
    train_dataset, eval_dataset = load_and_prepare_datasets()

    # Tokenize both splits and drop the raw text column
    train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=['text'])
    eval_dataset = eval_dataset.map(preprocess_function, batched=True, remove_columns=['text'])

    # The collator copies input_ids into labels, which causal LM training needs
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Set training arguments
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=1000,
        evaluation_strategy='steps',  # renamed to eval_strategy in newer transformers releases
        learning_rate=5e-5,  # adjust learning rate if necessary
    )

    # Train the model
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    trainer.train()
    return "Model trained successfully!"

# Function to generate text
def generate_text(prompt):
    global model, tokenizer  # ensure we use the global model and tokenizer
    if model is None or tokenizer is None:
        return "Model not initialized. Please train the model first."
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    # do_sample=True is required for temperature/top_k to take effect
    output = model.generate(
        **inputs,
        max_length=100,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# LLM Training and Text Generation")

    with gr.Row():
        with gr.Column():
            train_button = gr.Button("Train Model")
            output_message = gr.Textbox(label="Training Status", interactive=False)
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter prompt for text generation")
            generate_button = gr.Button("Generate Text")
            generated_output = gr.Textbox(label="Generated Text", interactive=False)

    # Button actions
    train_button.click(train_model, outputs=output_message)
    generate_button.click(generate_text, inputs=prompt_input, outputs=generated_output)

# Launch the app with share=True to create a public link
demo.launch(share=True)
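
# Usage (a sketch; assumes this file is saved as app.py and that gradio,
# transformers, datasets, and torch are installed):
#
#     python app.py
#
# Gradio prints a local URL and, with share=True, a temporary public link.
# Note that training on the full concatenated corpus is very slow on a single
# GPU; for a quick smoke test, subsample inside train_model() first, e.g.:
#
#     train_dataset = train_dataset.shuffle(seed=42).select(range(10_000))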