import torch
from torchtext.data.utils import get_tokenizer
from model_arch import TextClassifierModel, load_state_dict

labels = {0: 'messaging',
          1: 'calling',
          2: 'event',
          3: 'timer',
          4: 'music',
          5: 'weather',
          6: 'alarm',
          7: 'people',
          8: 'reminder',
          9: 'recipes',
          10: 'news'}

# Load the trained weights and the vocabulary built during training.
model_trained = torch.load('model_checkpoint.pth')
vocab = torch.load('vocab.pt')
# Spanish spaCy tokenizer; assumes the es_core_news_sm model is installed
# (python -m spacy download es_core_news_sm).
tokenizer = get_tokenizer("spacy", language="es_core_news_sm")

def text_pipeline(text):
    """Tokenize a raw string and map its tokens to vocabulary indices."""
    return vocab(tokenizer(text))

# These values must match the configuration the checkpoint was trained with.
num_class = 11
vocab_size = len(vocab)
embed_size = 300

model = TextClassifierModel(vocab_size, embed_size, num_class)

# Project helper from model_arch that copies the trained weights into the
# freshly constructed model.
model = load_state_dict(model, model_trained, vocab)

def predict(text, model=model, text_pipeline=text_pipeline):
    """Return the predicted intent label for a raw Spanish utterance."""
    model.eval()
    with torch.no_grad():
        text_tensor = torch.tensor(text_pipeline(text))
        # torch.tensor([0]) is presumably the offsets argument expected by the
        # model's forward pass: a single sequence starting at position 0.
        logits = model(text_tensor, torch.tensor([0]))
        return labels[logits.argmax(1).item()]
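
# A minimal usage sketch (assumes the checkpoint, vocabulary, and spaCy model
# above are available; the sample utterance and expected intent are
# illustrative, not taken from any dataset):
if __name__ == "__main__":
    sample = "pon una alarma para las siete de la mañana"
    print(predict(sample))  # a well-trained model should map this to 'alarm'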