import os import time from typing import List from qdrant_client import QdrantClient, models from langchain_core.documents import Document from semantic_cache.main import SemanticCache from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings from Router.router import Evaluator from langchain_openai import ChatOpenAI # from utils.pipelines.main import get_last_user_message, add_or_update_system_message, pop_system_message from blueprints.rag_utils import format_docs, translate from blueprints.prompts import QUERY_PROMPT, evaluator_intent, basic_template, chitchat_prompt, safe_prompt, cache_prompt from SafetyChecker import SafetyChecker from langchain.retrievers import EnsembleRetriever from BM25 import BM25SRetriever # from database_Routing import DB_Router from langchain.retrievers.multi_query import MultiQueryRetriever # import cohere from langchain.retrievers.document_compressors import LLMChainExtractor from langchain_core.output_parsers import StrOutputParser from langchain_core.output_parsers import BaseOutputParser # from langchain_cohere import CohereRerank from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain_groq import ChatGroq from langchain_core.runnables import RunnablePassthrough import time from qdrant_client import QdrantClient from langchain_community.vectorstores import Qdrant from langchain.retrievers.document_compressors import LLMChainFilter from langchain.retrievers.document_compressors import EmbeddingsFilter from langchain.retrievers.document_compressors import LLMListwiseRerank from dotenv import load_dotenv from langchain_openai import OpenAIEmbeddings load_dotenv() os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_KEY') os.environ["COHERE_API_KEY"] # HF_EMBEDDING = HuggingFaceEmbeddings(model_name="dangvantuan/vietnamese-embedding") HF_EMBEDDING = OpenAIEmbeddings(model='text-embedding-3-small', api_key = os.getenv('OPENAI_KEY')) class LineListOutputParser(BaseOutputParser[List[str]]): """Output parser for a list of lines.""" def parse(self, text: str) -> List[str]: lines = text.strip().split("\n") return list(filter(None, lines)) # Remove empty lines def add_or_update_system_message(content: str, messages: List[dict]): """ Adds a new system message at the beginning of the messages list :param msg: The message to be added or appended. :param messages: The list of message dictionaries. :return: The updated list of message dictionaries. """ if messages and messages[0].get("role") == "system": messages[0]["content"] += f"{content}\n" else: # Insert at the beginning messages.insert(0, {"role": "system", "content": content}) return messages def split_context( context): split_index = context.find("User question") system_prompt = context[:split_index].strip() user_question = context[split_index:].strip() user_split_index = user_question.find("") f_system_prompt = str(system_prompt) +"\n" + str(user_question[user_split_index:]) return f_system_prompt def extract_metadata(docs, headers=('Header_1', 'Header_2', 'Header_3')): meta_data_docs = [] for doc in docs: meta_data_doc = [doc.metadata[header] for header in headers if doc.metadata.get(header)] meta_data_docs.append(meta_data_doc) return meta_data_docs def search_with_filter(query, vector_store, k, headers): conditions = [] # Xử lý điều kiện theo số lượng headers if len(headers) == 1: conditions.append( models.FieldCondition( key="metadata.Header_1", match=models.MatchValue( value=headers[0] ), ) ) elif len(headers) == 2: conditions.append( models.FieldCondition( key="metadata.Header_1", match=models.MatchValue( value=headers[0] ), ) ) conditions.append( models.FieldCondition( key="metadata.Header_2", match=models.MatchValue( value=headers[1] ), ) ) elif len(headers) == 3: conditions.append( models.FieldCondition( key="metadata.Header_1", match=models.MatchValue( value=headers[0] ), ) ) conditions.append( models.FieldCondition( key="metadata.Header_2", match=models.MatchValue( value=headers[1] ), ) ) conditions.append( models.FieldCondition( key="metadata.Header_3", match=models.MatchValue( value=headers[2] ), ) ) # Thực hiện truy vấn với các điều kiện single_result = vector_store.similarity_search( query=query, k=k, filter=models.Filter( must=conditions ), ) return single_result def get_relevant_documents(documents: List[Document], limit: int) -> List[Document]: result = [] seen = set() for doc in documents: if doc.page_content not in seen: result.append(doc) seen.add(doc.page_content) if len(result) == limit: break return result if __name__ == "__main__": client = QdrantClient( url="http://localhost:6333" ) stsv = Qdrant(client, collection_name="sotaysinhvien_filter", embeddings= HF_EMBEDDING) stsv_db = stsv.as_retriever(search_kwargs={'k': 10}) gthv = Qdrant(client, collection_name="gioithieuhocvien_filter", embeddings= HF_EMBEDDING) gthv_db = gthv.as_retriever(search_kwargs={'k': 10}) ttts = Qdrant(client, collection_name="thongtintuyensinh_filter", embeddings= HF_EMBEDDING) ttts_db = ttts.as_retriever(search_kwargs={'k': 10}) import pickle with open('data/sotaysinhvien_filter.pkl', 'rb') as f: sotaysinhvien = pickle.load(f) with open('data/thongtintuyensinh_filter.pkl', 'rb') as f: thongtintuyensinh = pickle.load(f) with open('data/gioithieuhocvien_filter.pkl', 'rb') as f: gioithieuhocvien = pickle.load(f) retriever_bm25_tuyensinh = BM25SRetriever.from_documents(thongtintuyensinh, k= 10, save_directory = "data/bm25s/ttts") retriever_bm25_sotay = BM25SRetriever.from_documents(sotaysinhvien, k= 10, save_directory = "data/bm25s/stsv") retriever_bm25_hocvien = BM25SRetriever.from_documents(gioithieuhocvien, k= 10, save_directory = "data/bm25s/gthv" ) # reranker = CohereRerank(model = "rerank-multilingual-v3.0", top_n = 5) llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1,api_key= os.getenv('llm_api_3')) llm2 = ChatGroq(model_name="llama-3.1-70b-versatile", temperature=1,api_key= os.getenv('llm_api_8')) output_parser = LineListOutputParser() llm_chain = QUERY_PROMPT | llm | output_parser # messages = [ # {"role": "system", "content": "Dựa vào thông tin sau, trả lời câu hỏi bằng tiếng việt"} # ] # ########################### cache = SemanticCache() another_chain = ( chitchat_prompt | llm2 | StrOutputParser()) safe_chain = ( safe_prompt | llm2 | StrOutputParser()) cache_chain = ( cache_prompt | llm2 | StrOutputParser()) # def duy_phen(): while 1: body = {} user_message = input("Nhập câu hỏi nào!: ") checker = SafetyChecker() safety_result = checker.check_safety(translate(user_message)) print("Safety check :" ,safety_result) if safety_result != 'safe' : print("UNSAFE") response = safe_chain.invoke({'meaning': f'{safety_result}'}) print(response) exit() evaluator = Evaluator(llm="llama3-70b", prompt=evaluator_intent) output = evaluator.classify_text(user_message) print(output.result) retriever = None # or assign a specific default retriever if applicable db = None # initialize db as well if it is used later in the code # print(output.result) source = None cache_result =cache.checker(user_message) if cache_result is not None: print("###Cache hit!###") response = cache_chain.invoke({"question": f'{user_message}', "content": f"{cache_result}"}) print(response) if output and output.result == 'OUT_OF_SCOPE' : print('OUT OF SCOPE') # print(body) response = another_chain.invoke({"question": f"{user_message}"}) print(response) elif output and output.result == 'ASK_QUYDINH' : print('SO TAY SINH VIEN DB') retriever = stsv_db retriever_bm25 = retriever_bm25_sotay source = stsv # db = sotaysinhvien elif output and output.result == 'ASK_HOCVIEN' : print('GIOI THIEU HOC VIEN DB') retriever = gthv_db retriever_bm25 = retriever_bm25_hocvien source = gthv # db = gioithieuhocvien elif output and output.result == 'ASK_TUYENSINH' : print('THONG TIN TUYEN SINH DB') retriever = ttts_db retriever_bm25 = retriever_bm25_tuyensinh source = ttts # db = thongtintuyensinh if retriever is not None: # retriever_multi = MultiQueryRetriever( # retriever=retriever, llm_chain=llm_chain, parser_key="lines" # ) start_time = time.time() ensemble_retriever = EnsembleRetriever( retrievers=[retriever_bm25, retriever], weights=[0.5, 0.5]) # compressor = LLMChainExtractor.from_llm(llm) # _filter = LLMChainFilter.from_llm(llm) # embeddings_filter = EmbeddingsFilter(embeddings=HF_EMBEDDING, similarity_threshold=0.5) # compression = ContextualCompressionRetriever( # base_compressor=_filter2, base_retriever=ensemble_retriever # ) reranker = LLMListwiseRerank.from_llm( llm=llm, top_n=5 ) tailieu = ensemble_retriever.invoke(f"{user_message}") docs = reranker.compress_documents(tailieu, user_message) end_time = time.time() #################### Filter lại ở đây -> add more documents liên quan hơn ######################### # docs = compression.invoke(f"{user_message}") # print(docs) meta_data_docs = extract_metadata(docs) full_result = [] for meta_data_doc in meta_data_docs: result = search_with_filter(user_message, source, 10, meta_data_doc) for i in result: full_result.append(i) print("Context liên quan" + '\n') print(full_result) # rag_chain = ( # {"context": compression | format_docs, "question": RunnablePassthrough()} # | basic_template | llm2 | StrOutputParser() # ) result_final = get_relevant_documents(full_result, 10) context = format_docs(result_final) best_chain = ( basic_template | llm2 | StrOutputParser()) best_result = best_chain.invoke({"question": f'{user_message}', "context": f"{context}"}) print(f'Câu trả lời tối ưu nhất: {best_result}') print(f'TIME USING : {end_time -start_time}') else: print('Retriever is not defined. Check output results and ensure retriever is assigned correctly.') # duy_phen()