update
Files changed:
- chatbot.py +9 -6
- main.py +31 -7
chatbot.py
CHANGED
@@ -4,27 +4,30 @@ from datasets import load_dataset
 import pandas as pd
 import re
 
-
 class ChatBot:
-    def __init__(self):
-        self.directory =
-        self.tokenizer =
-        self.model =
-        self.device =
+    def __init__(self, dir, tokenizer, model, device):
+        self.directory = dir
+        self.tokenizer = tokenizer
+        self.model = model
+        self.device = device
         self.model.to(self.device)
 
     def generate_response(self, history):
         combined_prompt = ""
+
         # self.tokenizer.eos_token_id = '<|endoftext|>'
         if len(history.user) > 7:
             history.user = history.user[-7:]
             history.ai = history.ai[-6:]
+
         # Iterate over user and AI messages
         for user_message, ai_message in zip(history.user, history.ai):
             combined_prompt += f"<user> {user_message}{self.tokenizer.eos_token_id}<AI> {ai_message}{self.tokenizer.eos_token_id}"
+
         # Include the last user message in the prompt for response generation
         if history.user:
             combined_prompt += f"<user> {history.user[-1]}{self.tokenizer.eos_token_id}<AI>"
+
         # Tokenize and generate response
         inputs = self.tokenizer.encode(combined_prompt, return_tensors="pt").to(self.device)
         attention_mask = torch.ones(inputs.shape, device=self.device)
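Worth noting while here: `generate_response` (unchanged by this diff) builds its prompt with `self.tokenizer.eos_token_id`, which for GPT-2 is the integer token id (50256), not the `<|endoftext|>` string; the commented-out line above it suggests the string was the intent. A minimal sketch of the difference, assuming a stock GPT-2 tokenizer (this example is illustrative, not part of the commit):

```python
# Minimal sketch of eos_token vs. eos_token_id, assuming a stock
# GPT-2 tokenizer; not part of this commit.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

print(tokenizer.eos_token)     # '<|endoftext|>' (the separator string)
print(tokenizer.eos_token_id)  # 50256 (the integer id)

# Embedding the id in an f-string yields "<user> hi50256<AI>",
# whereas the string form yields "<user> hi<|endoftext|><AI>":
prompt = f"<user> hi{tokenizer.eos_token}<AI>"
print(prompt)
```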
main.py
CHANGED
@@ -1,3 +1,9 @@
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
+import torch
+from datasets import load_dataset
+import pandas as pd
+import re
+
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 import chatbot
@@ -5,7 +11,6 @@ import chatbot
 
 app = FastAPI()
 
-
 # Add CORS middleware to allow any origin
 app.add_middleware(
     CORSMiddleware,
@@ -28,13 +33,32 @@ class HistoryRequest(BaseModel):
 
 @app.post("/generate")
 async def generate_response(history: HistoryRequest):
-    try:
-        model
-        print(12321323)
-    except:
+    # try:
+    #     model
+    #     print(12321323)
+    # except:
+    #     global model
+    #     # model = chatbot.ChatBot()
 
-        global model
-        # model = chatbot.ChatBot()
+    global model
+
+    try:
+        # check if model is already loaded
+        if not isinstance(model, chatbot.ChatBot):
+            model = chatbot.ChatBot(
+                'models/fine-tuned-gpt2',
+                GPT2Tokenizer.from_pretrained('models/fine-tuned-gpt2'),
+                GPT2LMHeadModel.from_pretrained('models/fine-tuned-gpt2'),
+                torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            )
+    except NameError:
+        # if model is not defined, load
+        model = chatbot.ChatBot(
+            'models/fine-tuned-gpt2',
+            GPT2Tokenizer.from_pretrained('models/fine-tuned-gpt2'),
+            GPT2LMHeadModel.from_pretrained('models/fine-tuned-gpt2'),
+            torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        )
 
     response = model.generate_response(history)
     return response
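The new block relies on `global model` plus a `NameError` catch to load the model lazily on the first request. FastAPI also offers a hook for one-time initialization; a sketch of the same load done at application startup, assuming the same `models/fine-tuned-gpt2` directory (an alternative pattern, not what this commit does):

```python
# Sketch of an alternative to the NameError-based lazy load above:
# load the model once at startup. Assumes the same model directory;
# not part of this commit.
import torch
from fastapi import FastAPI
from transformers import GPT2Tokenizer, GPT2LMHeadModel

import chatbot

app = FastAPI()
model = None  # populated once at startup


@app.on_event("startup")
def load_model():
    global model
    directory = "models/fine-tuned-gpt2"
    model = chatbot.ChatBot(
        directory,
        GPT2Tokenizer.from_pretrained(directory),
        GPT2LMHeadModel.from_pretrained(directory),
        torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
```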
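For reference, the endpoint takes a `HistoryRequest` whose `user` and `ai` fields `generate_response` reads as parallel message lists (the full model definition sits outside this hunk). A hypothetical client call, assuming the app is served locally on port 8000:

```python
# Hypothetical client call; the field names are inferred from how
# generate_response reads history.user and history.ai, and the URL
# assumes a local uvicorn server on port 8000.
import requests

payload = {
    "user": ["hello", "what can you do?"],
    "ai": ["Hi! How can I help?"],
}
response = requests.post("http://localhost:8000/generate", json=payload)
print(response.json())
```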