Chidam Gopal committed on
Commit
624162d
1 Parent(s): 788c760

included state and city in NER

Browse files
Files changed (1) hide show
  1. infer_location.py +5 -15
infer_location.py CHANGED
@@ -22,20 +22,6 @@ class LocationFinder:
22
  # Load the ONNX model
23
  self.ort_session = ort.InferenceSession(model_path)
24
 
25
- # State abbreviations list for post-processing
26
- self.state_abbr = {
27
- "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA", "HI", "ID", "IL", "IN", "IA", "KS", "KY",
28
- "LA", "ME", "MD", "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ", "NM", "NY", "NC", "ND",
29
- "OH", "OK", "OR", "PA", "RI", "SC", "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY"
30
- }
31
-
32
- # # Helper function to correct misclassified state abbreviations
33
- # def correct_state_abbreviation(self, token, predicted_label):
34
- # if token.upper() in self.state_abbr and predicted_label == "I-CITY":
35
- # return "I-STATE"
36
- # return predicted_label
37
-
38
-
39
  def find_location(self, sequence, verbose=False):
40
  inputs = self.tokenizer(sequence,
41
  return_tensors="np", # ONNX requires inputs in NumPy format
@@ -80,6 +66,11 @@ class LocationFinder:
80
  state_entities = []
81
  org_entities = []
82
  city_state_entities = []
 
 
 
 
 
83
  for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):
84
  if prob > threshold:
85
  if token in ["[CLS]", "[SEP]", "[PAD]"]:
@@ -115,7 +106,6 @@ class LocationFinder:
115
  return {
116
  'city': city_res,
117
  'state': state_res,
118
- 'organization': org_res,
119
  }
120
 
121
  if __name__ == '__main__':
 
22
  # Load the ONNX model
23
  self.ort_session = ort.InferenceSession(model_path)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def find_location(self, sequence, verbose=False):
26
  inputs = self.tokenizer(sequence,
27
  return_tensors="np", # ONNX requires inputs in NumPy format
 
66
  state_entities = []
67
  org_entities = []
68
  city_state_entities = []
69
+
70
+ city_entities = []
71
+ state_entities = []
72
+ city_state_entities = []
73
+ org_entities = []
74
  for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):
75
  if prob > threshold:
76
  if token in ["[CLS]", "[SEP]", "[PAD]"]:
 
106
  return {
107
  'city': city_res,
108
  'state': state_res,
 
109
  }
110
 
111
  if __name__ == '__main__':