zinoubm committed on
Commit
5a1b165
·
1 Parent(s): b678100

Refactoring the code according to the SOLID principles

Browse files
chat.py DELETED
@@ -1,66 +0,0 @@
1
- import os
2
- import openai
3
- from dotenv import load_dotenv
4
- import jsonlines
5
- from pathlib import Path
6
- from utils import (
7
- gpt3_embeddings,
8
- gpt3_completion,
9
- dot_similarity,
10
- load_prompt,
11
- )
12
-
13
- load_dotenv()
14
-
15
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
16
-
17
- openai.api_key = OPENAI_API_KEY
18
-
19
-
20
def search_index(question, indexes, count=4):
    """Return the `count` index entries most similar to `question`.

    Each returned item is a dict with the original entry under "index"
    and its dot-product similarity under "score", best match first.
    """
    query_vector = gpt3_embeddings(question)

    scored = [
        {"index": entry, "score": dot_similarity(query_vector, entry["embedding"])}
        for entry in indexes
    ]
    scored.sort(key=lambda item: item["score"], reverse=True)
    return scored[:count]
34
-
35
-
36
if __name__ == "__main__":
    # Load every pre-computed passage embedding once at startup.
    with jsonlines.open(Path("./index") / "index.jsonl") as passages:
        indexes = list(passages)

    # Simple REPL: answer one question per loop iteration.
    while True:
        question = input("User >")

        search_results = search_index(question=question, indexes=indexes, count=2)

        answers = []
        for result in search_results:
            # BUG FIX: "prompts\question_answering.txt" embedded a literal
            # backslash ("\q" is not a valid escape and the wrong separator
            # on POSIX); build the path with pathlib instead.
            prompt = (
                load_prompt(Path("prompts") / "question_answering.txt")
                .replace("<<PASSAGE>>", result["index"]["content"])
                .replace("<<QUESTION>>", question)
            )

            answer = gpt3_completion(
                prompt=prompt, max_tokens=80, model="text-curie-001"
            )
            answers.append(answer)

        # Summarize the per-passage answers into one final reply.
        prompt = load_prompt(Path("prompts") / "passage_summarization.txt").replace(
            "<<PASSAGE>>", "\n".join(answers)
        )

        final_answer = gpt3_completion(prompt=prompt)

        print(f"Bot: {final_answer}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
index/build_index.py CHANGED
@@ -7,7 +7,7 @@ import openai
7
  import textwrap
8
  import jsonlines
9
 
10
- from utils import gpt3_embeddings
11
 
12
  load_dotenv()
13
 
 
7
  import textwrap
8
  import jsonlines
9
 
10
+ from src.utils import gpt3_embeddings
11
 
12
  load_dotenv()
13
 
src/chat.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import openai
3
+ from dotenv import load_dotenv
4
+ from index import IndexSearchEngine
5
+ from gpt_3_manager import Gpt3Manager
6
+ from prompt import QuestionAnsweringPrompt, PassageSummarizationPrompt
7
+
8
+
9
+ load_dotenv()
10
+
11
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
12
+
13
+ openai.api_key = OPENAI_API_KEY
14
+
15
+
16
class ChatBot:
    """Answers a question by searching an index and summarizing GPT-3 answers."""

    def __init__(self, index_search_engine: IndexSearchEngine):
        self.index_search_engine = index_search_engine

    def ask(self, question):
        """Return a final answer for `question`.

        Searches the index, asks GPT-3 about each matching passage, then
        summarizes the per-passage answers into one reply.
        """
        search_results = self.index_search_engine.search(question=question)

        answers = []
        for result in search_results:
            # BUG FIX: the original called QuestionAnsweringPrompt.load(path)
            # on the class itself and never supplied the passage/question the
            # prompt template needs; construct the prompt object properly.
            # Also use forward slashes: "\q" is not a valid escape and a
            # literal backslash breaks the path on POSIX.
            qa_prompt = QuestionAnsweringPrompt(result, question).load(
                "prompts/question_answering.txt"
            )

            answer = Gpt3Manager.get_completion(
                prompt=qa_prompt, max_tokens=80, model="text-curie-001"
            )
            answers.append(answer)

        # BUG FIX: the original never passed `answers` into the
        # summarization prompt, so the template was left unfilled.
        summary_prompt = PassageSummarizationPrompt(answers).load(
            "prompts/passage_summarization.txt"
        )

        final_answer = Gpt3Manager.get_completion(prompt=summary_prompt)
        return final_answer
src/gpt_3_manager.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+
3
+
4
class Gpt3Manager:
    """Thin wrapper around the OpenAI completion and embedding endpoints."""

    def __init__(self, api_key):
        # Configures the shared openai module; all later calls use this key.
        openai.api_key = api_key

    # BUG FIX: these were defined without `self`, so calling them on an
    # instance raised TypeError (they only worked when called on the class).
    # They use no instance state, so mark them static explicitly.
    @staticmethod
    def get_completion(prompt, max_tokens=128, model="text-davinci-003"):
        """Return the completion text for `prompt`, or None on API failure."""
        response = None
        try:
            response = openai.Completion.create(
                model=model,
                prompt=prompt,
                max_tokens=max_tokens,
            )["choices"][0]["text"]
        except Exception as err:
            # Best-effort: report the error and fall through to None.
            print(f"Sorry, There was a problem \n\n {err}")
        return response

    @staticmethod
    def get_embedding(text, model="text-similarity-ada-001"):
        """Return the embedding vector for `text`, or None on API failure."""
        # Newlines are known to degrade embedding quality; flatten them.
        text = text.replace("\n", " ")
        embedding = None
        try:
            embedding = openai.Embedding.create(input=[text], model=model)["data"][0][
                "embedding"
            ]
        except Exception as err:
            print(f"Sorry, There was a problem {err}")
        return embedding
src/index.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import jsonlines
3
+ from gpt_3_manager import Gpt3Manager
4
+ from src.utils import dot_similarity
5
+
6
+
7
class Index(ABC):
    """Abstract source of passage index entries."""

    @abstractmethod
    def load(self, path):
        """Load and return the index entries stored at `path`."""


class JsonLinesIndex(Index):
    """Index stored as a JSON-Lines file, one passage record per line."""

    def load(self, path):
        """Read and return every record from the JSON-Lines file at `path`."""
        with jsonlines.open(path) as passages:
            return list(passages)


class IndexSearchEngine:
    """Ranks index entries by embedding similarity to a question."""

    def __init__(self, index):
        # BUG FIX: the original did `index = index` (a no-op local
        # assignment), so self.index was never set.
        self.index = index

    def search(self, question, indexes=None, count=4):
        """Return the `count` entries most similar to `question`.

        `indexes` now defaults to the engine's own index: the original
        required it, yet the only caller passed just `question`.
        """
        if indexes is None:
            indexes = self.index

        question_embedding = Gpt3Manager.get_embedding(question)

        similarities = [
            {
                "index": entry,
                "score": dot_similarity(question_embedding, entry["embedding"]),
            }
            for entry in indexes
        ]
        similarities.sort(key=lambda item: item["score"], reverse=True)
        return similarities[:count]
src/prompt.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
class Prompt(ABC):
    """Base class for prompts built from a template file."""

    def load_prompt(self, path):
        """Return the raw template text stored at `path`.

        BUG FIX: the original omitted `self`, so self.load_prompt(path)
        raised TypeError (two arguments passed to a one-argument function).
        """
        with open(path) as f:
            return f.read()

    @abstractmethod
    def load(self, path):
        """Return the filled-in prompt text for the template at `path`."""


class QuestionAnsweringPrompt(Prompt):
    """Prompt asking GPT-3 to answer a question from a single passage."""

    def __init__(self, result, question):
        # BUG FIX: the original assigned to locals (`result = result`),
        # leaving the instance attributes unset and breaking load().
        self.result = result
        self.question = question

    def load(self, path):
        """Fill the template with this prompt's passage and question."""
        return (
            self.load_prompt(path)
            .replace("<<PASSAGE>>", self.result["index"]["content"])
            .replace("<<QUESTION>>", self.question)
        )


class PassageSummarizationPrompt(Prompt):
    """Prompt asking GPT-3 to summarize several partial answers."""

    def __init__(self, answers):
        self.answers = answers

    def load(self, path):
        """Fill the template with the newline-joined answers."""
        return self.load_prompt(path).replace("<<PASSAGE>>", "\n".join(self.answers))
utils.py → src/utils.py RENAMED
@@ -1,35 +1,6 @@
1
- import openai
2
  import numpy as np
3
 
4
 
5
def gpt3_embeddings(text, model="text-similarity-ada-001"):
    """Return the embedding vector for `text`, or None if the API call fails."""
    # Newlines are flattened before embedding.
    flattened = text.replace("\n", " ")
    try:
        result = openai.Embedding.create(input=[flattened], model=model)
        return result["data"][0]["embedding"]
    except Exception as err:
        print(f"Sorry, There was a problem {err}")
        return None
16
-
17
-
18
def gpt3_completion(prompt, max_tokens=128, model="text-davinci-003"):
    """Return the completion text for `prompt`, or None if the API call fails."""
    try:
        completion = openai.Completion.create(
            model=model,
            prompt=prompt,
            max_tokens=max_tokens,
        )
        return completion["choices"][0]["text"]
    except Exception as err:
        print(f"Sorry, There was a problem \n\n {err}")
        return None
31
-
32
-
33
  def load_prompt(path):
34
  with open(path) as f:
35
  lines = f.readlines()
 
 
1
  import numpy as np
2
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  def load_prompt(path):
5
  with open(path) as f:
6
  lines = f.readlines()