gc / rag.py
ionosphere's picture
Update with Mistral and rag PDF
f676a1d
raw
history blame
3.11 kB
import os
# __import__('pysqlite3')
# import sys
# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_mistralai.embeddings import MistralAIEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.runnable import RunnablePassthrough
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores.utils import filter_complex_metadata
#add new import
from langchain_community.document_loaders.csv_loader import CSVLoader
# load .env in local dev
load_dotenv()
env_api_key = os.environ.get("MISTRAL_API_KEY")
llm_model = "open-mixtral-8x7b"
class ChatPDF:
vector_store = None
retriever = None
chain = None
def __init__(self):
# https://python.langchain.com/docs/integrations/chat/mistralai/
self.model = ChatMistralAI(model=llm_model)
self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
self.prompt = PromptTemplate.from_template(
"""
<s> [INST] Vous échangez en français et avec précision.
Vous êtes un assistant comptable spécialisé dans la comptabilité agricole en grandes cultures.
Vous devez analyser les documents ci-dessous et calculer les couts de productions.
Les documents fournis représente la comptabilité de l'exploitation agricole.
Vous devez répondre sous forme de tableaux et de textes.
Vous devez répondre de façon synthétique et argumentée.
[/INST] </s>
[INST]
Question: {question}
Context: {context}
Answer: [/INST]
"""
)
def ingest(self, pdf_file_path: str):
docs = PyPDFLoader(file_path=pdf_file_path).load()
chunks = self.text_splitter.split_documents(docs)
chunks = filter_complex_metadata(chunks)
embeddings = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=env_api_key)
vector_store = FAISS.from_documents(chunks, embeddings)
# vector_store = Chroma.from_documents(documents=chunks, embedding=embeddings)
self.retriever = vector_store.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={
"k": 3,
"score_threshold": 0.5,
},
)
self.chain = ({"context": self.retriever, "question": RunnablePassthrough()}
| self.prompt
| self.model
| StrOutputParser())
def ask(self, query: str):
if not self.chain:
return "Ajouter un document PDF d'abord."
return self.chain.invoke(query)
def clear(self):
self.vector_store = None
self.retriever = None
self.chain = None