portalniy-dev commited on
Commit
8d55ba9
Β·
verified Β·
1 Parent(s): 825d871

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -3,14 +3,14 @@ import torch
3
  from datasets import load_dataset, concatenate_datasets
4
  from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
5
 
6
- # Predefined datasets
7
- dataset_names = [
8
- 'imdb',
9
- 'ag_news',
10
- 'squad',
11
- 'cnn_dailymail',
12
- 'wiki40b'
13
- ]
14
 
15
  # Global variables for model and tokenizer
16
  model = None
@@ -18,7 +18,9 @@ tokenizer = None
18
 
19
  # Function to load and prepare datasets
20
  def load_and_prepare_datasets():
21
- datasets = [load_dataset(name) for name in dataset_names]
 
 
22
 
23
  # Concatenate train datasets only for training
24
  train_dataset = concatenate_datasets([ds['train'] for ds in datasets if 'train' in ds])
@@ -109,4 +111,4 @@ with gr.Blocks() as demo:
109
  generate_button.click(generate_text, inputs=prompt_input, outputs=generated_output)
110
 
111
  # Launch the app with share=True to create a public link
112
- demo.launch(share=True)
 
3
  from datasets import load_dataset, concatenate_datasets
4
  from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
5
 
6
+ # Predefined datasets with their configurations
7
+ dataset_names = {
8
+ 'imdb': None,
9
+ 'ag_news': None,
10
+ 'squad': None,
11
+ 'cnn_dailymail': '1.0.0', # Specify configuration for cnn_dailymail
12
+ 'wiki40b': None
13
+ }
14
 
15
  # Global variables for model and tokenizer
16
  model = None
 
18
 
19
  # Function to load and prepare datasets
20
  def load_and_prepare_datasets():
21
+ datasets = []
22
+ for name, config in dataset_names.items():
23
+ datasets.append(load_dataset(name, config))
24
 
25
  # Concatenate train datasets only for training
26
  train_dataset = concatenate_datasets([ds['train'] for ds in datasets if 'train' in ds])
 
111
  generate_button.click(generate_text, inputs=prompt_input, outputs=generated_output)
112
 
113
  # Launch the app with share=True to create a public link
114
+ demo.launch(share=True)