Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- __init__.py +7 -0
- naive_chatbot.py +172 -0
__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Empty __init__.py file
|
2 |
+
|
3 |
+
# This file serves as a marker for the package and doesn't require any specific code
|
4 |
+
# for this particular project. The functionality resides within the `NaiveChatbot` class.
|
5 |
+
|
6 |
+
from naive_chatbot.naive_chatbot import NaiveChatbot
|
7 |
+
from naive_chatbot.retvec_chatbot import RetvecChatbot
|
naive_chatbot.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""Naive Chatbot"""
|
3 |
+
import logging
|
4 |
+
import pickle
|
5 |
+
import numpy as np
|
6 |
+
import tensorflow as tf
|
7 |
+
from camel_tools.utils.normalize import normalize_unicode
|
8 |
+
from camel_tools.utils.normalize import normalize_alef_maksura_ar
|
9 |
+
from camel_tools.utils.normalize import normalize_alef_ar
|
10 |
+
from camel_tools.utils.normalize import normalize_teh_marbuta_ar
|
11 |
+
from keras.models import Sequential
|
12 |
+
from keras.layers import Dense, LSTM, Dropout, Embedding, Bidirectional
|
13 |
+
from keras.preprocessing.sequence import pad_sequences
|
14 |
+
from typing import Optional
|
15 |
+
|
16 |
+
"""A simple chatbot that utilizes an intent classifier then matching with predefined text mappings.
|
17 |
+
|
18 |
+
Typical usage example:
|
19 |
+
|
20 |
+
my_bot = NaiveChatbot(pretrained=True,
|
21 |
+
query_tokenizer_path="/../query_tokenizer.pickle",
|
22 |
+
intent_tokenizer_path="/../intent_tokenizer.pickle",
|
23 |
+
model_weights_path="/../checkpoint.ckpt",
|
24 |
+
db_responses2text_path="/../db_responses2text.pickle",
|
25 |
+
db_intent2response_path="/../db_intent2response.pickle",
|
26 |
+
db_stopwords_path="/../db_stopwords.pickle")
|
27 |
+
user_input = input("user > ")
|
28 |
+
print("bot > ", my_bot.get_reply(user_input))
|
29 |
+
"""
|
30 |
+
|
31 |
+
vocab_size = 500
|
32 |
+
embedding_dim = 128
|
33 |
+
max_length = 32
|
34 |
+
oov_tok = '<OOV>' # Out of Vocabulary
|
35 |
+
training_portion = 1
|
36 |
+
previous_reply = 'احنا لسه في بداية الكلام'
|
37 |
+
|
38 |
+
|
39 |
+
def load_pickle_data(filepath):
|
40 |
+
with open(filepath, 'rb') as pickle_file:
|
41 |
+
data = pickle.load(pickle_file)
|
42 |
+
return data
|
43 |
+
|
44 |
+
|
45 |
+
class NaiveChatbot:
|
46 |
+
|
47 |
+
def __get_model(self):
|
48 |
+
# TODO(mshetairy): Create a .gin for model hyperparameters
|
49 |
+
number_of_intents = len(self.intent_tokenizer.index_word.keys())
|
50 |
+
number_of_classes = number_of_intents + 1
|
51 |
+
model = Sequential(name="naive_chatbot")
|
52 |
+
model.add(Embedding(vocab_size, embedding_dim, input_length=max_length))
|
53 |
+
model.add(Dropout(0.5))
|
54 |
+
model.add(Bidirectional(LSTM(embedding_dim)))
|
55 |
+
model.add(Dense(number_of_classes, activation='softmax'))
|
56 |
+
logging.info(model.summary())
|
57 |
+
|
58 |
+
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, weight_decay=1e-6)
|
59 |
+
loss = tf.keras.losses.SparseCategoricalCrossentropy()
|
60 |
+
model.compile(loss=loss,
|
61 |
+
optimizer=optimizer,
|
62 |
+
metrics=['accuracy'])
|
63 |
+
return model
|
64 |
+
|
65 |
+
def __init__(self,
|
66 |
+
pretrained: bool = False,
|
67 |
+
query_tokenizer_path: Optional[str] = None,
|
68 |
+
intent_tokenizer_path: Optional[str] = None,
|
69 |
+
model_weights_path: Optional[str] = None,
|
70 |
+
db_responses2text_path: Optional[str] = None,
|
71 |
+
db_intent2response_path: Optional[str] = None,
|
72 |
+
db_stopwords_path: Optional[str] = None,
|
73 |
+
db_transliteration_path: Optional[str] = None):
|
74 |
+
"""Initializing an instance of the chatbot.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
pretrained: If True loads required tokenizers and model weights.
|
78 |
+
query_tokenizer_path: path to the Arabic query Tokenizer.
|
79 |
+
intent_tokenizer_path: path to the Label Tokenizer of the user query's
|
80 |
+
intent.
|
81 |
+
model_weights_path: path to the pretrained intent classifier model
|
82 |
+
weights.
|
83 |
+
db_responses2text_path: path to the mapping of bot response type to
|
84 |
+
possible text outcomes.
|
85 |
+
db_intent2response_path: path to the mapping of user intents to
|
86 |
+
possible bot response types.
|
87 |
+
|
88 |
+
Raises:
|
89 |
+
ValueError: An error occurred in the files paths.
|
90 |
+
"""
|
91 |
+
if pretrained:
|
92 |
+
if not all([query_tokenizer_path,
|
93 |
+
intent_tokenizer_path,
|
94 |
+
model_weights_path,
|
95 |
+
db_responses2text_path,
|
96 |
+
db_intent2response_path]):
|
97 |
+
raise ValueError("All arguments must be strings when pretrained is True.")
|
98 |
+
self.query_tokenizer = load_pickle_data(query_tokenizer_path)
|
99 |
+
self.intent_tokenizer = load_pickle_data(intent_tokenizer_path)
|
100 |
+
self.model = self.__get_model()
|
101 |
+
self.model.load_weights(model_weights_path).expect_partial()
|
102 |
+
self.db_responses2text = load_pickle_data(db_responses2text_path)
|
103 |
+
self.db_intent2response = load_pickle_data(db_intent2response_path)
|
104 |
+
# self.db_stopwords = load_pickle_data(db_stopwords_path)
|
105 |
+
self.db_transliteration = load_pickle_data(db_transliteration_path)
|
106 |
+
logging.info("Successfully loaded tokenizers, database and pretrained weights.")
|
107 |
+
else:
|
108 |
+
# Handle non-pretrained case if needed
|
109 |
+
# ...
|
110 |
+
pass
|
111 |
+
|
112 |
+
# Additional class attributes or methods
|
113 |
+
# ...
|
114 |
+
pass
|
115 |
+
|
116 |
+
def preprocess_query(self, query):
|
117 |
+
norm = normalize_unicode(query)
|
118 |
+
# Normalize alef variants to 'ا'
|
119 |
+
norm = normalize_alef_ar(norm)
|
120 |
+
# Normalize alef maksura 'ى' to yeh 'ي'
|
121 |
+
norm = normalize_alef_maksura_ar(norm)
|
122 |
+
# Normalize teh marbuta 'ة' to heh 'ه'
|
123 |
+
norm = normalize_teh_marbuta_ar(norm)
|
124 |
+
|
125 |
+
sent_safebw = self.db_transliteration(norm)
|
126 |
+
return sent_safebw
|
127 |
+
|
128 |
+
def __get_predictions(self, data):
|
129 |
+
"""Gets numerical model predictions."""
|
130 |
+
model = self.model
|
131 |
+
predictions = []
|
132 |
+
for i in range(0, len(data)):
|
133 |
+
prediction = model.predict(data[i, :].reshape(1, -1), verbose=0)
|
134 |
+
predictions.append(np.argmax(prediction))
|
135 |
+
return np.array(predictions)
|
136 |
+
|
137 |
+
def get_intent(self, text, threshold=0.4):
|
138 |
+
"""Classifies the intent behind the input text."""
|
139 |
+
intent_tokenizer = self.intent_tokenizer
|
140 |
+
model = self.model
|
141 |
+
query_tokenizer = self.query_tokenizer
|
142 |
+
# db_stopwords = self.db_stopwords
|
143 |
+
|
144 |
+
# for word in db_stopwords:
|
145 |
+
# token = ' ' + word + ' '
|
146 |
+
# text = text.replace(token, ' ')
|
147 |
+
# text = text.replace(' ', ' ')
|
148 |
+
norm = self.preprocess_query(text)
|
149 |
+
seq = query_tokenizer.texts_to_sequences([norm])
|
150 |
+
padded = pad_sequences(seq, maxlen=max_length)
|
151 |
+
pred = model.predict(padded, verbose=0)
|
152 |
+
|
153 |
+
try:
|
154 |
+
if np.max(pred) < threshold:
|
155 |
+
label = ['']
|
156 |
+
else:
|
157 |
+
label = intent_tokenizer.sequences_to_texts(np.array([[np.argmax(pred)]]))
|
158 |
+
label = ['other'] if label == [''] else label
|
159 |
+
answer = label
|
160 |
+
except:
|
161 |
+
answer = ['other']
|
162 |
+
return answer
|
163 |
+
|
164 |
+
def get_reply(self, text, threshold=0.4):
|
165 |
+
global previous_reply
|
166 |
+
intent = self.get_intent(text, threshold)[0]
|
167 |
+
if intent == "request_repeat":
|
168 |
+
return previous_reply
|
169 |
+
response_type = np.random.choice(self.db_intent2response[intent])
|
170 |
+
reply = np.random.choice(self.db_responses2text[response_type])
|
171 |
+
previous_reply = reply
|
172 |
+
return reply
|