"""Gradio demo exposing an out-of-scope classifier and a Rasa intent/entity extractor."""
from pprint import pprint, pformat
import json
import os

import gradio as gr
import click
from rasa.nlu.model import Interpreter
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Rasa intent + entity extractor
RASA_MODEL_PATH = "woz_nlu_agent/models/nlu"
# Set by main(); rasa_predict() refuses to run until it has been loaded.
interpreter = None

# OOS classifier
OOS_MODEL_NAME = "msamogh/autonlp-cai-out-of-scope-649919116"
tokenizer = AutoTokenizer.from_pretrained(OOS_MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(OOS_MODEL_NAME)

# Maps the label shown in the UI dropdown to an internal model identifier.
MODEL_TYPES = {
    "Out-of-scope classifier": "oos",
    "Intent classifier": "intent_transformer",
    "Intent and Entity extractor": "rasa_intent_entity",
}

# Number of top-ranked intents kept in the Rasa output.
TOP_K_INTENTS = 3


def predict(model_type, input):
    """Route the user's text to the model selected in the UI.

    Args:
        model_type: A key of MODEL_TYPES (the dropdown label).
        input: The raw text entered by the user.

    Returns:
        A string to display in the Gradio text output. An unknown or
        unselected model type yields an explanatory message instead of
        raising KeyError.
    """
    kind = MODEL_TYPES.get(model_type)
    if kind == "rasa_intent_entity":
        return rasa_predict(input)
    if kind == "oos":
        return oos_predict(input)
    if kind == "intent_transformer":
        return "TODO:: intent_transformer"
    return "Unknown model type: {!r}. Choose one of: {}".format(
        model_type, ", ".join(MODEL_TYPES)
    )


def oos_predict(input):
    """Run the out-of-scope classifier on ``input`` and return its raw output as a string."""
    # truncation=True caps the sequence at the tokenizer's max length; without it
    # over-long input can fail in the model's forward pass.
    inputs = tokenizer(input, return_tensors="pt", truncation=True)
    outputs = model(**inputs)
    return str(outputs)


def rasa_predict(input):
    """Parse ``input`` with the Rasa interpreter and return a trimmed, pretty-printed result.

    Drops the response-selector output, keeps only the top TOP_K_INTENTS
    intents, and strips bookkeeping fields (ids, extractor names, character
    offsets) so the displayed result is easier to read.

    Raises:
        RuntimeError: if the interpreter has not been loaded by main().
    """
    if interpreter is None:
        raise RuntimeError("Rasa interpreter is not loaded; call main() first.")

    response = interpreter.parse(str(input).strip())

    # pop() rather than del: this key is only present when the pipeline
    # contains a ResponseSelector, so del would raise KeyError otherwise.
    response.pop("response_selector", None)

    if "intent_ranking" in response:
        response["intent_ranking"] = response["intent_ranking"][:TOP_K_INTENTS]

    intent = response.get("intent")
    if intent:
        intent.pop("id", None)
    for ranked in response.get("intent_ranking", []):
        ranked.pop("id", None)

    for entity in response.get("entities", []):
        entity.pop("extractor", None)
        # Offsets are dropped only as a pair, matching the original behaviour.
        if "start" in entity and "end" in entity:
            del entity["start"]
            del entity["end"]

    return pformat(response, indent=4)


def main():
    """Load the Rasa model from RASA_MODEL_PATH and launch the Gradio demo UI."""
    global interpreter
    print("Loading model...")
    # Debug output: show the model directory contents and metadata before loading.
    print(os.listdir(RASA_MODEL_PATH))
    # Read once inside a context manager so the file handle is always closed.
    metadata_path = os.path.join(RASA_MODEL_PATH, "metadata.json")
    with open(metadata_path, "r", encoding="utf-8") as metadata_file:
        metadata_text = metadata_file.read()
    print(metadata_text)
    print(json.loads(metadata_text))

    interpreter = Interpreter.load(RASA_MODEL_PATH)
    print("Model loaded.")

    iface = gr.Interface(
        fn=predict,
        inputs=[gr.inputs.Dropdown(list(MODEL_TYPES)), "text"],
        outputs="text",
    )
    iface.launch()


if __name__ == "__main__":
    main()