mitulagr2's picture
Update rag.py
bdc84e2
raw
history blame
6.28 kB
import os
import logging
from llama_index.core import (
SimpleDirectoryReader,
# VectorStoreIndex,
StorageContext,
Settings,
get_response_synthesizer)
from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode, MetadataMode
from llama_index.core.retrievers import VectorIndexRetriever
# from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.response_synthesizers import ResponseMode
# from transformers import AutoTokenizer
from llama_index.core.vector_stores import VectorStoreQuery
from llama_index.core.indices.vector_store.base import VectorStoreIndex
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.embeddings.fastembed import FastEmbedEmbedding
# Connection settings for a remote Qdrant server, read from the environment
# (each is None when the corresponding variable is unset).
QDRANT_API_URL = os.environ.get("QDRANT_API_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")

# Configure root logging at INFO and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ChatPDF:
    """Retrieval-augmented question answering over a directory of documents.

    Files are split into sentence-based chunks, embedded with FastEmbed,
    stored in an in-memory Qdrant collection and answered with a local
    LlamaCPP model (Qwen1.5-1.8B-Chat, GGUF).
    """

    # Set by ``ingest``; ``None`` means nothing has been indexed yet.
    query_engine = None
    # Qdrant collection that holds the document chunks.
    _COLLECTION_NAME = "rag_documents"
    model_url = "https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q8_0.gguf"
    # Alternative model that was tried; it was paired with custom Phi-3 style
    # (<|system|>/<|user|>/<|assistant|>) messages_to_prompt /
    # completion_to_prompt formatters passed to LlamaCPP:
    # model_url = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf"

    def __init__(self):
        """Create the text splitter, vector store, embedding model and LLM.

        They are also installed on llama_index's global ``Settings`` object,
        which is shared by everything in the process that uses llama_index.
        """
        self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=20)

        logger.info("initializing the vector store related objects")
        # In-memory Qdrant. A remote server could be used instead via
        # QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY).
        self._client = QdrantClient(":memory:")
        self.vector_store = self._create_vector_store()

        logger.info("initializing the FastEmbedEmbedding")
        # Library-default FastEmbed model ("BAAI/bge-small-en" was an
        # alternative that was considered).
        self.embed_model = FastEmbedEmbedding()

        llm = LlamaCPP(
            model_url=self.model_url,
            temperature=0.1,
            max_new_tokens=256,
            context_window=3900,
            verbose=True,
        )

        logger.info("initializing the global settings")
        Settings.text_splitter = self.text_parser
        Settings.embed_model = self.embed_model
        Settings.llm = llm
        Settings.transformations = [self.text_parser]

    def _create_vector_store(self):
        """Return a QdrantVectorStore bound to this instance's client and collection."""
        return QdrantVectorStore(
            client=self._client,
            collection_name=self._COLLECTION_NAME,
        )

    def ingest(self, files_dir: str):
        """Index every file in ``files_dir`` and build ``self.query_engine``.

        Each document is split with ``self.text_parser``. Every chunk becomes a
        ``TextNode`` that carries a copy of its source document's metadata and
        an embedding of its full content, including metadata. The nodes are then
        stored in the Qdrant-backed ``VectorStoreIndex``.

        Args:
            files_dir: Directory handed to ``SimpleDirectoryReader``.
        """
        docs = SimpleDirectoryReader(input_dir=files_dir).load_data()

        logger.info("enumerating docs")
        # (source document, chunk text) pairs, in document order.
        chunks = [
            (doc, text_chunk)
            for doc in docs
            for text_chunk in self.text_parser.split_text(doc.text)
        ]

        logger.info("enumerating text_chunks")
        # Copy the metadata per node. Sharing one dict object would make a
        # change on any node visible on every other chunk of that document.
        nodes = [
            TextNode(text=text_chunk, metadata=dict(doc.metadata))
            for doc, text_chunk in chunks
        ]

        logger.info("enumerating nodes")
        for node in nodes:
            node.embedding = self.embed_model.get_text_embedding(
                node.get_content(metadata_mode=MetadataMode.ALL)
            )

        logger.info("initializing the storage context")
        storage_context = StorageContext.from_defaults(vector_store=self.vector_store)

        logger.info("indexing the nodes in VectorStoreIndex")
        index = VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            transformations=Settings.transformations,
        )
        # Default query engine. A custom RetrieverQueryEngine (top-k 6,
        # COMPACT synthesis) and a HyDE query transform were alternatives
        # that were tried.
        self.query_engine = index.as_query_engine()

    def ask(self, query: str):
        """Answer ``query`` using the ingested documents.

        Returns:
            The fixed string "Please, add a PDF document first." when nothing
            has been ingested; otherwise the object returned by
            ``self.query_engine.query(query)``.
        """
        if self.query_engine is None:
            return "Please, add a PDF document first."
        logger.info("retrieving the response to the query")
        response = self.query_engine.query(query)
        logger.info("response: %s", response)
        return response

    def clear(self):
        """Forget every ingested document.

        Resets the query engine and drops the Qdrant collection, so chunks
        from earlier files can no longer be retrieved after a later
        ``ingest``. Previously the collection was left in place and old
        chunks leaked into new answers.
        """
        self.query_engine = None
        self._client.delete_collection(collection_name=self._COLLECTION_NAME)
        # Build a fresh store wrapper against the now-emptied client instead of
        # reusing the one that was bound to the deleted collection.
        self.vector_store = self._create_vector_store()