inoki-giskard committed on
Commit
77961b6
·
1 Parent(s): 85095eb

Move text classification column mapping

Browse files
Files changed (2) hide show
  1. app.py +2 -110
  2. text_classification.py +112 -0
app.py CHANGED
@@ -7,12 +7,11 @@ import time
7
  from pathlib import Path
8
 
9
  import json
10
- import logging
11
-
12
- import pandas as pd
13
 
14
  from transformers.pipelines import TextClassificationPipeline
15
 
 
 
16
 
17
  HF_REPO_ID = 'HF_REPO_ID'
18
  HF_SPACE_ID = 'SPACE_ID'
@@ -61,113 +60,6 @@ def check_dataset(dataset_id, dataset_config="default", dataset_split="test"):
61
  return dataset_id, dataset_config, dataset_split
62
 
63
 
64
- def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
65
- for model_label in id2label_mapping.keys():
66
- if model_label.upper() == label.upper():
67
- return model_label, label
68
- return None, label
69
-
70
-
71
- def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
72
- id2label_mapping = {id2label[k]: None for k in id2label.keys()}
73
- dataset_labels = None
74
- for feature in dataset_features.values():
75
- if not isinstance(feature, datasets.ClassLabel):
76
- continue
77
- if len(feature.names) != len(id2label_mapping.keys()):
78
- continue
79
-
80
- dataset_labels = feature.names
81
-
82
- # Try to match labels
83
- for label in feature.names:
84
- if label in id2label_mapping.keys():
85
- model_label = label
86
- else:
87
- # Try to find case unsensative
88
- model_label, label = text_classificaiton_match_label_case_unsensative(id2label_mapping, label)
89
- if model_label is not None:
90
- id2label_mapping[model_label] = label
91
-
92
- return id2label_mapping, dataset_labels
93
-
94
-
95
- def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
96
- # We assume dataset is ok here
97
- ds = datasets.load_dataset(d_id, config)[split]
98
-
99
- try:
100
- dataset_features = ds.features
101
- except AttributeError:
102
- # Dataset does not have features, need to provide everything
103
- return None, None, None
104
-
105
- # Check whether we need to infer the text input column
106
- infer_text_input_column = True
107
- if "text" in column_mapping.keys():
108
- dataset_text_column = column_mapping["text"]
109
- if dataset_text_column in dataset_features.keys():
110
- infer_text_input_column = False
111
- else:
112
- logging.warning(f"Provided {dataset_text_column} is not in Dataset columns")
113
-
114
- if infer_text_input_column:
115
- # Try to retrieve one
116
- candidates = [f for f in dataset_features if dataset_features[f].dtype == "string"]
117
- if len(candidates) > 0:
118
- logging.debug(f"Candidates are {candidates}")
119
- column_mapping["text"] = candidates[0]
120
- else:
121
- # Not found a text feature
122
- return column_mapping, None, None
123
-
124
- # Load dataset as DataFrame
125
- df = ds.to_pandas()
126
-
127
- # Retrieve all labels
128
- id2label_mapping = {}
129
- id2label = ppl.model.config.id2label
130
- label2id = {v: k for k, v in id2label.items()}
131
- prediction_result = None
132
- try:
133
- # Use the first item to test prediction
134
- results = ppl({"text": df.head(1).at[0, column_mapping["text"]]}, top_k=None)
135
- prediction_result = {
136
- f'{result["label"]}({label2id[result["label"]]})': result["score"] for result in results
137
- }
138
- except Exception:
139
- # Pipeline prediction failed, need to provide labels
140
- return column_mapping, None, None
141
-
142
- # Infer labels
143
- id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)
144
- if "label" in column_mapping.keys():
145
- if not isinstance(column_mapping["label"], dict) or set(column_mapping["label"].values()) != set(dataset_labels):
146
- logging.warning(f'Provided {column_mapping["label"]} does not match labels in Dataset')
147
- return column_mapping, prediction_result, None
148
-
149
- if isinstance(column_mapping["label"], dict):
150
- for model_label in id2label_mapping.keys():
151
- id2label_mapping[model_label] = column_mapping["label"][str(label2id[model_label])]
152
- elif None in id2label_mapping.values():
153
- column_mapping["label"] = {
154
- i: None for i in id2label.keys()
155
- }
156
- return column_mapping, prediction_result, None
157
-
158
- id2label_df = pd.DataFrame({
159
- "ID": [i for i in id2label.keys()],
160
- "Model labels": [id2label[label] for label in id2label.keys()],
161
- "Dataset labels": [id2label_mapping[id2label[label]] for label in id2label.keys()],
162
- })
163
- if "label" not in column_mapping.keys():
164
- column_mapping["label"] = {
165
- i: id2label_mapping[id2label[i]] for i in id2label.keys()
166
- }
167
-
168
- return column_mapping, prediction_result, id2label_df
169
-
170
-
171
  def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping):
