Refactoring the project to be modular
- prompts/question_answering.txt +1 -1
- requirements.txt +2 -1
- test/test.py → src/__init__.py +0 -0
- src/gpt_3_manager.py +5 -5
- src/index.py +7 -9
- src/prompt.py +31 -9
- src/tests/__init__.py +1 -0
- src/tests/chat_test.py +28 -0
- src/tests/gpt_3_manager_test.py +21 -0
- src/tests/index_test.py +30 -0
- src/tests/prompt_test.py +62 -0
- src/tests/utils_test.py +14 -0
prompts/question_answering.txt CHANGED
@@ -4,4 +4,4 @@ passage: <<PASSAGE>>
 
 question: <<QUESTION>>
 
-answer:
+answer:
requirements.txt CHANGED
@@ -2,4 +2,5 @@ pdfplumber
 textwrap3
 openai
 python-dotenv
-jsonlines
+jsonlines
+pytest
test/test.py → src/__init__.py RENAMED
File without changes
src/gpt_3_manager.py CHANGED
@@ -5,13 +5,13 @@ class Gpt3Manager:
     def __init__(self, api_key):
         openai.api_key = api_key
 
-    def get_completion(prompt, max_tokens=128, model="text-davinci-003"):
+    def get_completion(self, prompt, max_tokens=128, model="text-davinci-003"):
         response = None
         try:
             response = openai.Completion.create(
-                model=model,
                 prompt=prompt,
                 max_tokens=max_tokens,
+                model=model,
             )["choices"][0]["text"]
 
         except Exception as err:
@@ -19,11 +19,11 @@ class Gpt3Manager:
 
         return response
 
-    def get_embedding(
-
+    def get_embedding(self, prompt, model="text-similarity-ada-001"):
+        prompt = prompt.replace("\n", " ")
         embedding = None
         try:
-            embedding = openai.Embedding.create(input=[
+            embedding = openai.Embedding.create(input=[prompt], model=model)["data"][0][
                 "embedding"
             ]
         except Exception as err:
src/index.py CHANGED
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 import jsonlines
 from gpt_3_manager import Gpt3Manager
-from
+from utils import dot_similarity
 
 
 class Index(ABC):
@@ -11,9 +11,6 @@ class Index(ABC):
 
 
 class JsonLinesIndex(Index):
-    def __init__(self):
-        pass
-
     def load(self, path):
         with jsonlines.open(path) as passages:
             indexes = list(passages)
@@ -21,14 +18,15 @@ class JsonLinesIndex(Index):
 
 
 class IndexSearchEngine:
-    def __init__(self,
-
+    def __init__(self, indexes, gpt_manager):
+        self.indexes = indexes
+        self.gpt_manager = gpt_manager
 
-    def search(self, question,
-        question_embedding =
+    def search(self, question, count=4):
+        question_embedding = self.gpt_manager.get_embedding(prompt=question)
 
         simmilarities = []
-        for index in indexes:
+        for index in self.indexes:
             embedding = index["embedding"]
             score = dot_similarity(question_embedding, embedding)
             simmilarities.append({"index": index, "score": score})
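
The new import `from utils import dot_similarity` points at a helper that is not touched by this commit, so its implementation is not shown here. A minimal sketch of what such a function could look like, assuming plain Python lists of floats (the real utils.py may use numpy instead):

# Hypothetical dot_similarity; utils.py is not part of this diff, so treat
# this as an illustration of the expected contract only.
def dot_similarity(a, b):
    # Dot product of two equal-length embedding vectors.
    return sum(x * y for x, y in zip(a, b))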
src/prompt.py CHANGED
@@ -1,35 +1,57 @@
 from abc import ABC, abstractmethod
 
-
-
+# Prompt Loaders
+class PromptLoader(ABC):
+    @abstractmethod
+    def load_prompt():
+        pass
+
+
+class TextPromptLoader(PromptLoader):
+    def load_prompt(self, path):
         with open(path) as f:
             lines = f.readlines()
         return "".join(lines)
 
+
+# Prompts
+class Prompt(ABC):
+    def __init__(self, prompt_loader: PromptLoader):
+        self.prompt_loader = prompt_loader
+
+    def load_prompt(self, path):
+        return self.prompt_loader.load_prompt(path)
+
     @abstractmethod
     def load(self, path):
         pass
 
 
 class QuestionAnsweringPrompt(Prompt):
-    def __init__(self,
-
-
+    def __init__(self, passage, question, prompt_loader):
+        super().__init__(prompt_loader=prompt_loader)
+        self.passage = passage
+        self.question = question
+
+    # trust me, you'll need this later
+    # .replace("<<PASSAGE>>", self.result["index"]["content"])
 
     def load(self, path):
         prompt = (
             self.load_prompt(path)
-            .replace("<<PASSAGE>>", self.
+            .replace("<<PASSAGE>>", self.passage)
             .replace("<<QUESTION>>", self.question)
         )
         return prompt
 
 
 class PassageSummarizationPrompt(Prompt):
-    def __init__(self,
-
+    def __init__(self, passage, prompt_loader):
+        super().__init__(prompt_loader=prompt_loader)
+        self.passage = passage
+
+    # prompt = self.load_prompt(path).replace("<<PASSAGE>>", "\n".join(self.answers))
 
     def load(self, path):
-        prompt = self.load_prompt(path).replace("<<PASSAGE>>",
+        prompt = self.load_prompt(path).replace("<<PASSAGE>>", self.passage)
         return prompt
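
With the loader/prompt split, callers now compose a PromptLoader with a Prompt subclass instead of having the prompt read files directly. A short usage sketch based on the tests added below; the passage and question strings are just the placeholder values used there:

from pathlib import Path
from prompt import TextPromptLoader, QuestionAnsweringPrompt

loader = TextPromptLoader()
prompt = QuestionAnsweringPrompt(
    passage="Hi, I'm foo and I love cycling and programming",  # placeholder passage
    question="What is foo's hobby",  # placeholder question
    prompt_loader=loader,
)

# Fills <<PASSAGE>> and <<QUESTION>> in the template file.
text = prompt.load(Path("prompts") / "question_answering.txt")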
src/tests/__init__.py ADDED
@@ -0,0 +1 @@
+
src/tests/chat_test.py ADDED
@@ -0,0 +1,28 @@
+import os
+from pathlib import Path
+from index import IndexSearchEngine
+from gpt_3_manager import Gpt3Manager
+from dotenv import load_dotenv
+from chat import ChatBot
+from index import JsonLinesIndex
+
+# load_dotenv()
+
+# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+
+
+# def test_chatbot():
+#     path = Path("index") / "index.jsonl"
+
+#     index = JsonLinesIndex()
+#     loaded = index.load(path)
+#     gpt_manager = Gpt3Manager(api_key=OPENAI_API_KEY)
+#     engine = IndexSearchEngine(loaded, gpt_manager=gpt_manager)
+
+#     chatbot = ChatBot(engine)
+#     answer = chatbot.ask("What does the twitter terms of service does")
+#     print(answer)
+#     # assert 0 == 0
+
+
+# test_chatbot()
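
chat.py itself is not changed in this commit; the commented-out test only shows that ChatBot wraps an IndexSearchEngine and exposes ask(). A purely hypothetical sketch of that shape, with the result handling and prompt wiring guessed from the rest of the diff rather than taken from the project's actual chat.py:

from prompt import TextPromptLoader, QuestionAnsweringPrompt


class ChatBot:
    def __init__(self, engine):
        self.engine = engine  # IndexSearchEngine from src/index.py

    def ask(self, question):
        # Guessed flow: take the best-scoring passage, build the QA prompt,
        # and call the completion endpoint through the engine's Gpt3Manager.
        results = self.engine.search(question=question)
        passage = results[0]["index"]["content"]  # assumed result shape
        prompt = QuestionAnsweringPrompt(passage, question, TextPromptLoader())
        return self.engine.gpt_manager.get_completion(
            prompt=prompt.load("prompts/question_answering.txt")
        )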
src/tests/gpt_3_manager_test.py ADDED
@@ -0,0 +1,21 @@
+import os
+from dotenv import load_dotenv
+from gpt_3_manager import Gpt3Manager
+
+load_dotenv()
+
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+
+
+def test_gpt3_completion():
+    manager = Gpt3Manager(api_key=OPENAI_API_KEY)
+    request = manager.get_completion(
+        prompt="This is a testing prompt", max_tokens=10, model="text-ada-001"
+    )
+    assert request != None
+
+
+def test_gpt3_embedding():
+    manager = Gpt3Manager(api_key=OPENAI_API_KEY)
+    request = manager.get_embedding(prompt="This is a testing prompt")
+    assert request != None
src/tests/index_test.py ADDED
@@ -0,0 +1,30 @@
+import os
+from index import JsonLinesIndex, IndexSearchEngine
+from gpt_3_manager import Gpt3Manager
+from pathlib import Path
+from dotenv import load_dotenv
+
+load_dotenv()
+
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+
+
+def test_jsonlines_index():
+    path = Path("index") / "index.jsonl"
+
+    index = JsonLinesIndex()
+    result = index.load(path)
+
+    assert result != None
+
+
+def test_index_serach_engine():
+    path = Path("index") / "index.jsonl"
+    gpt_manager = Gpt3Manager(OPENAI_API_KEY)
+    index = JsonLinesIndex()
+    loaded = index.load(path)
+    engine = IndexSearchEngine(loaded, gpt_manager=gpt_manager)
+
+    results = engine.search(question="What does the twitter tos does")
+
+    assert results != None
src/tests/prompt_test.py ADDED
@@ -0,0 +1,62 @@
+from pathlib import Path
+from prompt import QuestionAnsweringPrompt, PassageSummarizationPrompt, TextPromptLoader
+
+
+def test_text_prompt_loader():
+    path = Path("prompts") / "question_answering.txt"
+    prompt_loader = TextPromptLoader()
+
+    prompt = prompt_loader.load_prompt(path)
+    testing_prompt = (
+        "Use the passage to write a detailed answer to the following question\n"
+        "\n"
+        "passage: <<PASSAGE>>\n"
+        "\n"
+        "question: <<QUESTION>>\n"
+        "\n"
+        "answer:"
+    )
+
+    assert prompt == testing_prompt
+
+
+def test_question_answering_prompt():
+    path = Path("prompts") / "question_answering.txt"
+
+    passage = "Hi, I'm foo and I love cycling and programming"
+    question = "What is foo's hobby"
+
+    prompt_loader = TextPromptLoader()
+    prompt = QuestionAnsweringPrompt(passage, question, prompt_loader)
+    loaded_prompt = prompt.load(path)
+
+    testing_prompt = (
+        "Use the passage to write a detailed answer to the following question\n"
+        "\n"
+        "passage: Hi, I'm foo and I love cycling and programming\n"
+        "\n"
+        "question: What is foo's hobby\n"
+        "\n"
+        "answer:"
+    )
+
+    assert loaded_prompt == testing_prompt
+
+
+def test_passage_summarization_prompt():
+    path = Path("prompts") / "passage_summarization.txt"
+
+    passage = "Hi, I'm foo and I love cycling and programming"
+
+    prompt_loader = TextPromptLoader()
+    prompt = PassageSummarizationPrompt(passage, prompt_loader)
+    loaded_prompt = prompt.load(path)
+
+    testing_prompt = (
+        "Summarize the following passage in detail\n"
+        "passage: Hi, I'm foo and I love cycling and programming\n"
+        "\n"
+        "summary:"
+    )
+
+    assert loaded_prompt == testing_prompt
src/tests/utils_test.py ADDED
@@ -0,0 +1,14 @@
+from pathlib import Path
+from utils import load_prompt
+
+
+def test_load_prompt_default():
+    path = Path("prompts") / "question_answering.txt"
+
+    with open(path) as f:
+        lines = f.readlines()
+    testing_prompt = "".join(lines)
+
+    prompt = load_prompt(path)
+
+    assert prompt == testing_prompt
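
utils.py itself is not part of this commit, but the test above pins down load_prompt's contract: it must return the file's contents exactly as read. A minimal sketch consistent with that assertion; the real helper may differ:

# Illustration only: a load_prompt that satisfies utils_test.py.
def load_prompt(path):
    with open(path) as f:
        return "".join(f.readlines())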