mshetairy committed on
Commit 6a4e037 · verified · 1 Parent(s): d0ee69a

Upload 2 files

Files changed (2)
  1. __init__.py +7 -0
  2. naive_chatbot.py +172 -0
__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Package marker __init__.py file
+ 
+ # This file doesn't require any project-specific code; it only marks the package
+ # and re-exports its chatbot classes. The functionality resides within the `NaiveChatbot` class.
+ 
+ from naive_chatbot.naive_chatbot import NaiveChatbot
+ from naive_chatbot.retvec_chatbot import RetvecChatbot
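
These package-level re-exports let callers import the chatbot classes from the package root. A minimal sketch of the intended import, assuming the repository root (the directory containing the naive_chatbot/ package) is on sys.path:

    from naive_chatbot import NaiveChatbot, RetvecChatbot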
naive_chatbot.py ADDED
@@ -0,0 +1,172 @@
+ # -*- coding: utf-8 -*-
+ """Naive Chatbot"""
+ import logging
+ import pickle
+ import numpy as np
+ import tensorflow as tf
+ from camel_tools.utils.normalize import normalize_unicode
+ from camel_tools.utils.normalize import normalize_alef_maksura_ar
+ from camel_tools.utils.normalize import normalize_alef_ar
+ from camel_tools.utils.normalize import normalize_teh_marbuta_ar
+ from keras.models import Sequential
+ from keras.layers import Dense, LSTM, Dropout, Embedding, Bidirectional
+ from keras.preprocessing.sequence import pad_sequences
+ from typing import Optional
+ 
+ """A simple chatbot that classifies the intent of a user query, then maps it to predefined response texts.
+ 
+ Typical usage example:
+ 
+     my_bot = NaiveChatbot(pretrained=True,
+                           query_tokenizer_path="/../query_tokenizer.pickle",
+                           intent_tokenizer_path="/../intent_tokenizer.pickle",
+                           model_weights_path="/../checkpoint.ckpt",
+                           db_responses2text_path="/../db_responses2text.pickle",
+                           db_intent2response_path="/../db_intent2response.pickle",
+                           db_stopwords_path="/../db_stopwords.pickle",
+                           db_transliteration_path="/../db_transliteration.pickle")
+     user_input = input("user > ")
+     print("bot > ", my_bot.get_reply(user_input))
+ """
+ 
+ vocab_size = 500
+ embedding_dim = 128
+ max_length = 32
+ oov_tok = '<OOV>'  # Out of Vocabulary
+ training_portion = 1
+ # Default reply, roughly: "We are still at the beginning of the conversation."
+ previous_reply = 'احنا لسه في بداية الكلام'
+ 
+ 
+ def load_pickle_data(filepath):
+     with open(filepath, 'rb') as pickle_file:
+         data = pickle.load(pickle_file)
+     return data
+ 
+ 
+ class NaiveChatbot:
+ 
+     def __get_model(self):
+         # TODO(mshetairy): Create a .gin for model hyperparameters
+         number_of_intents = len(self.intent_tokenizer.index_word.keys())
+         number_of_classes = number_of_intents + 1
+         model = Sequential(name="naive_chatbot")
+         model.add(Embedding(vocab_size, embedding_dim, input_length=max_length))
+         model.add(Dropout(0.5))
+         model.add(Bidirectional(LSTM(embedding_dim)))
+         model.add(Dense(number_of_classes, activation='softmax'))
+         model.summary(print_fn=logging.info)
+ 
+         optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, weight_decay=1e-6)
+         loss = tf.keras.losses.SparseCategoricalCrossentropy()
+         model.compile(loss=loss,
+                       optimizer=optimizer,
+                       metrics=['accuracy'])
+         return model
+ 
+     def __init__(self,
+                  pretrained: bool = False,
+                  query_tokenizer_path: Optional[str] = None,
+                  intent_tokenizer_path: Optional[str] = None,
+                  model_weights_path: Optional[str] = None,
+                  db_responses2text_path: Optional[str] = None,
+                  db_intent2response_path: Optional[str] = None,
+                  db_stopwords_path: Optional[str] = None,
+                  db_transliteration_path: Optional[str] = None):
+         """Initializes an instance of the chatbot.
+ 
+         Args:
+           pretrained: If True, loads the required tokenizers and model weights.
+           query_tokenizer_path: Path to the Arabic query Tokenizer.
+           intent_tokenizer_path: Path to the label Tokenizer of the user query's
+             intent.
+           model_weights_path: Path to the pretrained intent classifier model
+             weights.
+           db_responses2text_path: Path to the mapping of bot response type to
+             possible text outcomes.
+           db_intent2response_path: Path to the mapping of user intents to
+             possible bot response types.
+           db_stopwords_path: Path to the list of Arabic stopwords (currently unused).
+           db_transliteration_path: Path to the mapping used to transliterate
+             normalized Arabic text to SafeBW.
+ 
+         Raises:
+             ValueError: A required file path is missing.
+         """
+         if pretrained:
+             if not all([query_tokenizer_path,
+                         intent_tokenizer_path,
+                         model_weights_path,
+                         db_responses2text_path,
+                         db_intent2response_path,
+                         db_transliteration_path]):
+                 raise ValueError("All path arguments must be provided when pretrained is True.")
+             self.query_tokenizer = load_pickle_data(query_tokenizer_path)
+             self.intent_tokenizer = load_pickle_data(intent_tokenizer_path)
+             self.model = self.__get_model()
+             self.model.load_weights(model_weights_path).expect_partial()
+             self.db_responses2text = load_pickle_data(db_responses2text_path)
+             self.db_intent2response = load_pickle_data(db_intent2response_path)
+             # self.db_stopwords = load_pickle_data(db_stopwords_path)
+             self.db_transliteration = load_pickle_data(db_transliteration_path)
+             logging.info("Successfully loaded tokenizers, database and pretrained weights.")
+         else:
+             # Handle the non-pretrained case if needed
+             # ...
+             pass
+ 
+         # Additional class attributes or methods
+         # ...
+ 
+     def preprocess_query(self, query):
+         """Normalizes an Arabic query and transliterates it to SafeBW."""
+         norm = normalize_unicode(query)
+         # Normalize alef variants to 'ا'
+         norm = normalize_alef_ar(norm)
+         # Normalize alef maksura 'ى' to yeh 'ي'
+         norm = normalize_alef_maksura_ar(norm)
+         # Normalize teh marbuta 'ة' to heh 'ه'
+         norm = normalize_teh_marbuta_ar(norm)
+ 
+         sent_safebw = self.db_transliteration(norm)
+         return sent_safebw
+ 
+     def __get_predictions(self, data):
+         """Gets numerical model predictions."""
+         model = self.model
+         predictions = []
+         for i in range(0, len(data)):
+             prediction = model.predict(data[i, :].reshape(1, -1), verbose=0)
+             predictions.append(np.argmax(prediction))
+         return np.array(predictions)
+ 
+     def get_intent(self, text, threshold=0.4):
+         """Classifies the intent behind the input text."""
+         intent_tokenizer = self.intent_tokenizer
+         model = self.model
+         query_tokenizer = self.query_tokenizer
+         # db_stopwords = self.db_stopwords
+ 
+         # for word in db_stopwords:
+         #     token = ' ' + word + ' '
+         #     text = text.replace(token, ' ')
+         #     text = text.replace('  ', ' ')
+         norm = self.preprocess_query(text)
+         seq = query_tokenizer.texts_to_sequences([norm])
+         padded = pad_sequences(seq, maxlen=max_length)
+         pred = model.predict(padded, verbose=0)
+ 
+         try:
+             if np.max(pred) < threshold:
+                 label = ['']
+             else:
+                 label = intent_tokenizer.sequences_to_texts(np.array([[np.argmax(pred)]]))
+             label = ['other'] if label == [''] else label
+             answer = label
+         except Exception:
+             # Fall back to the generic intent if decoding fails.
+             answer = ['other']
+         return answer
+ 
+     def get_reply(self, text, threshold=0.4):
+         """Picks a response type for the predicted intent and samples a reply text."""
+         global previous_reply
+         intent = self.get_intent(text, threshold)[0]
+         if intent == "request_repeat":
+             return previous_reply
+         response_type = np.random.choice(self.db_intent2response[intent])
+         reply = np.random.choice(self.db_responses2text[response_type])
+         previous_reply = reply
+         return reply
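
naive_chatbot.py only loads its pickled artifacts; this commit does not show how they are produced. Below is a hypothetical sketch of preparing and saving the tokenizers and mappings the class expects. All names and sample data are illustrative rather than part of the commit, and the db_transliteration artifact (expected to be a callable that maps normalized Arabic text to its SafeBW transliteration) is omitted here.

    import pickle
    from keras.preprocessing.text import Tokenizer

    # Illustrative training data: transliterated user queries and one intent label per query.
    train_queries = ["mrHbA", "AzAyk"]
    train_intents = ["greeting", "greeting"]

    # Query tokenizer built with the same vocab_size / OOV token as naive_chatbot.py.
    query_tokenizer = Tokenizer(num_words=500, oov_token='<OOV>')
    query_tokenizer.fit_on_texts(train_queries)

    # Label tokenizer over the intent names (its index_word is used by __get_model).
    intent_tokenizer = Tokenizer()
    intent_tokenizer.fit_on_texts(train_intents)

    # Intent -> candidate response types, and response type -> candidate reply texts.
    db_intent2response = {"greeting": ["greet_back"], "other": ["fallback"]}
    db_responses2text = {"greet_back": ["أهلاً بيك"], "fallback": ["ممكن توضح قصدك؟"]}

    # Pickle each artifact under the file name NaiveChatbot expects to receive as a path.
    artifacts = {"query_tokenizer": query_tokenizer,
                 "intent_tokenizer": intent_tokenizer,
                 "db_intent2response": db_intent2response,
                 "db_responses2text": db_responses2text}
    for name, obj in artifacts.items():
        with open(f"{name}.pickle", "wb") as handle:
            pickle.dump(obj, handle)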