In [1]:
import re

import numpy as np
import onnx
import onnxruntime
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# ONNX Runtime session used by predict_onnx to run the NER model file.
ort_session = onnxruntime.InferenceSession("ner_model.onnx")

# Tokenizer loaded from the local ./results/best_model directory.
tokenizer = AutoTokenizer.from_pretrained("./results/best_model")

# Class index (position along the last logits axis) -> BIO tag.
id2label = dict(
    enumerate(["O", "B-SERVICE", "I-SERVICE", "B-LOCATION", "I-LOCATION"])
)

# Canonical service name -> aliases (English and Vietnamese, with and without
# diacritics). Key order matters: map_service returns the first key that matches.
service_mapping = {
    "hotel": [
        "hotel",
        "hotels",
        "khách sạn",
        "khach san",
        "ks",
    ],
    "flight": [
        "flight",
        "flights",
        "vé máy bay",
        "máy bay",
        "may bay",
    ],
    "car rental": [
        "car rental",
        "car rentals",
        "thuê xe",
        "xe",
    ],
    "ticket": [
        "ticket",
        "tickets",
        "vé",
        "vé tham quan",
        "ve",
        "ve tham quan",
    ],
    "tour": [
        "tour",
        "tours",
        "du lịch",
        "du lich",
    ],
}

def map_service(service):
    """Map a free-text service span to its canonical service name.

    The span is lower-cased and compared against every alias in
    ``service_mapping`` (in dictionary order). An alias only matches as a
    whole word or phrase, i.e. it must not be directly preceded or followed by
    another word character. Plain substring matching would let short aliases
    fire inside unrelated words ("ve" in "river", "ks" in "thanks", "xe" in
    "taxes") and return the wrong service.

    Args:
        service: Text of the detected service entity.

    Returns:
        The first canonical service key with a matching alias, or ``None``
        when no alias matches.
    """
    service = service.lower()
    for key, values in service_mapping.items():
        for alias in values:
            # Look-arounds instead of \b so aliases are delimited by any
            # non-word character (spaces, punctuation, string boundaries).
            pattern = r"(?<!\w)" + re.escape(alias) + r"(?!\w)"
            if re.search(pattern, service):
                return key
    return None

def _flush_entity(entities, entity_type, tokens):
    """Store a finished entity span in ``entities``; no-op when no span is open."""
    if not entity_type:
        return
    span = " ".join(tokens)
    if entity_type == "SERVICE":
        # Services are normalized to a canonical name; unmapped spans are dropped.
        span = map_service(span)
        if not span:
            return
    entities[entity_type].append(span)


def predict_onnx(text):
    """Extract SERVICE and LOCATION entities from ``text`` with the ONNX model.

    The text is split on whitespace and the resulting words are passed to the
    tokenizer with ``is_split_into_words=True`` so that ``word_ids()`` indexes
    exactly into that word list. (Tokenizing the raw string and zipping the
    labels against ``text.split()`` misaligns every label after the first word
    that the tokenizer splits differently, e.g. one with attached punctuation.)
    Each word takes the label predicted for its first sub-token.

    NOTE(review): byte-level BPE tokenizers require ``add_prefix_space=True``
    for pre-split input - confirm against the tokenizer in ./results/best_model.

    Args:
        text: Raw input sentence.

    Returns:
        A dict ``{"SERVICE": [...], "LOCATION": [...]}``. ``SERVICE`` holds at
        most one canonical service name (the first one detected and mapped by
        ``map_service``); ``LOCATION`` holds the location spans as
        space-joined words.
    """
    entities = {"SERVICE": [], "LOCATION": []}

    words = text.split()
    if not words:
        # Nothing to tag; avoid running the model on an empty word list.
        return entities

    inputs = tokenizer(
        words,
        is_split_into_words=True,
        return_tensors="np",
        truncation=True,
        padding=True,
    )

    # Run inference
    ort_inputs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
    }
    ort_outputs = ort_session.run(None, ort_inputs)
    predictions = np.argmax(ort_outputs[0], axis=2)[0]

    # Label of the first sub-token of every word. Special tokens have a word id
    # of None and must not contribute a label.
    word_labels = {}
    for word_id, label_id in zip(inputs.word_ids(), predictions):
        if word_id is not None and word_id not in word_labels:
            word_labels[word_id] = id2label[int(label_id)]

    # Extract entities from the word-level BIO tags
    current_entity = None
    current_tokens = []
    for index, word in enumerate(words):
        label = word_labels.get(index)
        if label is None:
            # Word was cut off by truncation: no prediction for it or the rest.
            break
        prefix, _, entity_type = label.partition("-")
        if prefix == "B":
            _flush_entity(entities, current_entity, current_tokens)
            current_entity = entity_type
            current_tokens = [word]
        elif prefix == "I" and entity_type == current_entity:
            current_tokens.append(word)
        else:
            # "O", an I- tag with no open entity, or an I- tag of a different
            # type: close the open span rather than gluing the word onto it.
            _flush_entity(entities, current_entity, current_tokens)
            current_entity = None
            current_tokens = []
    _flush_entity(entities, current_entity, current_tokens)

    # Only the first detected service is reported.
    if entities["SERVICE"]:
        entities["SERVICE"] = [entities["SERVICE"][0]]

    return entities

In [8]:
# Test function
def test_ner_onnx(text):
    print(f"Input: {text}")
    result = predict_onnx(text)
    print("Output:", result)
    return result

# Sample requests in Vietnamese and English, one per service type
# (hotel, flight, car rental, tour, ticket).
sample_texts = [
    "DAT khách sạn ở Hà Nội",
    "flight to New York",
    "Thuê xe ở Đà Nẵng",
    "Đặt tour du lịch Hội An",
    "I need a ticket for the museum in Paris",
]

# Run each sample through the pipeline; a blank line separates the results.
for text in sample_texts:
    test_ner_onnx(text)
    print()

Input: DAT khách sạn ở Hà Nội
Output: {'SERVICE': ['hotel'], 'LOCATION': ['Hà Nội']}

Input: flight to New York
Output: {'SERVICE': ['flight'], 'LOCATION': ['York']}

Input: Thuê xe ở Đà Nẵng
Output: {'SERVICE': ['car rental'], 'LOCATION': ['Đà Nẵng']}

Input: Đặt tour du lịch Hội An
Output: {'SERVICE': ['tour'], 'LOCATION': ['Hội An']}

Input: I need a ticket for the museum in Paris
Output: {'SERVICE': ['ticket'], 'LOCATION': ['Paris']}

