bohraanuj23 commited on
Commit
e3f6cbe
·
verified ·
1 Parent(s): 6aad03c

Delete chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +0 -144
chatbot.py DELETED
@@ -1,144 +0,0 @@
1
- import streamlit as st
2
- import torch
3
- import numpy as np
4
- import nltk
5
- from nltk.stem.porter import PorterStemmer
6
- from torch.utils.data import Dataset, DataLoader
7
- import random
8
- import json
9
-
10
- nltk.download("punkt")
11
-
12
-
13
class NeuralNet(torch.nn.Module):
    """Three-layer fully connected intent classifier with ReLU activations.

    Attribute names (l1, l2, l3, relu) are kept exactly as saved in the
    checkpoint so ``load_state_dict`` continues to match.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        # Two equal-width hidden layers, then a projection to class logits.
        self.l1 = torch.nn.Linear(input_size, hidden_size)
        self.l2 = torch.nn.Linear(hidden_size, hidden_size)
        self.l3 = torch.nn.Linear(hidden_size, num_classes)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return raw class logits for ``x`` (no softmax applied here)."""
        hidden = self.relu(self.l1(x))
        hidden = self.relu(self.l2(hidden))
        return self.l3(hidden)
28
-
29
-
30
class ChatDataset(Dataset):
    """Minimal map-style dataset pairing feature vectors with class labels."""

    def __init__(self, X_train, y_train):
        # Store the arrays as-is; cache the sample count once.
        self.x_data, self.y_data = X_train, y_train
        self.n_samples = len(X_train)

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Number of samples in the dataset."""
        return self.n_samples
41
-
42
-
43
def tokenize(sentence):
    """Split ``sentence`` into word/punctuation tokens using NLTK's tokenizer."""
    tokens = nltk.word_tokenize(sentence)
    return tokens
45
-
46
-
47
def stem(word):
    """Lower-case ``word`` and reduce it to its Porter stem.

    NOTE(review): relies on the module-level ``stemmer`` (PorterStemmer)
    that is created further down in this file; calling stem() before that
    assignment would raise NameError.
    """
    lowered = word.lower()
    return stemmer.stem(lowered)
49
-
50
-
51
def bag_of_words(tokenized_sentence, words):
    """Encode a tokenized sentence as a float32 bag-of-words vector.

    Parameters
    ----------
    tokenized_sentence : list[str]
        Tokens as produced by ``tokenize``; each is stemmed before matching.
    words : sequence[str]
        The (already stemmed) vocabulary; position ``i`` of the result is
        1.0 when ``words[i]`` occurs in the stemmed sentence, else 0.0.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(words),)``, dtype float32.
    """
    # Stem once into a set: O(1) membership per vocabulary word instead of
    # the original O(len(sentence)) list scan for every word.
    sentence_words = {stem(word) for word in tokenized_sentence}
    bag = np.zeros(len(words), dtype=np.float32)
    for idx, w in enumerate(words):
        if w in sentence_words:
            bag[idx] = 1
    return bag
58
-
59
-
60
# ---------------------------------------------------------------------------
# Load intents and build the vocabulary / training arrays.
# ---------------------------------------------------------------------------
# NOTE(review): hard-coded absolute Windows path — this breaks on any other
# machine; consider moving it to a config file or environment variable.
intents_file_path = r"C:\Users\Anuj Bohra\Desktop\chatbot\data\intents.json"

# Explicit encoding so the JSON decodes identically on every platform
# (the platform default encoding is not UTF-8 everywhere).
with open(intents_file_path, "r", encoding="utf-8") as f:
    intents = json.load(f)

stemmer = PorterStemmer()

all_words = []  # every token seen in any pattern (pre-stemming)
tags = []       # one entry per intent
xy = []         # (token list, tag) pairs, one per training pattern

for intent in intents["intents"]:
    tag = intent["tag"]
    tags.append(tag)
    for pattern in intent["patterns"]:
        w = tokenize(pattern)
        all_words.extend(w)
        xy.append((w, tag))

# Stem, drop punctuation, de-duplicate; sorting makes the feature order
# deterministic across runs so it matches the saved checkpoint's layout.
ignore_words = ["?", ".", "!"]
all_words = sorted(set(stem(w) for w in all_words if w not in ignore_words))
tags = sorted(set(tags))

X_train = []
y_train = []

for pattern_sentence, tag in xy:
    X_train.append(bag_of_words(pattern_sentence, all_words))
    y_train.append(tags.index(tag))  # integer class index per sample

X_train = np.array(X_train)
y_train = np.array(y_train)
96
-
97
# ---------------------------------------------------------------------------
# Model construction and checkpoint restore.
# The sizes must match what the checkpoint was trained with: input_size and
# output_size are derived from the same intents file, hidden_size is fixed.
# ---------------------------------------------------------------------------
input_size = len(X_train[0])   # vocabulary size
hidden_size = 8
output_size = len(tags)        # number of intent classes

model = NeuralNet(input_size, hidden_size, output_size)

# NOTE(review): the original script also built a CrossEntropyLoss, an Adam
# optimizer, a ChatDataset and a DataLoader here, but this file contains no
# training loop — they were never used, so they have been removed.

# map_location forces CPU: nothing in this script ever moves the model or
# inputs to a GPU, and without it a GPU-saved checkpoint fails on CPU-only
# machines.
# SECURITY: torch.load unpickles arbitrary objects — only load checkpoint
# files from a trusted source.
checkpoint = torch.load("data.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state"])
model.eval()  # switch to inference mode before serving requests
118
-
119
# ---------------------------------------------------------------------------
# Streamlit chat UI: classify the user's sentence and reply with a canned
# response when the model's confidence exceeds 0.75.
# ---------------------------------------------------------------------------
st.title("Medical ChatBot")

user_input = st.text_input("You:", "")

if st.button("Ask"):
    tokens = tokenize(user_input)
    X = bag_of_words(tokens, all_words)
    X = torch.from_numpy(X.reshape(1, -1))  # batch of one sample

    # Inference only — no_grad skips building the autograd graph.
    with torch.no_grad():
        output = model(X)
    _, predicted = torch.max(output, dim=1)

    tag = tags[predicted.item()]

    # Softmax turns the logits into a confidence score for the top class.
    probs = torch.softmax(output, dim=1)
    prob = probs[0][predicted.item()]
    if prob.item() > 0.75:
        for intent in intents["intents"]:
            if tag == intent["tag"]:
                response = random.choice(intent["responses"])
                st.text("Bot: " + response)
                break  # tags are unique (built via sorted(set(...))); stop at first match
    else:
        st.text("Bot: I do not understand...")