Here's an adapted TWIZ intent detection model, trained on the TWIZ dataset, with an extra focus on simplicity!
It achieves ~85% accuracy on the TWIZ test set, and should be especially useful for the WSDM students @ NOVA.

I STRONGLY suggest that interested students check out model_code in the Files and versions tab, where all the code used to build the model (except the step of actually uploading it here) is laid out nicely (I hope!).

Here are the contents of intent-detection-example.ipynb, if you're just looking to use the model:

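import json

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# all_intents.json can be found in the task-intent-detector/model_code directory
with open("twiz-data/all_intents.json", 'r') as json_in: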
    data = json.load(json_in)

id_to_intent, intent_to_id = dict(), dict()
for i, intent in enumerate(data):
    id_to_intent[i] = intent
    intent_to_id[intent] = i

model = AutoModelForSequenceClassification.from_pretrained("NOVA-vision-language/task-intent-detector", num_labels=len(data), id2label=id_to_intent, label2id=intent_to_id)
tokenizer = AutoTokenizer.from_pretrained("roberta-base") # you could try 'NOVA-vision-language/task-intent-detector', but I'm not sure I configured it correctly

model_in = tokenizer("I really really wanna go to the next step", return_tensors='pt')
with torch.no_grad():
    logits = model(**model_in).logits # grab the predictions out of the model's classification head
    predicted_class_id = logits.argmax().item() # grab the index of the highest scoring output
    print(model.config.id2label[predicted_class_id]) # use the translation table we just created to translate between that id and the actual intent name
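If you'd rather see how confident the model is instead of just the single best guess, here's a minimal sketch that turns the raw logits into probabilities with a softmax and returns the top few intents. The helper name `top_k_intents` and the example labels and scores are made up for illustration; they aren't part of the notebook or the TWIZ intent list.

```python
import math


def top_k_intents(scores, id2label, k=3):
    """Return the k most likely (intent, probability) pairs from raw logits.

    `scores` is a plain list of floats, one per intent.
    `id2label` maps an integer class id to its intent name.
    """
    # Softmax, subtracting the max first for numerical stability
    highest = max(scores)
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Sort class ids from most to least probable and keep the first k
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Hypothetical labels and logits, just to show the shape of the output
example_labels = {0: "intent_a", 1: "intent_b", 2: "intent_c"}
example_scores = [2.0, 0.5, -1.0]
for intent, prob in top_k_intents(example_scores, example_labels, k=2):
    print(f"{intent}: {prob:.3f}")
```

With the model loaded as above, something like `top_k_intents(logits[0].tolist(), model.config.id2label)` should work, assuming a single input sentence so that `logits` has shape `(1, num_labels)`.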