anekameni commited on
Commit
56d99ec
·
1 Parent(s): f6d49e1

Refactor RAG system query methods; update descriptions and improve logging for better clarity

Browse files
app.py CHANGED
@@ -12,7 +12,7 @@ class ChatInterface:
12
 
13
  def respond(self, message: str, history: List[List[str]]):
14
  result = ""
15
- for text in self.rag_system.query_iter(message, history):
16
  result += text
17
  yield result
18
  return result
@@ -21,7 +21,7 @@ class ChatInterface:
21
  chat_interface = gr.ChatInterface(
22
  fn=self.respond,
23
  title="Medivocate",
24
- description="Medivocate is an AI-driven platform leveraging Retrieval-Augmented Generation (RAG) powered by African history. It processes and classifies document pages with precision to provide trustworthy, personalized guidance, fostering accurate knowledge and equitable access to historical insights.",
25
  # retry_btn=None,
26
  # undo_btn=None,
27
  # clear_btn="Clear",
 
12
 
13
  def respond(self, message: str, history: List[List[str]]):
14
  result = ""
15
+ for text in self.rag_system.query(message, history):
16
  result += text
17
  yield result
18
  return result
 
21
  chat_interface = gr.ChatInterface(
22
  fn=self.respond,
23
  title="Medivocate",
24
+ description="Medivocate est une application qui offre des informations claires et structurées sur l'histoire de l'Afrique et sa médecine traditionnelle, en s'appuyant exclusivement sur un contexte issu de documentaires sur l'histoire du continent africain.",
25
  # retry_btn=None,
26
  # undo_btn=None,
27
  # clear_btn="Clear",
src/rag_pipeline/prompts.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts.chat import (
2
+ ChatPromptTemplate,
3
+ HumanMessagePromptTemplate,
4
+ MessagesPlaceholder,
5
+ SystemMessagePromptTemplate,
6
+ )
7
+
8
+ system_template = """
9
+ Vous êtes un assistant IA qui fournit des informations sur l'histoire de l'Afrique et la médecine traditionnelle africaine. Vous recevez une question et fournissez une réponse claire et structurée. Lorsque cela est pertinent, utilisez des points et des listes pour structurer vos réponses.
10
+
11
+ Utilisez uniquement les éléments de contexte suivants pour répondre à la question de l'utilisateur. Si vous ne connaissez pas la réponse, dites simplement que vous ne savez pas, n'essayez pas d'inventer une réponse.
12
+
13
+ Si la question posée est dans une langue parlée en Afrique ou demande une traduction dans une de ces langues, répondez que vous ne savez pas et demandez à l'utilisateur de reformuler sa question.
14
+
15
+ Si vous connaissez la réponse à la question mais que cette réponse ne provient pas du contexte ou n'est pas relative à l'histoire africaine ou à la médecine traditionnelle, répondez que vous ne savez pas et demandez à l'utilisateur de reformuler sa question.
16
+
17
+ -----------------
18
+ {context}
19
+ """
20
+
21
+ messages = [
22
+ MessagesPlaceholder(variable_name="chat_history"),
23
+ SystemMessagePromptTemplate.from_template(system_template),
24
+ HumanMessagePromptTemplate.from_template("{input}"),
25
+ ]
26
+ CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
27
+
28
+
29
+ contextualize_q_system_prompt = (
30
+ "Étant donné un historique de conversation et la dernière question de l'utilisateur "
31
+ "qui pourrait faire référence au contexte dans l'historique de conversation, "
32
+ "formulez une question autonome qui peut être comprise "
33
+ "sans l'historique de conversation. NE répondez PAS à la question, reformulez-la "
34
+ "si nécessaire, sinon retournez-la telle quelle."
35
+ )
36
+
37
+ CONTEXTUEL_QUERY_PROMPT = ChatPromptTemplate.from_messages(
38
+ [
39
+ SystemMessagePromptTemplate.from_template(contextualize_q_system_prompt),
40
+ MessagesPlaceholder("chat_history"),
41
+ HumanMessagePromptTemplate.from_template("{input}"),
42
+ ]
43
+ )
src/rag_pipeline/rag_system.py CHANGED
@@ -1,13 +1,18 @@
1
- import os
2
  from typing import List, Optional
3
 
4
  from langchain.chains.combine_documents import create_stuff_documents_chain
 
 
 
 
 
 
5
  from langchain.chains.retrieval import create_retrieval_chain
6
- from langchain.prompts import PromptTemplate
7
- from langchain_core.runnables import Runnable
8
 
9
- from ..utilities.llm_models import get_llm_model_chat, get_llm_model_embedding
10
  from ..vector_store.vector_store import VectorStoreManager
 
11
 
12
 
13
  class RAGSystem:
@@ -19,19 +24,16 @@ class RAGSystem:
19
  top_k_documents=5,
20
  ):
21
  self.top_k_documents = top_k_documents
22
- self.embeddings = self._get_embeddings()
23
  self.llm = self._get_llm()
24
- self.chain: Optional[Runnable] = None
25
  self.vector_store_management = VectorStoreManager(
26
  docs_dir, persist_directory_dir, batch_size
27
  )
28
 
29
- def _get_llm(self):
30
- return get_llm_model_chat("GROQ", temperature=0.1, max_tokens=500)
31
-
32
- def _get_embeddings(self):
33
- """Initialize embeddings based on environment configuration"""
34
- return get_llm_model_embedding()
35
 
36
  def initialize_vector_store(self, documents: List = None):
37
  """Initialize or load the vector store"""
@@ -40,79 +42,28 @@ class RAGSystem:
40
  def setup_rag_chain(self):
41
  if self.chain is not None:
42
  return
43
- """Set up the RAG chain with custom prompt"""
44
- prompt_template = """Inspirez vous du contexte fourni ci-dessous pour répondre à la question qui suit de la manière la plus précise possible.
45
- Si la réponse ne peut pas être déterminée à partir du contexte, évitez d'inventer des informations.
46
- L'historique ici fait référence aux précédents échanges avec un utilisateur, tu devrais l'ignore si aucun rapport avec la question posée.
47
- Tes réponses doivent être naturelles sous forme de faits, au lieu de faire mention du fait que réponds en fonction d'un contexte.
48
-
49
- **Historique** :
50
- {history}
51
-
52
- **Contexte** :
53
- {context}
54
-
55
- **Question** :
56
- {input}
57
-
58
- Réponse (Vous devez répondre dans la même langue que celle de la question) :"""
59
-
60
- prompt = PromptTemplate(
61
- template=prompt_template, input_variables=["context", "input", "history"]
62
- )
63
  retriever = self.vector_store_management.vector_store.as_retriever(
64
  search_kwargs={"k": self.top_k_documents}
65
  )
66
- question_answer_chain = create_stuff_documents_chain(self.llm, prompt)
67
-
68
- self.chain = create_retrieval_chain(retriever, question_answer_chain)
69
-
70
- def query(self, question: str, history: List[tuple[str]] = []):
71
- """Query the RAG system"""
72
- if not self.vector_store_management.vector_store:
73
- self.initialize_vector_store()
74
 
75
- self.setup_rag_chain()
76
-
77
- # Format history as a single string of interactions
78
- history_text = "\n".join(
79
- [f"Utilisateur: {user}\nAssistant: {assistant}" for user, assistant in history]
80
  )
 
 
 
 
 
 
81
 
82
- response = self.chain.invoke({"input": question, "history": history_text})
83
-
84
- return {
85
- "answer": response["answer"],
86
- "source_documents": [doc.page_content for doc in response["context"]],
87
- }
88
-
89
- def query_iter(self, question: str, history: List[tuple[str]] = []):
90
  """Query the RAG system"""
91
  if not self.vector_store_management.vector_store:
92
  self.initialize_vector_store()
93
 
94
  self.setup_rag_chain()
95
 
96
- # Format history as a single string of interactions
97
- history_text = "\n".join(
98
- [f"Utilisateur: {user}\nAssistant: {assistant}" for user, assistant in history]
99
- )
100
-
101
- for token in self.chain.stream({"input": question, "history": history_text}):
102
  if "answer" in token:
103
  yield token["answer"]
104
-
105
-
106
- if __name__ == "__main__":
107
- from glob import glob
108
-
109
- docs_dir = "data/docs"
110
- persist_directory_dir = "data/chroma_db"
111
- batch_size = 64
112
-
113
- # Initialize RAG system
114
- rag = RAGSystem(docs_dir, persist_directory_dir, batch_size)
115
-
116
- rag.initialize_vector_store() # vector store initialized
117
-
118
- print(rag.query("Quand a eu lieu la traite négrière ?"))
 
1
+ import logging
2
  from typing import List, Optional
3
 
4
  from langchain.chains.combine_documents import create_stuff_documents_chain
5
+ from langchain.chains.conversational_retrieval.base import (
6
+ ConversationalRetrievalChain,
7
+ )
8
+ from langchain.chains.history_aware_retriever import (
9
+ create_history_aware_retriever,
10
+ )
11
  from langchain.chains.retrieval import create_retrieval_chain
 
 
12
 
13
+ from ..utilities.llm_models import get_llm_model_chat
14
  from ..vector_store.vector_store import VectorStoreManager
15
+ from .prompts import CHAT_PROMPT, CONTEXTUEL_QUERY_PROMPT
16
 
17
 
18
  class RAGSystem:
 
24
  top_k_documents=5,
25
  ):
