Xalt8 committed on
Commit
524e9f1
·
1 Parent(s): 3d33782

reranking working example

Browse files
rag_app/metadata.ipynb DELETED
@@ -1,170 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "from pathlib import Path\n",
10
- "from langchain_community.vectorstores import FAISS\n",
11
- "from dotenv import load_dotenv\n",
12
- "import os\n",
13
- "from langchain_huggingface import HuggingFaceEmbeddings"
14
- ]
15
- },
16
- {
17
- "cell_type": "code",
18
- "execution_count": 3,
19
- "metadata": {},
20
- "outputs": [
21
- {
22
- "data": {
23
- "text/plain": [
24
- "True"
25
- ]
26
- },
27
- "execution_count": 3,
28
- "metadata": {},
29
- "output_type": "execute_result"
30
- }
31
- ],
32
- "source": [
33
- "load_dotenv()"
34
- ]
35
- },
36
- {
37
- "cell_type": "code",
38
- "execution_count": 5,
39
- "metadata": {},
40
- "outputs": [],
41
- "source": [
42
- "HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')\n",
43
- "EMBEDDING_MODEL = os.getenv(\"EMBEDDING_MODEL\")"
44
- ]
45
- },
46
- {
47
- "cell_type": "code",
48
- "execution_count": null,
49
- "metadata": {},
50
- "outputs": [],
51
- "source": [
52
- "embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)"
53
- ]
54
- },
55
- {
56
- "cell_type": "code",
57
- "execution_count": 7,
58
- "metadata": {},
59
- "outputs": [],
60
- "source": [
61
- "folder_path = Path('..') / \"vectorstore/faiss-insurance-agent-500\"\n",
62
- "faissdb = FAISS.load_local(folder_path=str(folder_path.resolve()),\n",
63
- " embeddings=embeddings,\n",
64
- " allow_dangerous_deserialization=True) "
65
- ]
66
- },
67
- {
68
- "cell_type": "code",
69
- "execution_count": 24,
70
- "metadata": {},
71
- "outputs": [
72
- {
73
- "name": "stdout",
74
- "output_type": "stream",
75
- "text": [
76
- "Content: Die private Haftpflichtversicherung...\n",
77
- "Metadata: {'source': 'https://www.wuerttembergische.de/versicherungen/stadt/wuppertal/', 'content_type': 'text/html; charset=UTF-8', 'title': 'Versicherung in Wuppertal', 'description': 'Ihre Versicherungsagentur in Wuppertal: Kommen Sie zur Württembergischen Versicherung und profitieren Sie von einer persönlichen Beratung und ausgezeichnetem Service. ', 'language': 'de'}\n",
78
- "---\n",
79
- "Content: Haftpflichtversicherung...\n",
80
- "Metadata: {'source': 'https://www.wuerttembergische.de/wohnen/hausratversicherung/sengschaden/', 'content_type': 'text/html; charset=UTF-8', 'title': 'Sengschäden: So schützt Sie Ihre Hausrat- und Wohngebäudeversicherung', 'description': 'Deckt Ihre Hausratversicherung Sengschäden ab? Finden Sie heraus, wie Sie bei Schäden durch Glut oder Hitze ohne direktes Feuer geschützt sind.\\n', 'language': 'de'}\n",
81
- "---\n",
82
- "Content: Die Leistungen unserer privaten Haftpflichtversich...\n",
83
- "Metadata: {'source': 'https://www.wuerttembergische.de/existenz/private-haftpflichtversicherung/drohnen-versichern/', 'content_type': 'text/html; charset=UTF-8', 'title': 'Drohnen über die private Haftpflicht versichern', 'description': 'Müssen Drohnen versichert sein? Welcher Tarif ist der beste? Erfahren Sie hier die wichtigsten Informationen rund ums Thema Drohne versichern.', 'language': 'de'}\n",
84
- "---\n",
85
- "Content: Das kann ohne private Haftpflichtversicherung pass...\n",
86
- "Metadata: {'source': 'https://www.wuerttembergische.de/existenz/private-haftpflichtversicherung/pflicht/', 'content_type': 'text/html; charset=UTF-8', 'title': 'Ist die private Haftpflichtversicherung Pflicht oder freiwillig?', 'description': 'Ist eine Privathaftpflichtversicherung gesetzlich vorgeschrieben? Welche Haftpflichtversicherung Pflicht sind und welche freiwillig - das erfahren Sie hier.', 'language': 'de'}\n",
87
- "---\n",
88
- "Content: Private Haftpflicht: keine Pflichtversicherung\n",
89
- "Fre...\n",
90
- "Metadata: {'source': 'https://www.wuerttembergische.de/existenz/private-haftpflichtversicherung/pflicht/', 'content_type': 'text/html; charset=UTF-8', 'title': 'Ist die private Haftpflichtversicherung Pflicht oder freiwillig?', 'description': 'Ist eine Privathaftpflichtversicherung gesetzlich vorgeschrieben? Welche Haftpflichtversicherung Pflicht sind und welche freiwillig - das erfahren Sie hier.', 'language': 'de'}\n",
91
- "---\n"
92
- ]
93
- }
94
- ],
95
- "source": [
96
- "# Perform a similarity search for a sample query and inspect the top 5 documents\n",
97
- "documents = faissdb.similarity_search(\"Private Haftpflicht­versicherung\", k=5)\n",
98
- "\n",
99
- "for doc in documents:\n",
100
- " print(f\"Content: {doc.page_content[:50]}...\") # Print first 50 chars of content\n",
101
- " print(f\"Metadata: {doc.metadata}\")\n",
102
- " print(\"---\")"
103
- ]
104
- },
105
- {
106
- "cell_type": "code",
107
- "execution_count": 19,
108
- "metadata": {},
109
- "outputs": [
110
- {
111
- "name": "stdout",
112
- "output_type": "stream",
113
- "text": [
114
- "Number of entries in the database: 62496\n"
115
- ]
116
- }
117
- ],
118
- "source": [
119
- "num_entries = len(faissdb.index_to_docstore_id)\n",
120
- "print(f\"Number of entries in the database: {num_entries}\")"
121
- ]
122
- },
123
- {
124
- "cell_type": "code",
125
- "execution_count": 20,
126
- "metadata": {},
127
- "outputs": [
128
- {
129
- "name": "stdout",
130
- "output_type": "stream",
131
- "text": [
132
- "Number of entries in the database: 62496\n"
133
- ]
134
- }
135
- ],
136
- "source": [
137
- "num_entries = faissdb.index.ntotal\n",
138
- "print(f\"Number of entries in the database: {num_entries}\")"
139
- ]
140
- },
141
- {
142
- "cell_type": "code",
143
- "execution_count": null,
144
- "metadata": {},
145
- "outputs": [],
146
- "source": []
147
- }
148
- ],
149
- "metadata": {
150
- "kernelspec": {
151
- "display_name": "venv",
152
- "language": "python",
153
- "name": "python3"
154
- },
155
- "language_info": {
156
- "codemirror_mode": {
157
- "name": "ipython",
158
- "version": 3
159
- },
160
- "file_extension": ".py",
161
- "mimetype": "text/x-python",
162
- "name": "python",
163
- "nbconvert_exporter": "python",
164
- "pygments_lexer": "ipython3",
165
- "version": "3.11.4"
166
- }
167
- },
168
- "nbformat": 4,
169
- "nbformat_minor": 2
170
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rag_app/metadata_filtering.py DELETED
@@ -1,29 +0,0 @@
1
- from pathlib import Path
2
- from langchain_community.vectorstores import FAISS
3
- from dotenv import load_dotenv
4
- import os
5
- from langchain_huggingface import HuggingFaceEmbeddings
6
-
7
-
8
# Load environment variables from the local .env file before reading them.
load_dotenv(".env")

HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")


if __name__ == "__main__":
    # Debug script: load the FAISS vector store and print the metadata of the
    # first few stored documents.

    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

    # NOTE: the path is relative to the current working directory.
    folder_path = Path('..') / "vectorstore/faiss-insurance-agent-500"

    print(f'{Path(folder_path).exists() = }')

    faissdb = FAISS.load_local(folder_path=str(folder_path.resolve()),
                               embeddings=embeddings,
                               allow_dangerous_deserialization=True)

    # The store has no ``get`` method and documents are keyed by docstore id,
    # not by integer position. Translate the first index positions into
    # docstore ids and look those up, never asking for more entries than exist.
    sample_size = min(5, faissdb.index.ntotal)
    sample_ids = [faissdb.index_to_docstore_id[pos] for pos in range(sample_size)]
    documents = [faissdb.docstore.search(doc_id) for doc_id in sample_ids]

    for doc in documents:
        print(f"Metadata: {doc.metadata}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rag_app/reranking.py CHANGED
@@ -4,20 +4,77 @@ from langchain_community.vectorstores import FAISS
4
  from dotenv import load_dotenv
5
  import os
6
  from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
 
7
 
8
  load_dotenv()
9
 
10
- HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
11
- EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
12
 
13
- embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=HUGGINGFACEHUB_API_TOKEN,
14
- model_name=EMBEDDING_MODEL)
 
 
 
 
15
 
16
- path_to_vector_db = Path("..")/'vectorstore'/'faiss-insurance-agent-500'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- db = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
 
 
 
 
 
19
 
