detrina-grad / app.py
portalniy-dev's picture
Update app.py
f6e4f14 verified
raw
history blame
4.81 kB
import inspect

import gradio as gr
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
# Hugging Face Hub datasets to train on, mapped to the config name handed to
# `load_dataset` (None selects the dataset's default configuration).
dataset_names = dict(
    imdb=None,
    ag_news=None,
    squad=None,
    cnn_dailymail='1.0.0',  # version config for cnn_dailymail
    wiki40b='ru',  # wiki40b config selects the language edition
)

# Set by train_model(); read afterwards by generate_text().
model = tokenizer = None
# Candidate source columns holding the raw text, checked in this order.
# 'article' is the article body in cnn_dailymail; 'content' is kept from the
# original lookup logic.
_TEXT_FIELDS = ('text', 'content', 'article')


def _to_text_only(split, name, split_name):
    """Reduce *split* to a single ``'text'`` column, or return None.

    The first column from ``_TEXT_FIELDS`` present in the split is renamed to
    ``'text'`` and every other column is dropped. This gives splits from
    different datasets identical features, which ``concatenate_datasets``
    requires. If no candidate column exists, a warning is printed and None
    is returned.
    """
    field = next((f for f in _TEXT_FIELDS if f in split.column_names), None)
    if field is None:
        print(f"Warning: No suitable text field found in {name}/{split_name}: {split.features}")
        return None
    if field != 'text':
        split = split.rename_column(field, 'text')
    # Drop labels/ids/etc.: differing extra columns (e.g. ClassLabel 'label'
    # with different class names) would make concatenation fail.
    return split.remove_columns([c for c in split.column_names if c != 'text'])


def load_and_prepare_datasets():
    """Load every dataset in ``dataset_names`` and build text-only splits.

    Returns:
        ``(train_dataset, eval_dataset)``. These are the concatenation of all
        usable ``'train'`` splits and of all usable ``'test'`` splits. Each
        has a single ``'text'`` column. ``eval_dataset`` is None when no
        dataset has a usable ``'test'`` split.

    Raises:
        ValueError: if no dataset provides a usable ``'train'`` split.
    """
    train_parts = []
    eval_parts = []
    for name, config in dataset_names.items():
        ds = load_dataset(name, config)
        # Print dataset features for debugging (only if a train split exists).
        if 'train' in ds:
            print(f"Dataset: {name}, Features: {ds['train'].features}")
        for split_name, parts in (('train', train_parts), ('test', eval_parts)):
            if split_name in ds:
                part = _to_text_only(ds[split_name], name, split_name)
                if part is not None:
                    parts.append(part)

    if not train_parts:
        raise ValueError("No dataset provided a 'train' split with a usable text column.")
    train_dataset = concatenate_datasets(train_parts)
    # concatenate_datasets rejects an empty list, so signal "no eval data" with None.
    eval_dataset = concatenate_datasets(eval_parts) if eval_parts else None
    return train_dataset, eval_dataset
def preprocess_function(examples):
    """Tokenize one batch of examples for causal-LM training.

    Args:
        examples: A batch as passed by ``Dataset.map(..., batched=True)``;
            its ``'text'`` entry holds the raw strings.

    Returns:
        The output of the module-level ``tokenizer``. Every sequence is
        truncated or padded to exactly 512 tokens, so the tokenizer must
        have a padding token configured.
    """
    texts = examples['text']
    return tokenizer(
        texts,
        max_length=512,
        padding='max_length',
        truncation=True,
    )
def train_model():
    """Fine-tune GPT-2 on the combined datasets and keep it for generation.

    Loads the ``gpt2`` model and tokenizer into the module-level ``model`` /
    ``tokenizer`` globals. It then builds the datasets with
    ``load_and_prepare_datasets``, tokenizes the train and (if present) eval
    splits, and runs ``Trainer.train``.

    Returns:
        A status string for the Gradio "Training Status" textbox.
    """
    global model, tokenizer

    model_name = 'gpt2'  # You can choose another model if desired
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 ships without a pad token, and padding='max_length' in
    # preprocess_function raises without one. Reuse EOS, as is customary.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.pad_token_id

    train_dataset, eval_dataset = load_and_prepare_datasets()

    # Tokenize both splits; drop the raw columns so only model inputs remain.
    train_dataset = train_dataset.map(
        preprocess_function, batched=True, remove_columns=train_dataset.column_names
    )
    if eval_dataset is not None:
        eval_dataset = eval_dataset.map(
            preprocess_function, batched=True, remove_columns=eval_dataset.column_names
        )

    # The causal LM only returns a loss when given `labels`. This collator
    # (mlm=False) copies input_ids into labels and masks padding with -100.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # transformers renamed `evaluation_strategy` to `eval_strategy` and later
    # removed the old name; pass whichever keyword this version accepts.
    strategy_kwarg = (
        'eval_strategy'
        if 'eval_strategy' in inspect.signature(TrainingArguments).parameters
        else 'evaluation_strategy'
    )
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=1000,
        # Without this, eval_steps defaults to logging_steps (10), which would
        # run the full eval set every 10 training steps.
        eval_steps=1000,
        learning_rate=5e-5,  # Adjust learning rate if necessary
        **{strategy_kwarg: 'steps' if eval_dataset is not None else 'no'},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    trainer.train()
    return "Model trained successfully!"
def generate_text(prompt):
    """Continue *prompt* using the model trained by ``train_model``.

    Args:
        prompt: Text entered in the Gradio prompt box.

    Returns:
        The decoded generation (the prompt included, at most 100 tokens in
        total). If the model has not been loaded yet or the prompt is empty,
        an explanatory message is returned instead.
    """
    if tokenizer is None or model is None:
        return "Tokenizer not initialized. Please train the model first."
    if not prompt:
        # An empty prompt tokenizes to zero ids, which generate() cannot start from.
        return "Please enter a prompt."

    # Trainer may place the model on a GPU; inputs must be on the same device.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    # Training leaves the model in train mode (dropout active); switch to inference.
    model.eval()
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    output = model.generate(
        **inputs,  # passes attention_mask along with input_ids
        max_length=100,
        do_sample=True,  # temperature/top_k are ignored under greedy decoding
        temperature=0.7,
        top_k=50,
        pad_token_id=pad_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Gradio interface: training controls on the left, text generation on the right.
with gr.Blocks() as demo:
    gr.Markdown("# LLM Training and Text Generation")
    with gr.Row():
        with gr.Column():
            train_button = gr.Button("Train Model")
            output_message = gr.Textbox(label="Training Status", interactive=False)
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter prompt for text generation")
            generate_button = gr.Button("Generate Text")
            generated_output = gr.Textbox(label="Generated Text", interactive=False)
    # Button actions (event listeners must be registered inside the Blocks context)
    train_button.click(train_model, outputs=output_message)
    generate_button.click(generate_text, inputs=prompt_input, outputs=generated_output)

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    # share=True requests a public share link when run locally.
    demo.launch(share=True)