26
  self.top_k_documents = top_k_documents
 
27
  self.llm = self._get_llm()
28
+ self.chain: Optional[ConversationalRetrievalChain] = None
29
  self.vector_store_management = VectorStoreManager(
30
  docs_dir, persist_directory_dir, batch_size
31
  )
32
 
33
+ def _get_llm(
34
+ self,
35
+ ):
36
+ return get_llm_model_chat(temperature=0.1, max_tokens=1000)
 
 
37
 
38
  def initialize_vector_store(self, documents: List = None):
39
  """Initialize or load the vector store"""
 
42
  def setup_rag_chain(self):
43
  if self.chain is not None:
44
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  retriever = self.vector_store_management.vector_store.as_retriever(
46
  search_kwargs={"k": self.top_k_documents}
47
  )
 
 
 
 
 
 
 
 
48
 
49
+ # Contextualize question
50
+ history_aware_retriever = create_history_aware_retriever(
51
+ self.llm, retriever, CONTEXTUEL_QUERY_PROMPT
 
 
52
  )
53
+ question_answer_chain = create_stuff_documents_chain(self.llm, CHAT_PROMPT)
54
+ self.chain = create_retrieval_chain(
55
+ history_aware_retriever, question_answer_chain
56
+ )
57
+ logging.info("RAG chain setup complete" + str(self.chain))
58
+ return self.chain
59
 
