cai / app.py
msamogh's picture
Add import
4c94ed1
raw
history blame
2.33 kB
from pprint import pprint, pformat
import os
import gradio as gr
import click
from rasa.nlu.model import Interpreter
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Rasa intent + entity extractor.
# `interpreter` is a placeholder here; main() replaces it with the model
# loaded from RASA_MODEL_PATH.
RASA_MODEL_PATH = "woz_nlu_agent/models/nlu"
interpreter = None

# Out-of-scope (OOS) classifier.
# The tokenizer and the model are loaded from the same checkpoint, so the
# identifier is spelled once and shared to keep the two from drifting apart.
OOS_MODEL_NAME = "msamogh/autonlp-cai-out-of-scope-649919116"
tokenizer = AutoTokenizer.from_pretrained(OOS_MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(OOS_MODEL_NAME)

# Maps the label offered in the UI dropdown to the internal backend key that
# predict() dispatches on.
MODEL_TYPES = {
    "Out-of-scope classifier": "oos",
    "Intent classifier": "intent_transformer",
    "Intent and Entity extractor": "rasa_intent_entity",
}
def predict(model_type, input):
    """Route *input* to the backend selected by *model_type*.

    Args:
        model_type: A key of ``MODEL_TYPES``.
        input: The text to analyze.

    Returns:
        The string produced by the selected backend, or ``None`` when the
        backend key mapped from *model_type* matches none of the known ones.

    Raises:
        KeyError: If *model_type* is not a key of ``MODEL_TYPES``.
    """
    backend = MODEL_TYPES[model_type]

    if backend == "rasa_intent_entity":
        return rasa_predict(input)
    if backend == "oos":
        return oos_predict(input)
    if backend == "intent_transformer":
        # Placeholder: this backend is not implemented yet.
        return "TODO:: intent_transformer"
    return None
def oos_predict(input):
    """Run the out-of-scope classifier on *input*.

    Args:
        input: The text to classify.

    Returns:
        ``str()`` of the object returned by the model call; the output is
        not post-processed into a label.
    """
    # truncation=True asks the tokenizer to cut the sequence at its maximum
    # length, so an over-long message is shortened instead of being handed
    # to the model in full.
    encoded = tokenizer(input, return_tensors="pt", truncation=True)
    outputs = model(**encoded)
    return str(outputs)
def rasa_predict(input):
    """Parse *input* with the Rasa interpreter and return a trimmed report.

    The parse result is reduced before formatting:

    * the ``response_selector`` section is dropped;
    * only the first three ``intent_ranking`` entries are kept;
    * ``id`` is removed from the intent and from each ranking entry;
    * ``extractor`` is removed from each entity, and ``start``/``end`` are
      removed when an entity carries both.

    Any of these sections or fields may be missing from the parse result;
    missing ones are skipped instead of raising ``KeyError``.

    Args:
        input: The message to parse; it is converted with ``str()`` and
            stripped of surrounding whitespace first.

    Returns:
        The trimmed parse result formatted with ``pformat(..., indent=4)``.
    """
    message = str(input).strip()
    response = interpreter.parse(message)

    # Tolerant removal/trim: only touch sections that are actually present.
    response.pop("response_selector", None)
    if "intent_ranking" in response:
        response["intent_ranking"] = response["intent_ranking"][:3]

    # `or {}` also covers an intent that is present but empty/None.
    intent = response.get("intent") or {}
    intent.pop("id", None)

    for ranked in response.get("intent_ranking", ()):
        ranked.pop("id", None)

    for entity in response.get("entities", ()):
        entity.pop("extractor", None)
        # The span offsets are removed only as a pair, as before.
        if "start" in entity and "end" in entity:
            del entity["start"]
            del entity["end"]

    return pformat(response, indent=4)
def main():
    """Load the Rasa NLU model and launch the Gradio interface.

    Before loading, prints the contents of the model directory and the
    model's ``metadata.json`` (first as raw text, then as parsed JSON).
    The loaded interpreter is stored in the module-level ``interpreter``
    so that ``rasa_predict`` can use it.
    """
    global interpreter
    import json

    print("Loading model...")
    print(os.listdir(RASA_MODEL_PATH))

    # Read the metadata once and close the file deterministically; the same
    # text feeds both the raw and the parsed printout.
    metadata_path = os.path.join(RASA_MODEL_PATH, "metadata.json")
    with open(metadata_path, "r", encoding="utf-8") as metadata_file:
        metadata_text = metadata_file.read()
    print(metadata_text)
    print(json.loads(metadata_text))

    interpreter = Interpreter.load(RASA_MODEL_PATH)
    print("Model loaded.")

    iface = gr.Interface(
        fn=predict,
        inputs=[gr.inputs.Dropdown(list(MODEL_TYPES)), "text"],
        outputs="text",
    )
    iface.launch()


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    main()