omri374 commited on
Commit
c37c05e
·
1 Parent(s): 90730f5

Upload transformers_recognizer.py

Browse files
Files changed (1) hide show
  1. transformers_recognizer.py +245 -0
transformers_recognizer.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Optional, List, Tuple, Set
3
+
4
+ from presidio_analyzer import (
5
+ RecognizerResult,
6
+ EntityRecognizer,
7
+ AnalysisExplanation,
8
+ )
9
+ from presidio_analyzer.nlp_engine import NlpArtifacts
10
+
11
+ logger = logging.getLogger("presidio-analyzer")
12
+
13
+ try:
14
+ from transformers import (
15
+ AutoTokenizer,
16
+ AutoModelForTokenClassification,
17
+ pipeline,
18
+ models,
19
+ )
20
+ from transformers.models.bert.modeling_bert import BertForTokenClassification
21
+ except ImportError:
22
+ logger.error("transformers is not installed")
23
+
24
+
25
+
26
+ class TransformersRecognizer(EntityRecognizer):
27
+ """
28
+ Wrapper for a transformers model, if needed to be used within Presidio Analyzer.
29
+
30
+ :example:
31
+ >from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
32
+
33
+ >transformers_recognizer = TransformersRecognizer()
34
+
35
+ >registry = RecognizerRegistry()
36
+ >registry.add_recognizer(transformers_recognizer)
37
+
38
+ >analyzer = AnalyzerEngine(registry=registry)
39
+
40
+ >results = analyzer.analyze(
41
+ > "My name is Christopher and I live in Irbid.",
42
+ > language="en",
43
+ > return_decision_process=True,
44
+ >)
45
+ >for result in results:
46
+ > print(result)
47
+ > print(result.analysis_explanation)
48
+
49
+
50
+ """
51
+
52
+ ENTITIES = [
53
+ "LOCATION",
54
+ "PERSON",
55
+ "ORGANIZATION",
56
+ "AGE",
57
+ "ID",
58
+ "PHONE",
59
+ "EMAIL",
60
+ "DATE",
61
+
62
+ ]
63
+
64
+ DEFAULT_EXPLANATION = "Identified as {} by transformers's Named Entity Recognition"
65
+
66
+ CHECK_LABEL_GROUPS = [
67
+ ({"LOCATION"}, {"LOC", "HOSP"}),
68
+ ({"PERSON"}, {"PER", "PERSON", "STAFF","PATIENT"}),
69
+ ({"ORGANIZATION"}, {"ORGANIZATION", "ORG", "PATORG"}),
70
+ ({"AGE"}, {"AGE"}),
71
+ ({"ID"}, {"ID"}),
72
+ ({"EMAIL"}, {"EMAIL"}),
73
+ ({"DATE"}, {"DATE"}),
74
+
75
+ ]
76
+
77
+ PRESIDIO_EQUIVALENCES = {
78
+ "PER": "PERSON",
79
+ "LOC": "LOCATION",
80
+ "ORG": "ORGANIZATION",
81
+ "AGE": "AGE",
82
+ "ID": "ID",
83
+ "EMAIL": "EMAIL"
84
+ }
85
+
86
+ DEFAULT_MODEL_PATH = "obi/deid_roberta_i2b2"
87
+
88
+ def __init__(
89
+ self,
90
+ supported_entities: Optional[List[str]] = None,
91
+ check_label_groups: Optional[Tuple[Set, Set]] = None,
92
+ model: Optional[BertForTokenClassification] = None,
93
+ model_path: Optional[str] = None,
94
+ ):
95
+ if not model and not model_path:
96
+ model_path = self.DEFAULT_MODEL_PATH
97
+ logger.warning(
98
+ f"Both 'model' and 'model_path' arguments are None. Using default model_path={model_path}"
99
+ )
100
+
101
+ if model and model_path:
102
+ logger.warning(
103
+ f"Both 'model' and 'model_path' arguments were provided. Ignoring the model_path"
104
+ )
105
+
106
+ self.check_label_groups = (
107
+ check_label_groups if check_label_groups else self.CHECK_LABEL_GROUPS
108
+ )
109
+
110
+ supported_entities = supported_entities if supported_entities else self.ENTITIES
111
+ self.model = (
112
+ model
113
+ if model
114
+ else pipeline(
115
+ "ner",
116
+ model=AutoModelForTokenClassification.from_pretrained(model_path),
117
+ tokenizer=AutoTokenizer.from_pretrained(model_path),
118
+ aggregation_strategy="simple",
119
+ )
120
+ )
121
+
122
+ super().__init__(
123
+ supported_entities=supported_entities, name="transformers Analytics",
124
+ )
125
+
126
+ def load(self) -> None:
127
+ """Load the model, not used. Model is loaded during initialization."""
128
+ pass
129
+
130
+ def get_supported_entities(self) -> List[str]:
131
+ """
132
+ Return supported entities by this model.
133
+
134
+ :return: List of the supported entities.
135
+ """
136
+ return self.supported_entities
137
+
138
+ # Class to use transformers with Presidio as an external recognizer.
139
+ def analyze(
140
+ self, text: str, entities: List[str], nlp_artifacts: NlpArtifacts = None
141
+ ) -> List[RecognizerResult]:
142
+ """
143
+ Analyze text using Text Analytics.
144
+
145
+ :param text: The text for analysis.
146
+ :param entities: Not working properly for this recognizer.
147
+ :param nlp_artifacts: Not used by this recognizer.
148
+ :return: The list of Presidio RecognizerResult constructed from the recognized
149
+ transformers detections.
150
+ """
151
+
152
+ results = []
153
+ ner_results = self.model(text)
154
+
155
+ # If there are no specific list of entities, we will look for all of it.
156
+ if not entities:
157
+ entities = self.supported_entities
158
+
159
+ for entity in entities:
160
+ if entity not in self.supported_entities:
161
+ continue
162
+
163
+ for res in ner_results:
164
+ if not self.__check_label(
165
+ entity, res["entity_group"], self.check_label_groups
166
+ ):
167
+ continue
168
+ textual_explanation = self.DEFAULT_EXPLANATION.format(
169
+ res["entity_group"]
170
+ )
171
+ explanation = self.build_transformers_explanation(
172
+ round(res["score"], 2), textual_explanation
173
+ )
174
+ transformers_result = self._convert_to_recognizer_result(
175
+ res, explanation
176
+ )
177
+
178
+ results.append(transformers_result)
179
+
180
+ return results
181
+
182
+ def _convert_to_recognizer_result(self, res, explanation) -> RecognizerResult:
183
+
184
+ entity_type = self.PRESIDIO_EQUIVALENCES.get(
185
+ res["entity_group"], res["entity_group"]
186
+ )
187
+ transformers_score = round(res["score"], 2)
188
+
189
+ transformers_results = RecognizerResult(
190
+ entity_type=entity_type,
191
+ start=res["start"],
192
+ end=res["end"],
193
+ score=transformers_score,
194
+ analysis_explanation=explanation,
195
+ )
196
+
197
+ return transformers_results
198
+
199
+ def build_transformers_explanation(
200
+ self, original_score: float, explanation: str
201
+ ) -> AnalysisExplanation:
202
+ """
203
+ Create explanation for why this result was detected.
204
+
205
+ :param original_score: Score given by this recognizer
206
+ :param explanation: Explanation string
207
+ :return:
208
+ """
209
+ explanation = AnalysisExplanation(
210
+ recognizer=self.__class__.__name__,
211
+ original_score=original_score,
212
+ textual_explanation=explanation,
213
+ )
214
+ return explanation
215
+
216
+ @staticmethod
217
+ def __check_label(
218
+ entity: str, label: str, check_label_groups: Tuple[Set, Set]
219
+ ) -> bool:
220
+ return any(
221
+ [entity in egrp and label in lgrp for egrp, lgrp in check_label_groups]
222
+ )
223
+
224
+
225
+ if __name__ == "__main__":
226
+
227
+ from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
228
+
229
+ transformers_recognizer = (
230
+ TransformersRecognizer()
231
+ ) # This would download a large (~500Mb) model on the first run
232
+
233
+ registry = RecognizerRegistry()
234
+ registry.add_recognizer(transformers_recognizer)
235
+
236
+ analyzer = AnalyzerEngine(registry=registry)
237
+
238
+ results = analyzer.analyze(
239
+ "My name is Christopher and I live in Irbid.",
240
+ language="en",
241
+ return_decision_process=True,
242
+ )
243
+ for result in results:
244
+ print(result)
245
+ print(result.analysis_explanation)