RAGtest / rag_pipeline.py
willco-afk's picture
Create rag_pipeline.py
52bbfc2 verified
raw
history blame
810 Bytes
import pickle

import faiss
import numpy as np
from transformers import pipeline
# Retrieval-augmented generation (RAG) demo: embed a query, look up the
# nearest document(s) in a FAISS index, and let a text-generation model
# answer using the retrieved text as context.
#
# NOTE(review): the original flat script used `model` (a query encoder with an
# `.encode(list_of_str)` method) and `documents` (the texts the index rows
# refer to) without ever defining them, so it failed with NameError before the
# first search. Both are now explicit parameters of `main` / `retrieve`; they
# must be the same encoder and document list used to build the index.

INDEX_PATH = "faiss_index.index"
DEFAULT_QUERY = "What is the capital of France?"


def load_faiss_index(path=INDEX_PATH):
    """Read a FAISS index from *path* with ``faiss.read_index``.

    ``faiss.read_index`` is the counterpart of ``faiss.write_index``. The
    original ``pickle.load`` cannot parse that native format, and unpickling
    a file executes arbitrary code.
    NOTE(review): assumes the file was written with ``faiss.write_index``, as
    the ``.index`` extension suggests. Confirm against the indexing script.
    """
    return faiss.read_index(path)


def retrieve(query, index, encoder, documents, k=1):
    """Return up to *k* documents for *query*, in the order the index ranks them.

    *encoder* must expose ``encode(list_of_str)`` returning one vector per
    input. *documents* must be indexable by the integer ids stored in
    *index*. FAISS uses an id of -1 to pad results when the index holds fewer
    than *k* vectors. Those ids are skipped rather than silently wrapping
    around to ``documents[-1]``.
    """
    # FAISS expects a C-contiguous float32 matrix of shape (n_queries, dim).
    query_vectors = np.asarray(encoder.encode([query]), dtype=np.float32).reshape(1, -1)
    _, ids = index.search(np.ascontiguousarray(query_vectors), k)
    return [documents[int(i)] for i in ids[0] if i >= 0]


def build_prompt(contexts, question):
    """Format the retrieved *contexts* and the *question* as a completion prompt.

    Multiple contexts are joined with newlines. A single context yields
    exactly the prompt the original script built.
    """
    context = "\n".join(map(str, contexts))
    return f"Context: {context}\nQuestion: {question}\nAnswer:"


def main(encoder, documents, query=DEFAULT_QUERY, *, index_path=INDEX_PATH,
         model_name="gpt2", k=1, max_new_tokens=50):
    """Answer *query* with *model_name*, grounded in the top-*k* indexed documents.

    Prints the pipeline's ``generated_text`` (the prompt followed by the
    model's continuation) and returns it.

    Raises:
        LookupError: if the index returns no documents for the query.
    """
    index = load_faiss_index(index_path)
    # The original comment mentioned GPT-3/T5, but the code loads GPT-2.
    # The model is now configurable via model_name.
    generator = pipeline("text-generation", model=model_name)
    contexts = retrieve(query, index, encoder, documents, k=k)
    if not contexts:
        raise LookupError("FAISS returned no documents; is the index empty?")
    prompt = build_prompt(contexts, query)
    # max_new_tokens bounds only the continuation. The original max_length=50
    # also counted the prompt, so any prompt over 50 tokens left no room for
    # an answer.
    answer = generator(prompt, max_new_tokens=max_new_tokens)
    text = answer[0]["generated_text"]
    print(text)
    return text


if __name__ == "__main__":
    # TODO: construct the query encoder and the document list that were used
    # to build faiss_index.index, then call main(encoder, documents).
    raise SystemExit(
        "rag_pipeline.py needs a query encoder and the indexed documents; "
        "import it and call main(encoder, documents)."
    )