20
- # retriever = get_db_retriever(vector_db=Path("..")/)
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  if __name__ == "__main__":
23
- print(path_to_vector_db.exists())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from dotenv import load_dotenv
5
  import os
6
  from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
7
+ import requests
8
 
9
  load_dotenv()
10
 
 
 
11
 
12
def get_reranked_docs(query: str,
                      path_to_db: str,
                      embedding_model: str,
                      hf_api_key: str,
                      num_docs: int = 5) -> list:
    """Re-rank similarity-search results and return the highest ranked docs.

    Retrieves 10 candidate documents from the FAISS vector store, scores each
    (query, passage) pair with the hosted re-ranking model and returns the
    ``num_docs`` candidates with the highest score.

    Args:
        query (str): The search query
        path_to_db (str): Path to the vectorstore database (path-like objects
            are accepted and converted with ``str()``)
        embedding_model (str): Embedding model used in the vector store
        hf_api_key (str): Hugging Face API token, used for both the embedding
            and the re-ranking requests
        num_docs (int): Number of documents to return (0 to 10)

    Returns: A list of documents with the highest rank. The list is empty when
        nothing was retrieved or the re-ranking response could not be parsed.

    Raises:
        ValueError: If ``num_docs`` is outside the range 0..10.
    """
    num_candidates = 10  # size of the similarity-search pool that is re-ranked
    # Explicit check instead of ``assert`` so validation survives ``python -O``
    # and negative values cannot silently slice off the tail of the ranking.
    if not 0 <= num_docs <= num_candidates:
        raise ValueError(
            f"num_docs must be between 0 and {num_candidates} "
            "(the number of similarity search results)"
        )

    embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
                                                   model_name=embedding_model)
    # Load the vectorstore database
    db = FAISS.load_local(folder_path=str(path_to_db),
                          embeddings=embeddings,
                          allow_dangerous_deserialization=True)

    # Get the candidate documents based on similarity search
    docs = db.similarity_search(query=query, k=num_candidates)
    if not docs:
        return []

    # Add the page_content, description and title together
    passages = [
        doc.page_content + "\n" + doc.metadata.get('title', "") + "\n" + doc.metadata.get('description', "")
        for doc in docs
    ]

    # Prepare the payload
    inputs = [{"text": query, "text_pair": passage} for passage in passages]

    api_url = "https://api-inference.huggingface.co/models/deepset/gbert-base-germandpr-reranking"
    headers = {"Authorization": f"Bearer {hf_api_key}"}

    # A timeout keeps the call from hanging forever on a stalled connection.
    response = requests.post(api_url, headers=headers, json=inputs, timeout=30)

    try:
        relevance_scores = _parse_relevance_scores(response.json(), expected=len(docs))
    except (ValueError, KeyError, IndexError, TypeError):
        # ValueError covers a non-JSON body; the others cover an error payload
        # (e.g. a dict with an "error" key) or an unexpected result shape.
        print('Could not get the relevance_scores -> something might be wrong with the json output')
        return []

    ranked = sorted(zip(docs, relevance_scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked[:num_docs]]


def _parse_relevance_scores(payload, expected: int) -> list:
    """Extract one relevance score per candidate from the re-ranker response."""
    if not isinstance(payload, list):
        raise TypeError("re-ranking response is not a list of results")
    # NOTE(review): index 1 is assumed to be the "relevant" label of the
    # model's output for each pair -- confirm against the model's label order.
    scores = [item[1]['score'] for item in payload]
    if len(scores) != expected:
        raise ValueError("re-ranking response does not match the number of passages")
    return scores
62
+
63
+
64
if __name__ == "__main__":
    # Manual smoke test: re-rank the vector-store hits for a sample query and
    # print the top results.

    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
    # Fail early with a clear message instead of sending None to the APIs.
    if not HUGGINGFACEHUB_API_TOKEN or not EMBEDDING_MODEL:
        raise SystemExit("HUGGINGFACEHUB_API_TOKEN and EMBEDDING_MODEL must be set (e.g. in .env)")

    # NOTE: the path is relative to the current working directory.
    path_to_vector_db = Path("..")/'vectorstore/faiss-insurance-agent-500'

    query = "Ich möchte wissen, ob ich meine geriatrische Haustier-Eidechse versichern kann"

    # get_reranked_docs documents path_to_db as str, so pass a resolved string.
    top_5_docs = get_reranked_docs(query=query,
                                   path_to_db=str(path_to_vector_db.resolve()),
                                   embedding_model=EMBEDDING_MODEL,
                                   hf_api_key=HUGGINGFACEHUB_API_TOKEN,
                                   num_docs=5)

    # Re-ranking may yield no result; guard so the loop never iterates None.
    for i, doc in enumerate(top_5_docs or []):
        print(f"{i}: {doc}\n")