Programmes committed on
Commit
87392ed
·
verified ·
1 Parent(s): 4b4260f

Update rag_utils.py

Browse files
Files changed (1) hide show
  1. rag_utils.py +26 -11
rag_utils.py CHANGED
@@ -1,10 +1,11 @@
1
 
 
2
  import faiss
3
  import pickle
4
- from sentence_transformers import SentenceTransformer
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
- import torch
7
  import numpy as np
 
 
 
8
 
9
  def load_faiss_index(index_path="faiss_index/faiss_index.faiss", doc_path="faiss_index/documents.pkl"):
10
  index = faiss.read_index(index_path)
@@ -13,19 +14,33 @@ def load_faiss_index(index_path="faiss_index/faiss_index.faiss", doc_path="faiss
13
  return index, documents
14
 
15
  def get_embedding_model():
16
- return SentenceTransformer("all-MiniLM-L6-v2")
 
 
17
 
18
  def query_index(question, index, documents, model, k=3):
19
  question_embedding = model.encode([question])
20
  _, indices = index.search(np.array(question_embedding).astype("float32"), k)
21
- results = [documents[i] for i in indices[0]]
22
- return results
23
 
24
  def generate_answer(question, context):
25
- model_id = "mistralai/Mistral-7B-Instruct-v0.1"
26
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
27
- model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
 
 
 
 
 
 
 
 
 
 
28
  prompt = f"Voici un contexte :\n{context}\n\nQuestion : {question}\nRéponse :"
29
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
30
- outputs = model.generate(**inputs, max_new_tokens=256)
 
 
31
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
1
 
2
+ import os
3
  import faiss
4
  import pickle
 
 
 
5
  import numpy as np
6
+ import torch
7
+ from sentence_transformers import SentenceTransformer
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
 
10
  def load_faiss_index(index_path="faiss_index/faiss_index.faiss", doc_path="faiss_index/documents.pkl"):
11
  index = faiss.read_index(index_path)
 
14
  return index, documents
15
 
16
def get_embedding_model():
    """Load the sentence encoder used to embed queries for retrieval.

    The model is public, so no Hugging Face auth token is required here.
    """
    print("✅ Chargement de l'encodeur multi-qa-MiniLM-L6-cos-v1")
    return SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
20
 
21
def query_index(question, index, documents, model, k=3):
    """Return the ``k`` documents whose embeddings are nearest to *question*.

    Args:
        question: query string to embed.
        index: FAISS index exposing ``search``.
        documents: list of documents aligned with the index's vectors.
        model: encoder exposing ``encode`` (e.g. a SentenceTransformer).
        k: number of nearest neighbours to return.
    """
    query_vec = np.array(model.encode([question])).astype("float32")
    _, neighbour_ids = index.search(query_vec, k)
    return [documents[doc_id] for doc_id in neighbour_ids[0]]
 
25
 
26
def generate_answer(question, context):
    """Generate an answer to *question* grounded in *context* using Flan-T5.

    Loads the tokenizer and model on every call (no caching), builds a
    French prompt from the retrieved context, and returns the decoded text.

    Args:
        question: the user's question.
        context: retrieved passages to ground the answer in.

    Returns:
        The generated answer as a plain string.
    """
    # Flan-T5 is an encoder-decoder model: it must be loaded with the
    # Seq2Seq auto class — AutoModelForCausalLM cannot load T5 checkpoints.
    from transformers import AutoModelForSeq2SeqLM

    token = os.getenv("HUGGINGFACE")  # optional: google/flan-t5-base is public
    # Fixed typo: "mgoogle/flan-t5-base" is not a valid Hub repo id.
    model_id = "google/flan-t5-base"

    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    # Only fall back to EOS when the tokenizer truly lacks a pad token;
    # T5 ships its own <pad>, which should not be clobbered.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        token=token,
        device_map="auto",
        torch_dtype=torch.float16,  # NOTE(review): float16 assumes GPU/accelerate; confirm for CPU deployments
    )

    prompt = f"Voici un contexte :\n{context}\n\nQuestion : {question}\nRéponse :"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=256, pad_token_id=tokenizer.pad_token_id)
    print("🔍 Contexte utilisé pour la génération :")
    print(context[:500])
    return tokenizer.decode(outputs[0], skip_special_tokens=True)