Spaces: Runtime error
kenichiro committed · Commit 926183f
Parent(s): e96863e
Add application file
Browse files
- README.md +65 -11
- __pycache__/chat.cpython-38.pyc +0 -0
- app.py +19 -0
- chat.py +66 -0
- data.pth +3 -0
- index2word.pickle +3 -0
- intents.json +84 -0
- model.pickle +3 -0
- nltk_utils.py +27 -0
- run_segbot.py +106 -0
- solver.py +445 -0
- static/app.js +91 -0
- static/images/chatbox-icon.svg +3 -0
- static/style.css +200 -0
- templates/base.html +42 -0
README.md
CHANGED
@@ -1,13 +1,67 @@
----
-title: Clinical Segnemt
-emoji: 🌖
-colorFrom: purple
-colorTo: yellow
-sdk: streamlit
-sdk_version: 1.17.0
-app_file: app.py
-pinned: false
-license: cc-by-3.0
----
-
-
+# NLP based Chatbot in PyTorch
+<img src="https://miro.medium.com/max/1400/1*VqLvWcTKgVpv1idxII591A.jpeg" width="470" height="350">
+
+## Simple chatbot implementation with PyTorch.
+
+* The implementation should be easy to follow for beginners and provide a basic understanding of chatbots.
+
+* The implementation is straightforward: a feed-forward neural net with 2 hidden layers.
+
+* Customization for your own use case is super easy. Just modify intents.json with possible patterns and responses and re-run the training (see below for more info).
+
+In [this article](https://medium.com/@mlvictoriamaslova/nlp-based-chatbot-in-pytorch-bonus-flask-and-javascript-deployment-474c4e59ceff) on Medium I explain some NLP concepts that underlie building chatbots.
+
+---
+
+## Installation
+
+### Create an environment
+
+Use whatever you prefer (e.g. conda or venv):
+
+```
+$ mkdir myproject
+$ cd myproject
+$ python3 -m venv venv
+```
+
+### Activate it
+
+Mac / Linux:
+```
+. venv/bin/activate
+```
+Windows:
+
+```
+venv\Scripts\activate
+```
+
+### Install PyTorch and dependencies
+
+For installation of PyTorch, see the official website.
+
+You also need nltk:
+```
+pip install nltk
+```
+If you get an error during the first run, you also need to install nltk.tokenize.punkt. Run this once in your terminal:
+
+```
+$ python
+>>> import nltk
+>>> nltk.download('punkt')
+```
+
+### Usage
+
+Run
+```
+python train.py
+```
+This will dump a data.pth file. Then run
+```
+python chat.py
+```
__pycache__/chat.cpython-38.pyc ADDED
Binary file (1.46 kB)
app.py
ADDED
@@ -0,0 +1,19 @@
from flask import Flask, render_template, request, jsonify

from chat import get_response

app = Flask(__name__)

@app.get("/")
def index_get():
    return render_template("base.html")

@app.post("/predict")
def predict():
    text = request.get_json().get("message")
    response = get_response(text)
    message = {"answer": response}
    return jsonify(message)

if __name__=="__main__":
    app.run(debug=True)
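For a quick smoke test of the `/predict` route above, a request like the following should round-trip one message. This is a sketch, not part of the commit; it assumes app.py is running locally on Flask's default port 5000 and that the `requests` package is installed.

```python
import requests

# Hypothetical local test of app.py's /predict endpoint (not part of the commit).
resp = requests.post(
    "http://127.0.0.1:5000/predict",
    json={"message": "Hi"},  # same payload shape that static/app.js sends
)
resp.raise_for_status()
print(resp.json()["answer"])  # one of the "greeting" responses from intents.json
```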
chat.py
ADDED
@@ -0,0 +1,66 @@
import random
import json

import torch

from nltk_utils import bag_of_words, tokenize
from run_segbot import get_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with open('intents.json', 'r') as json_data:
    intents = json.load(json_data)

#FILE = "data.pth"
#data = torch.load(FILE)

#input_size = data["input_size"]
#hidden_size = data["hidden_size"]
#output_size = data["output_size"]
#all_words = data['all_words']
#tags = data['tags']
#model_state = data["model_state"]

#model = NeuralNet(input_size, hidden_size, output_size).to(device)
#model.load_state_dict(model_state)
#with open('model.pickle', 'rb') as f:
#    model = pickle.load(f)

model = get_model()

model.eval()

bot_name = "Sam"


def get_response(msg):
    sentence = tokenize(msg)
    X = bag_of_words(sentence, all_words)
    X = X.reshape(1, X.shape[0])
    X = torch.from_numpy(X).to(device)

    output = model(X)
    _, predicted = torch.max(output, dim=1)

    tag = tags[predicted.item()]

    probs = torch.softmax(output, dim=1)
    prob = probs[0][predicted.item()]
    if prob.item() > 0.75:
        for intent in intents['intents']:
            if tag == intent["tag"]:
                return random.choice(intent['responses'])

    return "I do not understand..."


if __name__ == "__main__":
    print("Let's chat! (type 'quit' to exit)")
    while True:
        # sentence = "do you use credit cards?"
        sentence = input("You: ")
        if sentence == "quit":
            break

        resp = get_response(sentence)
        print(resp)
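Note that `get_response` above references `all_words` and `tags`, but the block that loads them from data.pth is commented out, so calling it as committed raises a NameError (consistent with the Space's runtime-error status). A minimal sketch of restoring that load, assuming data.pth was written by train.py with the keys shown in the commented-out block:

```python
import torch

# Assumed fix, not part of the commit: reload the vocabulary and tag list
# from data.pth using the same keys as the commented-out block above.
data = torch.load("data.pth", map_location="cpu")
all_words = data["all_words"]
tags = data["tags"]
```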
data.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f20bb4bda5d1517c4bb6d201139d136b0840d48cda09237e92bbb5b0b1fd63f4
size 5015
index2word.pickle
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:75789974bed3cd0bc31ad888f26cf977a1c14fb35bc504849fa066cab1f845dd
size 47914175
intents.json
ADDED
@@ -0,0 +1,84 @@
{
  "intents": [
    {
      "tag": "greeting",
      "patterns": [
        "Hi",
        "Hey",
        "How are you",
        "Is anyone there?",
        "Hello",
        "Good day"
      ],
      "responses": [
        "Hey :-)",
        "Hello, thanks for visiting",
        "Hi there, what can I do for you?",
        "Hi there, how can I help?"
      ]
    },
    {
      "tag": "goodbye",
      "patterns": ["Bye", "See you later", "Goodbye"],
      "responses": [
        "See you later, thanks for visiting",
        "Have a nice day",
        "Bye! Come back again soon."
      ]
    },
    {
      "tag": "thanks",
      "patterns": ["Thanks", "Thank you", "That's helpful", "Thank's a lot!"],
      "responses": ["Happy to help!", "Any time!", "My pleasure"]
    },
    {
      "tag": "items",
      "patterns": [
        "Which items do you have?",
        "What kinds of items are there?",
        "What do you sell?"
      ],
      "responses": [
        "We sell coffee and tea",
        "We have coffee and tea"
      ]
    },
    {
      "tag": "payments",
      "patterns": [
        "Do you take credit cards?",
        "Do you accept Mastercard?",
        "Can I pay with Paypal?",
        "Are you cash only?"
      ],
      "responses": [
        "We accept VISA, Mastercard and Paypal",
        "We accept most major credit cards, and Paypal"
      ]
    },
    {
      "tag": "delivery",
      "patterns": [
        "How long does delivery take?",
        "How long does shipping take?",
        "When do I get my delivery?"
      ],
      "responses": [
        "Delivery takes 2-4 days",
        "Shipping takes 2-4 days"
      ]
    },
    {
      "tag": "funny",
      "patterns": [
        "Tell me a joke!",
        "Tell me something funny!",
        "Do you know a joke?"
      ],
      "responses": [
        "Why did the hipster burn his mouth? He drank the coffee before it was cool.",
        "What did the buffalo say when his son left for college? Bison."
      ]
    }
  ]
}
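Per the README, customizing the bot means editing intents.json and re-running the training. A sketch of adding an intent programmatically; the "hours" tag, patterns, and responses below are illustrative only and not part of this commit:

```python
import json

# Hypothetical new intent used purely as an example.
new_intent = {
    "tag": "hours",
    "patterns": ["When are you open?", "What are your opening hours?"],
    "responses": ["We are open 9am-5pm, Monday to Friday"],
}

with open("intents.json") as f:
    data = json.load(f)

data["intents"].append(new_intent)

with open("intents.json", "w") as f:
    json.dump(data, f, indent=2)

# Then re-run the training step from the README: python train.py
```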
model.pickle
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc13a6fa988683240ebd80f50a53a864fc5c9b6ad90c0e3d72c624749b542a9d
size 4948315605
nltk_utils.py
ADDED
@@ -0,0 +1,27 @@
import nltk
import numpy as np
#nltk.download('all')
from nltk.stem.porter import PorterStemmer

stemmer = PorterStemmer()


def tokenize(sentence):
    """
    split sentence into array of words/tokens
    a token can be a word or punctuation character, or number
    """
    return nltk.word_tokenize(sentence)


def stem(word):
    return stemmer.stem(word.lower())


def bag_of_words(tokenized_sentence, all_words):
    tokenized_sentence = [stem(w) for w in tokenized_sentence]

    bag = np.zeros(len(all_words), dtype=np.float32)
    for idx, w in enumerate(all_words):
        if w in tokenized_sentence:
            bag[idx] = 1.0
    return bag
ADDED
@@ -0,0 +1,106 @@
|
import re
from nltk.tokenize import word_tokenize
import pickle
import numpy as np
import random
import torch
from solver import TrainSolver

from model import PointerNetworks
import gensim
from tqdm import tqdm


class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {"RE_DIGITS":1,"UNKNOWN":0,"PADDING":2000001}
        self.word2count = {"RE_DIGITS":1,"UNKNOWN":1,"PADDING":1}
        self.index2word = {2000001: "PADDING", 1: "RE_DIGITS", 0: "UNKNOWN"}
        self.n_words = 3  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.strip('\n').strip('\r').split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1


def mytokenizer(inS,all_dict):
    #repDig = re.sub(r'\d+[\.,/]?\d+','RE_DIGITS',inS)
    #repDig = re.sub(r'\d*[\d,]*\d+', 'RE_DIGITS', inS)
    toked = inS
    or_toked = inS
    re_unk_list = []
    ori_list = []

    for (i,t) in enumerate(toked):
        if t not in all_dict and t not in ['RE_DIGITS']:
            re_unk_list.append('UNKNOWN')
            ori_list.append(or_toked[i])
        else:
            re_unk_list.append(t)
            ori_list.append(or_toked[i])

    labey_edus = [0]*len(re_unk_list)
    labey_edus[-1] = 1

    return ori_list,re_unk_list,labey_edus


def get_mapping(X,Y,D):
    X_map = []
    for w in X:
        if w in D:
            X_map.append(D[w])
        else:
            X_map.append(D['UNKNOWN'])

    X_map = np.array([X_map])
    Y_map = np.array([Y])

    return X_map,Y_map


def get_model():
    with open('model.pickle', 'rb') as f:
        mysolver = pickle.load(f)
    return mysolver

    # Evaluation leftovers below are unreachable because of the return above
    # (indentation reconstructed; the extracted diff had lost it).
    #for i in tqdm(range(0,26431)):
    test_batch_ave_loss, test_pre, test_rec, test_f1, visdata = mysolver.check_accuracy(X_tes, Y_tes,index2word, fukugen)
    #test_batch_ave_loss, test_pre, test_rec, test_f1, visdata = mysolver.check_accuracy(X_tes, Y_tes,0)
    #with open(str(i)+"seped","w")as f:
    #    f.write(o)
    #test_batch_ave_loss, test_pre, test_rec, test_f1, visdata = mysolver.check_accuracy(X_tes, Y_tes,0)
    print(test_pre, test_rec, test_f1)
    #start_b = visdata[3][0]
    #end_b = visdata[2][0] + 1
    #segments = []

    #for i, END in enumerate(end_b):
    #    START = start_b[i]
    #    segments.append(' '.join(ori_X[START:END]))

    return test_pre, test_rec, test_f1
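`Lang` above is a small vocabulary builder: `addSentence` splits on spaces and gives each unseen word the next free index after the three reserved entries. A sketch of its behaviour on made-up sentences, assuming the module's own imports (model, gensim) resolve in your environment:

```python
from run_segbot import Lang

lang = Lang("demo")  # the name argument is arbitrary
lang.addSentence("the patient was discharged")
lang.addSentence("the patient improved")

print(lang.n_words)             # 8: three reserved entries + five distinct words
print(lang.word2index["the"])   # 3: the first new word gets the next free index
print(lang.word2count["the"])   # 2: seen in both sentences
```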
solver.py
ADDED
@@ -0,0 +1,445 @@
import torch.optim as optim
import numpy as np
import torch
from torch.autograd import Variable

import random
from torch.nn.utils import clip_grad_norm
import copy
from tqdm import tqdm

import os
import pickle


def get_decoder_index_XY(batchY):
    '''
    :param batchY: like [0 0 1 0 0 0 0 1]
    :return:
    '''
    returnX =[]
    returnY =[]
    for i in range(len(batchY)):
        curY = batchY[i]
        index_1 = np.where(curY==1)

        decoderY = index_1[0]

        if len(index_1[0]) ==1:
            decoderX = np.array([0])
        else:
            decoderX = np.append([0],decoderY[0:-1]+1)
        returnX.append(decoderX)
        returnY.append(decoderY)

    returnX = np.array(returnX)
    returnY = np.array(returnY)

    return returnX,returnY


def align_variable_numpy(X,maxL,paddingNumber):
    aligned = []
    for cur in X:
        ext_cur = []
        ext_cur.extend(cur)
        ext_cur.extend([paddingNumber] * (maxL - len(cur)))
        aligned.append(ext_cur)
    aligned = np.array(aligned)

    return aligned


def sample_a_sorted_batch_from_numpy(numpyX,numpyY,batch_size,use_cuda):
    if batch_size != None:
        select_index = random.sample(range(len(numpyY)), batch_size)
    else:
        select_index = np.array(range(len(numpyY)))

    select_index = np.array(range(len(numpyX)))

    batch_x = [copy.deepcopy(numpyX[i]) for i in select_index]
    batch_y = [copy.deepcopy(numpyY[i]) for i in select_index]

    #print(batch_y)
    index_decoder_X,index_decoder_Y = get_decoder_index_XY(batch_y)
    #index_decoder = [get_decoder_index_XY(i) for i in batch_y]
    #index_decoder_X = [i[0] for i in index_decoder]
    #index_decoder_Y = [i[1] for i in index_decoder]
    #print(index_decoder_Y)

    #all_lens = []
    all_lens = np.array([len(x) for x in batch_y])
    #for x in batch_y:
    #    print(x)
    #    try:
    #        all_lens.append(len(x))
    #    except:
    #        all_lens.append(1)
    #all_lens = np.array(all_lens)

    maxL = np.max(all_lens)

    #idx = all_lens
    #print(idx)
    idx = np.argsort(all_lens)
    idx = np.sort(idx)
    #print(idx)
    #idx = idx[::-1] # decreasing
    #print(idx)
    batch_x = [batch_x[i] for i in idx]
    batch_y = [batch_y[i] for i in idx]
    all_lens = all_lens[idx]

    index_decoder_X = np.array([index_decoder_X[i] for i in idx])
    index_decoder_Y = np.array([index_decoder_Y[i] for i in idx])
    #print(index_decoder_Y)

    numpy_batch_x = batch_x

    batch_x = align_variable_numpy(batch_x,maxL,2000001)
    batch_y = align_variable_numpy(batch_y,maxL,2)

    print(len(batch_x))
    #batch_x = Variable(torch.from_numpy(batch_x.astype(np.int64)))
    batch_x = Variable(torch.from_numpy(np.array(batch_x, dtype="int64")))

    if use_cuda:
        batch_x = batch_x.cuda()

    return numpy_batch_x,batch_x,batch_y,index_decoder_X,index_decoder_Y,all_lens,maxL


class TrainSolver(object):
    def __init__(self, model,train_x,train_y,dev_x,dev_y,save_path,batch_size,eval_size,epoch, lr,lr_decay_epoch,weight_decay,use_cuda):
        self.lr = lr
        self.model = model
        self.epoch = epoch
        self.train_x = train_x
        self.train_y = train_y
        self.use_cuda = use_cuda
        self.batch_size = batch_size
        self.lr_decay_epoch = lr_decay_epoch
        self.eval_size = eval_size

        self.dev_x, self.dev_y = dev_x, dev_y

        self.model = model
        self.save_path = save_path
        self.weight_decay =weight_decay

    def sample_dev(self):
        test_tr_x = []
        test_tr_y = []
        select_index = random.sample(range(len(self.train_y)),self.eval_size)
        test_tr_x = [self.train_x[n] for n in select_index]
        test_tr_y = [self.train_y[n] for n in select_index]

        return test_tr_x,test_tr_y

    def get_batch_micro_metric(self,pre_b, ground_b, x,index2word, fukugen, nloop):
        tokendic = {}
        #with open('index2word.pickle', 'rb') as f:
        #    index2word = pickle.load(f)
        for n,i in enumerate(index2word):
            tokendic[n] = i
        All_C = []
        All_R = []
        All_G = []
        """
        for i,cur_seq_y in enumerate(zip(ground_b,fukugen[nloop])):
            #print(fukugen[nloop])
            fuku = cur_seq_y[1]
            cur_seq_y = cur_seq_y[0]
            index_of_1 = np.where(cur_seq_y==1)[0]
            #print(index_of_1)
            index_pre = pre_b[i]
            inp = x[i]
            #print(len(inp))
        """
        print(len(pre_b), len(ground_b), len(fukugen))
        #global leng
        #print(fukugen)
        for i,cur_seq_y in enumerate(ground_b):
            #print(fukugen[nloop])
            fuku = fukugen[i]
            #cur_seq_y = cur_seq_y[0]
            index_of_1 = np.where(cur_seq_y==1)[0]
            #print(index_of_1)
            index_pre = pre_b[i]
            inp = x[i]
            #print(len(inp))

            index_pre = np.array(index_pre)
            END_B = index_of_1[-1]
            index_pre = index_pre[index_pre != END_B]
            index_of_1 = index_of_1[index_of_1 != END_B]

            no_correct = len(np.intersect1d(list(index_of_1), list(index_pre)))
            All_C.append(no_correct)
            All_R.append(len(index_pre))
            All_G.append(len(index_of_1))

            index_of_1 = list(index_of_1)
            index_pre = list(index_pre)

            FN = []
            FP = []
            TP = []
            sent = []
            ex = ""
            for j in inp:
                sent.append(tokendic[int(j.to('cpu').detach().numpy().copy())])
            for k in index_of_1:
                if k not in index_pre:
                    FN.append(k)
                if k in index_pre:
                    TP.append(k)
            for k in index_pre:
                if k not in index_of_1:
                    FP.append(k)
            #if len(FN) == 0 and len(FP) == 0:
            #    continue
            #for n,i in enumerate(sent):
            for n,k in enumerate(zip(sent, fuku)):
                f = k[1]
                i = k[0]
                if k == "<pad>":
                    continue
                if n in FP:
                    ex += f + "<FP>"
                else:
                    ex += f
                """
                if n in FN:
                    #ex += i + "<FN>"
                    ex += i
                elif n in FP:
                    ex += i + "<FP>"
                elif n in TP:
                    ex += i + "<TP>"
                else:
                    ex += i
                """
            #with open(str(nloop)+"_sep_nounk.txt", "a")as f:
            #    f.write(ex+"\n")
            #print(i)
            #leng += 1

        return All_C,All_R,All_G

    def get_batch_metric(self,pre_b, ground_b):
        b_pr =[]
        b_re =[]
        b_f1 =[]
        for i,cur_seq_y in enumerate(ground_b):
            index_of_1 = np.where(cur_seq_y==1)[0]
            index_pre = pre_b[i]

            no_correct = len(np.intersect1d(index_of_1,index_pre))

            cur_pre = no_correct / len(index_pre)
            cur_rec = no_correct / len(index_of_1)
            cur_f1 = 2*cur_pre*cur_rec/ (cur_pre+cur_rec)

            b_pr.append(cur_pre)
            b_re.append(cur_rec)
            b_f1.append(cur_f1)

        return b_pr,b_re,b_f1

    def check_accuracy(self,data2X,data2Y,index2word, fukugen2):
        for nloop in tqdm(range(0,108)):
            dataY = data2Y[nloop]
            dataX = data2X[nloop]
            fukugen = fukugen2[nloop]
            #print(len(dataX), len(dataY), len(fukugen))
            need_loop = int(np.ceil(len(dataY) / self.batch_size))
            #need_loop = int(np.ceil(len(dataY) / 1))
            all_ave_loss =[]
            all_boundary =[]
            all_boundary_start = []
            all_align_matrix = []
            all_index_decoder_y =[]
            all_x_save = []

            all_C =[]
            all_R =[]
            all_G =[]

            for lp in range(need_loop):
                startN = lp*self.batch_size
                endN = (lp+1)*self.batch_size
                if endN > len(dataY):
                    endN = len(dataY)
                #print(fukugen)
                fukuge = fukugen[startN:endN]
                #print(startN, endN)
                #print(len(fukugen))
                #print(fukugen)
                #for nloop in tqdm(range(0,26431)):
                numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
                    dataX[startN:endN], dataY[startN:endN], None, self.use_cuda)
                #numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
                #    dataX, dataY, None, self.use_cuda)

                batch_ave_loss, batch_boundary, batch_boundary_start, batch_align_matrix = self.model.predict(batch_x,
                                                                                                              index_decoder_Y,
                                                                                                              all_lens)

                all_ave_loss.extend([batch_ave_loss.data.item()]) #[batch_ave_loss.data[0]]
                all_boundary.extend(batch_boundary)
                all_boundary_start.extend(batch_boundary_start)
                all_align_matrix.extend(batch_align_matrix)
                all_index_decoder_y.extend(index_decoder_Y)
                all_x_save.extend(numpy_batch_x)

                #print(batch_y)
                ba_C,ba_R,ba_G = self.get_batch_micro_metric(batch_boundary,batch_y,batch_x,index2word, fukuge, nloop)

                all_C.extend(ba_C)
                all_R.extend(ba_R)
                all_G.extend(ba_G)

        ba_pre = np.sum(all_C)/ np.sum(all_R)
        ba_rec = np.sum(all_C)/ np.sum(all_G)
        ba_f1 = 2*ba_pre*ba_rec/ (ba_pre+ba_rec)

        return np.mean(all_ave_loss),ba_pre,ba_rec,ba_f1, (all_x_save,all_index_decoder_y,all_boundary, all_boundary_start, all_align_matrix)

    def adjust_learning_rate(self,optimizer,epoch,lr_decay=0.5, lr_decay_epoch=5):
        if (epoch % lr_decay_epoch == 0) and (epoch != 0):
            for param_group in optimizer.param_groups:
                param_group['lr'] *= lr_decay

    def train(self,n):
        self.test_train_x, self.test_train_y = self.sample_dev()

        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr, weight_decay=self.weight_decay)

        num_each_batch = int(np.round(len(self.train_y) / self.batch_size))

        #os.mkdir(self.save_path)

        best_i =0
        best_f1 =0

        for epoch in range(self.epoch):
            print(epoch)
            self.adjust_learning_rate(optimizer, epoch, 0.8, self.lr_decay_epoch)

            track_epoch_loss = []
            for iter in tqdm(range(num_each_batch)):
                #print("epoch:%d,iteration:%d" % (epoch, iter))

                self.model.zero_grad()

                numpy_batch_x,batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
                    self.train_x, self.train_y, self.batch_size, self.use_cuda)

                neg_loss = self.model.neg_log_likelihood(batch_x, index_decoder_X, index_decoder_Y,all_lens)

                neg_loss_v = float(neg_loss.data.item())
                #print(neg_loss_v)
                track_epoch_loss.append(neg_loss_v)

                neg_loss.backward()

                clip_grad_norm(self.model.parameters(), 5)
                optimizer.step()

            #TODO: after each epoch,check accuracy

            self.model.eval()

            #tr_batch_ave_loss, tr_pre, tr_rec, tr_f1 ,visdata= self.check_accuracy(self.test_train_x,self.test_train_y)

            dev_batch_ave_loss, dev_pre, dev_rec, dev_f1, visdata =self.check_accuracy(self.dev_x,self.dev_y,n)
            print("f1="+str(dev_f1))
            print("loss="+str(dev_batch_ave_loss))
            """
            if best_f1 < dev_f1:
                best_f1 = dev_f1
                best_rec = dev_rec
                best_pre = dev_pre
                best_i = epoch

            save_data = [epoch,dev_batch_ave_loss,dev_pre,dev_rec,dev_f1]

            save_file_name = 'bs_{}_es_{}_lr_{}_lrdc_{}_wd_{}_epoch_loss_acc_pk_wd.txt'.format(self.batch_size,self.eval_size,self.lr,self.lr_decay_epoch,self.weight_decay)
            """
            #with open(os.path.join(self.save_path,save_file_name), 'a') as f:
            #    f.write(','.join(map(str,save_data))+'\n')

            #if epoch % 1 ==0 and epoch !=0:
            #    torch.save(self.model, os.path.join(self.save_path,r'model_epoch_%d.torchsave'%(epoch)))

            self.model.train()

        #return best_i,best_pre,best_rec,best_f1
        return best_i,best_f1,n
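To make the docstring of `get_decoder_index_XY` concrete: for the sample label sequence [0 0 1 0 0 0 0 1], the segment boundaries sit at positions 2 and 7, and each decoder step starts one position after the previous boundary. A quick sketch, assuming numpy arrays as input (as the rest of the code uses):

```python
import numpy as np
from solver import get_decoder_index_XY

batchY = [np.array([0, 0, 1, 0, 0, 0, 0, 1])]  # one sequence with boundaries at 2 and 7
returnX, returnY = get_decoder_index_XY(batchY)

print(returnY)  # [[2 7]] -> indices where the label is 1
print(returnX)  # [[0 3]] -> decoder start positions: 0, then previous boundary + 1
```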
static/app.js
ADDED
@@ -0,0 +1,91 @@
class Chatbox {
    constructor() {
        this.args = {
            openButton: document.querySelector('.chatbox__button'),
            chatBox: document.querySelector('.chatbox__support'),
            sendButton: document.querySelector('.send__button')
        }

        this.state = false;
        this.messages = [];
    }

    display() {
        const {openButton, chatBox, sendButton} = this.args;

        openButton.addEventListener('click', () => this.toggleState(chatBox))

        sendButton.addEventListener('click', () => this.onSendButton(chatBox))

        const node = chatBox.querySelector('input');
        node.addEventListener("keyup", ({key}) => {
            if (key === "Enter") {
                this.onSendButton(chatBox)
            }
        })
    }

    toggleState(chatbox) {
        this.state = !this.state;

        // show or hides the box
        if(this.state) {
            chatbox.classList.add('chatbox--active')
        } else {
            chatbox.classList.remove('chatbox--active')
        }
    }

    onSendButton(chatbox) {
        var textField = chatbox.querySelector('input');
        let text1 = textField.value
        if (text1 === "") {
            return;
        }

        let msg1 = { name: "User", message: text1 }
        this.messages.push(msg1);

        fetch('http://127.0.0.1:5000/predict', {
            method: 'POST',
            body: JSON.stringify({ message: text1 }),
            mode: 'cors',
            headers: {
                'Content-Type': 'application/json'
            },
        })
        .then(r => r.json())
        .then(r => {
            let msg2 = { name: "Sam", message: r.answer };
            this.messages.push(msg2);
            this.updateChatText(chatbox)
            textField.value = ''

        }).catch((error) => {
            console.error('Error:', error);
            this.updateChatText(chatbox)
            textField.value = ''
        });
    }

    updateChatText(chatbox) {
        var html = '';
        this.messages.slice().reverse().forEach(function(item, index) {
            if (item.name === "Sam")
            {
                html += '<div class="messages__item messages__item--visitor">' + item.message + '</div>'
            }
            else
            {
                html += '<div class="messages__item messages__item--operator">' + item.message + '</div>'
            }
        });

        const chatmessage = chatbox.querySelector('.chatbox__messages');
        chatmessage.innerHTML = html;
    }
}


const chatbox = new Chatbox();
chatbox.display();
static/images/chatbox-icon.svg
ADDED
static/style.css
ADDED
@@ -0,0 +1,200 @@
* {
    box-sizing: border-box;
    margin: 0;
    padding: 0;
}

body {
    font-family: 'Nunito', sans-serif;
    font-weight: 400;
    font-size: 100%;
    background: #F1F1F1;
}

*, html {
    --primaryGradient: linear-gradient(93.12deg, #581B98 0.52%, #9C1DE7 100%);
    --secondaryGradient: linear-gradient(268.91deg, #581B98 -2.14%, #9C1DE7 99.69%);
    --primaryBoxShadow: 0px 10px 15px rgba(0, 0, 0, 0.1);
    --secondaryBoxShadow: 0px -10px 15px rgba(0, 0, 0, 0.1);
    --primary: #581B98;
}

/* CHATBOX
=============== */
.chatbox {
    position: absolute;
    bottom: 30px;
    right: 30px;
}

/* CONTENT IS CLOSE */
.chatbox__support {
    display: flex;
    flex-direction: column;
    background: #eee;
    width: 300px;
    height: 350px;
    z-index: -123456;
    opacity: 0;
    transition: all .5s ease-in-out;
}

/* CONTENT ISOPEN */
.chatbox--active {
    transform: translateY(-40px);
    z-index: 123456;
    opacity: 1;
}

/* BUTTON */
.chatbox__button {
    text-align: right;
}

.send__button {
    padding: 6px;
    background: transparent;
    border: none;
    outline: none;
    cursor: pointer;
}


/* HEADER */
.chatbox__header {
    position: sticky;
    top: 0;
    background: orange;
}

/* MESSAGES */
.chatbox__messages {
    margin-top: auto;
    display: flex;
    overflow-y: scroll;
    flex-direction: column-reverse;
}

.messages__item {
    background: orange;
    max-width: 60.6%;
    width: fit-content;
}

.messages__item--operator {
    margin-left: auto;
}

.messages__item--visitor {
    margin-right: auto;
}

/* FOOTER */
.chatbox__footer {
    position: sticky;
    bottom: 0;
}

.chatbox__support {
    background: #f9f9f9;
    height: 450px;
    width: 350px;
    box-shadow: 0px 0px 15px rgba(0, 0, 0, 0.1);
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
}

/* HEADER */
.chatbox__header {
    background: var(--primaryGradient);
    display: flex;
    flex-direction: row;
    align-items: center;
    justify-content: center;
    padding: 15px 20px;
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
    box-shadow: var(--primaryBoxShadow);
}

.chatbox__image--header {
    margin-right: 10px;
}

.chatbox__heading--header {
    font-size: 1.2rem;
    color: white;
}

.chatbox__description--header {
    font-size: .9rem;
    color: white;
}

/* Messages */
.chatbox__messages {
    padding: 0 20px;
}

.messages__item {
    margin-top: 10px;
    background: #E0E0E0;
    padding: 8px 12px;
    max-width: 70%;
}

.messages__item--visitor,
.messages__item--typing {
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
    border-bottom-right-radius: 20px;
}

.messages__item--operator {
    border-top-left-radius: 20px;
    border-top-right-radius: 20px;
    border-bottom-left-radius: 20px;
    background: var(--primary);
    color: white;
}

/* FOOTER */
.chatbox__footer {
    display: flex;
    flex-direction: row;
    align-items: center;
    justify-content: space-between;
    padding: 20px 20px;
    background: var(--secondaryGradient);
    box-shadow: var(--secondaryBoxShadow);
    border-bottom-right-radius: 10px;
    border-bottom-left-radius: 10px;
    margin-top: 20px;
}

.chatbox__footer input {
    width: 80%;
    border: none;
    padding: 10px 10px;
    border-radius: 30px;
    text-align: left;
}

.chatbox__send--footer {
    color: white;
}

.chatbox__button button,
.chatbox__button button:focus,
.chatbox__button button:visited {
    padding: 10px;
    background: white;
    border: none;
    outline: none;
    border-top-left-radius: 50px;
    border-top-right-radius: 50px;
    border-bottom-left-radius: 50px;
    box-shadow: 0px 10px 15px rgba(0, 0, 0, 0.1);
    cursor: pointer;
}
templates/base.html
ADDED
@@ -0,0 +1,42 @@
<!DOCTYPE html>
<html lang="en">
<link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">

<head>
    <meta charset="UTF-8">
    <title>Chatbot</title>
</head>
<body>
    <div class="container">
        <div class="chatbox">
            <div class="chatbox__support">
                <div class="chatbox__header">
                    <div class="chatbox__image--header">
                        <img src="https://img.icons8.com/color/48/000000/circled-user-female-skin-type-5--v1.png" alt="image">
                    </div>
                    <div class="chatbox__content--header">
                        <h4 class="chatbox__heading--header">Chat support</h4>
                        <p class="chatbox__description--header">Hi. My name is Sam. How can I help you?</p>
                    </div>
                </div>
                <div class="chatbox__messages">
                    <div></div>
                </div>
                <div class="chatbox__footer">
                    <input type="text" placeholder="Write a message...">
                    <button class="chatbox__send--footer send__button">Send</button>
                </div>
            </div>
            <div class="chatbox__button">
                <button><img src="{{ url_for('static', filename='images/chatbox-icon.svg') }}" /></button>
            </div>
        </div>
    </div>

    <script>
        $SCRIPT_ROOT = {{ request.script_root|tojson }};
    </script>
    <script type="text/javascript" src="{{ url_for('static', filename='app.js') }}"></script>

</body>
</html>