Spaces: maximka608 / Sleeping
committed · Commit b93b2dc · 1 parent: c621bf1
Commit message: test
Browse files
- .DS_Store +0 -0
- .gitattributes +2 -0
- .idea/.gitignore +8 -0
- .idea/inspectionProfiles/Project_Default.xml +13 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/nlp.iml +8 -0
- .idea/vcs.xml +7 -0
- app.py +86 -0
- config.py +10 -0
- faiss_index.faiss +3 -0
- metadata.json +3 -0
- preprocessing_text.json +3 -0
- requirements.txt +112 -0
- script/create_vector_base.py +41 -0
- script/preprocessing_text.py +36 -0
- utils/embedding.py +14 -0
- utils/llm.py +28 -0
- utils/vector_base.py +27 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.faiss filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
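Note: the two added rules route *.faiss and *.json files through Git LFS, which is why faiss_index.faiss, metadata.json, and preprocessing_text.json further down in this commit are stored as LFS pointer files rather than their actual contents.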
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,13 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+      <option name="ignoredErrors">
+        <list>
+          <option value="N806" />
+          <option value="N803" />
+        </list>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
.idea/misc.xml
ADDED
@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Black">
+    <option name="sdkName" value="Python 3.10 (papersRag)" />
+  </component>
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (papersRag)" project-jdk-type="Python SDK" />
+</project>
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/nlp.iml" filepath="$PROJECT_DIR$/.idea/nlp.iml" />
+    </modules>
+  </component>
+</project>
.idea/nlp.iml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="Python 3.10 (papersRag)" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
.idea/vcs.xml
ADDED
@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
app.py
ADDED
@@ -0,0 +1,86 @@
+import streamlit as st
+from utils.vector_base import KnowledgeBase
+from utils.embedding import Embeddings
+from utils.llm import LLM
+from config import config
+import json
+
+
+def get_embedding_model():  # fixed typo: was get_emdedding_model
+    return Embeddings()
+
+
+def get_llm(url, api_key):
+    return LLM(url, api_key)
+
+
+def get_metadata(path):
+    titles, texts = [], []
+    with open(path, 'r') as file:
+        metadata = json.load(file)
+        for data in metadata:
+            titles.append(data['title'])
+            texts.append(data['text'])
+    return texts, titles
+
+
+def combine_docs(indexes, texts):
+    result = ""
+    for i, index in enumerate(indexes):
+        result += " [" + str(i + 1) + "] " + texts[index]
+    return result
+
+
+def create_prompt(query, docs):
+    system_prompt = f"""You are a language model integrated into a search and
+    generation system based on relevant documents (RAG system).
+    Your task is to provide answers to the user's queries based on the provided
+    documents. Respond only based on the provided documents. Do not make up
+    information that is not in the sources. If you use data from a document,
+    indicate the document number in square brackets. For example: "This term
+    means such-and-such [1]." If there is no information in the documents,
+    politely explain that the information is not available. Do not alter the
+    content of the sources; convey the information accurately.
+    User query: {query}. Documents: {docs}
+    """
+
+    return system_prompt
+
+
+st.title("PaperRAG")
+st.write("RAG system for scientific papers with selectable search types")
+
+query = st.text_input("Enter your query", "")
+search_types = st.multiselect(
+    "Search Types", options=["Vector", "BM25"], default=["Vector", "BM25"]
+)
+llm_url = st.text_input("LLM URL", "", placeholder="Enter LLM ENDPOINT")
+llm_api_key = st.text_input("LLM API Key", "", placeholder="Enter LLM API Key", type="password")
+
+if st.button("Search"):
+    if query and llm_url and llm_api_key:
+        model = get_embedding_model()
+        llm = get_llm(llm_url, llm_api_key)
+
+        texts, titles = get_metadata(config.PATH_METADATA)
+        embedding = model.get_query_embedding(query)
+
+        knowledge_base = KnowledgeBase(config.PATH_FAISS, config.PATH_PREPROCESSING_TEXT)
+
+        vector_search = []
+        bm25_search = []
+
+        if "Vector" in search_types:
+            vector_search = knowledge_base.search_by_embedding(embedding, 5)[0].tolist()
+        if "BM25" in search_types:
+            bm25_search = knowledge_base.search_by_BM25(query, 5)
+
+        docs = combine_docs(vector_search + bm25_search, texts)
+        prompt = create_prompt(query, docs)
+
+        response = llm.generate_response(prompt)
+
+        st.subheader("Response")
+        st.write(response)
+    else:
+        st.error("Please fill in all the required fields.")
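Note: combine_docs renumbers the retrieved chunks, so the bracketed citations the prompt asks the model to emit refer to positions in the combined result list, not to the underlying FAISS/BM25 row ids. A quick illustration with hypothetical texts (not part of the commit):

# combine_docs copied from app.py above so the snippet runs standalone
def combine_docs(indexes, texts):
    result = ""
    for i, index in enumerate(indexes):
        result += " [" + str(i + 1) + "] " + texts[index]
    return result

texts = ["chunk about CNNs", "chunk about RNNs", "chunk about GANs"]
print(combine_docs([2, 0], texts))
# -> ' [1] chunk about GANs [2] chunk about CNNs'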
config.py
ADDED
@@ -0,0 +1,10 @@
+from pathlib import Path
+
+
+class Config:
+    PATH_FAISS = str(Path(__file__).parent / 'faiss_index.faiss')
+    PATH_METADATA = str(Path(__file__).parent / 'metadata.json')
+    PATH_PREPROCESSING_TEXT = str(Path(__file__).parent / 'preprocessing_text.json')
+
+
+config = Config()
faiss_index.faiss
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1422951bc60fbc02da260a6d9059740149b8724e13f71b7110e440e66bcc9f79
+size 76847661
metadata.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:052c218b62d563adf9d26339d58c1296f22e6674f36f3b55e3675c3865e50d8f
+size 17923018
preprocessing_text.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51ccf934a26b90ca2d1753d51dd3e5e5498121ba7f661ce14a232a9993667bdf
+size 8317837
requirements.txt
ADDED
@@ -0,0 +1,112 @@
+aiofiles==23.2.1
+aiohappyeyeballs==2.4.3
+aiohttp==3.11.2
+aiosignal==1.3.1
+altair==5.5.0
+annotated-types==0.7.0
+anyio==4.6.2.post1
+async-timeout==5.0.1
+attrs==24.2.0
+blinker==1.9.0
+Brotli==1.1.0
+cachetools==5.5.0
+certifi==2024.8.30
+cffi==1.17.1
+charset-normalizer==3.4.0
+click==8.1.7
+cryptography==43.0.3
+datasets==3.1.0
+dill==0.3.8
+distro==1.9.0
+einops==0.8.0
+exceptiongroup==1.2.2
+faiss-cpu==1.9.0
+fastapi==0.115.5
+ffmpy==0.4.0
+filelock==3.16.1
+frozenlist==1.5.0
+fsspec==2024.9.0
+gitdb==4.0.11
+GitPython==3.1.43
+gradio==5.6.0
+gradio_client==1.4.3
+h11==0.14.0
+httpcore==1.0.7
+httpx==0.27.2
+huggingface-hub==0.26.2
+idna==3.10
+Jinja2==3.1.4
+jiter==0.7.1
+joblib==1.4.2
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.1.0
+multiprocess==0.70.16
+narwhals==1.14.2
+networkx==3.4.2
+nltk==3.9.1
+numpy==2.1.3
+openai==1.54.4
+orjson==3.10.11
+packaging==24.2
+pandas==2.2.3
+pdfminer.six==20231228
+pdfplumber==0.11.4
+pillow==11.0.0
+propcache==0.2.0
+protobuf==5.28.3
+pyarrow==18.0.0
+pycparser==2.22
+pydantic==2.9.2
+pydantic_core==2.23.4
+pydeck==0.9.1
+pydub==0.25.1
+Pygments==2.18.0
+pypdfium2==4.30.0
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-multipart==0.0.12
+pytz==2024.2
+PyYAML==6.0.2
+rank-bm25==0.2.2
+referencing==0.35.1
+regex==2024.11.6
+requests==2.32.3
+rich==13.9.4
+rpds-py==0.21.0
+ruff==0.7.4
+safehttpx==0.1.1
+safetensors==0.4.5
+scikit-learn==1.5.2
+scipy==1.14.1
+semantic-version==2.10.0
+sentence-transformers==3.3.1
+shellingham==1.5.4
+six==1.16.0
+smmap==5.0.1
+sniffio==1.3.1
+starlette==0.41.3
+sympy==1.13.1
+tenacity==9.0.0
+threadpoolctl==3.5.0
+tokenizers==0.20.3
+toml==0.10.2
+tomlkit==0.12.0
+torch==2.5.1
+torchaudio==2.5.1
+torchvision==0.20.1
+tornado==6.4.2
+tqdm==4.67.0
+transformers==4.46.3
+typer==0.13.1
+typing_extensions==4.12.2
+tzdata==2024.2
+urllib3==2.2.3
+uvicorn==0.32.1
+websockets==12.0
+xxhash==3.5.0
+yarl==1.17.2
script/create_vector_base.py
ADDED
@@ -0,0 +1,41 @@
+import faiss
+import json
+from datasets import load_dataset
+from utils.embedding import Embeddings
+
+
+def get_chunks(docs, size):  # fixed typo: was get_chunkes
+    chunked_texts, metadata = [], []
+
+    for text in docs:  # the enumerate index was unused
+        for i in range(0, len(text['abstract']), size):
+            chunk = text['abstract'][i:i + size]
+
+            chunked_texts.append(chunk)
+            metadata.append({'title': text['title'], 'text': chunk})
+
+    return chunked_texts, metadata
+
+
+def create_base(docs, model: Embeddings):
+    chunks, metadata = get_chunks(docs, 256)
+    embeddings = model.get_embeddings(chunks)
+    dimension = embeddings.shape[1]  # 384 for BAAI/bge-small-en-v1.5; was a hard-coded constant
+    index = faiss.IndexFlatL2(dimension)
+    index.add(embeddings)
+
+    return index, metadata
+
+
+def main():
+    data = load_dataset("aalksii/ml-arxiv-papers")
+    articles = data['train'].select(range(10000))
+    embed_model = Embeddings()
+
+    vector_base, metadata = create_base(articles, embed_model)
+    # write both artifacts to the repo root; the index originally went to the working directory
+    faiss.write_index(vector_base, "../faiss_index.faiss")
+
+    with open("../metadata.json", "w") as f:
+        json.dump(metadata, f, indent=4)
+
+
+if __name__ == '__main__':
+    main()
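The chunking in get_chunks is plain fixed-width slicing of the abstract with no overlap; a quick illustration of the boundary behavior (hypothetical input, not part of the commit):

# a 600-character abstract sliced with size=256 yields pieces of 256, 256, and 88 characters
text = "x" * 600
chunks = [text[i:i + 256] for i in range(0, len(text), 256)]
print([len(c) for c in chunks])  # [256, 256, 88]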
script/preprocessing_text.py
ADDED
@@ -0,0 +1,36 @@
+import nltk
+import json
+from nltk.stem import PorterStemmer
+from nltk.corpus import stopwords
+from config import config
+
+nltk.download('stopwords')  # fetch the stopword list once at import rather than once per document
+
+
+class Preprocessor:
+    def _tokenize(self, text):
+        return text.lower().split(' ')
+
+    def preprocessing_text(self, doc):
+        tokens = self._tokenize(doc)
+
+        stop_words = set(stopwords.words('english'))
+        filtered_tokens = [token for token in tokens if token not in stop_words]
+
+        stemmer = PorterStemmer()
+        stemmed_tokens = [stemmer.stem(token) for token in filtered_tokens]  # fixed typo: was stemmed_tokes
+        return " ".join(stemmed_tokens)
+
+    def _save(self, docs):
+        with open("../preprocessing_text.json", "w") as f:
+            json.dump(docs, f, indent=4)
+
+    def preprocessing(self, docs):
+        preprocessed_docs = [self.preprocessing_text(doc) for doc in docs]
+        self._save(preprocessed_docs)
+
+
+if __name__ == '__main__':
+    # imported here rather than at module top to avoid a circular import
+    # (app -> utils.vector_base -> this module); the commented-out
+    # `from app.main import get_metadata` pointed at a module that does not exist,
+    # leaving get_metadata undefined below
+    from app import get_metadata
+    texts, _ = get_metadata(config.PATH_METADATA)
+    preprocessor = Preprocessor()
+    preprocessor.preprocessing(texts)
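A quick check of what preprocessing_text produces (expected output, assuming the NLTK stopword list has been downloaded):

from script.preprocessing_text import Preprocessor

p = Preprocessor()
print(p.preprocessing_text("The models were running quickly"))
# -> 'model run quickli'  (stopwords dropped, remaining tokens Porter-stemmed)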
utils/embedding.py
ADDED
@@ -0,0 +1,14 @@
+from sentence_transformers import SentenceTransformer
+
+
+class Embeddings:
+    def __init__(self, model_name: str = 'BAAI/bge-small-en-v1.5'):
+        self.model = SentenceTransformer(model_name, trust_remote_code=True, revision="main")
+
+    def get_query_embedding(self, query):
+        query_embed = self.model.encode([query], normalize_embeddings=True)
+        return query_embed
+
+    def get_embeddings(self, texts):
+        embeddings = self.model.encode(texts, normalize_embeddings=True)
+        return embeddings
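A smoke test of the wrapper; BAAI/bge-small-en-v1.5 returns 384-dimensional vectors, which matches the index dimension used in script/create_vector_base.py:

from utils.embedding import Embeddings

model = Embeddings()
q = model.get_query_embedding("graph neural networks")
docs = model.get_embeddings(["first abstract", "second abstract"])
print(q.shape, docs.shape)  # (1, 384) (2, 384)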
utils/llm.py
ADDED
@@ -0,0 +1,28 @@
+import requests
+from dotenv import load_dotenv
+
+load_dotenv()
+
+
+class LLM:
+    def __init__(self, url, api_key):
+        self.endpoint = url
+        self.api_key = api_key
+
+    def generate_response(self, prompt):
+        headers = {
+            "Content-Type": "application/json",
+            "api-key": self.api_key,
+        }
+
+        data = {
+            "messages": [{"role": "user", "content": prompt}],
+            "max_tokens": 1500,
+            "temperature": 0.5,
+        }
+
+        response = requests.post(self.endpoint, headers=headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()["choices"][0]["message"]["content"]
+        else:
+            raise ValueError(response.text)  # fixed: was `return ValueError(...)`, which handed back an exception object instead of raising
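The "api-key" header and the choices[0].message.content response shape match the Azure OpenAI-style chat-completions API, so the URL entered in the app is expected to be a full chat-completions endpoint. A usage sketch with placeholder values (hypothetical, not part of the commit):

from utils.llm import LLM

llm = LLM(
    "https://YOUR-RESOURCE.openai.azure.com/openai/deployments/YOUR-DEPLOYMENT/chat/completions?api-version=2024-02-01",  # placeholder endpoint
    "YOUR_API_KEY",  # placeholder key
)
print(llm.generate_response("Reply with one word: hello"))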
utils/vector_base.py
ADDED
@@ -0,0 +1,27 @@
+import faiss
+import json
+from script.preprocessing_text import Preprocessor
+from rank_bm25 import BM25Okapi
+import numpy as np
+
+
+class KnowledgeBase:
+    def __init__(self, faiss_path, preprocessing_path) -> None:
+        # BM25Okapi expects a tokenized corpus; the preprocessed docs are stored as joined strings, so split them
+        self.BM25_model = BM25Okapi([doc.split() for doc in self._load(preprocessing_path)])
+        self.vector_base = faiss.read_index(faiss_path)
+
+    def _load(self, path):
+        with open(path, 'r') as file:
+            data = json.load(file)
+        return data
+
+    def search_by_BM25(self, query, k=5):
+        preprocessor = Preprocessor()
+        prep_query = preprocessor.preprocessing_text(query)
+        # get_scores also wants a token list; a raw string would be scored character by character
+        doc_scores = self.BM25_model.get_scores(prep_query.split())
+        sorted_docs = np.argsort(-doc_scores)
+        return sorted_docs[:k].tolist()
+
+    def search_by_embedding(self, embedding, k):
+        _, indexes = self.vector_base.search(embedding, k)
+        return indexes
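Both search paths return positional indices into the metadata list, which is what app.py feeds to combine_docs. A minimal end-to-end retrieval sketch (assuming the LFS artifacts from this commit have been pulled locally):

from utils.vector_base import KnowledgeBase
from utils.embedding import Embeddings
from config import config

kb = KnowledgeBase(config.PATH_FAISS, config.PATH_PREPROCESSING_TEXT)
emb = Embeddings().get_query_embedding("contrastive learning")

vector_ids = kb.search_by_embedding(emb, 5)[0].tolist()  # FAISS row ids
bm25_ids = kb.search_by_BM25("contrastive learning", 5)  # BM25 row ids
print(vector_ids, bm25_ids)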