added flan model
Browse files- ChitChat/common/utils.py +16 -7
- ChitChat/resources/routes.py +10 -2
- app.py +4 -4
- instance/site.db +0 -0
- requirements.txt +1 -0
ChitChat/common/utils.py
CHANGED
@@ -1,16 +1,20 @@
|
|
1 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
2 |
import torch
|
3 |
from flask import current_app
|
4 |
from ChitChat import db
|
5 |
from ChitChat.models import ChatHistory
|
6 |
|
7 |
-
|
8 |
default_model = 'microsoft/DialoGPT-medium'
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
|
|
|
|
|
14 |
|
15 |
def getChatHistory(user):
|
16 |
if user.history is None:
|
@@ -34,11 +38,11 @@ def saveChatHistory(user, chat_history_ids):
|
|
34 |
def conversation(user, userInput):
|
35 |
chat_history_ids = getChatHistory(user)
|
36 |
# print(chat_history_ids)
|
37 |
-
user_input_ids =
|
38 |
# print(user_input_ids)
|
39 |
bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
|
40 |
# print(bot_input_ids)
|
41 |
-
chat_history_ids =
|
42 |
bot_input_ids,
|
43 |
max_length = 500,
|
44 |
no_repeat_ngram_size = 3,
|
@@ -49,4 +53,9 @@ def conversation(user, userInput):
|
|
49 |
)
|
50 |
# print(f"chat_history_ids : {type(chat_history_ids)}")
|
51 |
saveChatHistory(user, chat_history_ids)
|
52 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
2 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
3 |
import torch
|
4 |
from flask import current_app
|
5 |
from ChitChat import db
|
6 |
from ChitChat.models import ChatHistory
|
7 |
|
8 |
+
small_model_name = 'Th3BossC/DialoGPT-medium-AICLUB_NITC'
|
9 |
default_model = 'microsoft/DialoGPT-medium'
|
10 |
+
large_model_name = 'google/flan-t5-large'
|
11 |
|
12 |
+
small_model = AutoModelForCausalLM.from_pretrained(small_model_name)
|
13 |
+
small_tokenizer = AutoTokenizer.from_pretrained(default_model)
|
14 |
+
small_tokenizer.pad_token = small_tokenizer.eos_token
|
15 |
|
16 |
+
large_model = T5ForConditionalGeneration.from_pretrained(large_model_name)
|
17 |
+
large_tokenizer = T5Tokenizer.from_pretrained(large_model_name)
|
18 |
|
19 |
def getChatHistory(user):
|
20 |
if user.history is None:
|
|
|
38 |
def conversation(user, userInput):
|
39 |
chat_history_ids = getChatHistory(user)
|
40 |
# print(chat_history_ids)
|
41 |
+
user_input_ids = small_tokenizer.encode(userInput + small_tokenizer.eos_token, return_tensors = "pt")
|
42 |
# print(user_input_ids)
|
43 |
bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
|
44 |
# print(bot_input_ids)
|
45 |
+
chat_history_ids = small_model.generate(
|
46 |
bot_input_ids,
|
47 |
max_length = 500,
|
48 |
no_repeat_ngram_size = 3,
|
|
|
53 |
)
|
54 |
# print(f"chat_history_ids : {type(chat_history_ids)}")
|
55 |
saveChatHistory(user, chat_history_ids)
|
56 |
+
return small_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens = True)
|
57 |
+
|
58 |
+
def complexChat(userInput):
|
59 |
+
input_ids = large_tokenizer(userInput, return_tensors="pt").input_ids
|
60 |
+
outputs = large_model.generate(input_ids)
|
61 |
+
return large_tokenizer.decode(outputs[0], skip_special_tokens = True)
|
ChitChat/resources/routes.py
CHANGED
@@ -2,7 +2,7 @@ from flask import Blueprint, request, current_app
|
|
2 |
from flask_restful import Api, Resource
|
3 |
from ChitChat.models import ChatHistory
|
4 |
from ChitChat import bcrypt, db
|
5 |
-
from ChitChat.common.utils import conversation
|
6 |
|
7 |
|
8 |
resources = Blueprint('resources', __name__)
|
@@ -50,7 +50,15 @@ class ChatBot(Resource):
|
|
50 |
|
51 |
reply = conversation(user, userInput)
|
52 |
return {'reply' : reply}
|
53 |
-
api.add_resource(ChatBot, '/chat/<string:user_id>')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
|
56 |
|
|
|
2 |
from flask_restful import Api, Resource
|
3 |
from ChitChat.models import ChatHistory
|
4 |
from ChitChat import bcrypt, db
|
5 |
+
from ChitChat.common.utils import conversation, complexChat
|
6 |
|
7 |
|
8 |
resources = Blueprint('resources', __name__)
|
|
|
50 |
|
51 |
reply = conversation(user, userInput)
|
52 |
return {'reply' : reply}
|
53 |
+
api.add_resource(ChatBot, '/chat/dumbbot/<string:user_id>')
|
54 |
+
|
55 |
+
|
56 |
+
class ComplexChatBot(Resource):
|
57 |
+
def post(self):
|
58 |
+
userInput = request.json['user']
|
59 |
+
reply = complexChat(userInput)
|
60 |
+
return {'reply' : reply}
|
61 |
+
api.add_resource(ComplexChatBot, '/chat/smartbot')
|
62 |
|
63 |
|
64 |
|
app.py
CHANGED
@@ -3,8 +3,8 @@ from ChitChat import create_app
|
|
3 |
|
4 |
app = create_app()
|
5 |
|
6 |
-
# if __name__ == '__main__':
|
7 |
-
# app.run(debug = True, port = 5000)
|
8 |
-
|
9 |
if __name__ == '__main__':
|
10 |
-
app.run(debug =
|
|
|
|
|
|
|
|
3 |
|
4 |
app = create_app()
|
5 |
|
|
|
|
|
|
|
6 |
if __name__ == '__main__':
|
7 |
+
app.run(debug = True, port = 5000)
|
8 |
+
|
9 |
+
# if __name__ == '__main__':
|
10 |
+
# app.run(debug = False, host = "0.0.0.0", port = 7860)
|
instance/site.db
CHANGED
Binary files a/instance/site.db and b/instance/site.db differ
|
|
requirements.txt
CHANGED
@@ -27,6 +27,7 @@ PyYAML==6.0
|
|
27 |
regex==2023.6.3
|
28 |
requests==2.31.0
|
29 |
safetensors==0.3.1
|
|
|
30 |
six==1.16.0
|
31 |
SQLAlchemy==2.0.16
|
32 |
sympy==1.12
|
|
|
27 |
regex==2023.6.3
|
28 |
requests==2.31.0
|
29 |
safetensors==0.3.1
|
30 |
+
sentencepiece==0.1.99
|
31 |
six==1.16.0
|
32 |
SQLAlchemy==2.0.16
|
33 |
sympy==1.12
|