Spaces:
Sleeping
Sleeping
Update rag.py
Browse files- app/rag.py +35 -31
app/rag.py
CHANGED
|
@@ -27,13 +27,14 @@ from llama_index.embeddings.fastembed import FastEmbedEmbedding
|
|
| 27 |
QDRANT_API_URL = os.getenv('QDRANT_API_URL')
|
| 28 |
QDRANT_API_KEY = os.getenv('QDRANT_API_KEY')
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
class ChatPDF:
|
| 31 |
-
logging.basicConfig(level=logging.INFO)
|
| 32 |
-
logger = logging.getLogger(__name__)
|
| 33 |
query_engine = None
|
| 34 |
|
| 35 |
-
|
| 36 |
-
model_url = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf"
|
| 37 |
|
| 38 |
# def messages_to_prompt(messages):
|
| 39 |
# prompt = ""
|
|
@@ -59,7 +60,7 @@ class ChatPDF:
|
|
| 59 |
def __init__(self):
|
| 60 |
self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=20)
|
| 61 |
|
| 62 |
-
|
| 63 |
# client = QdrantClient(host="localhost", port=6333)
|
| 64 |
# client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
|
| 65 |
client = QdrantClient(":memory:")
|
|
@@ -69,7 +70,7 @@ class ChatPDF:
|
|
| 69 |
# enable_hybrid=True
|
| 70 |
)
|
| 71 |
|
| 72 |
-
|
| 73 |
self.embed_model = FastEmbedEmbedding(
|
| 74 |
# model_name="BAAI/bge-small-en"
|
| 75 |
)
|
|
@@ -89,7 +90,7 @@ class ChatPDF:
|
|
| 89 |
# tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
|
| 90 |
# tokenizer.save_pretrained("./models/tokenizer/")
|
| 91 |
|
| 92 |
-
|
| 93 |
Settings.text_splitter = self.text_parser
|
| 94 |
Settings.embed_model = self.embed_model
|
| 95 |
Settings.llm = llm
|
|
@@ -103,55 +104,57 @@ class ChatPDF:
|
|
| 103 |
|
| 104 |
docs = SimpleDirectoryReader(input_dir=files_dir).load_data()
|
| 105 |
|
| 106 |
-
|
| 107 |
for doc_idx, doc in enumerate(docs):
|
| 108 |
curr_text_chunks = self.text_parser.split_text(doc.text)
|
| 109 |
text_chunks.extend(curr_text_chunks)
|
| 110 |
doc_ids.extend([doc_idx] * len(curr_text_chunks))
|
| 111 |
|
| 112 |
-
|
| 113 |
for idx, text_chunk in enumerate(text_chunks):
|
| 114 |
node = TextNode(text=text_chunk)
|
| 115 |
src_doc = docs[doc_ids[idx]]
|
| 116 |
node.metadata = src_doc.metadata
|
| 117 |
nodes.append(node)
|
| 118 |
|
| 119 |
-
|
| 120 |
for node in nodes:
|
| 121 |
node_embedding = self.embed_model.get_text_embedding(
|
| 122 |
node.get_content(metadata_mode=MetadataMode.ALL)
|
| 123 |
)
|
| 124 |
node.embedding = node_embedding
|
| 125 |
|
| 126 |
-
|
| 127 |
storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
|
| 128 |
-
|
| 129 |
index = VectorStoreIndex(
|
| 130 |
nodes=nodes,
|
| 131 |
storage_context=storage_context,
|
| 132 |
transformations=Settings.transformations,
|
| 133 |
)
|
| 134 |
|
| 135 |
-
|
| 136 |
-
retriever = VectorIndexRetriever(
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
)
|
| 141 |
|
| 142 |
-
|
| 143 |
-
response_synthesizer = get_response_synthesizer(
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
)
|
| 147 |
|
| 148 |
-
|
| 149 |
-
self.query_engine = RetrieverQueryEngine(
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
)
|
|
|
|
|
|
|
| 153 |
|
| 154 |
-
#
|
| 155 |
# hyde = HyDEQueryTransform(include_original=True)
|
| 156 |
# self.hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)
|
| 157 |
|
|
@@ -159,8 +162,9 @@ class ChatPDF:
|
|
| 159 |
if not self.query_engine:
|
| 160 |
return "Please, add a PDF document first."
|
| 161 |
|
| 162 |
-
|
| 163 |
-
response = self.query_engine.query(str_or_query_bundle=query)
|
|
|
|
| 164 |
print(response)
|
| 165 |
return response
|
| 166 |
|
|
|
|
| 27 |
QDRANT_API_URL = os.getenv('QDRANT_API_URL')
|
| 28 |
QDRANT_API_KEY = os.getenv('QDRANT_API_KEY')
|
| 29 |
|
| 30 |
+
logging.basicConfig(level=logging.INFO)
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
class ChatPDF:
|
|
|
|
|
|
|
| 34 |
query_engine = None
|
| 35 |
|
| 36 |
+
model_url = "https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q8_0.gguf"
|
| 37 |
+
# model_url = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf"
|
| 38 |
|
| 39 |
# def messages_to_prompt(messages):
|
| 40 |
# prompt = ""
|
|
|
|
| 60 |
def __init__(self):
|
| 61 |
self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=20)
|
| 62 |
|
| 63 |
+
logger.info("initializing the vector store related objects")
|
| 64 |
# client = QdrantClient(host="localhost", port=6333)
|
| 65 |
# client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
|
| 66 |
client = QdrantClient(":memory:")
|
|
|
|
| 70 |
# enable_hybrid=True
|
| 71 |
)
|
| 72 |
|
| 73 |
+
logger.info("initializing the FastEmbedEmbedding")
|
| 74 |
self.embed_model = FastEmbedEmbedding(
|
| 75 |
# model_name="BAAI/bge-small-en"
|
| 76 |
)
|
|
|
|
| 90 |
# tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
|
| 91 |
# tokenizer.save_pretrained("./models/tokenizer/")
|
| 92 |
|
| 93 |
+
logger.info("initializing the global settings")
|
| 94 |
Settings.text_splitter = self.text_parser
|
| 95 |
Settings.embed_model = self.embed_model
|
| 96 |
Settings.llm = llm
|
|
|
|
| 104 |
|
| 105 |
docs = SimpleDirectoryReader(input_dir=files_dir).load_data()
|
| 106 |
|
| 107 |
+
logger.info("enumerating docs")
|
| 108 |
for doc_idx, doc in enumerate(docs):
|
| 109 |
curr_text_chunks = self.text_parser.split_text(doc.text)
|
| 110 |
text_chunks.extend(curr_text_chunks)
|
| 111 |
doc_ids.extend([doc_idx] * len(curr_text_chunks))
|
| 112 |
|
| 113 |
+
logger.info("enumerating text_chunks")
|
| 114 |
for idx, text_chunk in enumerate(text_chunks):
|
| 115 |
node = TextNode(text=text_chunk)
|
| 116 |
src_doc = docs[doc_ids[idx]]
|
| 117 |
node.metadata = src_doc.metadata
|
| 118 |
nodes.append(node)
|
| 119 |
|
| 120 |
+
logger.info("enumerating nodes")
|
| 121 |
for node in nodes:
|
| 122 |
node_embedding = self.embed_model.get_text_embedding(
|
| 123 |
node.get_content(metadata_mode=MetadataMode.ALL)
|
| 124 |
)
|
| 125 |
node.embedding = node_embedding
|
| 126 |
|
| 127 |
+
logger.info("initializing the storage context")
|
| 128 |
storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
|
| 129 |
+
logger.info("indexing the nodes in VectorStoreIndex")
|
| 130 |
index = VectorStoreIndex(
|
| 131 |
nodes=nodes,
|
| 132 |
storage_context=storage_context,
|
| 133 |
transformations=Settings.transformations,
|
| 134 |
)
|
| 135 |
|
| 136 |
+
# logger.info("configure retriever")
|
| 137 |
+
# retriever = VectorIndexRetriever(
|
| 138 |
+
# index=index,
|
| 139 |
+
# similarity_top_k=6,
|
| 140 |
+
# # vector_store_query_mode="hybrid"
|
| 141 |
+
# )
|
| 142 |
|
| 143 |
+
# logger.info("configure response synthesizer")
|
| 144 |
+
# response_synthesizer = get_response_synthesizer(
|
| 145 |
+
# # streaming=True,
|
| 146 |
+
# response_mode=ResponseMode.COMPACT,
|
| 147 |
+
# )
|
| 148 |
|
| 149 |
+
# logger.info("assemble query engine")
|
| 150 |
+
# self.query_engine = RetrieverQueryEngine(
|
| 151 |
+
# retriever=retriever,
|
| 152 |
+
# response_synthesizer=response_synthesizer,
|
| 153 |
+
# )
|
| 154 |
+
|
| 155 |
+
self.query_engine = index.as_query_engine()
|
| 156 |
|
| 157 |
+
# logger.info("creating the HyDEQueryTransform instance")
|
| 158 |
# hyde = HyDEQueryTransform(include_original=True)
|
| 159 |
# self.hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)
|
| 160 |
|
|
|
|
| 162 |
if not self.query_engine:
|
| 163 |
return "Please, add a PDF document first."
|
| 164 |
|
| 165 |
+
logger.info("retrieving the response to the query")
|
| 166 |
+
# response = self.query_engine.query(str_or_query_bundle=query)
|
| 167 |
+
response = self.query_engine.query(query)
|
| 168 |
print(response)
|
| 169 |
return response
|
| 170 |
|