|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
from langchain.schema.document import Document |
|
from langchain_community.embeddings import HuggingFaceBgeEmbeddings |
|
from langchain_chroma import Chroma |
|
import spaces |
|
from langchain_text_splitters import MarkdownHeaderTextSplitter |
|
import os |
|
from transformers import AutoTokenizer |
|
# Hugging Face access token taken from the environment (None when unset);
# forwarded to the tokenizer download below.
api_token = os.environ.get("HF_TOKEN")

# Tokenizer of the generation model. create_rag_index measures chunk sizes
# with it, so chunk_size values are token counts for this model.
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)

# BGE sentence-embedding model used to embed both chunks and queries.
# Embeddings are L2-normalized, and no instruction prefix is prepended to
# queries (query_instruction is empty).
embedding_model = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs=dict(device="cuda"),
    encode_kwargs=dict(normalize_embeddings=True),
    query_instruction="",
)
|
|
|
|
|
def create_rag_index(text_no_prefix, chunk_size=256):
    """Split a document's text into token-bounded chunks and index them in Chroma.

    The text is split with a RecursiveCharacterTextSplitter whose length
    function is the module-level ``tokenizer``. Each chunk therefore holds up
    to ``chunk_size`` tokens, with no overlap between chunks. The chunks are
    embedded with the module-level ``embedding_model`` and stored in a new
    Chroma collection that is unique to this call.

    Args:
        text_no_prefix: Plain text of the document to index.
        chunk_size: Maximum chunk length in tokenizer tokens. Pass the same
            value to ``run_naive_rag_query`` so its ``k`` computation matches.

    Returns:
        A ``Chroma`` vector store holding one document per chunk. Each
        document's metadata carries a ``start_index`` entry: the character
        offset of the chunk in ``text_no_prefix``.
    """
    import uuid

    # Chunk length is measured with the Llama tokenizer, so chunk_size is a
    # token count rather than a character count.
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=chunk_size,
        chunk_overlap=0,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )

    # create_documents (unlike split_text) honours add_start_index and
    # records each chunk's character offset in metadata["start_index"].
    docs = text_splitter.create_documents([text_no_prefix])

    # Without an explicit name, every index goes into Chroma's default
    # "langchain" collection. In-memory Chroma clients in one process share
    # their backend, so chunks from earlier calls would otherwise be mixed
    # into this index. A per-call unique name keeps indexes isolated.
    collection_name = f"rag_{uuid.uuid4().hex}"
    vectorstore = Chroma.from_documents(
        documents=docs,
        embedding=embedding_model,
        collection_name=collection_name,
    )
    # NOTE(review): collections are never deleted here. If many indexes are
    # built in one process, callers may want to call
    # vectorstore.delete_collection() once an index is no longer needed.
    return vectorstore
|
|
|
def run_naive_rag_query(
    vectorstore,
    query,
    rag_token_size,
    prefix,
    task,
    few_shot_examples,
    chunk_size=256,
):
    """Assemble a naive-RAG prompt context from the chunks most similar to a query.

    Retrieves the top ``k = max(1, rag_token_size // chunk_size)`` chunks by
    similarity. Each retrieved chunk is printed to stdout between separator
    lines. The chunks are joined with blank lines and concatenated as: prefix,
    the literal "Retrieved context: " plus a newline, the joined chunks,
    ``task``, then ``few_shot_examples``. No answer is generated here; the
    returned string is for the caller to send to a model.

    Args:
        vectorstore: Vector store exposing ``as_retriever`` (e.g. the one
            returned by ``create_rag_index``).
        query: Text used for the similarity search.
        rag_token_size: Target token budget for the retrieved context.
        prefix: Text placed before the retrieved context.
        task: Text appended directly after the context, with no separator.
        few_shot_examples: Text appended after ``task``.
        chunk_size: Chunk size, in tokens, that the index was built with. It
            should match the value given to ``create_rag_index``.

    Returns:
        The assembled prompt-context string.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Assuming the index was built with the same chunk_size, k chunks roughly
    # fill the token budget. At least one chunk is always requested.
    k = max(1, rag_token_size // chunk_size)
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
    retrieved_docs = retriever.invoke(query)

    # Debug output: show every retrieved chunk.
    for doc in retrieved_docs:
        print("=================")
        print(doc.page_content)
        print("=================")

    formatted_context = "\n\n".join(doc.page_content for doc in retrieved_docs)
    return prefix + "Retrieved context: \n" + formatted_context + task + few_shot_examples