portalniy-dev committed on
Commit
37f0cbb
·
verified ·
1 Parent(s): 72c1ae2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -7
app.py CHANGED
@@ -9,7 +9,7 @@ dataset_names = {
9
  'ag_news': None,
10
  'squad': None,
11
  'cnn_dailymail': '1.0.0', # Specify configuration for cnn_dailymail
12
- 'wiki40b': 'ru' # Specify language for wiki40b
13
  }
14
 
15
  # Global variables for model and tokenizer
@@ -20,18 +20,33 @@ tokenizer = None
20
  def load_and_prepare_datasets():
21
  datasets = []
22
  for name, config in dataset_names.items():
23
- datasets.append(load_dataset(name, config))
24
-
25
- # Extract only the 'text' field from each dataset for training
 
 
 
 
26
  train_datasets = []
27
  eval_datasets = []
28
 
29
  for ds in datasets:
30
  if 'train' in ds:
31
- train_datasets.append(ds['train'].map(lambda x: {'text': x['text']}))
 
 
 
 
 
 
32
  if 'test' in ds:
33
- eval_datasets.append(ds['test'].map(lambda x: {'text': x['text']}))
34
-
 
 
 
 
 
35
  # Concatenate train datasets only for training
36
  train_dataset = concatenate_datasets(train_datasets)
37
 
 
9
  'ag_news': None,
10
  'squad': None,
11
  'cnn_dailymail': '1.0.0', # Specify configuration for cnn_dailymail
12
+ 'wiki40b': 'en' # Specify language for wiki40b
13
  }
14
 
15
  # Global variables for model and tokenizer
 
20
  def load_and_prepare_datasets():
21
  datasets = []
22
  for name, config in dataset_names.items():
23
+ ds = load_dataset(name, config)
24
+ datasets.append(ds)
25
+
26
+ # Print dataset features for debugging
27
+ print(f"Dataset: {name}, Features: {ds['train'].features}")
28
+
29
+ # Extract only the relevant fields from each dataset for training
30
  train_datasets = []
31
  eval_datasets = []
32
 
33
  for ds in datasets:
34
  if 'train' in ds:
35
+ if 'text' in ds['train'].features:
36
+ train_datasets.append(ds['train'].map(lambda x: {'text': x['text']}))
37
+ elif 'content' in ds['train'].features: # Example for CNN/DailyMail
38
+ train_datasets.append(ds['train'].map(lambda x: {'text': x['content']}))
39
+ else:
40
+ print(f"Warning: No suitable text field found in {ds['train'].features}")
41
+
42
  if 'test' in ds:
43
+ if 'text' in ds['test'].features:
44
+ eval_datasets.append(ds['test'].map(lambda x: {'text': x['text']}))
45
+ elif 'content' in ds['test'].features: # Example for CNN/DailyMail
46
+ eval_datasets.append(ds['test'].map(lambda x: {'text': x['content']}))
47
+ else:
48
+ print(f"Warning: No suitable text field found in {ds['test'].features}")
49
+
50
  # Concatenate train datasets only for training
51
  train_dataset = concatenate_datasets(train_datasets)
52