|
import os |
|
import json |
|
import torch |
|
import numpy as np |
|
from transformers import BertTokenizer |
|
from ts.torch_handler.base_handler import BaseHandler |
|
from sklearn.preprocessing import OneHotEncoder |
|
|
|
import transformers |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
class AttentionPool(nn.Module):
    """Pool a sequence of hidden states into one vector with learned attention.

    A linear layer scores every position, the scores are softmax-normalised
    over the sequence dimension, and the hidden states are averaged using
    those weights.

    Args:
        hidden_size: Size of the last dimension of the hidden states.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, last_hidden_state, attention_mask=None):
        """Return the attention-weighted average of ``last_hidden_state``.

        Args:
            last_hidden_state: Tensor of shape ``(batch, seq_len, hidden_size)``
                (``torch.bmm`` below requires a 3-D input).
            attention_mask: Optional ``(batch, seq_len)`` tensor. Positions
                where it equals 0 receive zero weight. When omitted, every
                position participates in the pooling.

        Returns:
            Tensor of shape ``(batch, hidden_size)``.
        """
        attention_scores = self.attention(last_hidden_state).squeeze(-1)
        if attention_mask is not None:
            # Use the most negative finite value instead of -inf so that a fully
            # masked row softmaxes to uniform weights rather than NaN.
            attention_scores = attention_scores.masked_fill(
                attention_mask == 0, torch.finfo(attention_scores.dtype).min
            )
        attention_weights = F.softmax(attention_scores, dim=1)
        # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden)
        pooled_output = torch.bmm(attention_weights.unsqueeze(1), last_hidden_state).squeeze(1)
        return pooled_output
|
|
|
class MultiSampleDropout(nn.Module):
    """Average several independent dropout draws applied to the same input.

    Args:
        dropout: Drop probability passed to ``nn.Dropout``.
        num_samples: Number of dropout draws averaged in training mode;
            must be at least 1.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """

    def __init__(self, dropout=0.5, num_samples=5):
        super().__init__()
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        self.dropout = nn.Dropout(dropout)
        self.num_samples = num_samples

    def forward(self, x):
        """Return the mean of ``num_samples`` dropout draws of ``x`` (same shape as ``x``)."""
        if not self.training:
            # nn.Dropout is the identity in eval mode, so every draw would equal
            # x; averaging identical copies only wastes time and memory.
            return x
        samples = torch.stack([self.dropout(x) for _ in range(self.num_samples)])
        return samples.mean(dim=0)
|
|
|
|
|
class ImprovedBERTClass(nn.Module):
    """BERT encoder followed by attention pooling and a linear classifier.

    Pipeline: BERT -> AttentionPool -> MultiSampleDropout -> LayerNorm -> Linear.
    The output is raw logits, one per class.

    Args:
        num_classes: Number of logits produced by the classifier head.
        pretrained_model_name: Name or path passed to
            ``transformers.BertModel.from_pretrained``.
    """

    def __init__(self, num_classes=13, pretrained_model_name='bert-base-uncased'):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained(pretrained_model_name)
        # Size the head from the encoder config instead of hard-coding 768, so
        # encoders with a different hidden size also work.
        hidden_size = self.bert.config.hidden_size
        self.attention_pool = AttentionPool(hidden_size)
        self.dropout = MultiSampleDropout()
        self.norm = nn.LayerNorm(hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return classification logits of shape ``(batch, num_classes)``.

        Args:
            input_ids: Token ids passed to BERT.
            attention_mask: Attention mask passed to BERT.
            token_type_ids: Optional segment ids passed to BERT; ``None`` lets
                BERT use its default.
        """
        bert_output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # NOTE(review): pooling runs over every position, including padding
        # tokens. This is left unchanged so outputs match the loaded
        # checkpoint; confirm how the model was trained before masking padding.
        pooled_output = self.attention_pool(bert_output.last_hidden_state)
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.norm(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
|
|
|
|
|
class UICardMappingHandler(BaseHandler):
    """TorchServe handler that ranks UI card labels for a text query.

    Each request body is a JSON object with a ``"text"`` field and an
    optional ``"k"`` field (how many top predictions to return, default 3).
    Each response is a dict with ``"most_likely_card"`` and
    ``"top_k_predictions"``, a list of ``(label, probability)`` pairs ordered
    from most to least probable.
    """

    # Number of top predictions returned when the request omits "k".
    DEFAULT_K = 3
    # Token length every input is padded/truncated to.
    MAX_LENGTH = 64
    # Per-label sigmoid probability above which a label counts as predicted.
    THRESHOLD = 0.5
    # Card reported when no label clears THRESHOLD.
    FALLBACK_CARD = "Answer"

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, context):
        """Load the config, label encoder, model weights and tokenizer.

        Args:
            context: TorchServe context whose ``system_properties["model_dir"]``
                directory contains ``config.json`` and ``model.pth``.
        """
        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Not read elsewhere in this class; kept as an attribute for other users.
        with open(os.path.join(model_dir, 'config.json'), 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        self.labels = ['Videos', 'Unit Conversion', 'Translation', 'Shopping Product Comparison', 'Restaurants', 'Product', 'Information', 'Images', 'Gift', 'General Comparison', 'Flights', 'Answer', 'Aircraft Seat Map']
        # OneHotEncoder stores its categories sorted, so this handler decodes
        # model output column i as self.encoder.categories_[0][i], which is
        # not necessarily self.labels[i].
        labels_np = np.array(self.labels).reshape(-1, 1)
        self.encoder = OneHotEncoder(sparse_output=False)
        self.encoder.fit(labels_np)

        self.model = ImprovedBERTClass()
        # torch.load unpickles the file, so model.pth must come from a trusted
        # model archive.
        state_dict = torch.load(os.path.join(model_dir, 'model.pth'), map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.initialized = True

    @staticmethod
    def _request_body(request):
        """Return the JSON object carried by one TorchServe request.

        Accepts a body that is already a dict. Also accepts raw
        ``bytes``/``bytearray``/``str`` JSON, which TorchServe may deliver for
        non-JSON content types. Falls back to the ``"data"`` key when
        ``"body"`` is absent.

        Raises:
            ValueError: If the body is missing, is not valid JSON, or is not a
                JSON object.
        """
        body = request.get("body")
        if body is None:
            body = request.get("data")
        if isinstance(body, (bytes, bytearray)):
            body = body.decode("utf-8")
        if isinstance(body, str):
            body = json.loads(body)
        if not isinstance(body, dict):
            raise ValueError("Request body must be a JSON object")
        return body

    def preprocess(self, data):
        """Tokenize the first request in ``data``.

        Returns:
            Dict with ``"ids"``, ``"mask"`` and ``"token_type_ids"`` (long
            tensors on ``self.device``, padded/truncated to ``MAX_LENGTH``
            tokens) and the raw requested ``"k"``.

        Raises:
            ValueError: If the body is malformed or ``"text"`` is not a string.
        """
        body = self._request_body(data[0])
        text = body.get("text")
        if text is None:
            text = ""
        if not isinstance(text, str):
            raise ValueError("'text' must be a string")
        k = body.get("k", self.DEFAULT_K)

        inputs = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.MAX_LENGTH,
            padding='max_length',
            return_tensors='pt',
            truncation=True,
        )

        return {
            "ids": inputs['input_ids'].to(self.device, dtype=torch.long),
            "mask": inputs['attention_mask'].to(self.device, dtype=torch.long),
            "token_type_ids": inputs['token_type_ids'].to(self.device, dtype=torch.long),
            "k": k,
        }

    def inference(self, data):
        """Run the model and return ``(per-label sigmoid probabilities, k)``.

        The probabilities are a flattened 1-D numpy array.
        """
        with torch.no_grad():
            outputs = self.model(data["ids"], data["mask"], data["token_type_ids"])
            probabilities = torch.sigmoid(outputs)
        return probabilities.cpu().numpy().flatten(), data["k"]

    @staticmethod
    def _clamp_k(k, num_labels):
        """Coerce ``k`` to an int in ``[1, num_labels]``.

        Raises:
            ValueError: If ``k`` cannot be converted to an int.
        """
        try:
            k = int(k)
        except (TypeError, ValueError):
            raise ValueError(f"'k' must be an integer, got {k!r}") from None
        # Without clamping, k <= 0 or k > num_labels crashed the top-k logic.
        return max(1, min(k, num_labels))

    def postprocess(self, inference_output):
        """Build the response dict from ``(probabilities, k)``.

        Returns:
            A one-element list containing a dict with ``"most_likely_card"``
            and ``"top_k_predictions"``.
        """
        probabilities, k = inference_output
        # Column i of the model output decodes to the encoder's i-th category;
        # this is the same mapping inverse_transform applies to a one-hot row.
        categories = self.encoder.categories_[0]
        k = self._clamp_k(k, len(probabilities))

        # Indices of the k highest probabilities, highest first.
        top_k_indices = np.argsort(probabilities)[-k:][::-1]
        top_k_predictions = [
            (str(categories[idx]), float(probabilities[idx])) for idx in top_k_indices
        ]

        if (probabilities > self.THRESHOLD).any():
            # Several labels may clear the threshold; report the highest-scoring
            # one rather than the first positive label in category order.
            most_likely_card = str(categories[int(np.argmax(probabilities))])
        else:
            most_likely_card = self.FALLBACK_CARD

        result = {
            "most_likely_card": most_likely_card,
            "top_k_predictions": top_k_predictions,
        }
        return [result]

    def decode_vector(self, vector):
        """Map a ``(1, n_labels)`` one-hot row back to its label string."""
        original_label = self.encoder.inverse_transform(vector)
        return original_label[0][0]

    def handle(self, data, context):
        """Serve every request in ``data`` and return one result per request.

        Each request goes through preprocess -> inference -> postprocess on
        its own, so a batch of N requests yields a list of N result dicts.
        """
        self.context = context
        results = []
        for request in data:
            model_input = self.preprocess([request])
            model_output = self.inference(model_input)
            results.extend(self.postprocess(model_output))
        return results
|
|