# Graduation/pipelines/eval_rag.py
import os
import pickle
import time
from typing import List

from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_community.vectorstores import Qdrant
from langchain_openai import OpenAIEmbeddings
from langchain_mistralai import ChatMistralAI
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    EmbeddingsFilter,
    LLMChainExtractor,
    LLMChainFilter,
    LLMListwiseRerank,
)
from blueprints.rag_utils import format_docs
from blueprints.prompts import QUERY_PROMPT, evaluator_intent, basic_template, chitchat_prompt, safe_prompt, cache_prompt
from BM25 import BM25SRetriever
# from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings  # unused: embeddings come from OpenAI below
# from database_Routing import DB_Router
# from langchain_cohere import CohereRerank
# from langchain_groq import ChatGroq

load_dotenv()
# NOTE: despite the HF_ prefix, this is an OpenAI embedding model (needs OPENAI_API_KEY in .env).
HF_EMBEDDING = OpenAIEmbeddings(model='text-embedding-3-small')
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser for a list of lines."""

    def parse(self, text: str) -> List[str]:
        lines = text.strip().split("\n")
        return list(filter(None, lines))  # Remove empty lines
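
# Illustrative behavior of the parser above (a sketch, not part of the pipeline):
#   LineListOutputParser().parse("query A\n\nquery B\n")  ->  ["query A", "query B"]
# MultiQueryRetriever treats each returned line as an alternative phrasing of the question.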
# def extract_metadata(docs, headers=('Header_1', 'Header_2', 'Header_3')):
#     meta_data_docs = []
#     for doc in docs:
#         meta_data_doc = [doc.metadata[header] for header in headers if doc.metadata.get(header)]
#         meta_data_docs.append(meta_data_doc)
#     return meta_data_docs
# def search_with_filter(query, vector_store, k, headers):
#     conditions = [
#         models.FieldCondition(
#             key="metadata.Header_1",
#             match=models.MatchValue(value=headers[0]),
#         ),
#         models.FieldCondition(
#             key="metadata.Header_2",
#             match=models.MatchValue(value=headers[1]),
#         ),
#     ]
#     if len(headers) == 3:
#         conditions.append(
#             models.FieldCondition(
#                 key="metadata.Header_3",
#                 match=models.MatchValue(value=headers[2]),
#             )
#         )
#     single_result = vector_store.similarity_search(
#         query=query,
#         k=k,
#         filter=models.Filter(must=conditions),
#     )
#     return single_result
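
# Hypothetical call if the helper above were re-enabled (header values are placeholders):
#   search_with_filter("your question", stsv, k=3, headers=("Header A", "Header B"))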

if __name__ == "__main__":
    client = QdrantClient(url="http://localhost:6333")
    stsv = Qdrant(client, collection_name="eval_collection2", embeddings=HF_EMBEDDING)
    retriever = stsv.as_retriever(search_kwargs={'k': 3})

    with open('/home/justtuananh/AI4TUAN/DOAN2024/offical/pipelines/documents_new.pkl', 'rb') as f:
        sotaysinhvien = pickle.load(f)
    retriever_bm25 = BM25SRetriever.from_documents(sotaysinhvien, k=5, activate_numba=True)
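
    # Two complementary retrievers: `retriever` does dense (embedding) search over
    # the Qdrant collection, while `retriever_bm25` does sparse keyword matching
    # over the pickled documents; they are merged below by EnsembleRetriever.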
    # reranker = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
    llm = ChatMistralAI(
        model="mistral-large-2407",
        temperature=0,
        max_retries=2,
    )
    # llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1, api_key=os.getenv('llm_api_4'))
    # llm2 = ChatGroq(model_name="llama3-70b-8192", temperature=0.1, api_key=os.getenv('llm_api_5'))
    output_parser = LineListOutputParser()
    llm_chain = QUERY_PROMPT | llm | output_parser
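
    # `llm_chain` asks the LLM (via QUERY_PROMPT) to rewrite the user question as
    # several alternative queries, one per line; LineListOutputParser splits them
    # so MultiQueryRetriever can run one dense search per variant.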
    # Defined but never used below; leftover from a chat-style message setup.
    messages = [
        # "Based on the following information, answer the question in Vietnamese"
        {"role": "system", "content": "Dựa vào thông tin sau, trả lời câu hỏi bằng tiếng việt"}
    ]

    def duy_phen():
        user_message = input("Nhập câu hỏi của bạn: ")  # "Enter your question: "
        start_time = time.time()
        if retriever is not None:
            retriever_multi = MultiQueryRetriever(
                retriever=retriever, llm_chain=llm_chain, parser_key="lines"
            )
            ensemble_retriever = EnsembleRetriever(
                retrievers=[retriever_bm25, retriever_multi], weights=[0.5, 0.5]
            )
            compressor = LLMChainExtractor.from_llm(llm)  # built but unused; _filter2 is the active compressor
            # _filter = LLMChainFilter.from_llm(llm)
            _filter2 = LLMListwiseRerank.from_llm(llm, top_n=5)
            embeddings_filter = EmbeddingsFilter(embeddings=HF_EMBEDDING, similarity_threshold=0.5)  # also unused
            compression = ContextualCompressionRetriever(
                base_compressor=_filter2, base_retriever=ensemble_retriever
            )
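
            # The compression retriever first pulls candidates from the 50/50
            # BM25 + multi-query ensemble, then has the LLM rerank them listwise
            # and keep the top 5 before anything is printed.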
            # rag_chain = (
            #     {"context": compression | format_docs, "question": RunnablePassthrough()}
            #     | basic_template | llm2 | StrOutputParser()
            # )
            print(compression.invoke(user_message))
            end_time = time.time()
            print(f'TIME USED: {end_time - start_time:.2f}s')
        else:
            print('Retriever is not defined. Check output results and ensure retriever is assigned correctly.')

    duy_phen()