Spaces:
No application file
No application file
refactoring the code on the SOLID principles
Browse files- chat.py +0 -66
- index/build_index.py +1 -1
- src/chat.py +41 -0
- src/gpt_3_manager.py +32 -0
- src/index.py +40 -0
- src/prompt.py +35 -0
- utils.py → src/utils.py +0 -29
chat.py
DELETED
@@ -1,66 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import openai
|
3 |
-
from dotenv import load_dotenv
|
4 |
-
import jsonlines
|
5 |
-
from pathlib import Path
|
6 |
-
from utils import (
|
7 |
-
gpt3_embeddings,
|
8 |
-
gpt3_completion,
|
9 |
-
dot_similarity,
|
10 |
-
load_prompt,
|
11 |
-
)
|
12 |
-
|
13 |
-
load_dotenv()
|
14 |
-
|
15 |
-
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
16 |
-
|
17 |
-
openai.api_key = OPENAI_API_KEY
|
18 |
-
|
19 |
-
|
20 |
-
def search_index(question, indexes, count=4):
|
21 |
-
question_embedding = gpt3_embeddings(question)
|
22 |
-
|
23 |
-
simmilarities = []
|
24 |
-
for index in indexes:
|
25 |
-
embedding = index["embedding"]
|
26 |
-
score = dot_similarity(question_embedding, embedding)
|
27 |
-
simmilarities.append({"index": index, "score": score})
|
28 |
-
|
29 |
-
sorted_similarities = sorted(
|
30 |
-
simmilarities, key=lambda x: x["score"], reverse=True
|
31 |
-
)
|
32 |
-
|
33 |
-
return sorted_similarities[:count]
|
34 |
-
|
35 |
-
|
36 |
-
if __name__ == "__main__":
|
37 |
-
with jsonlines.open(Path("./index") / "index.jsonl") as passages:
|
38 |
-
indexes = list(passages)
|
39 |
-
|
40 |
-
while True:
|
41 |
-
question = input("User >")
|
42 |
-
|
43 |
-
search_results = search_index(question=question, indexes=indexes, count=2)
|
44 |
-
|
45 |
-
answers = []
|
46 |
-
for result in search_results:
|
47 |
-
print("iterating over answering questions")
|
48 |
-
|
49 |
-
prompt = (
|
50 |
-
load_prompt("prompts\question_answering.txt")
|
51 |
-
.replace("<<PASSAGE>>", result["index"]["content"])
|
52 |
-
.replace("<<QUESTION>>", question)
|
53 |
-
)
|
54 |
-
|
55 |
-
answer = gpt3_completion(
|
56 |
-
prompt=prompt, max_tokens=80, model="text-curie-001"
|
57 |
-
)
|
58 |
-
answers.append(answer)
|
59 |
-
|
60 |
-
prompt = load_prompt("prompts\passage_summarization.txt").replace(
|
61 |
-
"<<PASSAGE>>", "\n".join(answers)
|
62 |
-
)
|
63 |
-
|
64 |
-
final_answer = gpt3_completion(prompt=prompt)
|
65 |
-
|
66 |
-
print(f"Bot: {final_answer}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
index/build_index.py
CHANGED
@@ -7,7 +7,7 @@ import openai
|
|
7 |
import textwrap
|
8 |
import jsonlines
|
9 |
|
10 |
-
from utils import gpt3_embeddings
|
11 |
|
12 |
load_dotenv()
|
13 |
|
|
|
7 |
import textwrap
|
8 |
import jsonlines
|
9 |
|
10 |
+
from src.utils import gpt3_embeddings
|
11 |
|
12 |
load_dotenv()
|
13 |
|
src/chat.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import openai
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
from index import IndexSearchEngine
|
5 |
+
from gpt_3_manager import Gpt3Manager
|
6 |
+
from prompt import QuestionAnsweringPrompt, PassageSummarizationPrompt
|
7 |
+
|
8 |
+
|
9 |
+
load_dotenv()
|
10 |
+
|
11 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
12 |
+
|
13 |
+
openai.api_key = OPENAI_API_KEY
|
14 |
+
|
15 |
+
|
16 |
+
class ChatBot:
    """Answers user questions by searching an index and summarizing with GPT-3.

    Collaborators:
      - IndexSearchEngine: retrieves the passages most similar to the question.
      - QuestionAnsweringPrompt / PassageSummarizationPrompt: build the prompts.
      - Gpt3Manager: wraps the OpenAI completion API.
    """

    def __init__(self, index_search_engine: IndexSearchEngine):
        self.index_search_engine = index_search_engine

    def ask(self, question):
        """Return a summarized answer to *question*.

        Each retrieved passage is answered individually with a cheaper model,
        then the partial answers are condensed into one final reply by a
        second completion call.
        """
        search_result = self.index_search_engine.search(question=question)

        answers = []
        for result in search_result:
            print("iterating over answering questions")

            # BUG FIX: QuestionAnsweringPrompt.load is an instance method that
            # substitutes the passage/question held on the instance; the old
            # class-level call passed the path in as `self` and crashed.
            # Also use forward slashes: "\q"/"\p" were literal backslashes
            # that only resolve on Windows.
            question_answering_prompt = QuestionAnsweringPrompt(
                result, question
            ).load("prompts/question_answering.txt")

            answer = Gpt3Manager.get_completion(
                prompt=question_answering_prompt, max_tokens=80, model="text-curie-001"
            )
            answers.append(answer)

        # Same fix for the summarization prompt: construct, then load.
        passage_summarization_prompt = PassageSummarizationPrompt(answers).load(
            "prompts/passage_summarization.txt"
        )

        final_answer = Gpt3Manager.get_completion(prompt=passage_summarization_prompt)
        return final_answer
|
src/gpt_3_manager.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
|
3 |
+
|
4 |
+
class Gpt3Manager:
    """Thin wrapper around the OpenAI completion and embedding endpoints.

    Callers invoke the helpers directly on the class
    (``Gpt3Manager.get_completion(...)``), so they are declared as static
    methods. BUG FIX: the original definitions omitted both ``self`` and
    ``@staticmethod`` — class-level calls happened to work, but any
    instance-level call would have passed the instance as ``prompt``/``text``.
    """

    def __init__(self, api_key):
        # Configures the module-global key used by every subsequent API call.
        openai.api_key = api_key

    @staticmethod
    def get_completion(prompt, max_tokens=128, model="text-davinci-003"):
        """Return the completion text for *prompt*, or None on API failure."""
        response = None
        try:
            response = openai.Completion.create(
                model=model,
                prompt=prompt,
                max_tokens=max_tokens,
            )["choices"][0]["text"]
        except Exception as err:
            # Best-effort: report the failure and fall through with None.
            print(f"Sorry, There was a problem \n\n {err}")

        return response

    @staticmethod
    def get_embedding(text, model="text-similarity-ada-001"):
        """Return the embedding vector for *text*, or None on API failure."""
        # The embeddings endpoint behaves better without raw newlines.
        text = text.replace("\n", " ")
        embedding = None
        try:
            embedding = openai.Embedding.create(input=[text], model=model)["data"][0][
                "embedding"
            ]
        except Exception as err:
            print(f"Sorry, There was a problem {err}")

        return embedding
|
src/index.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
import jsonlines
|
3 |
+
from gpt_3_manager import Gpt3Manager
|
4 |
+
from src.utils import dot_similarity
|
5 |
+
|
6 |
+
|
7 |
+
class Index(ABC):
    """Abstract source of index records; subclasses define how to load them."""

    @abstractmethod
    def load(self, path):
        """Read and return the index records stored at *path*."""
        pass
|
11 |
+
|
12 |
+
|
13 |
+
class JsonLinesIndex(Index):
    """Index backed by a JSON Lines file: one record per line."""

    def __init__(self):
        # Stateless: loading is driven entirely by the *path* argument.
        pass

    def load(self, path):
        """Materialize every record of the .jsonl file at *path* as a list."""
        with jsonlines.open(path) as reader:
            records = list(reader)
        return records
|
21 |
+
|
22 |
+
|
23 |
+
class IndexSearchEngine:
    """Ranks index records by embedding similarity to a user question."""

    def __init__(self, index):
        # BUG FIX: the original assigned ``index = index`` (a no-op local
        # rebinding), so the engine never retained the records it was
        # constructed with.
        self.index = index

    def search(self, question, indexes=None, count=4):
        """Return the *count* records most similar to *question*.

        Args:
            question: free-text question; embedded via Gpt3Manager.
            indexes: optional explicit record list. Defaults to the records
                supplied at construction time (BUG FIX: the parameter used to
                be mandatory, so the caller's ``search(question=...)`` raised
                TypeError).
            count: number of top-scoring records to return.

        Returns:
            List of ``{"index": record, "score": float}`` dicts, best first.
        """
        if indexes is None:
            indexes = self.index

        question_embedding = Gpt3Manager.get_embedding(question)

        similarities = []
        for record in indexes:
            score = dot_similarity(question_embedding, record["embedding"])
            similarities.append({"index": record, "score": score})

        sorted_similarities = sorted(
            similarities, key=lambda x: x["score"], reverse=True
        )

        return sorted_similarities[:count]
|
src/prompt.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
|
4 |
+
class Prompt(ABC):
    """Base class for file-backed prompt templates.

    Subclasses implement :meth:`load`, which reads a template file and
    substitutes their specific placeholders.
    """

    def load_prompt(self, path):
        """Return the raw template text stored at *path*.

        BUG FIX: the original definition omitted ``self`` even though it is
        invoked as ``self.load_prompt(path)``, which passed the instance as
        the path and raised TypeError on every call.
        """
        with open(path) as f:
            return f.read()

    @abstractmethod
    def load(self, path):
        """Read the template at *path* and return the filled-in prompt."""
        pass


class QuestionAnsweringPrompt(Prompt):
    """Prompt asking *question* about a single retrieved passage."""

    def __init__(self, result, question):
        # BUG FIX: the original bound ``result = result`` and
        # ``question = question`` to locals, so the attributes read by
        # load() were never set and load() raised AttributeError.
        self.result = result
        self.question = question

    def load(self, path):
        """Fill <<PASSAGE>> and <<QUESTION>> in the template at *path*."""
        prompt = (
            self.load_prompt(path)
            .replace("<<PASSAGE>>", self.result["index"]["content"])
            .replace("<<QUESTION>>", self.question)
        )
        return prompt


class PassageSummarizationPrompt(Prompt):
    """Prompt that condenses several partial answers into one passage."""

    def __init__(self, answers):
        self.answers = answers

    def load(self, path):
        """Fill <<PASSAGE>> in the template at *path* with the joined answers."""
        prompt = self.load_prompt(path).replace("<<PASSAGE>>", "\n".join(self.answers))
        return prompt
|
utils.py → src/utils.py
RENAMED
@@ -1,35 +1,6 @@
|
|
1 |
-
import openai
|
2 |
import numpy as np
|
3 |
|
4 |
|
5 |
-
def gpt3_embeddings(text, model="text-similarity-ada-001"):
|
6 |
-
text = text.replace("\n", " ")
|
7 |
-
embedding = None
|
8 |
-
try:
|
9 |
-
embedding = openai.Embedding.create(input=[text], model=model)["data"][0][
|
10 |
-
"embedding"
|
11 |
-
]
|
12 |
-
except Exception as err:
|
13 |
-
print(f"Sorry, There was a problem {err}")
|
14 |
-
|
15 |
-
return embedding
|
16 |
-
|
17 |
-
|
18 |
-
def gpt3_completion(prompt, max_tokens=128, model="text-davinci-003"):
|
19 |
-
response = None
|
20 |
-
try:
|
21 |
-
response = openai.Completion.create(
|
22 |
-
model=model,
|
23 |
-
prompt=prompt,
|
24 |
-
max_tokens=max_tokens,
|
25 |
-
)["choices"][0]["text"]
|
26 |
-
|
27 |
-
except Exception as err:
|
28 |
-
print(f"Sorry, There was a problem \n\n {err}")
|
29 |
-
|
30 |
-
return response
|
31 |
-
|
32 |
-
|
33 |
def load_prompt(path):
|
34 |
with open(path) as f:
|
35 |
lines = f.readlines()
|
|
|
|
|
1 |
import numpy as np
|
2 |
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
def load_prompt(path):
|
5 |
with open(path) as f:
|
6 |
lines = f.readlines()
|