Spaces:
Runtime error
Runtime error
import gradio as gr
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
# Datasets used for training, keyed by dataset name. Each value is the
# configuration handed to load_dataset, or None when no configuration is given.
dataset_names = dict(
    imdb=None,
    ag_news=None,
    squad=None,
    cnn_dailymail='1.0.0',  # configuration for cnn_dailymail
    wiki40b='ru',  # language for wiki40b
)
# Shared model/tokenizer state. Both start as None; train_model() assigns
# them and generate_text() reads them.
model = tokenizer = None
# Function to load and prepare datasets
def load_and_prepare_datasets():
    """Load every dataset in ``dataset_names`` and merge them into text-only splits.

    Each usable split is reduced to a single ``text`` column before it is
    concatenated. ``concatenate_datasets`` requires all inputs to share the
    same features, and the source datasets carry different extra columns
    (for example ``label`` columns with different class sets), so everything
    except the text is dropped.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple of concatenated datasets,
        each with exactly one ``text`` column.

    Raises:
        ValueError: If no training split, or no evaluation split, with a
            usable text column was found.
    """
    # Columns that may hold the body text, in order of preference.
    # cnn_dailymail stores it in 'article' and squad in 'context'.
    text_columns = ('text', 'content', 'article', 'context')

    def to_text_only(split):
        """Return *split* reduced to one ``text`` column, or None if it has no text."""
        column = next((c for c in text_columns if c in split.features), None)
        if column is None:
            print(f"Warning: No suitable text field found in {split.features}")
            return None
        # Columns produced by the mapped function are kept even when their
        # name is listed in remove_columns, so exactly 'text' remains.
        return split.map(
            lambda x: {'text': x[column]},
            remove_columns=split.column_names,
        )

    train_datasets = []
    eval_datasets = []
    for name, config in dataset_names.items():
        ds = load_dataset(name, config)

        if 'train' in ds:
            # Print dataset features for debugging
            print(f"Dataset: {name}, Features: {ds['train'].features}")
            train_split = to_text_only(ds['train'])
            if train_split is not None:
                train_datasets.append(train_split)

        # Some datasets (e.g. squad) publish 'validation' instead of 'test'.
        eval_name = next((s for s in ('test', 'validation') if s in ds), None)
        if eval_name is not None:
            eval_split = to_text_only(ds[eval_name])
            if eval_split is not None:
                eval_datasets.append(eval_split)

    if not train_datasets:
        raise ValueError("No training split with a usable text column was found.")
    if not eval_datasets:
        raise ValueError("No evaluation split with a usable text column was found.")

    # Concatenate train datasets only for training
    train_dataset = concatenate_datasets(train_datasets)
    # Concatenate eval datasets only for evaluation
    eval_dataset = concatenate_datasets(eval_datasets)
    return train_dataset, eval_dataset
# Function to preprocess data
def preprocess_function(examples):
    """Tokenize a batch of examples with the module-level tokenizer.

    Args:
        examples: A batch mapping whose ``'text'`` entry holds the strings to
            tokenize.

    Returns:
        The tokenizer output, truncated and padded to 512 tokens per example.

    Raises:
        RuntimeError: If the module-level tokenizer has not been loaded yet.
    """
    if tokenizer is None:
        raise RuntimeError(
            "Tokenizer not initialized. Call train_model() before preprocessing."
        )
    # GPT-2's tokenizer defines no padding token, and padding='max_length'
    # raises without one; fall back to the end-of-sequence token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=512,
    )
# Function to train the model
def train_model():
    """Fine-tune GPT-2 on the combined datasets and keep it for generation.

    Loads the model and tokenizer into the module-level ``model`` and
    ``tokenizer`` variables, tokenizes the train and eval splits returned by
    ``load_and_prepare_datasets``, and runs a ``Trainer`` over them.

    Returns:
        The status message ``"Model trained successfully!"`` once training
        has finished.
    """
    global model, tokenizer

    # Load model and tokenizer
    model_name = 'gpt2'  # You can choose another model if desired
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 ships without a padding token, but preprocess_function pads to a
    # fixed length; reuse the end-of-sequence token so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load and prepare datasets
    train_dataset, eval_dataset = load_and_prepare_datasets()

    # Tokenize BOTH splits (the evaluation loop needs input_ids too) and drop
    # the raw columns so only tokenizer outputs reach the collator.
    train_dataset = train_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=train_dataset.column_names,
    )
    eval_dataset = eval_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=eval_dataset.column_names,
    )

    # A causal LM only returns a loss when labels are supplied. With
    # mlm=False the collator copies input_ids into labels and masks padding.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Set training arguments
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=1000,
        evaluation_strategy="steps",
        # Without eval_steps the evaluation interval falls back to
        # logging_steps (every 10 steps); align it with checkpointing.
        eval_steps=1000,
        learning_rate=5e-5,  # Adjust learning rate if necessary
    )

    # Train the model
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    trainer.train()
    return "Model trained successfully!"
# Function to generate text
def generate_text(prompt):
    """Generate a continuation of *prompt* with the module-level model.

    Args:
        prompt: The text to continue.

    Returns:
        The decoded generated text, or a short explanatory message when the
        model has not been trained yet or the prompt is empty.
    """
    if model is None or tokenizer is None:
        return "Tokenizer not initialized. Please train the model first."
    if not prompt or not prompt.strip():
        # An empty prompt encodes to zero tokens, which generate() cannot extend.
        return "Please enter a prompt."

    # Training may have moved the model to an accelerator; the inputs must
    # live on the same device. Calling the tokenizer (rather than encode)
    # also yields the attention mask that generate() expects.
    encoded = tokenizer(prompt, return_tensors='pt').to(model.device)

    # temperature and top_k only take effect when sampling is enabled.
    output = model.generate(
        **encoded,
        max_length=100,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Gradio interface: one row with a training column and a generation column
with gr.Blocks() as demo:
    gr.Markdown("# LLM Training and Text Generation")

    with gr.Row():
        # First column: start training and display its status message
        with gr.Column():
            train_button = gr.Button("Train Model")
            output_message = gr.Textbox(label="Training Status", interactive=False)

        # Second column: enter a prompt and display the generated text
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter prompt for text generation")
            generate_button = gr.Button("Generate Text")
            generated_output = gr.Textbox(label="Generated Text", interactive=False)

    # Button actions
    train_button.click(fn=train_model, outputs=output_message)
    generate_button.click(
        fn=generate_text,
        inputs=prompt_input,
        outputs=generated_output,
    )

# Launch the app with share=True to create a public link
demo.launch(share=True)