# intent_classifier / infer_location.py
# Chidam Gopal: "version update for intent app" (commit 1c56cce, unverified)
# File-viewer residue: raw | history | blame | 5.19 kB
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
import requests
import os
# Revision of the Mozilla/distilbert-uncased-NER-LoRA repo from which the ONNX
# model is fetched: it is interpolated into the ".../resolve/{VERSION}/..."
# download URL in LocationFinder.__init__. The tokenizer is loaded without it.
# NOTE(review): the cached model file path does not include VERSION, so bumping
# this value does not replace a model that was already downloaded.
VERSION = "v0.1.1"
class LocationFinder:
    """Extract city and state mentions from free text.

    Wraps the ``Mozilla/distilbert-uncased-NER-LoRA`` token-classification
    model, executed with ONNX Runtime, and turns its per-token labels into a
    ``{"city": ..., "state": ...}`` dict.
    """

    # Inputs are padded / truncated to exactly this many tokens.
    MAX_LENGTH = 64
    # A token's predicted label is used only if its softmax probability is
    # strictly greater than this value.
    THRESHOLD = 0.6
    # Model output index -> BIO label.
    LABEL_MAP = {
        0: "O",  # Outside any named entity
        1: "B-PER",  # Beginning of a person entity
        2: "I-PER",  # Inside a person entity
        3: "B-ORG",  # Beginning of an organization entity
        4: "I-ORG",  # Inside an organization entity
        5: "B-CITY",  # Beginning of a city entity
        6: "I-CITY",  # Inside a city entity
        7: "B-STATE",  # Beginning of a state entity
        8: "I-STATE",  # Inside a state entity
        9: "B-CITYSTATE",  # Beginning of a city_state entity
        10: "I-CITYSTATE",  # Inside a city_state entity
    }
    # Tokenizer special tokens that can never be part of an entity.
    SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})
    # Location label -> bucket the token is collected into. B- and I- tags are
    # treated alike; PER/ORG/O labels map to nothing and are ignored.
    _ENTITY_BUCKET = {
        "B-CITY": "city",
        "I-CITY": "city",
        "B-STATE": "state",
        "I-STATE": "state",
        "B-CITYSTATE": "citystate",
        "I-CITYSTATE": "citystate",
    }

    def __init__(self):
        """Load the tokenizer and the ONNX model, downloading the model once.

        The model file is cached at ``models/distilbert-uncased-NER-LoRA``,
        relative to the current working directory, and reused on later runs.

        Raises:
            requests.HTTPError: if the model download returns an error status.
        """
        # NOTE(review): the tokenizer is loaded from the repo's default
        # revision, not VERSION -- confirm it matches the pinned ONNX model.
        self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/distilbert-uncased-NER-LoRA")
        model_url = f"https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/{VERSION}/onnx/model_quantized.onnx"
        model_dir_path = "models"
        model_path = f"{model_dir_path}/distilbert-uncased-NER-LoRA"
        os.makedirs(model_dir_path, exist_ok=True)
        # NOTE(review): the cache path does not include VERSION, so bumping
        # VERSION does not replace an already-downloaded model.
        if not os.path.exists(model_path):
            print("Downloading ONNX model...")
            response = requests.get(model_url, timeout=60)
            # Fail loudly rather than caching an HTML error page as the model.
            response.raise_for_status()
            # Write to a side file and rename into place, so an interrupted
            # write never leaves a truncated file that later runs would
            # mistake for a valid cached model.
            tmp_path = f"{model_path}.part"
            with open(tmp_path, "wb") as f:
                f.write(response.content)
            os.replace(tmp_path, model_path)
            print("ONNX model downloaded.")
        # Load the ONNX model
        self.ort_session = ort.InferenceSession(model_path)

    @staticmethod
    def _softmax(logits):
        """Return a numerically stable softmax of *logits* over the last axis."""
        # Subtracting the row max keeps np.exp from overflowing to inf (which
        # would turn every probability into NaN).
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        exp = np.exp(shifted)
        return exp / np.sum(exp, axis=-1, keepdims=True)

    @staticmethod
    def _join(parts):
        """Join collected entity tokens with spaces; None if nothing remains."""
        return " ".join(parts).strip() or None

    @staticmethod
    def _split_city_state(parts):
        """Split CITYSTATE tokens into ``(city, state)`` on commas.

        The text before the first comma is the city; the text between the
        first and second comma is the state. Missing or blank parts are None.
        """
        if not parts:
            return None, None
        pieces = " ".join(parts).split(",")
        city = pieces[0].strip() or None
        state = (pieces[1].strip() or None) if len(pieces) > 1 else None
        return city, state

    def find_location(self, sequence, verbose=False):
        """Return the city and state mentioned in *sequence*.

        Args:
            sequence: Text to analyse. Only the first tokenized sequence of the
                batch is decoded.
            verbose: If true, print every non-special token together with its
                predicted label and probability.

        Returns:
            A dict ``{"city": str | None, "state": str | None}``. Values are
            the tagged tokens re-joined with single spaces (so punctuation may
            appear space-separated), or None when nothing was detected.
            A CITYSTATE span is preferred; any component it does not supply
            falls back to the separately tagged CITY / STATE tokens.
        """
        inputs = self.tokenizer(
            sequence,
            return_tensors="np",  # ONNX Runtime consumes NumPy arrays
            padding="max_length",
            truncation=True,
            max_length=self.MAX_LENGTH,
        )
        input_feed = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64),
        }
        # The first model output is taken to be the per-token logits.
        logits = self.ort_session.run(None, input_feed)[0]
        probabilities = self._softmax(logits)
        predicted_ids = np.argmax(logits, axis=-1)
        predicted_probs = np.max(probabilities, axis=-1)

        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        buckets = {"city": [], "state": [], "citystate": []}
        for token, predicted_id, prob in zip(tokens, predicted_ids[0], predicted_probs[0]):
            if token in self.SPECIAL_TOKENS:
                continue
            # Ids outside LABEL_MAP are treated as "O" instead of raising.
            label = self.LABEL_MAP.get(int(predicted_id), "O")
            if verbose:
                print(f"{token}\t{label}\t{float(prob):.3f}")
            bucket_name = self._ENTITY_BUCKET.get(label)
            # `not prob > ...` (rather than `<=`) also rejects NaN probabilities.
            if bucket_name is None or not prob > self.THRESHOLD:
                continue
            bucket = buckets[bucket_name]
            # A WordPiece continuation ("##xx") extends the previous token
            # collected for the same entity kind.
            if token.startswith("##") and bucket:
                bucket[-1] += token[2:]
            else:
                bucket.append(token)

        city_res, state_res = self._split_city_state(buckets["citystate"])
        if city_res is None:
            city_res = self._join(buckets["city"])
        if state_res is None:
            state_res = self._join(buckets["state"])
        # Return the detected city and state as separate components
        return {"city": city_res, "state": state_res}
if __name__ == "__main__":
    # Manual smoke check: run one example query end to end (this downloads the
    # model on first run) and print what was extracted.
    example = "weather in san francisco, ca"
    result = LocationFinder().find_location(example)
    print(f"query = {example} => {result}")