[email protected] committed on
Commit
e496d26
·
1 Parent(s): ff0c986

Add some FAISS reinitialisation strategy

Browse files
Files changed (1) hide show
  1. rag.py +33 -8
rag.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
 
3
  from dotenv import load_dotenv
4
  from langchain_community.vectorstores import FAISS
@@ -10,6 +11,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
  from langchain.schema.runnable import RunnablePassthrough
11
  from langchain.prompts import PromptTemplate
12
  from langchain_community.vectorstores.utils import filter_complex_metadata
 
13
 
14
  from util import getYamlConfig
15
 
@@ -19,16 +21,15 @@ load_dotenv()
19
  env_api_key = os.environ.get("MISTRAL_API_KEY")
20
 
21
  class Rag:
22
- document_vector_store = None
23
- retriever = None
24
- chain = None
25
- readableModelName = ""
26
- documents = []
27
 
28
  def __init__(self, vectore_store=None):
 
 
 
 
 
 
29
 
30
- print(self.document_vector_store)
31
- # self.model = ChatMistralAI(model=llm_model)
32
  self.embedding = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=env_api_key)
33
 
34
  self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=300, separators="\n\n", length_function=len)
@@ -36,8 +37,24 @@ class Rag:
36
  base_template = getYamlConfig()['prompt_template']
37
  self.prompt = PromptTemplate.from_template(base_template)
38
 
 
39
  self.vector_store = vectore_store
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def setModel(self, model, readableModelName = ""):
42
  self.model = model
43
  self.readableModelName = readableModelName
@@ -66,9 +83,17 @@ class Rag:
66
  docs = PyPDFLoader(file_path=pdf_file_path).load()
67
 
68
  chunks = self.text_splitter.split_documents(docs)
69
-
70
  self.documents.extend(chunks)
 
 
 
 
 
 
71
  self.document_vector_store = FAISS.from_documents(self.documents, self.embedding)
 
 
 
72
 
73
 
74
  self.retriever = self.document_vector_store.as_retriever(
 
1
  import os
2
+ import faiss
3
 
4
  from dotenv import load_dotenv
5
  from langchain_community.vectorstores import FAISS
 
11
  from langchain.schema.runnable import RunnablePassthrough
12
  from langchain.prompts import PromptTemplate
13
  from langchain_community.vectorstores.utils import filter_complex_metadata
14
+ from langchain_core.documents import Document
15
 
16
  from util import getYamlConfig
17
 
 
21
  env_api_key = os.environ.get("MISTRAL_API_KEY")
22
 
23
  class Rag:
 
 
 
 
 
24
 
25
  def __init__(self, vectore_store=None):
26
+ print("Nouvelle instance de Rag créée")
27
+ self.document_vector_store = None
28
+ self.retriever = None
29
+ self.chain = None
30
+ self.readableModelName = ""
31
+ self.documents = []
32
 
 
 
33
  self.embedding = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=env_api_key)
34
 
35
  self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=300, separators="\n\n", length_function=len)
 
37
  base_template = getYamlConfig()['prompt_template']
38
  self.prompt = PromptTemplate.from_template(base_template)
39
 
40
+ self.reset_faiss_store()
41
  self.vector_store = vectore_store
42
 
43
+
44
+ def reset_faiss_store(self):
45
+ """ Initialise un FAISS vide avec la bonne dimension """
46
+
47
+ # Ajouter un document à l'index FAISS
48
+ docs = [ Document(page_content=" ") ]
49
+ self.document_vector_store = FAISS.from_documents(docs, self.embedding)
50
+
51
+ # Vider l'index FAISS
52
+ self.document_vector_store.index.reset()
53
+
54
+ # Vérifier que l'index est vidé
55
+ print(f"Nombre de vecteurs après reset: {self.document_vector_store.index.ntotal}")
56
+
57
+
58
  def setModel(self, model, readableModelName = ""):
59
  self.model = model
60
  self.readableModelName = readableModelName
 
83
  docs = PyPDFLoader(file_path=pdf_file_path).load()
84
 
85
  chunks = self.text_splitter.split_documents(docs)
 
86
  self.documents.extend(chunks)
87
+
88
+ if self.document_vector_store:
89
+ print(f"Nombre de documents indexés dans FAISS : {self.document_vector_store.index.ntotal}")
90
+ else:
91
+ print("No document_vectore")
92
+
93
  self.document_vector_store = FAISS.from_documents(self.documents, self.embedding)
94
+ print(f"Après ingestion, FAISS contient {self.document_vector_store.index.ntotal} documents.")
95
+
96
+
97
 
98
 
99
  self.retriever = self.document_vector_store.as_retriever(