import gradio as gr
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments

# Predefined datasets with their configurations
dataset_names = {
    'imdb': None,
    'ag_news': None,
    'squad': None,
    'cnn_dailymail': '1.0.0',  # Specify configuration for cnn_dailymail
    'wiki40b': 'ru'            # Specify language for wiki40b
}
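
# Note: as of the current Hub versions, these corpora do not share a common
# schema -- imdb, ag_news and wiki40b expose 'text', cnn_dailymail exposes
# 'article', and squad exposes 'context' -- so load_and_prepare_datasets()
# below normalizes every split to a single 'text' column before concatenation.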
# Global variables for model and tokenizer
model = None
tokenizer = None

# Function to load and prepare datasets
def load_and_prepare_datasets():
    datasets = []
    for name, config in dataset_names.items():
        ds = load_dataset(name, config)
        datasets.append(ds)
        # Print dataset features for debugging
        print(f"Dataset: {name}, Features: {ds['train'].features}")

    # Normalize each split to a single 'text' column. The original columns are
    # dropped so that concatenate_datasets() sees identical features everywhere.
    def extract_text(split):
        for field in ('text', 'article', 'context'):  # 'article': cnn_dailymail, 'context': squad
            if field in split.features:
                return split.map(lambda x: {'text': x[field]},
                                 remove_columns=split.column_names)
        print(f"Warning: No suitable text field found in {split.features}")
        return None

    train_datasets = []
    eval_datasets = []
    for ds in datasets:
        if 'train' in ds:
            extracted = extract_text(ds['train'])
            if extracted is not None:
                train_datasets.append(extracted)
        # Some datasets (e.g. squad) ship a 'validation' split instead of 'test'
        eval_split = 'test' if 'test' in ds else 'validation' if 'validation' in ds else None
        if eval_split is not None:
            extracted = extract_text(ds[eval_split])
            if extracted is not None:
                eval_datasets.append(extracted)

    # Concatenate per-dataset splits into one train set and one eval set
    train_dataset = concatenate_datasets(train_datasets)
    eval_dataset = concatenate_datasets(eval_datasets)
    return train_dataset, eval_dataset
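
# Optional quick-run sketch (not part of the original flow): the combined corpus
# is large, so for a smoke test you could subsample inside train_model() before
# tokenization, e.g.:
#   train_dataset = train_dataset.shuffle(seed=42).select(range(1000))
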
# Function to preprocess data
def preprocess_function(examples):
    return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)

# Function to train the model
def train_model():
    global model, tokenizer
    # Load model and tokenizer
    model_name = 'gpt2'  # You can choose another model if desired
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 has no padding token, so reuse the end-of-sequence token for padding
    tokenizer.pad_token = tokenizer.eos_token

    # Load and prepare datasets
    train_dataset, eval_dataset = load_and_prepare_datasets()

    # Preprocess both datasets (the eval set must be tokenized as well)
    train_dataset = train_dataset.map(preprocess_function, batched=True)
    eval_dataset = eval_dataset.map(preprocess_function, batched=True)

    # Set training arguments
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=1000,
        evaluation_strategy="steps",
        learning_rate=5e-5,  # Adjust learning rate if necessary
    )

    # The collator copies input_ids into labels (mlm=False), which causal-LM
    # training needs in order to compute a loss
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Train the model
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    trainer.train()
    return "Model trained successfully!"
# Function to generate text
def generate_text(prompt):
    global model, tokenizer  # Ensure we use the global model and tokenizer
    if model is None or tokenizer is None:
        return "Model not initialized. Please train the model first."
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    # Adjust generation parameters for better quality output; do_sample=True is
    # required for temperature and top_k to take effect
    output = model.generate(
        input_ids,
        max_length=100,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# LLM Training and Text Generation")
    with gr.Row():
        with gr.Column():
            train_button = gr.Button("Train Model")
            output_message = gr.Textbox(label="Training Status", interactive=False)
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter prompt for text generation")
            generate_button = gr.Button("Generate Text")
            generated_output = gr.Textbox(label="Generated Text", interactive=False)

    # Button actions
    train_button.click(train_model, outputs=output_message)
    generate_button.click(generate_text, inputs=prompt_input, outputs=generated_output)

# Launch the app with share=True to create a public link
demo.launch(share=True)
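
# A minimal console smoke test (assumes training has already been run via the
# UI; the prompt is just an illustrative example):
#   print(generate_text("The movie was"))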