import gradio as gr
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Predefined datasets with their configurations
dataset_names = {
    'imdb': None,
    'ag_news': None,
    'squad': None,
    'cnn_dailymail': '1.0.0',  # Specify configuration for cnn_dailymail
    'wiki40b': 'ru'  # Specify language for wiki40b
}
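
# Additional datasets can be added to this mapping: None loads a dataset's
# default configuration, while a string selects a named config (e.g. a
# CNN/DailyMail version) or, for wiki40b, a language code.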

# Global variables for model and tokenizer
model = None
tokenizer = None
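# Gradio event handlers run in the same Python process, so these module-level
# globals persist between the "Train Model" and "Generate Text" callbacks.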

# Function to load and prepare datasets
def load_and_prepare_datasets():
    # Column that holds the raw text in each dataset: SQuAD stores passages
    # in 'context' and CNN/DailyMail stores articles in 'article'.
    text_fields = {
        'imdb': 'text',
        'ag_news': 'text',
        'squad': 'context',
        'cnn_dailymail': 'article',
        'wiki40b': 'text'
    }

    train_datasets = []
    eval_datasets = []

    for name, config in dataset_names.items():
        ds = load_dataset(name, config)

        # Print dataset features for debugging
        print(f"Dataset: {name}, Features: {ds['train'].features}")

        field = text_fields[name]

        # Keep only a single 'text' column so every dataset shares the same
        # schema and they can be concatenated.
        if 'train' in ds:
            train_datasets.append(
                ds['train'].map(lambda x: {'text': x[field]},
                                remove_columns=ds['train'].column_names)
            )

        # SQuAD ships a 'validation' split instead of 'test'
        eval_split = 'test' if 'test' in ds else 'validation' if 'validation' in ds else None
        if eval_split is not None:
            eval_datasets.append(
                ds[eval_split].map(lambda x: {'text': x[field]},
                                   remove_columns=ds[eval_split].column_names)
            )

    # Concatenate the per-dataset splits for training and evaluation
    train_dataset = concatenate_datasets(train_datasets)
    eval_dataset = concatenate_datasets(eval_datasets)

    return train_dataset, eval_dataset
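
# For quicker experiments, the combined corpus can be subsampled before
# training, e.g. train_dataset.shuffle(seed=42).select(range(10_000)).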

# Function to preprocess data
def preprocess_function(examples):
    return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)
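
# Note: padding requires tokenizer.pad_token to be set; train_model() below
# assigns the EOS token as the pad token before this function is mapped.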

# Function to train the model
def train_model():
    global model, tokenizer
    
    # Load model and tokenizer
    model_name = 'gpt2'  # You can choose another model if desired
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS

    # Load and prepare datasets
    train_dataset, eval_dataset = load_and_prepare_datasets()

    # Preprocess (tokenize) both the training and evaluation datasets
    train_dataset = train_dataset.map(preprocess_function, batched=True)
    eval_dataset = eval_dataset.map(preprocess_function, batched=True)
    
    # Set training arguments
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=1000,
        evaluation_strategy="steps",
        eval_steps=1000,  # Without this, evaluation falls back to logging_steps (every 10 steps)
        learning_rate=5e-5  # Adjust learning rate if necessary
    )

    # The data collator copies input_ids into a 'labels' field (masking pad
    # positions), which the causal LM needs in order to compute a training loss
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Train the model
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    trainer.train()
    
    return "Model trained successfully!"

# Function to generate text
def generate_text(prompt):
    global model, tokenizer  # Use the globals populated by train_model

    if model is None or tokenizer is None:
        return "Model not initialized. Please train the model first."

    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    input_ids = input_ids.to(model.device)  # Match the model's device (CPU or GPU)

    # Sampling must be enabled for temperature/top_k to have any effect
    output = model.generate(
        input_ids,
        max_length=100,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    
    return generated_text
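
# Example (outside the UI): once train_model() has run, generate_text("The movie
# was") returns a sampled continuation; max_length caps prompt plus new tokens
# at 100 total.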

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# LLM Training and Text Generation")
    
    with gr.Row():
        with gr.Column():
            train_button = gr.Button("Train Model")
            output_message = gr.Textbox(label="Training Status", interactive=False)
        
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter prompt for text generation")
            generate_button = gr.Button("Generate Text")
            generated_output = gr.Textbox(label="Generated Text", interactive=False)

    # Button actions
    train_button.click(train_model, outputs=output_message)
    generate_button.click(generate_text, inputs=prompt_input, outputs=generated_output)

# Launch the app; share=True requests a temporary public gradio.live link
# (omit it to serve on localhost only)
demo.launch(share=True)