Spaces:
No application file
No application file
adding final details
Browse files- requirements.txt +2 -1
- src/__init__.py +0 -0
- src/chat.py +22 -11
- src/index.py +1 -2
- src/main.py +29 -0
- src/prompt.py +2 -2
- src/tests/__init__.py +0 -1
- src/tests/chat_test.py +16 -13
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ textwrap3
|
|
3 |
openai
|
4 |
python-dotenv
|
5 |
jsonlines
|
6 |
-
pytest
|
|
|
|
3 |
openai
|
4 |
python-dotenv
|
5 |
jsonlines
|
6 |
+
pytest
|
7 |
+
numpy
|
src/__init__.py
DELETED
File without changes
|
src/chat.py
CHANGED
@@ -3,7 +3,8 @@ import openai
|
|
3 |
from dotenv import load_dotenv
|
4 |
from index import IndexSearchEngine
|
5 |
from gpt_3_manager import Gpt3Manager
|
6 |
-
from prompt import QuestionAnsweringPrompt, PassageSummarizationPrompt
|
|
|
7 |
|
8 |
|
9 |
load_dotenv()
|
@@ -14,28 +15,38 @@ openai.api_key = OPENAI_API_KEY
|
|
14 |
|
15 |
|
16 |
class ChatBot:
|
17 |
-
def __init__(
|
|
|
|
|
18 |
self.index_search_engine = index_search_engine
|
|
|
|
|
19 |
|
20 |
def ask(self, question):
|
21 |
-
search_result = self.index_search_engine.search(question=question)
|
22 |
|
23 |
answers = []
|
24 |
for result in search_result:
|
25 |
print("iterating over answering questions")
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
29 |
)
|
30 |
|
31 |
-
answer =
|
32 |
-
prompt=
|
33 |
)
|
34 |
answers.append(answer)
|
35 |
|
36 |
-
passage_summarization_prompt = PassageSummarizationPrompt
|
37 |
-
"
|
|
|
|
|
|
|
|
|
38 |
)
|
39 |
|
40 |
-
final_answer =
|
41 |
return final_answer
|
|
|
3 |
from dotenv import load_dotenv
|
4 |
from index import IndexSearchEngine
|
5 |
from gpt_3_manager import Gpt3Manager
|
6 |
+
from prompt import QuestionAnsweringPrompt, PassageSummarizationPrompt, TextPromptLoader
|
7 |
+
from pathlib import Path
|
8 |
|
9 |
|
10 |
load_dotenv()
|
|
|
15 |
|
16 |
|
17 |
class ChatBot:
|
18 |
+
def __init__(
|
19 |
+
self, index_search_engine: IndexSearchEngine, prompt_loader, gpt_manager
|
20 |
+
):
|
21 |
self.index_search_engine = index_search_engine
|
22 |
+
self.prompet_loader = prompt_loader
|
23 |
+
self.gpt_manager = gpt_manager
|
24 |
|
25 |
def ask(self, question):
|
26 |
+
search_result = self.index_search_engine.search(question=question, count=2)
|
27 |
|
28 |
answers = []
|
29 |
for result in search_result:
|
30 |
print("iterating over answering questions")
|
31 |
+
question_answering_prompt = QuestionAnsweringPrompt(
|
32 |
+
passage=result, question=question, prompt_loader=self.prompet_loader
|
33 |
+
)
|
34 |
+
prompt = question_answering_prompt.load(
|
35 |
+
Path("prompts") / "question_answering.txt"
|
36 |
)
|
37 |
|
38 |
+
answer = self.gpt_manager.get_completion(
|
39 |
+
prompt=prompt, max_tokens=80, model="text-curie-001"
|
40 |
)
|
41 |
answers.append(answer)
|
42 |
|
43 |
+
passage_summarization_prompt = PassageSummarizationPrompt(
|
44 |
+
"\n".join(answers), self.prompet_loader
|
45 |
+
)
|
46 |
+
|
47 |
+
prompt = passage_summarization_prompt.load(
|
48 |
+
Path("prompts") / "passage_summarization.txt"
|
49 |
)
|
50 |
|
51 |
+
final_answer = self.gpt_manager.get_completion(prompt=prompt)
|
52 |
return final_answer
|
src/index.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
from abc import ABC, abstractmethod
|
2 |
import jsonlines
|
3 |
-
from gpt_3_manager import Gpt3Manager
|
4 |
from utils import dot_similarity
|
5 |
|
6 |
|
@@ -35,4 +34,4 @@ class IndexSearchEngine:
|
|
35 |
simmilarities, key=lambda x: x["score"], reverse=True
|
36 |
)
|
37 |
|
38 |
-
return sorted_similarities[:count]
|
|
|
1 |
from abc import ABC, abstractmethod
|
2 |
import jsonlines
|
|
|
3 |
from utils import dot_similarity
|
4 |
|
5 |
|
|
|
34 |
simmilarities, key=lambda x: x["score"], reverse=True
|
35 |
)
|
36 |
|
37 |
+
return [result["index"]["content"] for result in sorted_similarities[:count]]
|
src/main.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from index import IndexSearchEngine
|
5 |
+
from gpt_3_manager import Gpt3Manager
|
6 |
+
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
from chat import ChatBot
|
9 |
+
from index import JsonLinesIndex
|
10 |
+
|
11 |
+
from prompt import TextPromptLoader
|
12 |
+
|
13 |
+
load_dotenv()
|
14 |
+
|
15 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
16 |
+
|
17 |
+
|
18 |
+
path = Path("index") / "index.jsonl"
|
19 |
+
|
20 |
+
index = JsonLinesIndex()
|
21 |
+
loaded = index.load(path)
|
22 |
+
gpt_manager = Gpt3Manager(api_key=OPENAI_API_KEY)
|
23 |
+
|
24 |
+
engine = IndexSearchEngine(loaded, gpt_manager=gpt_manager)
|
25 |
+
loader = TextPromptLoader()
|
26 |
+
chatbot = ChatBot(engine, prompt_loader=loader, gpt_manager=gpt_manager)
|
27 |
+
|
28 |
+
answer = chatbot.ask("What does the twitter terms of service does")
|
29 |
+
print(answer)
|
src/prompt.py
CHANGED
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
|
|
3 |
# Prompt Loaders
|
4 |
class PromptLoader(ABC):
|
5 |
@abstractmethod
|
6 |
-
def load_prompt():
|
7 |
pass
|
8 |
|
9 |
|
@@ -50,7 +50,7 @@ class PassageSummarizationPrompt(Prompt):
|
|
50 |
super().__init__(prompt_loader=prompt_loader)
|
51 |
self.passage = passage
|
52 |
|
53 |
-
# prompt = self.load_prompt(path).replace("<<PASSAGE>>",
|
54 |
|
55 |
def load(self, path):
|
56 |
prompt = self.load_prompt(path).replace("<<PASSAGE>>", self.passage)
|
|
|
3 |
# Prompt Loaders
|
4 |
class PromptLoader(ABC):
|
5 |
@abstractmethod
|
6 |
+
def load_prompt(self, path):
|
7 |
pass
|
8 |
|
9 |
|
|
|
50 |
super().__init__(prompt_loader=prompt_loader)
|
51 |
self.passage = passage
|
52 |
|
53 |
+
# prompt = self.load_prompt(path).replace("<<PASSAGE>>", )
|
54 |
|
55 |
def load(self, path):
|
56 |
prompt = self.load_prompt(path).replace("<<PASSAGE>>", self.passage)
|
src/tests/__init__.py
CHANGED
@@ -1 +0,0 @@
|
|
1 |
-
|
|
|
|
src/tests/chat_test.py
CHANGED
@@ -1,28 +1,31 @@
|
|
1 |
import os
|
2 |
from pathlib import Path
|
|
|
3 |
from index import IndexSearchEngine
|
4 |
from gpt_3_manager import Gpt3Manager
|
|
|
5 |
from dotenv import load_dotenv
|
6 |
from chat import ChatBot
|
7 |
from index import JsonLinesIndex
|
8 |
|
9 |
-
|
|
|
|
|
10 |
|
11 |
-
|
12 |
|
13 |
|
14 |
-
|
15 |
-
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
# engine = IndexSearchEngine(loaded, gpt_manager=gpt_manager)
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
# # assert 0 == 0
|
26 |
|
|
|
27 |
|
28 |
-
|
|
|
1 |
import os
|
2 |
from pathlib import Path
|
3 |
+
|
4 |
from index import IndexSearchEngine
|
5 |
from gpt_3_manager import Gpt3Manager
|
6 |
+
|
7 |
from dotenv import load_dotenv
|
8 |
from chat import ChatBot
|
9 |
from index import JsonLinesIndex
|
10 |
|
11 |
+
from prompt import TextPromptLoader
|
12 |
+
|
13 |
+
load_dotenv()
|
14 |
|
15 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
16 |
|
17 |
|
18 |
+
def test_chatbot():
|
19 |
+
path = Path("index") / "index.jsonl"
|
20 |
|
21 |
+
index = JsonLinesIndex()
|
22 |
+
loaded = index.load(path)
|
23 |
+
gpt_manager = Gpt3Manager(api_key=OPENAI_API_KEY)
|
|
|
24 |
|
25 |
+
engine = IndexSearchEngine(loaded, gpt_manager=gpt_manager)
|
26 |
+
loader = TextPromptLoader()
|
27 |
+
chatbot = ChatBot(engine, prompt_loader=loader, gpt_manager=gpt_manager)
|
|
|
28 |
|
29 |
+
answer = chatbot.ask("What does the twitter terms of service does")
|
30 |
|
31 |
+
assert answer != None
|