Spaces:
Sleeping
Sleeping
update
Browse files- chatbot.py +47 -0
- main.py +13 -38
- models/fine-tuned-gpt2/config.json +1 -1
- models/fine-tuned-gpt2/config.json:Zone.Identifier +3 -0
- models/fine-tuned-gpt2/generation_config.json +1 -1
- models/fine-tuned-gpt2/generation_config.json:Zone.Identifier +3 -0
- models/fine-tuned-gpt2/merges.txt:Zone.Identifier +3 -0
- models/fine-tuned-gpt2/model.safetensors +1 -1
- models/fine-tuned-gpt2/model.safetensors:Zone.Identifier +3 -0
- models/fine-tuned-gpt2/special_tokens_map.json:Zone.Identifier +3 -0
- models/fine-tuned-gpt2/tokenizer_config.json:Zone.Identifier +3 -0
- models/fine-tuned-gpt2/vocab.json:Zone.Identifier +3 -0
chatbot.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
2 |
+
import torch
|
3 |
+
from datasets import load_dataset
|
4 |
+
import pandas as pd
|
5 |
+
import re
|
6 |
+
|
7 |
+
|
8 |
+
class ChatBot:
    """Wrapper around a locally fine-tuned GPT-2 checkpoint for chat.

    Loads tokenizer and model once at construction time; callers should
    create a single instance and reuse it, since loading is expensive.
    """

    def __init__(self):
        self.directory = 'models/fine-tuned-gpt2'
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.directory)
        self.model = GPT2LMHeadModel.from_pretrained(self.directory)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate_response(self, history):
        """Build a prompt from the chat history and generate one AI reply.

        `history` is expected to expose `user` and `ai` lists of strings
        (the FastAPI `HistoryRequest` model) — TODO confirm against caller.
        Returns the generated text with the prompt stripped, truncated at
        the first period.
        """
        # Keep only the most recent turns so the prompt stays within the
        # model's context window.
        if len(history.user) > 7:
            history.user = history.user[-7:]
            history.ai = history.ai[-6:]

        # BUG FIX: the original interpolated `eos_token_id` (the integer
        # 50256) into the prompt text; the separator must be the eos token
        # *string* (e.g. '<|endoftext|>'), as the commented-out line
        # `# self.tokenizer.eos_token_id = '<|endoftext|>'` suggests was
        # intended.
        eos = self.tokenizer.eos_token

        # Interleave completed user/AI turns into a single prompt.
        combined_prompt = ""
        for user_message, ai_message in zip(history.user, history.ai):
            combined_prompt += f"<user> {user_message}{eos}<AI> {ai_message}{eos}"

        # Append the pending user message the model should answer.
        # NOTE(review): if len(history.user) == len(history.ai), the last
        # user message is already emitted by the zip above and appears
        # twice — confirm callers always send one more user message than
        # AI replies.
        if history.user:
            combined_prompt += f"<user> {history.user[-1]}{eos}<AI>"

        # Tokenize and generate the reply.
        inputs = self.tokenizer.encode(combined_prompt, return_tensors="pt").to(self.device)
        attention_mask = torch.ones(inputs.shape, device=self.device)
        outputs = self.model.generate(
            inputs,
            max_length=500,  # hard cap on prompt + reply tokens; adjust as needed
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=2,
            temperature=0.7,  # NOTE(review): temperature/top_k/top_p are ignored
            top_k=50,         # by beam search unless do_sample=True — confirm
            top_p=0.95,       # whether sampling was intended here
            pad_token_id=self.tokenizer.eos_token_id,
            attention_mask=attention_mask,
            repetition_penalty=1.2
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the echoed prompt and keep only the first sentence.
        return response.replace(combined_prompt, "").split(".")[0]
|
main.py
CHANGED
@@ -1,22 +1,11 @@
|
|
1 |
-
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
2 |
-
import torch
|
3 |
-
from datasets import load_dataset
|
4 |
-
import pandas as pd
|
5 |
-
import re
|
6 |
-
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
7 |
from fastapi import FastAPI
|
8 |
from fastapi.middleware.cors import CORSMiddleware
|
9 |
-
|
10 |
|
11 |
|
12 |
app = FastAPI()
|
13 |
|
14 |
-
|
15 |
-
dir = 'models/fine-tuned-gpt2'
|
16 |
-
tokenizer = GPT2Tokenizer.from_pretrained(dir)
|
17 |
-
model = GPT2LMHeadModel.from_pretrained(dir)
|
18 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
19 |
-
model.to(device)
|
20 |
|
21 |
# Add CORS middleware to allow any origin
|
22 |
app.add_middleware(
|
@@ -29,38 +18,24 @@ app.add_middleware(
|
|
29 |
|
30 |
@app.get("/")
|
31 |
def root():
|
32 |
-
return
|
33 |
|
34 |
# Define the Pydantic model to parse JSON input
|
|
|
|
|
35 |
class HistoryRequest(BaseModel):
|
36 |
user: list[str]
|
37 |
ai: list[str]
|
38 |
|
39 |
@app.post("/generate")
|
40 |
def generate_response(history: HistoryRequest):
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
combined_prompt += f"<user> {user_message}\n<AI> {ai_message}\n"
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
# Tokenize and generate response
|
52 |
-
inputs = tokenizer.encode(combined_prompt, return_tensors="pt").to(device)
|
53 |
-
outputs = model.generate(
|
54 |
-
inputs,
|
55 |
-
max_length=150, # Adjust length as needed
|
56 |
-
num_beams=5,
|
57 |
-
early_stopping=True,
|
58 |
-
no_repeat_ngram_size=2,
|
59 |
-
temperature=0.7,
|
60 |
-
top_k=50,
|
61 |
-
top_p=0.95
|
62 |
-
)
|
63 |
-
|
64 |
-
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
65 |
-
response = response.replace(combined_prompt, "").split(".")[0]
|
66 |
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from fastapi import FastAPI
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
+
import chatbot
|
4 |
|
5 |
|
6 |
app = FastAPI()
|
7 |
|
8 |
+
model = None
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Add CORS middleware to allow any origin
|
11 |
app.add_middleware(
|
|
|
18 |
|
19 |
@app.get("/")
|
20 |
def root():
|
21 |
+
return "Hello World"
|
22 |
|
23 |
# Define the Pydantic model to parse JSON input
|
24 |
+
|
25 |
+
from pydantic import BaseModel
|
26 |
class HistoryRequest(BaseModel):
|
27 |
user: list[str]
|
28 |
ai: list[str]
|
29 |
|
30 |
@app.post("/generate")
def generate_response(history: HistoryRequest):
    """Generate a chatbot reply for the supplied conversation history.

    Lazily instantiates the ChatBot on first use so the model is loaded
    once per process instead of on every request.
    """
    # BUG FIX: `model` is assigned in this function, so without `global`
    # it is a function-local name — the original `try: model / except:`
    # raised UnboundLocalError on every call and the freshly loaded
    # ChatBot was discarded after each request.
    global model

    # BUG FIX: the original compared `type(model) != type(chatbot.ChatBot())`,
    # which constructed (and fully loaded) a throwaway model on every
    # request just to obtain its type. isinstance avoids that cost and
    # also covers the initial `model = None` state.
    if not isinstance(model, chatbot.ChatBot):
        model = chatbot.ChatBot()

    response = model.generate_response(history)
    return response
|
models/fine-tuned-gpt2/config.json
CHANGED
@@ -33,7 +33,7 @@
|
|
33 |
}
|
34 |
},
|
35 |
"torch_dtype": "float32",
|
36 |
-
"transformers_version": "4.
|
37 |
"use_cache": true,
|
38 |
"vocab_size": 50257
|
39 |
}
|
|
|
33 |
}
|
34 |
},
|
35 |
"torch_dtype": "float32",
|
36 |
+
"transformers_version": "4.44.0",
|
37 |
"use_cache": true,
|
38 |
"vocab_size": 50257
|
39 |
}
|
models/fine-tuned-gpt2/config.json:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[ZoneTransfer]
|
2 |
+
ZoneId=3
|
3 |
+
HostUrl=https://www.kaggle.com/
|
models/fine-tuned-gpt2/generation_config.json
CHANGED
@@ -2,5 +2,5 @@
|
|
2 |
"_from_model_config": true,
|
3 |
"bos_token_id": 50256,
|
4 |
"eos_token_id": 50256,
|
5 |
-
"transformers_version": "4.
|
6 |
}
|
|
|
2 |
"_from_model_config": true,
|
3 |
"bos_token_id": 50256,
|
4 |
"eos_token_id": 50256,
|
5 |
+
"transformers_version": "4.44.0"
|
6 |
}
|
models/fine-tuned-gpt2/generation_config.json:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[ZoneTransfer]
|
2 |
+
ZoneId=3
|
3 |
+
HostUrl=https://www.kaggle.com/
|
models/fine-tuned-gpt2/merges.txt:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[ZoneTransfer]
|
2 |
+
ZoneId=3
|
3 |
+
HostUrl=https://www.kaggle.com/
|
models/fine-tuned-gpt2/model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 497774208
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8be8247018b9ae965bcf6d6e3edaa797753fcf42623b65efa34973d31dae6aa3
|
3 |
size 497774208
|
models/fine-tuned-gpt2/model.safetensors:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[ZoneTransfer]
|
2 |
+
ZoneId=3
|
3 |
+
HostUrl=https://www.kaggle.com/
|
models/fine-tuned-gpt2/special_tokens_map.json:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[ZoneTransfer]
|
2 |
+
ZoneId=3
|
3 |
+
HostUrl=https://www.kaggle.com/
|
models/fine-tuned-gpt2/tokenizer_config.json:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[ZoneTransfer]
|
2 |
+
ZoneId=3
|
3 |
+
HostUrl=https://www.kaggle.com/
|
models/fine-tuned-gpt2/vocab.json:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[ZoneTransfer]
|
2 |
+
ZoneId=3
|
3 |
+
HostUrl=https://www.kaggle.com/
|