|
from pprint import pprint, pformat |
|
import os |
|
|
|
import gradio as gr |
|
import click |
|
from rasa.nlu.model import Interpreter |
|
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
|
|
|
# Directory of the trained Rasa NLU model; main() loads it into `interpreter`.
RASA_MODEL_PATH = "woz_nlu_agent/models/nlu"

# Rasa NLU Interpreter used by rasa_predict(). It stays None until main()
# assigns Interpreter.load(RASA_MODEL_PATH), so rasa_predict() must not be
# called before main() has run.
interpreter = None

# Out-of-scope classifier used by oos_predict(). Both objects are created via
# from_pretrained() at import time, so merely importing this module loads the
# model (NOTE(review): this presumably downloads from the Hugging Face Hub when
# the files are not already cached locally — confirm for deployment).
tokenizer = AutoTokenizer.from_pretrained("msamogh/autonlp-cai-out-of-scope-649919116")
model = AutoModelForSequenceClassification.from_pretrained("msamogh/autonlp-cai-out-of-scope-649919116")

# Maps the display names offered in the Gradio dropdown (see main()) to the
# internal backend keys that predict() dispatches on.
MODEL_TYPES = {
    "Out-of-scope classifier": "oos",
    # predict() currently returns a placeholder string for this backend.
    "Intent classifier": "intent_transformer",
    "Intent and Entity extractor": "rasa_intent_entity"
}
|
|
|
def predict(model_type, input):
    """Route a Gradio request to the backend chosen in the dropdown.

    Args:
        model_type: A display name from ``MODEL_TYPES`` (the dropdown choice).
            May be ``None`` or unexpected if nothing valid was selected.
        input: The raw text entered by the user.

    Returns:
        A string for the Gradio text output. For an unknown or missing
        ``model_type``, this is a message listing the valid choices rather
        than a raised ``KeyError``.
    """
    # Look the backend up once; .get() avoids a KeyError when the dropdown is
    # empty (None) or holds an unexpected value.
    backend = MODEL_TYPES.get(model_type)

    if backend == "rasa_intent_entity":
        return rasa_predict(input)
    if backend == "oos":
        return oos_predict(input)
    if backend == "intent_transformer":
        return "TODO:: intent_transformer"

    choices = ", ".join(MODEL_TYPES)
    return f"Unknown model type: {model_type!r}. Choose one of: {choices}"
|
|
|
|
|
def oos_predict(input):
    """Run the out-of-scope classifier on ``input`` and return its output as text.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        input: Text to classify.

    Returns:
        ``str()`` of the model's output object, which carries the
        classification logits.
    """
    import torch

    # truncation=True clips the token sequence to the tokenizer's maximum
    # length, so long inputs don't overrun the model's position embeddings.
    inputs = tokenizer(input, return_tensors="pt", truncation=True)

    # Pure inference: skip building the autograd graph (saves memory and keeps
    # grad_fn out of the printed tensors).
    with torch.no_grad():
        outputs = model(**inputs)

    return str(outputs)
|
|
|
|
|
def rasa_predict(input):
    """Parse ``input`` with the Rasa NLU interpreter and pretty-print a trimmed result.

    The module-level ``interpreter`` must already be loaded (``main`` does
    this) before this function is called.

    Args:
        input: Text to parse. It is converted with ``str`` and stripped first.

    Returns:
        ``pformat`` (indent 4) of the parse result, trimmed as follows:
        - ``response_selector`` is removed.
        - ``intent_ranking`` is cut to its first 3 entries.
        - ``id`` is removed from the intent and from each ranked intent.
        - ``extractor``, and ``start``/``end`` when both are present, are
          removed from every entity.
    """
    response = interpreter.parse(str(input).strip())

    # Any of these keys may be absent from the parse result depending on the
    # pipeline and the input (NOTE(review): e.g. blank text appears to yield a
    # minimal result). pop/get with defaults keep that from raising
    # KeyError/TypeError.
    response.pop("response_selector", None)

    if "intent_ranking" in response:
        response["intent_ranking"] = response["intent_ranking"][:3]

    intent = response.get("intent")
    if isinstance(intent, dict):
        intent.pop("id", None)

    for ranked in response.get("intent_ranking", []):
        ranked.pop("id", None)

    for entity in response.get("entities", []):
        entity.pop("extractor", None)
        # Offsets are only dropped as a pair.
        if "start" in entity and "end" in entity:
            del entity["start"]
            del entity["end"]

    return pformat(response, indent=4)
|
|
|
|
|
def main():
    """Load the Rasa NLU model and launch the Gradio demo.

    Assigns the module-level ``interpreter`` (used by ``rasa_predict``), then
    blocks inside ``iface.launch()``.
    """
    global interpreter

    import json

    print("Loading model...")

    # Diagnostic output: the model directory listing and its metadata, both
    # raw and parsed. The file is read once and closed by the with-block.
    print(os.listdir(RASA_MODEL_PATH))
    metadata_path = os.path.join(RASA_MODEL_PATH, "metadata.json")
    with open(metadata_path, "r", encoding="utf-8") as metadata_file:
        metadata_text = metadata_file.read()
    print(metadata_text)
    print(json.loads(metadata_text))

    interpreter = Interpreter.load(RASA_MODEL_PATH)
    print("Model loaded.")

    # NOTE(review): the gr.inputs namespace is deprecated in newer Gradio
    # releases (gr.Dropdown replaces it) — confirm against the pinned version.
    iface = gr.Interface(fn=predict, inputs=[gr.inputs.Dropdown(list(MODEL_TYPES.keys())), "text"], outputs="text")
    iface.launch()
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|