"""Gradio demo exposing three NLU back-ends behind a single text box.

The user picks a model from a dropdown and types an utterance; the selected
back-end's prediction is rendered as plain text.
"""

from pprint import pprint, pformat
import os

import click
import gradio as gr
import torch
from rasa.nlu.model import Interpreter
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Rasa intent + entity extractor.
# NOTE: the interpreter is loaded at import time, so importing this module
# requires the model directory below to exist.
RASA_MODEL_PATH = "woz_nlu_agent/models/nlu-lookup-1"
interpreter = Interpreter.load(RASA_MODEL_PATH)

# Out-of-scope (OOS) classifier, also loaded at import time.
OOS_MODEL = "msamogh/autonlp-cai-out-of-scope-649919116"
tokenizer = AutoTokenizer.from_pretrained(OOS_MODEL)
model = AutoModelForSequenceClassification.from_pretrained(OOS_MODEL)

# Human-readable dropdown label -> internal back-end identifier.
MODEL_TYPES = {
    "Out-of-scope classifier": "oos",
    "Intent classifier": "intent_transformer",
    "Intent and Entity extractor": "rasa_intent_entity",
}


def predict(model_type, input):
    """Run the back-end selected by *model_type* on the utterance *input*.

    Args:
        model_type: One of the keys of ``MODEL_TYPES`` (the dropdown label).
        input: The raw user utterance.

    Returns:
        The selected back-end's prediction as text. An unrecognised
        ``model_type`` (including ``None``) yields a short explanatory
        message instead of raising ``KeyError``.
    """
    backend = MODEL_TYPES.get(model_type)
    if backend == "rasa_intent_entity":
        return rasa_predict(input)
    if backend == "oos":
        return oos_predict(input)
    if backend == "intent_transformer":
        return "WIP: intent_transformer"
    return f"Unknown model type: {model_type!r}"


def oos_predict(input):
    """Classify *input* as in-scope or out-of-scope.

    Args:
        input: The raw user utterance.

    Returns:
        The ``str()`` of a dict with the softmax scores under the keys
        ``"In scope"`` and ``"Out of scope"``.
    """
    inputs = tokenizer(input, return_tensors="pt")
    # Inference only: skip autograd bookkeeping. This also removes the need
    # to re-wrap the logits in ``torch.tensor(...)``, which copies the data
    # and makes PyTorch emit a copy-construct UserWarning.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Single utterance, so take the first (only) row of the batch.
    probs = torch.softmax(logits, dim=-1)[0]
    # NOTE(review): index 1 is reported as in-scope and index 0 as
    # out-of-scope; confirm against the model's label mapping.
    return str({"In scope": probs[1], "Out of scope": probs[0]})


def rasa_predict(input):
    """Parse *input* with the Rasa interpreter and return a trimmed summary.

    The parse result is reduced to what is useful for display:
    ``response_selector`` and ``intent_ranking`` are dropped, per-entity
    bookkeeping fields are removed, and ``intent`` is collapsed to its name.

    Args:
        input: The raw user utterance; it is stringified and stripped of
            surrounding whitespace before parsing.

    Returns:
        The trimmed parse result, pretty-printed with ``pformat``.
    """
    response = interpreter.parse(str(input).strip())

    # Keys that may be absent depending on the pipeline; drop them if present
    # rather than failing with KeyError.
    response.pop("response_selector", None)
    response.pop("intent_ranking", None)

    for entity in response.get("entities") or []:
        entity.pop("extractor", None)
        # Character offsets are only removed as a pair.
        if "start" in entity and "end" in entity:
            del entity["start"]
            del entity["end"]
        entity.pop("confidence_entity", None)

    # Show only the intent's name; the rest of the intent dict is discarded.
    if "intent" in response:
        response["intent"] = response["intent"]["name"]

    return pformat(response, indent=4)


def main():
    """Build the Gradio interface around ``predict`` and launch it."""
    # NOTE(review): ``gr.inputs`` is the legacy component namespace; newer
    # Gradio releases expose ``gr.Dropdown`` instead. Confirm the pinned
    # Gradio version before changing this.
    iface = gr.Interface(
        fn=predict,
        inputs=[gr.inputs.Dropdown(list(MODEL_TYPES)), "text"],
        outputs="text",
    )
    iface.launch()


if __name__ == "__main__":
    main()