import os, sys

# Make the project root importable when this script is run from its own directory.
myPath = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, myPath + '/../')

# ==========

import torch

from ercbcm.model_loader import load
from ercbcm.ERCBCM import ERCBCM
from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ==========

# Load the trained ERCBCM weights for inference.
model_for_predict = ERCBCM().to(device)
load('ercbcm/model.pt', model_for_predict, device)

MAX_LEN = 128

def predict(sentence, name):
    """Classify whether `name` in `sentence` is being called or merely mentioned."""
    # Dummy label: the model's forward pass expects (text, label), but only
    # the output logits are used for prediction.
    label = torch.tensor([0]).type(torch.LongTensor).to(device)
    # Tokenize the normalized sentence, then truncate and pad to MAX_LEN tokens.
    text = tokenizer.encode(normalize_v2(sentence, name))[:MAX_LEN]
    text += [PAD_TOKEN_ID] * (MAX_LEN - len(text))
    text = torch.tensor([text]).type(torch.LongTensor).to(device)
    _, output = model_for_predict(text, label)
    pred = torch.argmax(output, 1).tolist()[0]
    # Class 1 means the name is used to address someone; class 0 means it is only mentioned.
    return 'CALLING' if pred == 1 else 'MENTIONING'
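
# ==========

# Minimal usage sketch, assuming the trained checkpoint at 'ercbcm/model.pt'
# is available; the sentences and name below are illustrative placeholders,
# not values from the original project.
if __name__ == '__main__':
    print(predict('Alex, can you take a look at this?', 'Alex'))
    print(predict('I talked to Alex about it yesterday.', 'Alex'))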