import os
import pickle
import time
from typing import List

from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain_mistralai import ChatMistralAI
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    EmbeddingsFilter,
    LLMChainExtractor,
    LLMChainFilter,
    LLMListwiseRerank,
)
from BM25 import BM25SRetriever
from blueprints.rag_utils import format_docs
from blueprints.prompts import (
    QUERY_PROMPT,
    evaluator_intent,
    basic_template,
    chitchat_prompt,
    safe_prompt,
    cache_prompt,
)
# from database_Routing import DB_Router
# from langchain_cohere import CohereRerank
# from langchain_groq import ChatGroq

load_dotenv()

# Note: despite the HF_ prefix, this is an OpenAI embedding model.
HF_EMBEDDING = OpenAIEmbeddings(model='text-embedding-3-small')


class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser for a list of lines."""

    def parse(self, text: str) -> List[str]:
        lines = text.strip().split("\n")
        return list(filter(None, lines))  # Remove empty lines


# def extract_metadata(docs, headers=('Header_1', 'Header_2', 'Header_3')):
#     meta_data_docs = []
#     for doc in docs:
#         meta_data_doc = [doc.metadata[header] for header in headers if doc.metadata.get(header)]
#         meta_data_docs.append(meta_data_doc)
#     return meta_data_docs


# def search_with_filter(query, vector_store, k, headers):
#     conditions = [
#         models.FieldCondition(
#             key="metadata.Header_1",
#             match=models.MatchValue(value=headers[0]),
#         ),
#         models.FieldCondition(
#             key="metadata.Header_2",
#             match=models.MatchValue(value=headers[1]),
#         ),
#     ]
#     if len(headers) == 3:
#         conditions.append(
#             models.FieldCondition(
#                 key="metadata.Header_3",
#                 match=models.MatchValue(value=headers[2]),
#             )
#         )
#     single_result = vector_store.similarity_search(
#         query=query,
#         k=k,
#         filter=models.Filter(must=conditions),
#     )
#     return single_result
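# A minimal sketch (hypothetical helper, not called anywhere in this script)
# showing what LineListOutputParser produces when MultiQueryRetriever asks the
# LLM for query variants; the sample response text is illustrative only.
def _demo_line_list_parser() -> List[str]:
    sample_llm_response = "What are the tuition fees?\n\nHow much is tuition per semester?\n"
    # Blank lines are filtered out, so this returns the two query variants.
    return LineListOutputParser().parse(sample_llm_response)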
if __name__ == "__main__":
    client = QdrantClient(url="http://localhost:6333")
    stsv = Qdrant(client, collection_name="eval_collection2", embeddings=HF_EMBEDDING)
    retriever = stsv.as_retriever(search_kwargs={'k': 3})

    with open('/home/justtuananh/AI4TUAN/DOAN2024/offical/pipelines/documents_new.pkl', 'rb') as f:
        sotaysinhvien = pickle.load(f)
    retriever_bm25 = BM25SRetriever.from_documents(sotaysinhvien, k=5, activate_numba=True)

    # reranker = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
    llm = ChatMistralAI(
        model="mistral-large-2407",
        temperature=0,
        max_retries=2,
    )
    # llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1, api_key=os.getenv('llm_api_4'))
    # llm2 = ChatGroq(model_name="llama3-70b-8192", temperature=0.1, api_key=os.getenv('llm_api_5'))

    output_parser = LineListOutputParser()
    llm_chain = QUERY_PROMPT | llm | output_parser
    messages = [
        # "Based on the following information, answer the question in Vietnamese"
        {"role": "system", "content": "Dựa vào thông tin sau, trả lời câu hỏi bằng tiếng việt"}
    ]

    def duy_phen():
        user_message = input("Nhập câu hỏi của bạn: ")  # "Enter your question: "
        start_time = time.time()
        if retriever is not None:
            # Expand the user question into multiple query variants.
            retriever_multi = MultiQueryRetriever(
                retriever=retriever, llm_chain=llm_chain, parser_key="lines"
            )
            # Blend lexical (BM25) and dense (multi-query) retrieval results.
            ensemble_retriever = EnsembleRetriever(
                retrievers=[retriever_bm25, retriever_multi], weights=[0.5, 0.5]
            )
            compressor = LLMChainExtractor.from_llm(llm)
            # _filter = LLMChainFilter.from_llm(llm)
            _filter2 = LLMListwiseRerank.from_llm(llm, top_n=5)
            embeddings_filter = EmbeddingsFilter(embeddings=HF_EMBEDDING, similarity_threshold=0.5)
            # Rerank the ensemble's candidates with the listwise LLM reranker.
            compression = ContextualCompressionRetriever(
                base_compressor=_filter2, base_retriever=ensemble_retriever
            )
            # rag_chain = (
            #     {"context": compression | format_docs, "question": RunnablePassthrough()}
            #     | basic_template | llm2 | StrOutputParser()
            # )
            print(compression.invoke(user_message))
            end_time = time.time()
            print(f'TIME ELAPSED: {end_time - start_time}')
        else:
            print('Retriever is not defined. Check output results and ensure retriever is assigned correctly.')

    duy_phen()
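# A sketch of how the commented-out rag_chain above could be assembled into a
# complete answer-generating chain. This is an assumption about the intended
# wiring, not the author's final design: it reuses `llm` in place of the
# commented-out `llm2`, and it assumes `basic_template` expects "context" and
# "question" inputs and that `format_docs` flattens retrieved documents into a
# single string.
from langchain_core.runnables import RunnablePassthrough


def build_rag_chain(compression_retriever, chat_model):
    """Hypothetical helper: retrieve and compress context, then generate an answer."""
    return (
        {"context": compression_retriever | format_docs, "question": RunnablePassthrough()}
        | basic_template
        | chat_model
        | StrOutputParser()
    )

# Example usage inside duy_phen(), instead of printing the raw documents:
#     rag_chain = build_rag_chain(compression, llm)
#     print(rag_chain.invoke(user_message))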