Chidam Gopal committed
Commit 788c760
1 Parent(s): cd5d153

NER model updates for infer location

Files changed: infer_location.py  +60 -39
infer_location.py  CHANGED

@@ -22,6 +22,20 @@ class LocationFinder:
         # Load the ONNX model
         self.ort_session = ort.InferenceSession(model_path)
 
+        # State abbreviations list for post-processing
+        self.state_abbr = {
+            "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA", "HI", "ID", "IL", "IN", "IA", "KS", "KY",
+            "LA", "ME", "MD", "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ", "NM", "NY", "NC", "ND",
+            "OH", "OK", "OR", "PA", "RI", "SC", "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY"
+        }
+
+        # # Helper function to correct misclassified state abbreviations
+        # def correct_state_abbreviation(self, token, predicted_label):
+        #     if token.upper() in self.state_abbr and predicted_label == "I-CITY":
+        #         return "I-STATE"
+        #     return predicted_label
+
+
     def find_location(self, sequence, verbose=False):
         inputs = self.tokenizer(sequence,
                                 return_tensors="np",  # ONNX requires inputs in NumPy format
@@ -44,61 +58,68 @@ class LocationFinder:
         # Define the threshold for NER probability
         threshold = 0.6
 
+        # Define the label map for city, state, organization, citystate
         label_map = {
             0: "O",      # Outside any named entity
             1: "B-PER",  # Beginning of a person entity
             2: "I-PER",  # Inside a person entity
             3: "B-ORG",  # Beginning of an organization entity
             4: "I-ORG",  # Inside an organization entity
-            5: "B-LOC",
-            6: "I-LOC",
-            7: "B-…",
-            8: "I-…",
+            5: "B-CITY",        # Beginning of a city entity
+            6: "I-CITY",        # Inside a city entity
+            7: "B-STATE",       # Beginning of a state entity
+            8: "I-STATE",       # Inside a state entity
+            9: "B-CITYSTATE",   # Beginning of a city_state entity
+            10: "I-CITYSTATE",  # Inside a city_state entity
         }
 
         tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
 
-        # List to hold the detected location entities
-        location_entities = []
-        current_location = []
-        …
-        …
+        # List to hold the detected entities (city, state, organization, citystate)
+        city_entities = []
+        state_entities = []
+        org_entities = []
+        city_state_entities = []
         for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):
-            label = label_map[predicted_id]
-
-            # Ignore special tokens like [CLS], [SEP]
-            if token in ["[CLS]", "[SEP]", "[PAD]"]:
-                continue
-
-            # Only consider tokens with probability above the threshold
             if prob > threshold:
-                …
-                …
-                if label == "B-LOC":
-                    # If we encounter a B-LOC, we may need to store the previous location
-                    if current_location:
-                        location_entities.append(" ".join(current_location).replace("##", ""))
-                    # Start a new location entity
-                    current_location = [token]
-                elif label == "I-LOC" and current_location:
-                    # Continue appending to the current location entity
-                    current_location.append(token)
+                if token in ["[CLS]", "[SEP]", "[PAD]"]:
+                    continue
                 else:
-                    …
-                    …
-                    …
-                    …
+                    if label_map[predicted_id] in ["B-CITY", "I-CITY"]:
+                        city_entities.append(token.replace("##", ""))
+                    elif label_map[predicted_id] in ["B-STATE", "I-STATE"]:
+                        state_entities.append(token.replace("##", ""))
+                    elif label_map[predicted_id] in ["B-CITYSTATE", "I-CITYSTATE"]:
+                        city_state_entities.append(token.replace("##", ""))
+                    elif label_map[predicted_id] in ["B-ORG", "I-ORG"]:
+                        org_entities.append(token.replace("##", ""))
 
-        …
-        if …
-            …
+        city_state_res = "".join(cs_entity.replace(",", ", ") for cs_entity in city_state_entities) if city_state_entities else None
+        if city_entities:
+            city_res = " ".join(city_entities)
+        elif city_state_res:
+            city_res = city_state_res.split(", ")[0]
+        else:
+            city_res = None
 
-        …
-        …
+        if state_entities:
+            state_res = " ".join(state_entities)
+        elif city_state_res and len(city_state_res) > 0:
+            state_res = city_state_res.split(", ")[-1]
+        else:
+            state_res = None
 
+        org_res = " ".join(org_entities) if org_entities else None
+
+        # Return the detected entities
+        return {
+            'city': city_res,
+            'state': state_res,
+            'organization': org_res,
+        }
 
 if __name__ == '__main__':
-    query = "weather in …
+    query = "weather in san francisco, ca"
     loc_finder = LocationFinder()
-    …
-    print(f"query = {query} => {…
+    entities = loc_finder.find_location(query)
+    print(f"query = {query} => {entities}")
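
For reference, the post-processing this commit adds at the end of find_location() can be exercised on its own. The sketch below assumes hypothetical token lists that the NER head might emit (already filtered by label and stripped of "##" word-piece markers, as the loop in the diff does); merge_entities is an illustrative helper name, not a function in the commit.

# Standalone sketch of the merge step introduced by this commit
# (assumptions noted above; not part of the committed file).
def merge_entities(city_entities, state_entities, city_state_entities, org_entities):
    # A CITYSTATE span such as ["boston,", "ma"] holds both fields at once:
    # the tokens are concatenated and the comma is re-expanded to ", " so the
    # result can be split into a city part and a state part.
    city_state_res = "".join(cs.replace(",", ", ") for cs in city_state_entities) if city_state_entities else None

    if city_entities:
        city_res = " ".join(city_entities)        # explicit CITY tokens win
    elif city_state_res:
        city_res = city_state_res.split(", ")[0]  # otherwise take the city from CITYSTATE
    else:
        city_res = None

    if state_entities:
        state_res = " ".join(state_entities)
    elif city_state_res:
        state_res = city_state_res.split(", ")[-1]
    else:
        state_res = None

    org_res = " ".join(org_entities) if org_entities else None
    return {'city': city_res, 'state': state_res, 'organization': org_res}

if __name__ == '__main__':
    # Hypothetical model outputs, not taken from the commit:
    print(merge_entities([], [], ["boston,", "ma"], []))
    # -> {'city': 'boston', 'state': 'ma', 'organization': None}
    print(merge_entities(["san", "francisco"], ["ca"], [], []))
    # -> {'city': 'san francisco', 'state': 'ca', 'organization': None}

Note that the CITYSTATE tokens are joined without spaces before the split, so in this sketch a multi-word city inside a CITYSTATE span comes back without its internal spaces; single-token spans like "boston, ma" are unaffected.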