# File size: 4,689 Bytes
# f9714b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
import requests
import os

class LocationFinder:
    """Extract a location mention from text with a token-classification ONNX model.

    On construction the tokenizer is loaded via ``AutoTokenizer`` and the
    quantized ONNX model is downloaded into ``models/`` (relative to the
    current working directory) if it is not already there, then opened with
    ONNX Runtime.
    """

    # Class index -> BIO tag.
    # NOTE(review): assumed to match the model's own id2label config — confirm.
    _LABEL_MAP = {
        0: "O",        # Outside any named entity
        1: "B-PER",    # Beginning of a person entity
        2: "I-PER",    # Inside a person entity
        3: "B-ORG",    # Beginning of an organization entity
        4: "I-ORG",    # Inside an organization entity
        5: "B-LOC",    # Beginning of a location entity
        6: "I-LOC",    # Inside a location entity
        7: "B-MISC",   # Beginning of a miscellaneous entity
        8: "I-MISC",   # Inside a miscellaneous entity
    }
    _LOCATION_LABELS = ("B-LOC", "I-LOC")
    _SPECIAL_TOKENS = ("[CLS]", "[SEP]", "[PAD]")

    def __init__(self):
        """Load the tokenizer and the ONNX session, downloading the model if needed.

        Raises:
            requests.HTTPError: if the model download returns an error status.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/distilbert-uncased-NER-LoRA")
        model_url = "https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx"
        model_dir_path = "models"
        model_path = f"{model_dir_path}/distilbert-uncased-NER-LoRA"
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(model_dir_path, exist_ok=True)
        if not os.path.exists(model_path):
            print("Downloading ONNX model...")
            response = requests.get(model_url, timeout=60)
            # Fail loudly instead of caching an HTTP error page as the model.
            response.raise_for_status()
            # Write to a temporary name and rename, so an interrupted write
            # never leaves a truncated file at model_path that later runs
            # would mistake for a complete download.
            partial_path = f"{model_path}.part"
            with open(partial_path, "wb") as f:
                f.write(response.content)
            os.replace(partial_path, model_path)
            print("ONNX model downloaded.")

        # Load the ONNX model
        self.ort_session = ort.InferenceSession(model_path)

    @staticmethod
    def _softmax(logits):
        """Return a softmax over the last axis, shifted by the max to avoid overflow."""
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        exponentials = np.exp(shifted)
        return exponentials / np.sum(exponentials, axis=-1, keepdims=True)

    @staticmethod
    def _join_wordpieces(pieces):
        """Merge ``##``-prefixed pieces into the preceding word; join words with spaces."""
        words = []
        for piece in pieces:
            if piece.startswith("##"):
                suffix = piece[2:]
                if words:
                    words[-1] += suffix
                else:
                    words.append(suffix)
            else:
                words.append(piece)
        return " ".join(words)

    def find_location(self, sequence, verbose=False, threshold=0.6):
        """Return the first location entity detected in ``sequence``.

        Args:
            sequence: Text to analyse. It is padded/truncated to 64 tokens.
            verbose: Accepted for backward compatibility; currently unused.
            threshold: A token's top-class probability must be strictly
                greater than this value for the token to be considered.

        Returns:
            The first detected location as a string (word pieces merged,
            words separated by single spaces), or ``None`` if none is found.
        """
        inputs = self.tokenizer(sequence,
                                return_tensors="np",  # ONNX requires inputs in NumPy format
                                padding="max_length",  # Pad to max length
                                truncation=True,       # Truncate if the text is too long
                                max_length=64)
        input_feed = {
            'input_ids': inputs['input_ids'].astype(np.int64),
            'attention_mask': inputs['attention_mask'].astype(np.int64),
        }

        # Run inference with the ONNX model
        outputs = self.ort_session.run(None, input_feed)
        logits = outputs[0]  # Assuming the model output is logits
        probabilities = self._softmax(logits)

        predicted_ids = np.argmax(logits, axis=-1)
        predicted_probs = np.max(probabilities, axis=-1)

        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        location_entities = []  # completed entities, in order of appearance
        current_location = []   # tokens of the entity currently being built

        for token, predicted_id, prob in zip(tokens, predicted_ids[0], predicted_probs[0]):
            # Ignore special tokens like [CLS], [SEP]
            if token in self._SPECIAL_TOKENS:
                continue

            # Low-confidence tokens are skipped entirely: they neither extend
            # nor terminate the entity in progress. Written as `not (a > b)`
            # so that a NaN probability is also skipped.
            if not prob > threshold:
                continue

            # Unknown class indices are treated as "outside" rather than raising.
            label = self._LABEL_MAP.get(int(predicted_id), "O")

            if label not in self._LOCATION_LABELS:
                # A confident non-location token closes the entity in progress.
                if current_location:
                    location_entities.append(self._join_wordpieces(current_location))
                    current_location = []
                continue

            if current_location and (label == "I-LOC" or token.startswith("##")):
                # I-LOC continues the entity; so does a "##" continuation
                # piece even when tagged B-LOC, since it cannot start a word.
                current_location.append(token)
            elif label == "B-LOC":
                # A new entity begins; store the previous one first.
                if current_location:
                    location_entities.append(self._join_wordpieces(current_location))
                current_location = [token]
            # An I-LOC with no entity in progress is ignored.

        # Append the last location entity if it exists
        if current_location:
            location_entities.append(self._join_wordpieces(current_location))

        # Only the first detected location is returned.
        return location_entities[0] if location_entities else None


if __name__ == "__main__":
    # Manual smoke test: build the finder (this loads the tokenizer and may
    # download the ONNX model) and print what it extracts from a sample query.
    query = "weather in seattle"
    location = LocationFinder().find_location(query)
    print(f"query = {query} => {location}")