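"""Gradio demo for the CAI NLU models.

The app exposes three backends behind a single text box:
  * an out-of-scope (OOS) classifier (Hugging Face transformer),
  * a transformer intent classifier (not implemented yet),
  * a Rasa intent + entity extractor loaded from woz_nlu_agent/models/nlu.

Run locally with `python app.py` (assumes the Rasa model files are present).
"""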
from pprint import pprint, pformat
import json
import os

import gradio as gr
import click
import torch
from rasa.nlu.model import Interpreter
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Rasa intent + entity extractor (the interpreter itself is loaded in main()).
RASA_MODEL_PATH = "woz_nlu_agent/models/nlu"
interpreter = None

# Out-of-scope (OOS) classifier fine-tuned via AutoNLP.
tokenizer = AutoTokenizer.from_pretrained("msamogh/autonlp-cai-out-of-scope-649919116")
model = AutoModelForSequenceClassification.from_pretrained("msamogh/autonlp-cai-out-of-scope-649919116")
# Map the dropdown labels shown in the UI to internal model identifiers.
MODEL_TYPES = {
    "Out-of-scope classifier": "oos",
    "Intent classifier": "intent_transformer",
    "Intent and Entity extractor": "rasa_intent_entity",
}
def predict(model_type, text):
    """Dispatch the input text to the model selected in the dropdown."""
    if MODEL_TYPES[model_type] == "rasa_intent_entity":
        return rasa_predict(text)
    elif MODEL_TYPES[model_type] == "oos":
        return oos_predict(text)
    elif MODEL_TYPES[model_type] == "intent_transformer":
        return "TODO: intent_transformer"
def oos_predict(text):
    """Return in-scope vs. out-of-scope probabilities for a single utterance."""
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax over the two classes; index 0 = out of scope, index 1 = in scope.
    probs = torch.softmax(logits, dim=-1)[0]
    return str({"In scope": probs[1].item(), "Out of scope": probs[0].item()})
def rasa_predict(text):
    """Run the Rasa interpreter and return a trimmed, pretty-printed parse."""

    def rasa_output(message):
        message = str(message).strip()
        return interpreter.parse(message)

    response = rasa_output(text)

    # Drop fields that only add noise to the demo output, and keep the top 3 intents.
    del response["response_selector"]
    response["intent_ranking"] = response["intent_ranking"][:3]
    if "id" in response["intent"]:
        del response["intent"]["id"]
    for intent in response["intent_ranking"]:
        if "id" in intent:
            del intent["id"]
    for entity in response["entities"]:
        if "extractor" in entity:
            del entity["extractor"]
        if "start" in entity and "end" in entity:
            del entity["start"]
            del entity["end"]

    return pformat(response, indent=4)
def main():
    global interpreter

    print("Loading Rasa model...")
    # Debug output: confirm the model files and metadata are readable.
    print(os.listdir(RASA_MODEL_PATH))
    with open(os.path.join(RASA_MODEL_PATH, "metadata.json")) as f:
        print(json.load(f))
    interpreter = Interpreter.load(RASA_MODEL_PATH)
    print("Model loaded.")

    iface = gr.Interface(
        fn=predict,
        inputs=[gr.inputs.Dropdown(list(MODEL_TYPES.keys())), "text"],
        outputs="text",
    )
    iface.launch()
if __name__ == "__main__":
    main()