portalniy-dev committed on
Commit
64cafec
·
verified ·
1 Parent(s): 289b970

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from datasets import load_dataset, concatenate_datasets
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
5
+
6
+ # Predefined datasets
7
+ dataset_names = [
8
+ 'imdb',
9
+ 'ag_news',
10
+ 'squad',
11
+ 'cnn_dailymail',
12
+ 'wiki40b'
13
+ ]
14
+
15
+ # Function to load and prepare datasets
16
+ def load_and_prepare_datasets():
17
+ datasets = [load_dataset(name) for name in dataset_names]
18
+
19
+ # Concatenate train and validation datasets
20
+ train_dataset = concatenate_datasets([ds['train'] for ds in datasets if 'train' in ds])
21
+ eval_dataset = concatenate_datasets([ds['validation'] for ds in datasets if 'validation' in ds])
22
+
23
+ return train_dataset, eval_dataset
24
+
25
+ # Function to preprocess data
26
+ def preprocess_function(examples):
27
+ return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)
28
+
29
+ # Function to train the model
30
+ def train_model():
31
+ global model, tokenizer
32
+
33
+ # Load model and tokenizer
34
+ model_name = 'gpt2' # You can choose another model if desired
35
+ model = AutoModelForCausalLM.from_pretrained(model_name)
36
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
37
+
38
+ # Load and prepare datasets
39
+ train_dataset, eval_dataset = load_and_prepare_datasets()
40
+
41
+ # Preprocess the datasets
42
+ train_dataset = train_dataset.map(preprocess_function, batched=True)
43
+ eval_dataset = eval_dataset.map(preprocess_function, batched=True)
44
+
45
+ # Set training arguments
46
+ training_args = TrainingArguments(
47
+ output_dir='./results',
48
+ num_train_epochs=3,
49
+ per_device_train_batch_size=4,
50
+ per_device_eval_batch_size=4,
51
+ warmup_steps=500,
52
+ weight_decay=0.01,
53
+ logging_dir='./logs',
54
+ logging_steps=10,
55
+ save_steps=1000,
56
+ evaluation_strategy="steps",
57
+ )
58
+
59
+ # Train the model
60
+ trainer = Trainer(
61
+ model=model,
62
+ args=training_args,
63
+ train_dataset=train_dataset,
64
+ eval_dataset=eval_dataset,
65
+ )
66
+
67
+ trainer.train()
68
+
69
+ return "Model trained successfully!"
70
+
71
+ # Function to generate text
72
+ def generate_text(prompt):
73
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
74
+ output = model.generate(input_ids, max_length=100)
75
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
76
+
77
+ return generated_text
78
+
79
+ # Gradio interface
80
+ with gr.Blocks() as demo:
81
+ gr.Markdown("# LLM Training and Text Generation")
82
+
83
+ with gr.Row():
84
+ with gr.Column():
85
+ train_button = gr.Button("Train Model")
86
+ output_message = gr.Textbox(label="Training Status", interactive=False)
87
+
88
+ with gr.Column():
89
+ prompt_input = gr.Textbox(label="Enter prompt for text generation")
90
+ generate_button = gr.Button("Generate Text")
91
+ generated_output = gr.Textbox(label="Generated Text", interactive=False)
92
+
93
+ # Button actions
94
+ train_button.click(train_model, outputs=output_message)
95
+ generate_button.click(generate_text, inputs=prompt_input, outputs=generated_output)
96
+
97
+ # Launch the app
98
+ demo.launch()