arcsu1 committed
Commit fd25dfc · 1 Parent(s): 3575d41
Files changed (2):
  1. chatbot.py +9 -6
  2. main.py +31 -7
chatbot.py CHANGED
@@ -4,27 +4,30 @@ from datasets import load_dataset
 import pandas as pd
 import re
 
-
 class ChatBot:
-    def __init__(self):
-        self.directory = 'models/fine-tuned-gpt2'
-        self.tokenizer = GPT2Tokenizer.from_pretrained(self.directory)
-        self.model = GPT2LMHeadModel.from_pretrained(self.directory)
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    def __init__(self, dir, tokenizer, model, device):
+        self.directory = dir
+        self.tokenizer = tokenizer
+        self.model = model
+        self.device = device
         self.model.to(self.device)
 
     def generate_response(self, history):
         combined_prompt = ""
+
         # self.tokenizer.eos_token_id = '<|endoftext|>'
         if len(history.user) > 7:
             history.user = history.user[-7:]
             history.ai = history.ai[-6:]
+
         # Iterate over user and AI messages
         for user_message, ai_message in zip(history.user, history.ai):
             combined_prompt += f"<user> {user_message}{self.tokenizer.eos_token_id}<AI> {ai_message}{self.tokenizer.eos_token_id}"
+
         # Include the last user message in the prompt for response generation
         if history.user:
             combined_prompt += f"<user> {history.user[-1]}{self.tokenizer.eos_token_id}<AI>"
+
         # Tokenize and generate response
         inputs = self.tokenizer.encode(combined_prompt, return_tensors="pt").to(self.device)
         attention_mask = torch.ones(inputs.shape, device=self.device)
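
After this change, ChatBot no longer loads its own weights; the caller injects the model directory, tokenizer, model, and device. A minimal sketch of constructing it under the new signature, mirroring the call that main.py below now makes:

    from transformers import GPT2Tokenizer, GPT2LMHeadModel
    import torch
    import chatbot

    # Dependencies are injected instead of being hard-coded in __init__
    directory = 'models/fine-tuned-gpt2'
    bot = chatbot.ChatBot(
        directory,
        GPT2Tokenizer.from_pretrained(directory),
        GPT2LMHeadModel.from_pretrained(directory),
        torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )

One caveat the commented-out line hints at: the prompt f-strings interpolate self.tokenizer.eos_token_id (an integer, 50256 for GPT-2) rather than self.tokenizer.eos_token (the string '<|endoftext|>'), so the separator baked into the prompt is a number, not the end-of-text token.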
main.py CHANGED
@@ -1,3 +1,9 @@
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
+import torch
+from datasets import load_dataset
+import pandas as pd
+import re
+
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 import chatbot
@@ -5,7 +11,6 @@ import chatbot
 
 app = FastAPI()
 
-
 # Add CORS middleware to allow any origin
 app.add_middleware(
     CORSMiddleware,
@@ -28,13 +33,32 @@ class HistoryRequest(BaseModel):
 
 @app.post("/generate")
 async def generate_response(history: HistoryRequest):
-    try:
-        model
-    except:
-        model = chatbot.ChatBot()
+    # try:
+    #     model
+    #     print(12321323)
+    # except:
+    #     global model
+    #     # model = chatbot.ChatBot()
 
-    if type(model) != type(chatbot.ChatBot()):
-        model = chatbot.ChatBot()
+    global model
+
+    try:
+        # check if model is already loaded
+        if not isinstance(model, chatbot.ChatBot):
+            model = chatbot.ChatBot(
+                'models/fine-tuned-gpt2',
+                GPT2Tokenizer.from_pretrained('models/fine-tuned-gpt2'),
+                GPT2LMHeadModel.from_pretrained('models/fine-tuned-gpt2'),
+                torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            )
+    except NameError:
+        # if model is not defined, load it
+        model = chatbot.ChatBot(
+            'models/fine-tuned-gpt2',
+            GPT2Tokenizer.from_pretrained('models/fine-tuned-gpt2'),
+            GPT2LMHeadModel.from_pretrained('models/fine-tuned-gpt2'),
+            torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        )
 
     response = model.generate_response(history)
     return response
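
For reference, a quick way to exercise the endpoint, assuming the app is served locally on the default port. The field names are assumptions read off the history.user / history.ai accesses in chatbot.py, since the HistoryRequest definition itself is outside this diff:

    import requests

    # Parallel lists of past user and AI turns; the last user entry is the
    # message awaiting a reply (see generate_response in chatbot.py).
    payload = {
        "user": ["Hi there", "What can you do?"],
        "ai": ["Hello! How can I help?"],
    }

    # The first request hits the NameError branch and loads the model;
    # subsequent requests reuse the module-level `model`.
    resp = requests.post("http://localhost:8000/generate", json=payload)
    print(resp.json())

The global model plus except NameError pattern is a hand-rolled lazy singleton; loading the model once at module import time or in a FastAPI startup hook would give the same reuse without the NameError dance.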