Th3BossC commited on
Commit
c48357a
·
1 Parent(s): c809dd6

added flan model

Browse files
ChitChat/common/utils.py CHANGED
@@ -1,16 +1,20 @@
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
2
  import torch
3
  from flask import current_app
4
  from ChitChat import db
5
  from ChitChat.models import ChatHistory
6
 
7
- model_name = 'Th3BossC/DialoGPT-medium-AICLUB_NITC'
8
  default_model = 'microsoft/DialoGPT-medium'
 
9
 
10
- model = AutoModelForCausalLM.from_pretrained(model_name)
11
- tokenizer = AutoTokenizer.from_pretrained(default_model)
12
- tokenizer.pad_token = tokenizer.eos_token
13
 
 
 
14
 
15
  def getChatHistory(user):
16
  if user.history is None:
@@ -34,11 +38,11 @@ def saveChatHistory(user, chat_history_ids):
34
  def conversation(user, userInput):
35
  chat_history_ids = getChatHistory(user)
36
  # print(chat_history_ids)
37
- user_input_ids = tokenizer.encode(userInput + tokenizer.eos_token, return_tensors = "pt")
38
  # print(user_input_ids)
39
  bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
40
  # print(bot_input_ids)
41
- chat_history_ids = model.generate(
42
  bot_input_ids,
43
  max_length = 500,
44
  no_repeat_ngram_size = 3,
@@ -49,4 +53,9 @@ def conversation(user, userInput):
49
  )
50
  # print(f"chat_history_ids : {type(chat_history_ids)}")
51
  saveChatHistory(user, chat_history_ids)
52
- return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens = True)
 
 
 
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
3
  import torch
4
  from flask import current_app
5
  from ChitChat import db
6
  from ChitChat.models import ChatHistory
7
 
8
+ small_model_name = 'Th3BossC/DialoGPT-medium-AICLUB_NITC'
9
  default_model = 'microsoft/DialoGPT-medium'
10
+ large_model_name = 'google/flan-t5-large'
11
 
12
+ small_model = AutoModelForCausalLM.from_pretrained(small_model_name)
13
+ small_tokenizer = AutoTokenizer.from_pretrained(default_model)
14
+ small_tokenizer.pad_token = small_tokenizer.eos_token
15
 
16
+ large_model = T5ForConditionalGeneration.from_pretrained(large_model_name)
17
+ large_tokenizer = T5Tokenizer.from_pretrained(large_model_name)
18
 
19
  def getChatHistory(user):
20
  if user.history is None:
 
38
  def conversation(user, userInput):
39
  chat_history_ids = getChatHistory(user)
40
  # print(chat_history_ids)
41
+ user_input_ids = small_tokenizer.encode(userInput + small_tokenizer.eos_token, return_tensors = "pt")
42
  # print(user_input_ids)
43
  bot_input_ids = torch.cat([chat_history_ids, user_input_ids], axis = -1) if chat_history_ids is not None else user_input_ids
44
  # print(bot_input_ids)
45
+ chat_history_ids = small_model.generate(
46
  bot_input_ids,
47
  max_length = 500,
48
  no_repeat_ngram_size = 3,
 
53
  )
54
  # print(f"chat_history_ids : {type(chat_history_ids)}")
55
  saveChatHistory(user, chat_history_ids)
56
+ return small_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens = True)
57
+
58
+ def complexChat(userInput):
59
+ input_ids = large_tokenizer(userInput, return_tensors="pt").input_ids
60
+ outputs = large_model.generate(input_ids)
61
+ return large_tokenizer.decode(outputs[0], skip_special_tokens = True)
ChitChat/resources/routes.py CHANGED
@@ -2,7 +2,7 @@ from flask import Blueprint, request, current_app
2
  from flask_restful import Api, Resource
3
  from ChitChat.models import ChatHistory
4
  from ChitChat import bcrypt, db
5
- from ChitChat.common.utils import conversation
6
 
7
 
8
  resources = Blueprint('resources', __name__)
@@ -50,7 +50,15 @@ class ChatBot(Resource):
50
 
51
  reply = conversation(user, userInput)
52
  return {'reply' : reply}
53
- api.add_resource(ChatBot, '/chat/<string:user_id>')
 
 
 
 
 
 
 
 
54
 
55
 
56
 
 
2
  from flask_restful import Api, Resource
3
  from ChitChat.models import ChatHistory
4
  from ChitChat import bcrypt, db
5
+ from ChitChat.common.utils import conversation, complexChat
6
 
7
 
8
  resources = Blueprint('resources', __name__)
 
50
 
51
  reply = conversation(user, userInput)
52
  return {'reply' : reply}
53
+ api.add_resource(ChatBot, '/chat/dumbbot/<string:user_id>')
54
+
55
+
56
+ class ComplexChatBot(Resource):
57
+ def post(self):
58
+ userInput = request.json['user']
59
+ reply = complexChat(userInput)
60
+ return {'reply' : reply}
61
+ api.add_resource(ComplexChatBot, '/chat/smartbot')
62
 
63
 
64
 
app.py CHANGED
@@ -3,8 +3,8 @@ from ChitChat import create_app
3
 
4
  app = create_app()
5
 
6
- # if __name__ == '__main__':
7
- # app.run(debug = True, port = 5000)
8
-
9
  if __name__ == '__main__':
10
- app.run(debug = False, host = "0.0.0.0", port = 7860)
 
 
 
 
3
 
4
  app = create_app()
5
 
 
 
 
6
  if __name__ == '__main__':
7
+ app.run(debug = True, port = 5000)
8
+
9
+ # if __name__ == '__main__':
10
+ # app.run(debug = False, host = "0.0.0.0", port = 7860)
instance/site.db CHANGED
Binary files a/instance/site.db and b/instance/site.db differ
 
requirements.txt CHANGED
@@ -27,6 +27,7 @@ PyYAML==6.0
27
  regex==2023.6.3
28
  requests==2.31.0
29
  safetensors==0.3.1
 
30
  six==1.16.0
31
  SQLAlchemy==2.0.16
32
  sympy==1.12
 
27
  regex==2023.6.3
28
  requests==2.31.0
29
  safetensors==0.3.1
30
+ sentencepiece==0.1.99
31
  six==1.16.0
32
  SQLAlchemy==2.0.16
33
  sympy==1.12