Commit 56d99ec (parent: f6d49e1) by anekameni
Refactor RAG system query methods; update descriptions and improve logging for better clarity

Files changed:
- app.py +2 -2
- src/rag_pipeline/prompts.py +43 -0
- src/rag_pipeline/rag_system.py +25 -74
- src/utilities/llm_models.py +8 -11
- src/vector_store/vector_store.py +0 -22
app.py CHANGED
@@ -12,7 +12,7 @@ class ChatInterface:
 
     def respond(self, message: str, history: List[List[str]]):
         result = ""
-        for text in self.rag_system.…
+        for text in self.rag_system.query(message, history):
             result += text
             yield result
         return result
@@ -21,7 +21,7 @@ class ChatInterface:
         chat_interface = gr.ChatInterface(
             fn=self.respond,
             title="Medivocate",
-            description="Medivocate…
+            description="Medivocate est une application qui offre des informations claires et structurées sur l'histoire de l'Afrique et sa médecine traditionnelle, en s'appuyant exclusivement sur un contexte issu de documentaires sur l'histoire du continent africain.",
             # retry_btn=None,
             # undo_btn=None,
             # clear_btn="Clear",
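Note on the respond change: gr.ChatInterface treats a generator fn as a stream and re-renders the reply with each yielded value, so respond yields the accumulated string rather than individual fragments. A minimal standalone sketch of the same pattern, assuming a stand-in echo function rather than the RAG system from this commit:

import time

import gradio as gr


def respond(message, history):
    # Yield the growing string: gr.ChatInterface replaces the displayed
    # reply with each yielded value, which produces a typing effect.
    result = ""
    for word in ("echo: " + message).split():
        result += word + " "
        time.sleep(0.05)
        yield result


demo = gr.ChatInterface(fn=respond, title="Streaming demo")

if __name__ == "__main__":
    demo.launch()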
src/rag_pipeline/prompts.py ADDED
@@ -0,0 +1,43 @@
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    MessagesPlaceholder,
+    SystemMessagePromptTemplate,
+)
+
+system_template = """
+Vous êtes un assistant IA qui fournit des informations sur l'histoire de l'Afrique et la médecine traditionnelle africaine. Vous recevez une question et fournissez une réponse claire et structurée. Lorsque cela est pertinent, utilisez des points et des listes pour structurer vos réponses.
+
+Utilisez uniquement les éléments de contexte suivants pour répondre à la question de l'utilisateur. Si vous ne connaissez pas la réponse, dites simplement que vous ne savez pas, n'essayez pas d'inventer une réponse.
+
+Si la question posée est dans une langue parlée en Afrique ou demande une traduction dans une de ces langues, répondez que vous ne savez pas et demandez à l'utilisateur de reformuler sa question.
+
+Si vous connaissez la réponse à la question mais que cette réponse ne provient pas du contexte ou n'est pas relative à l'histoire africaine ou à la médecine traditionnelle, répondez que vous ne savez pas et demandez à l'utilisateur de reformuler sa question.
+
+-----------------
+{context}
+"""
+
+messages = [
+    MessagesPlaceholder(variable_name="chat_history"),
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{input}"),
+]
+CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
+
+
+contextualize_q_system_prompt = (
+    "Étant donné un historique de conversation et la dernière question de l'utilisateur "
+    "qui pourrait faire référence au contexte dans l'historique de conversation, "
+    "formulez une question autonome qui peut être comprise "
+    "sans l'historique de conversation. NE répondez PAS à la question, reformulez-la "
+    "si nécessaire, sinon retournez-la telle quelle."
+)
+
+CONTEXTUEL_QUERY_PROMPT = ChatPromptTemplate.from_messages(
+    [
+        SystemMessagePromptTemplate.from_template(contextualize_q_system_prompt),
+        MessagesPlaceholder("chat_history"),
+        HumanMessagePromptTemplate.from_template("{input}"),
+    ]
+)
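For reference, CHAT_PROMPT expects chat_history (a list of messages), context, and input, while CONTEXTUEL_QUERY_PROMPT expects chat_history and input. A minimal sketch that formats CHAT_PROMPT outside the chain to inspect the rendered messages; the message contents are placeholders, not values from this commit:

from langchain_core.messages import AIMessage, HumanMessage

from src.rag_pipeline.prompts import CHAT_PROMPT

# Render the chat prompt with dummy values to see how the placeholders are filled.
msgs = CHAT_PROMPT.format_messages(
    chat_history=[
        HumanMessage(content="Parle-moi de Tombouctou."),
        AIMessage(content="Tombouctou est une ville historique du Mali."),
    ],
    context="(documents retrieved from the vector store go here)",
    input="Quels savants y ont enseigné ?",
)
for m in msgs:
    print(type(m).__name__, ":", m.content[:80])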
src/rag_pipeline/rag_system.py CHANGED
@@ -1,13 +1,18 @@
-import …
+import logging
 from typing import List, Optional
 
 from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain.chains.conversational_retrieval.base import (
+    ConversationalRetrievalChain,
+)
+from langchain.chains.history_aware_retriever import (
+    create_history_aware_retriever,
+)
 from langchain.chains.retrieval import create_retrieval_chain
-from langchain.prompts import PromptTemplate
-from langchain_core.runnables import Runnable
 
-from ..utilities.llm_models import get_llm_model_chat
+from ..utilities.llm_models import get_llm_model_chat
 from ..vector_store.vector_store import VectorStoreManager
+from .prompts import CHAT_PROMPT, CONTEXTUEL_QUERY_PROMPT
 
 
 class RAGSystem:
@@ -19,19 +24,16 @@ class RAGSystem:
         top_k_documents=5,
     ):
         self.top_k_documents = top_k_documents
-        self.embeddings = self._get_embeddings()
         self.llm = self._get_llm()
-        self.chain: Optional[…
+        self.chain: Optional[ConversationalRetrievalChain] = None
         self.vector_store_management = VectorStoreManager(
            docs_dir, persist_directory_dir, batch_size
         )
 
-    def _get_llm(…
-    …
-    …
-    …
-        """Initialize embeddings based on environment configuration"""
-        return get_llm_model_embedding()
+    def _get_llm(
+        self,
+    ):
+        return get_llm_model_chat(temperature=0.1, max_tokens=1000)
 
     def initialize_vector_store(self, documents: List = None):
         """Initialize or load the vector store"""
@@ -40,79 +42,28 @@ class RAGSystem:
     def setup_rag_chain(self):
         if self.chain is not None:
             return
-        """Set up the RAG chain with custom prompt"""
-        prompt_template = """Inspirez vous du contexte fourni ci-dessous pour répondre à la question qui suit de la manière la plus précise possible.
-        Si la réponse ne peut pas être déterminée à partir du contexte, évitez d'inventer des informations.
-        L'historique ici fait référence aux précédents échanges avec un utilisateur, tu devrais l'ignore si aucun rapport avec la question posée.
-        Tes réponses doivent être naturelles sous forme de faits, au lieu de faire mention du fait que réponds en fonction d'un contexte.
-
-        **Historique** :
-        {history}
-
-        **Contexte** :
-        {context}
-
-        **Question** :
-        {input}
-
-        Réponse (Vous devez répondre dans la même langue que celle de la question) :"""
-
-        prompt = PromptTemplate(
-            template=prompt_template, input_variables=["context", "input", "history"]
-        )
         retriever = self.vector_store_management.vector_store.as_retriever(
             search_kwargs={"k": self.top_k_documents}
         )
-        question_answer_chain = create_stuff_documents_chain(self.llm, prompt)
-
-        self.chain = create_retrieval_chain(retriever, question_answer_chain)
-
-    def query(self, question: str, history: List[tuple[str]] = []):
-        """Query the RAG system"""
-        if not self.vector_store_management.vector_store:
-            self.initialize_vector_store()
-
-        …
-        history_text = "\n".join(
-            [f"Utilisateur: {user}\nAssistant: {assistant}" for user, assistant in history]
-        )
-
-        …
-        return {
-            "answer": response["answer"],
-            "source_documents": [doc.page_content for doc in response["context"]],
-        }
-
-    def query_iter(self, question: str, history: List[tuple[str]] = []):
+
+        # Contextualize question
+        history_aware_retriever = create_history_aware_retriever(
+            self.llm, retriever, CONTEXTUEL_QUERY_PROMPT
+        )
+        question_answer_chain = create_stuff_documents_chain(self.llm, CHAT_PROMPT)
+        self.chain = create_retrieval_chain(
+            history_aware_retriever, question_answer_chain
+        )
+        logging.info("RAG chain setup complete" + str(self.chain))
+        return self.chain
+
+    def query(self, question: str, history: list = []):
         """Query the RAG system"""
         if not self.vector_store_management.vector_store:
             self.initialize_vector_store()
 
         self.setup_rag_chain()
 
-        history_text = "\n".join(
-            [f"Utilisateur: {user}\nAssistant: {assistant}" for user, assistant in history]
-        )
-
-        for token in self.chain.stream({"input": question, "history": history_text}):
+        for token in self.chain.stream({"input": question, "chat_history": history}):
             if "answer" in token:
                 yield token["answer"]
-
-
-if __name__ == "__main__":
-    from glob import glob
-
-    docs_dir = "data/docs"
-    persist_directory_dir = "data/chroma_db"
-    batch_size = 64
-
-    # Initialize RAG system
-    rag = RAGSystem(docs_dir, persist_directory_dir, batch_size)
-
-    rag.initialize_vector_store()  # vector store initialized
-
-    print(rag.query("Quand a eu lieu la traite négrière ?"))
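Taken together, setup_rag_chain now delegates history handling to LangChain: CONTEXTUEL_QUERY_PROMPT rewrites a follow-up question into a standalone one, the rewritten question drives retrieval, and CHAT_PROMPT answers over the stuffed documents. A minimal usage sketch of the streaming query method, reusing the paths from the removed __main__ block (illustrative values, not requirements):

import logging

from src.rag_pipeline.rag_system import RAGSystem

logging.basicConfig(level=logging.INFO)

# Values mirror the removed __main__ block; adjust them to your deployment.
docs_dir = "data/docs"
persist_directory_dir = "data/chroma_db"
batch_size = 64

rag = RAGSystem(docs_dir, persist_directory_dir, batch_size)
rag.initialize_vector_store()

# query() is now a generator: it streams answer fragments as the chain produces them.
answer = ""
for chunk in rag.query("Quand a eu lieu la traite négrière ?", history=[]):
    answer += chunk
print(answer)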
src/utilities/llm_models.py CHANGED
@@ -1,10 +1,9 @@
 import os
 from enum import Enum
-from typing import Union
 
 from langchain_groq import ChatGroq
-from langchain_ollama import ChatOllama, OllamaEmbeddings
 from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_ollama import ChatOllama, OllamaEmbeddings
 
 
 class LLMModel(Enum):
@@ -12,13 +11,11 @@ class LLMModel(Enum):
     GROQ = ChatGroq
 
 
-def get_llm_model_chat(
-    …
-)
-    …
-    if model_type == LLMModel.OLLAMA:
-        return model_type.value(
+def get_llm_model_chat(temperature=0.01, max_tokens=None):
+    if str(os.getenv("USE_OLLAMA_CHAT")) == "1" and "localhost" not in str(
+        os.getenv("OLLAMA_HOST")
+    ):
+        return ChatOllama(
             model=os.getenv("OLLAMA_MODEL"),
             temperature=temperature,
             max_tokens=max_tokens,
@@ -30,7 +27,7 @@ def get_llm_model_chat(
             }
         },
     )
-    return …
+    return ChatGroq(
         model=os.getenv("GROQ_MODEL_NAME"),
         temperature=temperature,
         max_tokens=max_tokens,
@@ -38,7 +35,7 @@ def get_llm_model_chat(
 
 
 def get_llm_model_embedding():
-    if os.getenv("…
+    if str(os.getenv("USE_HF_EMBEDDING")) == "1":
         return HuggingFaceEmbeddings(
             model_name=os.getenv("HF_MODEL"),  # You can replace with any HF model
             model_kwargs={"device": "cpu"},
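Model selection is now driven entirely by environment variables read inside the helpers. A sketch of those variables with placeholder values; only the variable names and the two helper functions come from the code above, the model names and host are assumptions:

import os

# Placeholder values for illustration; only the variable names come from llm_models.py.
os.environ.setdefault("USE_OLLAMA_CHAT", "1")        # "1" plus a non-localhost OLLAMA_HOST selects ChatOllama
os.environ.setdefault("OLLAMA_HOST", "http://ollama:11434")
os.environ.setdefault("OLLAMA_MODEL", "llama3")
os.environ.setdefault("GROQ_MODEL_NAME", "llama-3.1-8b-instant")  # used when the Ollama branch is skipped
os.environ.setdefault("USE_HF_EMBEDDING", "1")       # "1" selects HuggingFaceEmbeddings
os.environ.setdefault("HF_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

# Import after setting the environment so the helpers pick up these values.
from src.utilities.llm_models import get_llm_model_chat, get_llm_model_embedding

llm = get_llm_model_chat(temperature=0.1, max_tokens=1000)
embeddings = get_llm_model_embedding()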
src/vector_store/vector_store.py CHANGED
@@ -1,33 +1,11 @@
-import json
 import os
-from concurrent.futures import ThreadPoolExecutor
-from glob import glob
 from typing import List
 
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
-from langchain_community.document_loaders import DirectoryLoader, TextLoader
-from langchain_core.documents import Document
 from tqdm import tqdm
 
 from ..utilities.llm_models import get_llm_model_embedding
 
-
-def sanitize_metadata(metadata):
-    sanitized = {}
-    for key, value in metadata.items():
-        if isinstance(value, list):
-            # Convert lists to comma-separated strings or handle appropriately
-            sanitized[key] = ", ".join(value)
-        elif isinstance(value, (str, int, float, bool)):
-            sanitized[key] = value
-        else:
-            raise ValueError(
-                f"Unsupported metadata type for key '{key}': {type(value)}"
-            )
-    return sanitized
-
-
 class VectorStoreManager:
     def __init__(self, docs_dir: str, persist_directory_dir: str, batch_size=64):
         self.embeddings = get_llm_model_embedding()