import os
import torch
import faiss
from huggingface_hub import InferenceClient
from transformers import AutoModel, AutoTokenizer

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware


embedding_model_name = "intfloat/multilingual-e5-large"
embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
embedding_model = AutoModel.from_pretrained(embedding_model_name)

def embed_texts(texts):
    """Generate embeddings for a list of texts."""
    inputs = embedding_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    # Mean-pool over all token positions; padding tokens are included, which
    # only matters when texts of different lengths are batched together
    embeddings = torch.mean(outputs.last_hidden_state, dim=1)
    return embeddings.numpy()

# Function to load the FAISS index and document mapping
def load_faiss_index_and_mapping(index_path="document_index.faiss", mapping_path="document_mapping.txt"):
    """Loads the FAISS index and document mapping from files."""
    faiss_index = faiss.read_index(index_path)  # Load the FAISS index
    
    document_mapping = {}  # Dictionary to store document mapping
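    # Each line of the mapping file is expected to look like "<index>\t<filename>"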
    with open(mapping_path, "r") as f:
        for line in f:
            index, filename = line.strip().split("\t")
            document_mapping[int(index)] = filename

    return faiss_index, document_mapping

# Load the index and mapping
faiss_index, document_mapping = load_faiss_index_and_mapping()

# Load documents in index order, following the document mapping
def load_documents(document_mapping, folder_path="Data"):
    """Loads documents based on the document mapping."""
    documents = []
    for index in sorted(document_mapping.keys()):
        filename = document_mapping[index]
        file_path = os.path.join(folder_path, filename)
        with open(file_path, "r", encoding="utf-8") as file:
            documents.append(file.read())
    return documents


documents = load_documents(document_mapping)
print(f"Loaded {len(documents)} documents.")

secret = os.environ["API_TOKEN"]  # Hugging Face Inference API token
client = InferenceClient(api_key=secret)

def generate_response(query, retrieved_docs):
    """Generate a response with streaming tokens using OpenVINO and TextIteratorStreamer."""
    context = " ".join(retrieved_docs)
    # French prompt, roughly: "Answer the following question concisely, using
    # only the relevant information from the provided context."
    prompt = (
        "Répondez à la question suivante de manière concise en utilisant uniquement les informations pertinentes du contexte fourni.\n\n"
        f"Contexte : {context}\n\n"
        f"Question : {query}\n\n"
        "Réponse :"
    )
    
    messages = [
        {
            "role": "system",
            # System prompt, roughly: "You are an advanced French-language
            # model, designed to provide clear, complete, grammatically
            # correct, and useful answers while remaining courteous."
            "content": "Vous êtes un modèle de langage avancé en français, conçu pour fournir des réponses claires, complètes, grammaticalement correctes, et utiles, tout en restant courtois.",
        },
        {
            "role": "user",
            "content": prompt,
        },
    ]

    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=messages,
        max_tokens=500,
    )

    return completion.choices[0].message.content
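

# Optional: a streamed variant of generate_response, assuming huggingface_hub's
# stream=True chat-completion interface. It is not wired into the /generate
# endpoint below; treat it as an illustrative sketch (system prompt omitted
# for brevity).
def generate_response_stream(query, retrieved_docs):
    """Yield the answer incrementally instead of waiting for the full reply."""
    context = " ".join(retrieved_docs)
    messages = [
        {
            "role": "user",
            "content": f"Contexte : {context}\n\nQuestion : {query}\n\nRéponse :",
        },
    ]
    stream = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=messages,
        max_tokens=500,
        stream=True,  # tokens arrive as they are generated
    )
    for chunk in stream:
        # Each chunk carries a delta holding the newly generated text fragment
        yield chunk.choices[0].delta.content or ""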

# Query and retrieve the most relevant documents
def retrieve_documents(query, k=3):
    """Retrieve the top-k most relevant documents."""
    query_embedding = embed_texts([query])
    distances, indices = faiss_index.search(query_embedding.astype('float32'), k)
    return [documents[i] for i in indices[0]]

def rag_pipeline(query):
    """Complete RAG pipeline."""
    # Step 1: Retrieve the single most relevant document (k=1)
    relevant_docs = retrieve_documents(query, 1)
    # Step 2: Generate a response using the retrieved documents
    response = generate_response(query, relevant_docs)
    print("Query:", query)
    print("Response:", response)
    return response

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace '*' with specific domains in production for security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/generate")
async def generate(query: str | None = None):
    if not query:
        raise HTTPException(status_code=400, detail="Query parameter is required")
    
    response = rag_pipeline(query)
    return JSONResponse(content={"response": response})
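

# Minimal local entry point; a sketch assuming the app is not already served
# by an external process (for example, a Hugging Face Space launching uvicorn
# itself). The host and port below are illustrative defaults, not taken from
# the original deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)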