Fix typo
Browse files
app.py
CHANGED
@@ -6,9 +6,15 @@ import click
|
|
6 |
from rasa.nlu.model import Interpreter
|
7 |
|
8 |
|
|
|
9 |
RASA_MODEL_PATH = "woz_nlu_agent/models/nlu"
|
10 |
interpreter = None
|
11 |
|
|
|
|
|
|
|
|
|
|
|
12 |
MODEL_TYPES = {
|
13 |
"Out-of-scope classifier": "oos",
|
14 |
"Intent classifier": "intent_transformer",
|
@@ -19,10 +25,17 @@ def predict(model_type, input):
|
|
19 |
if MODEL_TYPES[model_type] == "rasa_intent_entity":
|
20 |
return rasa_predict(input)
|
21 |
elif MODEL_TYPES[model_type] == "oos":
|
22 |
-
return
|
23 |
elif MODEL_TYPES[model_type] == "intent_transformer":
|
24 |
return "TODO:: intent_transformer"
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
def rasa_predict(input):
|
28 |
|
|
|
6 |
from rasa.nlu.model import Interpreter
|
7 |
|
8 |
|
9 |
+
# Rasa intent + entity extractor
|
10 |
RASA_MODEL_PATH = "woz_nlu_agent/models/nlu"
|
11 |
interpreter = None
|
12 |
|
13 |
+
# OOS classifier
|
14 |
+
tokenizer = AutoTokenizer.from_pretrained("msamogh/autonlp-cai-out-of-scope-649919116")
|
15 |
+
model = AutoModelForSequenceClassification.from_pretrained("msamogh/autonlp-cai-out-of-scope-649919116")
|
16 |
+
|
17 |
+
|
18 |
MODEL_TYPES = {
|
19 |
"Out-of-scope classifier": "oos",
|
20 |
"Intent classifier": "intent_transformer",
|
|
|
25 |
if MODEL_TYPES[model_type] == "rasa_intent_entity":
|
26 |
return rasa_predict(input)
|
27 |
elif MODEL_TYPES[model_type] == "oos":
|
28 |
+
return oos_predict(input)
|
29 |
elif MODEL_TYPES[model_type] == "intent_transformer":
|
30 |
return "TODO:: intent_transformer"
|
31 |
|
32 |
+
|
33 |
+
def oos_predict(input):
|
34 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
35 |
+
inputs = tokenizer(input, return_tensors="pt")
|
36 |
+
outputs = model(**inputs)
|
37 |
+
return str(outputs)
|
38 |
+
|
39 |
|
40 |
def rasa_predict(input):
|
41 |
|