Chidam Gopal commited on
Commit
788c760
1 Parent(s): cd5d153

NER model updates for infer location

Browse files
Files changed (1) hide show
  1. infer_location.py +60 -39
infer_location.py CHANGED
@@ -22,6 +22,20 @@ class LocationFinder:
22
  # Load the ONNX model
23
  self.ort_session = ort.InferenceSession(model_path)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def find_location(self, sequence, verbose=False):
26
  inputs = self.tokenizer(sequence,
27
  return_tensors="np", # ONNX requires inputs in NumPy format
@@ -44,61 +58,68 @@ class LocationFinder:
44
  # Define the threshold for NER probability
45
  threshold = 0.6
46
 
 
47
  label_map = {
48
  0: "O", # Outside any named entity
49
  1: "B-PER", # Beginning of a person entity
50
  2: "I-PER", # Inside a person entity
51
  3: "B-ORG", # Beginning of an organization entity
52
  4: "I-ORG", # Inside an organization entity
53
- 5: "B-LOC", # Beginning of a location entity
54
- 6: "I-LOC", # Inside a location entity
55
- 7: "B-MISC", # Beginning of a miscellaneous entity
56
- 8: "I-MISC" # Inside a miscellaneous entity
 
 
57
  }
58
 
59
  tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
60
 
61
- # List to hold the detected location terms
62
- location_entities = []
63
- current_location = []
64
-
65
- # Loop through each token and its predicted label and probability
66
  for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):
67
- label = label_map[predicted_id]
68
-
69
- # Ignore special tokens like [CLS], [SEP]
70
- if token in ["[CLS]", "[SEP]", "[PAD]"]:
71
- continue
72
-
73
- # Only consider tokens with probability above the threshold
74
  if prob > threshold:
75
- # If the token is a part of a location entity (B-LOC or I-LOC)
76
- if label in ["B-LOC", "I-LOC"]:
77
- if label == "B-LOC":
78
- # If we encounter a B-LOC, we may need to store the previous location
79
- if current_location:
80
- location_entities.append(" ".join(current_location).replace("##", ""))
81
- # Start a new location entity
82
- current_location = [token]
83
- elif label == "I-LOC" and current_location:
84
- # Continue appending to the current location entity
85
- current_location.append(token)
86
  else:
87
- # If we encounter a non-location entity, store the current location and reset
88
- if current_location:
89
- location_entities.append(" ".join(current_location).replace("##", ""))
90
- current_location = []
 
 
 
 
91
 
92
- # Append the last location entity if it exists
93
- if current_location:
94
- location_entities.append(" ".join(current_location).replace("##", ""))
 
 
 
 
95
 
96
- # Return the detected location terms
97
- return location_entities[0] if location_entities != [] else None
 
 
 
 
98
 
 
 
 
 
 
 
 
 
99
 
100
  if __name__ == '__main__':
101
- query = "weather in seattle"
102
  loc_finder = LocationFinder()
103
- location = loc_finder.find_location(query)
104
- print(f"query = {query} => {location}")
 
22
  # Load the ONNX model
23
  self.ort_session = ort.InferenceSession(model_path)
24
 
25
+ # State abbreviations list for post-processing
26
+ self.state_abbr = {
27
+ "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA", "HI", "ID", "IL", "IN", "IA", "KS", "KY",
28
+ "LA", "ME", "MD", "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ", "NM", "NY", "NC", "ND",
29
+ "OH", "OK", "OR", "PA", "RI", "SC", "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY"
30
+ }
31
+
32
+ # # Helper function to correct misclassified state abbreviations
33
+ # def correct_state_abbreviation(self, token, predicted_label):
34
+ # if token.upper() in self.state_abbr and predicted_label == "I-CITY":
35
+ # return "I-STATE"
36
+ # return predicted_label
37
+
38
+
39
  def find_location(self, sequence, verbose=False):
40
  inputs = self.tokenizer(sequence,
41
  return_tensors="np", # ONNX requires inputs in NumPy format
 
58
  # Define the threshold for NER probability
59
  threshold = 0.6
60
 
61
+ # Define the label map for city, state, organization, citystate
62
  label_map = {
63
  0: "O", # Outside any named entity
64
  1: "B-PER", # Beginning of a person entity
65
  2: "I-PER", # Inside a person entity
66
  3: "B-ORG", # Beginning of an organization entity
67
  4: "I-ORG", # Inside an organization entity
68
+ 5: "B-CITY", # Beginning of a city entity
69
+ 6: "I-CITY", # Inside a city entity
70
+ 7: "B-STATE", # Beginning of a state entity
71
+ 8: "I-STATE", # Inside a state entity
72
+ 9: "B-CITYSTATE", # Beginning of a city_state entity
73
+ 10: "I-CITYSTATE", # Inside a city_state entity
74
  }
75
 
76
  tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
77
 
78
+ # List to hold the detected entities (city, state, organization, citystate)
79
+ city_entities = []
80
+ state_entities = []
81
+ org_entities = []
82
+ city_state_entities = []
83
  for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):
 
 
 
 
 
 
 
84
  if prob > threshold:
85
+ if token in ["[CLS]", "[SEP]", "[PAD]"]:
86
+ continue
 
 
 
 
 
 
 
 
 
87
  else:
88
+ if label_map[predicted_id] in ["B-CITY", "I-CITY"]:
89
+ city_entities.append(token.replace("##", ""))
90
+ elif label_map[predicted_id] in ["B-STATE", "I-STATE"]:
91
+ state_entities.append(token.replace("##", ""))
92
+ elif label_map[predicted_id] in ["B-CITYSTATE", "I-CITYSTATE"]:
93
+ city_state_entities.append(token.replace("##", ""))
94
+ elif label_map[predicted_id] in ["B-ORG", "I-ORG"]:
95
+ org_entities.append(token.replace("##", ""))
96
 
97
+ city_state_res = "".join(cs_entity.replace(",", ", ") for cs_entity in city_state_entities) if city_state_entities else None
98
+ if city_entities:
99
+ city_res = " ".join(city_entities)
100
+ elif city_state_res:
101
+ city_res = city_state_res.split(", ")[0]
102
+ else:
103
+ city_res = None
104
 
105
+ if state_entities:
106
+ state_res = " ".join(state_entities)
107
+ elif city_state_res and len(city_state_res) > 0:
108
+ state_res = city_state_res.split(", ")[-1]
109
+ else:
110
+ state_res = None
111
 
112
+ org_res = " ".join(org_entities) if org_entities else None
113
+
114
+ # Return the detected entities
115
+ return {
116
+ 'city': city_res,
117
+ 'state': state_res,
118
+ 'organization': org_res,
119
+ }
120
 
121
  if __name__ == '__main__':
122
+ query = "weather in san francisco, ca"
123
  loc_finder = LocationFinder()
124
+ entities = loc_finder.find_location(query)
125
+ print(f"query = {query} => {entities}")