172
  # Validate model
173
  m_id, ppl = check_model(model_id=model_id)
 
7
  from pathlib import Path
8
 
9
  import json
 
 
 
10
 
11
  from transformers.pipelines import TextClassificationPipeline
12
 
13
+ from text_classification import text_classification_fix_column_mapping
14
+
15
 
16
  HF_REPO_ID = 'HF_REPO_ID'
17
  HF_SPACE_ID = 'SPACE_ID'
 
60
  return dataset_id, dataset_config, dataset_split
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping):
64
  # Validate model
65
  m_id, ppl = check_model(model_id=model_id)
text_classification.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+
3
+ import logging
4
+
5
+ import pandas as pd
6
+
7
+
8
def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
    """Find the model label that equals ``label`` when letter case is ignored.

    The keys of ``id2label_mapping`` are scanned in iteration order and both
    sides are compared after ``upper()``.

    Returns:
        ``(model_label, label)`` for the first matching key, or
        ``(None, label)`` when no key matches.
    """
    # label.upper() stays inside the generator on purpose: with an empty
    # mapping it is never evaluated.
    matched_model_label = next(
        (candidate for candidate in id2label_mapping.keys() if candidate.upper() == label.upper()),
        None,
    )
    return matched_model_label, label
13
+
14
+
15
def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
    """Pair each model label with a dataset label of the same name.

    Only dataset features that are ``datasets.ClassLabel`` instances with as
    many names as the model has labels are considered. A dataset label is
    paired with a model label when the two are equal, or equal ignoring case.

    Args:
        id2label: mapping from model label id to model label name.
        dataset_features: mapping from column name to feature object.

    Returns:
        ``(id2label_mapping, dataset_labels)`` where ``id2label_mapping`` maps
        every model label name to its dataset label (``None`` if unmatched)
        and ``dataset_labels`` is the ``names`` list of the last considered
        feature, or ``None`` when no feature qualified.
    """
    model_to_dataset = dict.fromkeys(id2label.values())
    # Only values of existing keys are ever assigned below, so the number of
    # model labels stays constant for the whole scan.
    expected_count = len(model_to_dataset)
    matched_names = None

    for feature in dataset_features.values():
        qualifies = isinstance(feature, datasets.ClassLabel) and len(feature.names) == expected_count
        if not qualifies:
            continue

        matched_names = feature.names

        for dataset_label in feature.names:
            # An exact match wins.
            if dataset_label in model_to_dataset:
                model_to_dataset[dataset_label] = dataset_label
                continue

            # Otherwise fall back to a case-insensitive match.
            model_label, dataset_label = text_classificaiton_match_label_case_unsensative(
                model_to_dataset, dataset_label
            )
            if model_label is not None:
                model_to_dataset[model_label] = dataset_label

    return model_to_dataset, matched_names