60
+ def query(self, question: str, history: list = []):
 
 
 
 
 
 
 
61
  """Query the RAG system"""
62
  if not self.vector_store_management.vector_store:
63
  self.initialize_vector_store()
64
 
65
  self.setup_rag_chain()
66
 
67
+ for token in self.chain.stream({"input": question, "chat_history": history}):
 
 
 
 
 
68
  if "answer" in token:
69
  yield token["answer"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utilities/llm_models.py CHANGED
@@ -1,10 +1,9 @@
1
  import os
2
  from enum import Enum
3
- from typing import Union
4
 
5
  from langchain_groq import ChatGroq
6
- from langchain_ollama import ChatOllama, OllamaEmbeddings
7
  from langchain_huggingface import HuggingFaceEmbeddings
 
8
 
9
 
10
  class LLMModel(Enum):
@@ -12,13 +11,11 @@ class LLMModel(Enum):
12
  GROQ = ChatGroq
13
 
14
 
15
- def get_llm_model_chat(
16
- model_type: Union[str, LLMModel], temperature=0, max_tokens=None
17
- ):
18
- if isinstance(model_type, str):
19
- model_type = LLMModel[model_type.upper()]
20
- if model_type == LLMModel.OLLAMA:
21
- return model_type.value(
22
  model=os.getenv("OLLAMA_MODEL"),
23
  temperature=temperature,
24
  max_tokens=max_tokens,
@@ -30,7 +27,7 @@ def get_llm_model_chat(
30
  }
31
  },
32
  )
