volAI_Avril / RAG_Mistral.py
LostPikachu's picture
Upload 3 files
f056b9f verified
raw
history blame
5.7 kB
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 24 15:51:34 2025
@author: MIPO10053340
C:/Users/MIPO10053340/OneDrive - Groupe Avril/Bureau/Salon_Agriculture_2024/Micka_API_Call/Docs_pdf/Docs_pdf/
"""
# -*- coding: utf-8 -*-
"""
Optimisation du RAG avec MistralAI - Embeddings en batch
"""
import os
import numpy as np
import fitz # PyMuPDF pour extraction PDF
import faiss
import pickle
import matplotlib.pyplot as plt
from mistralai import Mistral
from sklearn.manifold import TSNE
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from dotenv import load_dotenv
# Charger les variables d'environnement
load_dotenv()
MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY_static')
# 📌 Initialisation du client Mistral
client = Mistral(api_key=MISTRAL_API_KEY)
model_embedding = "mistral-embed"
model_chat = "ministral-8b-latest"
temperature = 0.1 # Réduction de la température pour privilégier la RAG
probability = 0.9 # Ajustement de la probabilité pour plus de contrôle
# 📌 Paramètres de segmentation
chunk_size = 256 # Réduction du chunk size pour un meilleur contrôle du contexte
chunk_overlap = 15
# 📌 Définition des chemins de stockage
index_path = "faiss_index.bin"
chunks_path = "chunked_docs.pkl"
# 📌 Vérification et chargement des données
if os.path.exists(index_path) and os.path.exists(chunks_path):
print("🔄 Chargement des données existantes...")
index = faiss.read_index(index_path) # Charger l'index FAISS
with open(chunks_path, "rb") as f:
chunked_docs = pickle.load(f) # Charger les chunks de texte
print("✅ Index et chunks chargés avec succès !")
else:
print("⚡ Création et stockage d'un nouvel index FAISS...")
# 📌 Extraction et segmentation des PDF
pdf_folder = 'C:/Users/MIPO10053340/OneDrive - Groupe Avril/Bureau/Salon_Agriculture_2024/Micka_API_Call/Docs_pdf/'
chunked_docs = SimpleDirectoryReader(pdf_folder).load_data()
chunked_docs = [doc.text for doc in chunked_docs]
# 📌 Génération des embeddings
embeddings = []
batch_size = 5
for i in range(0, len(chunked_docs), batch_size):
batch = chunked_docs[i:i + batch_size]
embeddings_batch_response = client.embeddings.create(
model=model_embedding,
inputs=batch,
)
batch_embeddings = [data.embedding for data in embeddings_batch_response.data]
embeddings.extend(batch_embeddings)
embeddings = np.array(embeddings).astype('float32')
# 📌 Vérification avant d’indexer dans FAISS
if embeddings is None or len(embeddings) == 0:
raise ValueError("⚠️ ERREUR : Aucun embedding généré ! Vérifie l'étape de génération des embeddings.")
# 📌 Création et stockage de l'index FAISS
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
faiss.write_index(index, index_path) # Sauvegarde de l'index
# 📌 Sauvegarde des chunks de texte
with open(chunks_path, "wb") as f:
pickle.dump(chunked_docs, f)
print("✅ Index et chunks sauvegardés !")
# 📌 Récupération des chunks les plus pertinents
def retrieve_relevant_chunks(question, k=5):
"""Recherche les chunks les plus pertinents en fonction de la similarité des embeddings."""
question_embedding_response = client.embeddings.create(
model=model_embedding,
inputs=[question],
)
question_embedding = np.array(question_embedding_response.data[0].embedding).astype('float32').reshape(1, -1)
# Vérification de la compatibilité des dimensions
dimension = index.d
if question_embedding.shape[1] != dimension:
raise ValueError(f"⚠️ ERREUR : La dimension de l'embedding de la question ({question_embedding.shape[1]}) ne correspond pas aux embeddings indexés ({dimension}).")
distances, indices = index.search(question_embedding, k)
if len(indices[0]) == 0:
print("⚠️ Avertissement : Aucun chunk pertinent trouvé, réponse possible moins précise.")
return []
return [chunked_docs[i] for i in indices[0]]
# 📌 Génération de réponse avec MistralAI
def generate_response(context, question):
"""Génère une réponse basée sur le contexte extrait du corpus avec une basse température et un contrôle de probabilité."""
messages = [
{"role": "system", "content": f"Voici des informations contextuelles à utiliser avec priorité : {context}"},
{"role": "user", "content": question}
]
response = client.chat.complete(model=model_chat, messages=messages, temperature=temperature)
return response.choices[0].message.content
# 📌 Exécuter une requête utilisateur
user_question = "Bonjour le Chat, je suis éléveur de poulets depuis plus de 20 ans et j'ai un doctorat de nutrition animale.Qu’est-ce qu’une protéine idéale en poule pondeuse ? Peux-tu suggérer une protéine idéale en pondeuse ? Merci d'être exhaustif et d'approfondir tes réponses et de ne pas survoler le sujet"
relevant_chunks = retrieve_relevant_chunks(user_question)
context = "\n".join(relevant_chunks)
answer = generate_response(context, user_question)
# 📊 Affichage de la réponse
print("\n🔹 Réponse Mistral :")
print(answer)
# 💾 Sauvegarde des résultats
with open("mistral_response_types.txt", "w", encoding="utf-8") as f:
f.write(f"Question : {user_question}\n")
f.write(f"Réponse :\n{answer}\n")
print("\n✅ Réponse enregistrée dans 'mistral_response_types.txt'")