Upload 9 files
Browse files- README.md +51 -12
- b_bot.py +76 -0
- chat.py +105 -0
- data.pth +3 -0
- model.py +22 -0
- nltk_utils.py +48 -0
- requirements.txt +8 -0
- spell_check.py +35 -0
- train.py +148 -0
README.md
CHANGED
@@ -1,12 +1,51 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Hosted Link:
|
2 |
+
https://bibekbot.streamlit.app/
|
3 |
+
|
4 |
+
courses? --response later on
|
5 |
+
|
6 |
+
## Setup:
|
7 |
+
Clone the repo
|
8 |
+
```console
|
9 |
+
git clone [email protected]:bibekyess/Personal-Chatbot.git
|
10 |
+
cd Personal-Chatbot
|
11 |
+
```
|
12 |
+
|
13 |
+
### Create an environment using pipenv
|
14 |
+
Install pipenv if not installed. [Pipenv is required to install streamlit in macOS]
|
15 |
+
```console
|
16 |
+
pip3 install pipenv
|
17 |
+
```
|
18 |
+
|
19 |
+
Creates a new Pipenv environment using python-3.9 and activates it
|
20 |
+
```console
|
21 |
+
pipenv --python 3.9
|
22 |
+
pipenv shell
|
23 |
+
```
|
24 |
+
|
25 |
+
Installs streamlit in the recently created environment
|
26 |
+
```console
|
27 |
+
pipenv install streamlit==1.11.1
|
28 |
+
```
|
29 |
+
|
30 |
+
Installs the dependencies mentioned in requirements.txt file
|
31 |
+
```console
|
32 |
+
pip install -r requirements.txt
|
33 |
+
```
|
34 |
+
|
35 |
+
Runs the b_bot app
|
36 |
+
```console
|
37 |
+
streamlit run b_bot.py
|
38 |
+
```
|
39 |
+
|
40 |
+
# For training:
|
41 |
+
```console
|
42 |
+
python train.py
|
43 |
+
```
|
44 |
+
This saves the checkpoint in 'data.pth' file. Then run the b_bot app
|
45 |
+
```console
|
46 |
+
streamlit run b_bot.py
|
47 |
+
```
|
48 |
+
|
49 |
+
Reference:
|
50 |
+
I referred to https://github.com/patrickloeber/pytorch-chatbot for simple implementation of contextual chatbot with interactions from terminals.
|
51 |
+
I referred to https://github.com/AI-Yash/st-chat for the beautiful chatbot interface.
|
b_bot.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
from streamlit_chat import message
|
5 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
6 |
+
|
7 |
+
from chat import generate_response
|
8 |
+
|
9 |
+
if "tokenizer" not in st.session_state:
|
10 |
+
st.session_state["tokenizer"] = AutoTokenizer.from_pretrained(
|
11 |
+
"./generative_model/LaMini-Flan-T5-783M"
|
12 |
+
)
|
13 |
+
st.session_state["model"] = AutoModelForSeq2SeqLM.from_pretrained(
|
14 |
+
"./generative_model/LaMini-Flan-T5-783M"
|
15 |
+
)
|
16 |
+
|
17 |
+
st.title("B-Bot : Bibek's Personal Chatbot")
|
18 |
+
# Storing the chat
|
19 |
+
if "generated" not in st.session_state:
|
20 |
+
st.session_state["generated"] = []
|
21 |
+
|
22 |
+
if "past" not in st.session_state:
|
23 |
+
st.session_state["past"] = []
|
24 |
+
|
25 |
+
|
26 |
+
# We will get the user's input by calling the get_text function
|
27 |
+
def get_text():
|
28 |
+
input_text = st.text_input("Enter your inquiries here: ", "Hi!!")
|
29 |
+
return input_text
|
30 |
+
|
31 |
+
|
32 |
+
user_input = get_text()
|
33 |
+
|
34 |
+
if user_input:
|
35 |
+
tokenizer = st.session_state["tokenizer"]
|
36 |
+
model = st.session_state["model"]
|
37 |
+
output = generate_response(user_input)
|
38 |
+
prompt_template = "\nPlease make meaningful sentence and try to be descriptive as possible, ending with proper punctuations. If you don't have descriptive answers from the available prompt, write sorry and advise them to contact Bibek directly." # NoQA
|
39 |
+
short_response_template = "\nIf your response is very short like 1 or 2 sentence, add a followup sentence like 'Let me know if there's anything else I can help you with. or If there's anything else I can assist with, please don't hesitate to ask. I mean something similar in polite way." # NoQA
|
40 |
+
|
41 |
+
start = time.time()
|
42 |
+
input_ids = tokenizer(
|
43 |
+
output + user_input + prompt_template + short_response_template,
|
44 |
+
return_tensors="pt",
|
45 |
+
).input_ids
|
46 |
+
|
47 |
+
outputs = model.generate(input_ids, max_length=512, do_sample=True)
|
48 |
+
output = tokenizer.decode(outputs[0]).strip('<pad></s>').strip()
|
49 |
+
end = time.time()
|
50 |
+
|
51 |
+
print("Time for model inference: ", end - start)
|
52 |
+
# Checks for memory overflow
|
53 |
+
if len(st.session_state.past) == 15:
|
54 |
+
st.session_state.past.pop(0)
|
55 |
+
st.session_state.generated.pop(0)
|
56 |
+
|
57 |
+
# store the output
|
58 |
+
st.session_state.past.append(user_input)
|
59 |
+
st.session_state.generated.append(output)
|
60 |
+
|
61 |
+
if st.session_state["generated"]:
|
62 |
+
# print(st.session_state)
|
63 |
+
for i in range(len(st.session_state["generated"]) - 1, -1, -1):
|
64 |
+
message(
|
65 |
+
st.session_state["generated"][i],
|
66 |
+
avatar_style="bottts",
|
67 |
+
seed=39,
|
68 |
+
key=str(i), # NoQA
|
69 |
+
)
|
70 |
+
message(
|
71 |
+
st.session_state["past"][i],
|
72 |
+
is_user=True,
|
73 |
+
avatar_style="identicon",
|
74 |
+
seed=4,
|
75 |
+
key=str(i) + "_user",
|
76 |
+
) # NoQA
|
chat.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from model import NeuralNet
|
7 |
+
from nltk_utils import bag_of_words, tokenize
|
8 |
+
from spell_check import correct_typos
|
9 |
+
|
10 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
+
|
12 |
+
with open("intents.json") as json_data:
|
13 |
+
intents = json.load(json_data)
|
14 |
+
|
15 |
+
FILE = "data.pth"
|
16 |
+
data = torch.load(FILE)
|
17 |
+
|
18 |
+
input_size = data["input_size"]
|
19 |
+
hidden_size = data["hidden_size"]
|
20 |
+
output_size = data["output_size"]
|
21 |
+
all_words = data["all_words"]
|
22 |
+
tags = data["tags"]
|
23 |
+
model_state = data["model_state"]
|
24 |
+
|
25 |
+
model = NeuralNet(input_size, hidden_size, output_size).to(device)
|
26 |
+
model.load_state_dict(model_state)
|
27 |
+
model.eval()
|
28 |
+
|
29 |
+
bot_name = "B-Bot"
|
30 |
+
# print(
|
31 |
+
# "Hello, I am B-BOT, personal ChatBOT of Mr. Bibek. Let's chat! (type 'quit' or 'q' to exit)" # NoQA
|
32 |
+
# )
|
33 |
+
|
34 |
+
|
35 |
+
def generate_response(sentence):
|
36 |
+
# sentence = input("You: ")
|
37 |
+
sentence = correct_typos(sentence)
|
38 |
+
# print(sentence)
|
39 |
+
if sentence.lower() == "quit" or sentence.lower() == "q":
|
40 |
+
# Needs to quit
|
41 |
+
pass
|
42 |
+
|
43 |
+
sentence = tokenize(sentence)
|
44 |
+
X = bag_of_words(sentence, all_words)
|
45 |
+
X = X.reshape(1, X.shape[0])
|
46 |
+
X = torch.from_numpy(X).to(device)
|
47 |
+
|
48 |
+
output = model(X)
|
49 |
+
_, predicted = torch.max(output, dim=1)
|
50 |
+
|
51 |
+
tag = tags[predicted.item()]
|
52 |
+
|
53 |
+
probs = torch.softmax(output, dim=1)
|
54 |
+
prob = probs[0][predicted.item()]
|
55 |
+
print(prob.item())
|
56 |
+
if prob.item() > 0.95:
|
57 |
+
for intent in intents["intents"]:
|
58 |
+
if tag == intent["tag"]:
|
59 |
+
return f"{bot_name}: {random.choice(intent['responses'])}"
|
60 |
+
else:
|
61 |
+
return (
|
62 |
+
f"{bot_name}: Sorry, I didn't understand... Can you be more "
|
63 |
+
"specific on your question? You can ask about Bibek's skillset, "
|
64 |
+
"experiences, portfolio, education, achievements "
|
65 |
+
"and KAIST activities."
|
66 |
+
"These are some sample questions: "
|
67 |
+
"(I) Tell me about Bibek,\n"
|
68 |
+
"(II) What skills does he have?,\n"
|
69 |
+
"(III) What work experience does Bibek have?,\n"
|
70 |
+
"(IV) What is Bibek's educational background?,\n"
|
71 |
+
"(V) What awards has he won?,\n"
|
72 |
+
"(VI) What projects has he completed? &\n"
|
73 |
+
"(VII) How can I contact Bibek?"
|
74 |
+
)
|
75 |
+
|
76 |
+
|
77 |
+
# while True:
|
78 |
+
# # sentence = "do you use credit cards?"
|
79 |
+
# sentence = input("You: ")
|
80 |
+
# if sentence.lower() == "quit" or sentence.lower() == "q":
|
81 |
+
# break
|
82 |
+
|
83 |
+
# sentence = tokenize(sentence)
|
84 |
+
# X = bag_of_words(sentence, all_words)
|
85 |
+
# X = X.reshape(1, X.shape[0])
|
86 |
+
# X = torch.from_numpy(X).to(device)
|
87 |
+
|
88 |
+
# output = model(X)
|
89 |
+
# _, predicted = torch.max(output, dim=1)
|
90 |
+
|
91 |
+
# tag = tags[predicted.item()]
|
92 |
+
|
93 |
+
# probs = torch.softmax(output, dim=1)
|
94 |
+
# prob = probs[0][predicted.item()]
|
95 |
+
# if prob.item() > 0.8:
|
96 |
+
# for intent in intents["intents"]:
|
97 |
+
# if tag == intent["tag"]:
|
98 |
+
# print(f"{bot_name}: {random.choice(intent['responses'])}")
|
99 |
+
# else:
|
100 |
+
# print(
|
101 |
+
# f"{bot_name}: Sorry, I do not understand... Can you be more "
|
102 |
+
# "specific on your question? You can ask about Bibek's skillset, "
|
103 |
+
# "experiences, portfolio, education, achievements "
|
104 |
+
# "and KAIST activities."
|
105 |
+
# )
|
data.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:edbf28e6717a7af532466dda112419472c7b8fb747dd9a10105a3aeb117ddce4
|
3 |
+
size 102591
|
model.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
class NeuralNet(nn.Module):
|
5 |
+
def __init__(self, input_size, hidden_size, num_classes):
|
6 |
+
super().__init__()
|
7 |
+
self.l1 = nn.Linear(input_size, hidden_size)
|
8 |
+
self.l2 = nn.Linear(hidden_size, hidden_size)
|
9 |
+
self.l3 = nn.Linear(hidden_size, num_classes)
|
10 |
+
self.relu = nn.ReLU()
|
11 |
+
self.dropout = nn.Dropout(p=0.5)
|
12 |
+
|
13 |
+
def forward(self, x):
|
14 |
+
out = self.l1(x)
|
15 |
+
out = self.relu(out)
|
16 |
+
out = self.dropout(out)
|
17 |
+
out = self.l2(out)
|
18 |
+
out = self.relu(out)
|
19 |
+
out = self.dropout(out)
|
20 |
+
out = self.l3(out)
|
21 |
+
# no activation and no softmax at the end
|
22 |
+
return out
|
nltk_utils.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import nltk
|
2 |
+
import numpy as np
|
3 |
+
from nltk.stem.porter import PorterStemmer
|
4 |
+
|
5 |
+
# package with a pretrained tokenizer, may need to uncomment the following
|
6 |
+
# to download for the first time
|
7 |
+
# nltk.download('punkt')
|
8 |
+
|
9 |
+
stemmer = PorterStemmer()
|
10 |
+
|
11 |
+
|
12 |
+
def tokenize(sentence):
|
13 |
+
"""
|
14 |
+
split sentence into array of words/tokens
|
15 |
+
a token can be a word or punctuation character, or number
|
16 |
+
"""
|
17 |
+
return nltk.word_tokenize(sentence)
|
18 |
+
|
19 |
+
|
20 |
+
def stem(word):
|
21 |
+
"""
|
22 |
+
stemming = find the root form of the word
|
23 |
+
examples:
|
24 |
+
words = ["organize", "organizes", "organizing"]
|
25 |
+
words = [stem(w) for w in words]
|
26 |
+
-> ["organ", "organ", "organ"]
|
27 |
+
"""
|
28 |
+
return stemmer.stem(word.lower())
|
29 |
+
|
30 |
+
|
31 |
+
def bag_of_words(tokenized_sentence, words):
|
32 |
+
"""
|
33 |
+
return bag of words array:
|
34 |
+
1 for each known word that exists in the sentence, 0 otherwise
|
35 |
+
example:
|
36 |
+
sentence = ["hello", "how", "are", "you"]
|
37 |
+
words = ["hi", "hello", "I", "you", "bye", "thank", "cool"]
|
38 |
+
bog = [ 0 , 1 , 0 , 1 , 0 , 0 , 0]
|
39 |
+
"""
|
40 |
+
# stem each word
|
41 |
+
sentence_words = [stem(word) for word in tokenized_sentence]
|
42 |
+
# initialize bag with 0 for each word
|
43 |
+
bag = np.zeros(len(words), dtype=np.float32)
|
44 |
+
for idx, w in enumerate(words):
|
45 |
+
if w in sentence_words:
|
46 |
+
bag[idx] = 1
|
47 |
+
|
48 |
+
return bag
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
nltk
|
2 |
+
pyspellchecker
|
3 |
+
streamlit-chat
|
4 |
+
torch
|
5 |
+
torchaudio
|
6 |
+
torchvision
|
7 |
+
huggingface_hub
|
8 |
+
transformers
|
spell_check.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spellchecker
|
2 |
+
|
3 |
+
|
4 |
+
def correct_typos(sentence):
|
5 |
+
# Initialize the spell checker object
|
6 |
+
spell = spellchecker.SpellChecker(language="en")
|
7 |
+
# Adds Bibek to its frequency dictionary to make it a known word
|
8 |
+
spell.word_frequency.load_words(
|
9 |
+
[
|
10 |
+
"Bibek",
|
11 |
+
"Bibek's",
|
12 |
+
"skillsets",
|
13 |
+
"skillset",
|
14 |
+
"CV",
|
15 |
+
"RIRO",
|
16 |
+
"Bisonai",
|
17 |
+
"IC",
|
18 |
+
"BMC",
|
19 |
+
"KAIST",
|
20 |
+
]
|
21 |
+
)
|
22 |
+
sentence_split = sentence.split()
|
23 |
+
# Find the typos in the input sentence
|
24 |
+
typos = spell.unknown(sentence_split)
|
25 |
+
# Correct the typos
|
26 |
+
corrected_sentence = [
|
27 |
+
spell.correction(word)
|
28 |
+
if spell.correction(word)
|
29 |
+
else word
|
30 |
+
if word in typos
|
31 |
+
else word
|
32 |
+
for word in sentence_split
|
33 |
+
]
|
34 |
+
# Return the corrected sentence as a string
|
35 |
+
return " ".join(corrected_sentence)
|
train.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.utils.data import DataLoader, Dataset
|
7 |
+
|
8 |
+
from model import NeuralNet
|
9 |
+
from nltk_utils import bag_of_words, stem, tokenize
|
10 |
+
|
11 |
+
with open("intents.json") as f:
|
12 |
+
intents = json.load(f)
|
13 |
+
|
14 |
+
all_words = []
|
15 |
+
tags = []
|
16 |
+
xy = []
|
17 |
+
# loop through each sentence in our intents patterns
|
18 |
+
for intent in intents["intents"]:
|
19 |
+
tag = intent["tag"]
|
20 |
+
# add to tag list
|
21 |
+
tags.append(tag)
|
22 |
+
for pattern in intent["patterns"]:
|
23 |
+
# tokenize each word in the sentence
|
24 |
+
w = tokenize(pattern)
|
25 |
+
# add to our words list
|
26 |
+
all_words.extend(w)
|
27 |
+
# add to xy pair
|
28 |
+
xy.append((w, tag))
|
29 |
+
AUGMENT = False
|
30 |
+
if "Bibek" in pattern:
|
31 |
+
pattern = pattern.replace("Bibek", "he")
|
32 |
+
AUGMENT = True
|
33 |
+
elif "bibek" in pattern:
|
34 |
+
pattern = pattern.replace("bibek", "he")
|
35 |
+
AUGMENT = True
|
36 |
+
elif "BIBEK" in pattern:
|
37 |
+
pattern = pattern.replace("BIBEK", "he")
|
38 |
+
AUGMENT = True
|
39 |
+
if AUGMENT:
|
40 |
+
w = tokenize(pattern)
|
41 |
+
all_words.extend(w)
|
42 |
+
xy.append((w, tag))
|
43 |
+
|
44 |
+
# stem and lower each word
|
45 |
+
ignore_words = ["?", ".", "!"]
|
46 |
+
all_words = [stem(w) for w in all_words if w not in ignore_words]
|
47 |
+
# remove duplicates and sort
|
48 |
+
all_words = sorted(set(all_words))
|
49 |
+
tags = sorted(set(tags))
|
50 |
+
|
51 |
+
print(len(xy), "patterns")
|
52 |
+
print(len(tags), "tags:", tags)
|
53 |
+
print(len(all_words), "unique stemmed words:", all_words)
|
54 |
+
|
55 |
+
# create training data
|
56 |
+
X_train = []
|
57 |
+
y_train = []
|
58 |
+
for (pattern_sentence, tag) in xy:
|
59 |
+
# X: bag of words for each pattern_sentence
|
60 |
+
bag = bag_of_words(pattern_sentence, all_words)
|
61 |
+
X_train.append(bag)
|
62 |
+
# y: PyTorch CrossEntropyLoss needs only class labels, not one-hot
|
63 |
+
label = tags.index(tag)
|
64 |
+
y_train.append(label)
|
65 |
+
|
66 |
+
X_train = np.array(X_train)
|
67 |
+
y_train = np.array(y_train)
|
68 |
+
|
69 |
+
# Hyper-parameters
|
70 |
+
num_epochs = 1000
|
71 |
+
batch_size = 32
|
72 |
+
learning_rate = 0.001
|
73 |
+
input_size = len(X_train[0])
|
74 |
+
hidden_size = 64
|
75 |
+
num_heads = 8
|
76 |
+
num_layer = 6
|
77 |
+
output_size = len(tags)
|
78 |
+
print(input_size, output_size)
|
79 |
+
|
80 |
+
|
81 |
+
class ChatDataset(Dataset):
|
82 |
+
"""
|
83 |
+
Creates PyTorch dataset to automatically iterate and do batch training
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self):
|
87 |
+
self.n_samples = len(X_train)
|
88 |
+
self.x_data = X_train
|
89 |
+
self.y_data = y_train
|
90 |
+
|
91 |
+
# support indexing such that dataset[i] can be used to get i-th sample
|
92 |
+
def __getitem__(self, index):
|
93 |
+
return self.x_data[index], self.y_data[index]
|
94 |
+
|
95 |
+
# we can call len(dataset) to return the size
|
96 |
+
def __len__(self):
|
97 |
+
return self.n_samples
|
98 |
+
|
99 |
+
|
100 |
+
dataset = ChatDataset()
|
101 |
+
train_loader = DataLoader(
|
102 |
+
dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=0
|
103 |
+
)
|
104 |
+
|
105 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
106 |
+
|
107 |
+
model = NeuralNet(input_size, hidden_size, output_size).to(device)
|
108 |
+
|
109 |
+
# Loss and optimizer
|
110 |
+
criterion = nn.CrossEntropyLoss()
|
111 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
112 |
+
|
113 |
+
# Train the model
|
114 |
+
for epoch in range(num_epochs):
|
115 |
+
for (words, labels) in train_loader:
|
116 |
+
words = words.to(device)
|
117 |
+
labels = labels.to(dtype=torch.long).to(device)
|
118 |
+
|
119 |
+
# Forward pass
|
120 |
+
outputs = model(words)
|
121 |
+
# if y would be one-hot, we must apply
|
122 |
+
# labels = torch.max(labels, 1)[1]
|
123 |
+
loss = criterion(outputs, labels)
|
124 |
+
|
125 |
+
# Backward and optimize
|
126 |
+
optimizer.zero_grad()
|
127 |
+
loss.backward()
|
128 |
+
optimizer.step()
|
129 |
+
|
130 |
+
if (epoch + 1) % 100 == 0:
|
131 |
+
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
|
132 |
+
|
133 |
+
|
134 |
+
print(f"final loss: {loss.item():.4f}")
|
135 |
+
|
136 |
+
data = {
|
137 |
+
"model_state": model.state_dict(),
|
138 |
+
"input_size": input_size,
|
139 |
+
"hidden_size": hidden_size,
|
140 |
+
"output_size": output_size,
|
141 |
+
"all_words": all_words,
|
142 |
+
"tags": tags,
|
143 |
+
}
|
144 |
+
|
145 |
+
FILE = "data.pth"
|
146 |
+
torch.save(data, FILE)
|
147 |
+
|
148 |
+
print(f"training complete. file saved to {FILE}")
|