33
- return model_type.value(
34
  model=os.getenv("GROQ_MODEL_NAME"),
35
  temperature=temperature,
36
  max_tokens=max_tokens,
@@ -38,7 +35,7 @@ def get_llm_model_chat(
38
 
39
 
40
  def get_llm_model_embedding():
41
- if os.getenv("USE_HF"):
42
  return HuggingFaceEmbeddings(
43
  model_name=os.getenv("HF_MODEL"), # You can replace with any HF model
44
  model_kwargs={"device": "cpu"},
 
1
  import os
2
  from enum import Enum
 
3
 
4
  from langchain_groq import ChatGroq
 
5
  from langchain_huggingface import HuggingFaceEmbeddings
6
+ from langchain_ollama import ChatOllama, OllamaEmbeddings
7
 
8
 
9
  class LLMModel(Enum):
 
11
  GROQ = ChatGroq
12
 
13
 
14
+ def get_llm_model_chat(temperature=0.01, max_tokens=None):
15
+ if str(os.getenv("USE_OLLAMA_CHAT")) == "1" and "localhost" not in str(
16
+ os.getenv("OLLAMA_HOST")
17
+ ):
18
+ return ChatOllama(
 
 
19
  model=os.getenv("OLLAMA_MODEL"),
20
  temperature=temperature,
21
  max_tokens=max_tokens,
 
27
  }
28
  },
29
  )
30
+ return ChatGroq(
31
  model=os.getenv("GROQ_MODEL_NAME"),
32
  temperature=temperature,
33
  max_tokens=max_tokens,
 
35
 
36
 
37
  def get_llm_model_embedding():
38
+ if str(os.getenv("USE_HF_EMBEDDING")) == "1":
39
  return HuggingFaceEmbeddings(
40
  model_name=os.getenv("HF_MODEL"), # You can replace with any HF model
41
  model_kwargs={"device": "cpu"},
src/vector_store/vector_store.py CHANGED
@@ -1,33 +1,11 @@
1
- import json
2
  import os
3
- from concurrent.futures import ThreadPoolExecutor
4
- from glob import glob
5
  from typing import List
6
 
7
- from langchain.text_splitter import RecursiveCharacterTextSplitter
8
  from langchain_chroma import Chroma
9
- from langchain_community.document_loaders import DirectoryLoader, TextLoader
10
- from langchain_core.documents import Document
11
  from tqdm import tqdm
12
 
13
  from ..utilities.llm_models import get_llm_model_embedding
14
 
15
-
16
- def sanitize_metadata(metadata):
17
- sanitized = {}
18
- for key, value in metadata.items():
19
- if isinstance(value, list):
20
- # Convert lists to comma-separated strings or handle appropriately
21
- sanitized[key] = ", ".join(value)
22
- elif isinstance(value, (str, int, float, bool)):
23
- sanitized[key] = value
24
- else:
25
- raise ValueError(
26
- f"Unsupported metadata type for key '{key}': {type(value)}"
27
- )
28
- return sanitized
29
-
30
-
31
  class VectorStoreManager:
32
  def __init__(self, docs_dir: str, persist_directory_dir: str, batch_size=64):
33
  self.embeddings = get_llm_model_embedding()
 
 
1
  import os
 
 
2
  from typing import List
3
 
 
4
  from langchain_chroma import Chroma
 
 
5
  from tqdm import tqdm
6
 
7
  from ..utilities.llm_models import get_llm_model_embedding
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class VectorStoreManager:
10
  def __init__(self, docs_dir: str, persist_directory_dir: str, batch_size=64):
11
  self.embeddings = get_llm_model_embedding()