37
+
38
+
39
def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
    """Validate and complete the column mapping for a text-classification run.

    Steps, in order:

    1. Load the dataset split and read its features.
    2. Make sure ``column_mapping["text"]`` names an existing column; if it is
       missing or unknown, pick the first feature whose dtype is ``"string"``.
    3. Run the pipeline once on the first row as a smoke test.
    4. Reconcile the model labels with the dataset labels, using the
       user-provided ``column_mapping["label"]`` when present, otherwise the
       automatic name matching.

    ``column_mapping`` is modified in place and also returned.

    Args:
        column_mapping: dict that may contain ``"text"`` (dataset column name)
            and ``"label"`` (dict from model label id to dataset label; ids
            are looked up as strings first, then as raw ids).
        ppl: pipeline object; it is called with ``{"text": ...}`` and
            ``top_k=None``, and ``ppl.model.config.id2label`` is read.
        d_id: dataset identifier passed to ``datasets.load_dataset``.
        config: dataset configuration passed to ``datasets.load_dataset``.
        split: key used to select the split from the loaded dataset.

    Returns:
        A ``(column_mapping, prediction_result, id2label_df)`` tuple.

        * ``(None, None, None)`` when the dataset exposes no ``features``.
        * ``(column_mapping, None, None)`` when no text column could be found
          or the test prediction failed.
        * ``(column_mapping, prediction_result, None)`` when the provided
          label mapping does not match the dataset labels, or when the labels
          could not all be inferred (``column_mapping["label"]`` is then set
          to ``None`` for every model label id).
        * Otherwise all three values, with ``id2label_df`` a DataFrame with
          the columns ``ID``, ``Model labels`` and ``Dataset labels``.
    """
    # We assume dataset is ok here
    ds = datasets.load_dataset(d_id, config)[split]

    try:
        dataset_features = ds.features
    except AttributeError:
        # Dataset does not have features, need to provide everything
        return None, None, None

    # Check whether we need to infer the text input column
    infer_text_input_column = True
    if "text" in column_mapping:
        dataset_text_column = column_mapping["text"]
        if dataset_text_column in dataset_features:
            infer_text_input_column = False
        else:
            logging.warning(f"Provided {dataset_text_column} is not in Dataset columns")

    if infer_text_input_column:
        # Nested features (plain dict / list) carry no ``dtype`` attribute, so
        # read it defensively instead of failing on the first such column.
        candidates = [
            name
            for name, feature in dataset_features.items()
            if getattr(feature, "dtype", None) == "string"
        ]
        if not candidates:
            # Not found a text feature
            return column_mapping, None, None
        logging.debug(f"Candidates are {candidates}")
        column_mapping["text"] = candidates[0]

    # Load dataset as DataFrame
    df = ds.to_pandas()

    # Retrieve all labels
    id2label = ppl.model.config.id2label
    label2id = {v: k for k, v in id2label.items()}
    try:
        # Use the first item to test prediction
        results = ppl({"text": df.head(1).at[0, column_mapping["text"]]}, top_k=None)
        prediction_result = {
            f'{result["label"]}({label2id[result["label"]]})': result["score"] for result in results
        }
    except Exception:
        # Pipeline prediction failed, need to provide labels. Keep the
        # best-effort behaviour but leave a trace of the reason.
        logging.warning("Test prediction with the pipeline failed", exc_info=True)
        return column_mapping, None, None

    # Infer labels
    id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)
    if "label" in column_mapping:
        provided_labels = column_mapping["label"]
        # dataset_labels is None when no ClassLabel feature matched the model;
        # the provided mapping cannot be validated then, so treat it as a
        # mismatch instead of crashing on set(None).
        if (
            not isinstance(provided_labels, dict)
            or dataset_labels is None
            or set(provided_labels.values()) != set(dataset_labels)
        ):
            logging.warning(f'Provided {column_mapping["label"]} does not match labels in Dataset')
            return column_mapping, prediction_result, None

        for model_label in id2label_mapping:
            label_id = label2id[model_label]
            # Keys are expected as strings, but this function itself stores
            # raw ids below, so accept those as well.
            key = str(label_id) if str(label_id) in provided_labels else label_id
            id2label_mapping[model_label] = provided_labels[key]
    elif None in id2label_mapping.values():
        # Some model labels have no dataset counterpart: hand back an empty
        # mapping for the caller to fill in.
        column_mapping["label"] = dict.fromkeys(id2label)
        return column_mapping, prediction_result, None

    id2label_df = pd.DataFrame({
        "ID": list(id2label),
        "Model labels": list(id2label.values()),
        "Dataset labels": [id2label_mapping[model_label] for model_label in id2label.values()],
    })
    if "label" not in column_mapping:
        column_mapping["label"] = {
            label_id: id2label_mapping[model_label] for label_id, model_label in id2label.items()
        }

    return column_mapping, prediction_result, id2label_df