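"""Interactive RAG retrieval demo: a user question is expanded into multiple
queries, run against an ensemble of BM25 and dense (Qdrant) retrievers, and the
merged candidates are reranked by a Mistral LLM before being printed."""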
import pickle
import time
from typing import List

from dotenv import load_dotenv
from qdrant_client import QdrantClient
from langchain_community.vectorstores import Qdrant
from langchain_core.output_parsers import BaseOutputParser
from langchain_openai import OpenAIEmbeddings
from langchain_mistralai import ChatMistralAI
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    EmbeddingsFilter,
    LLMChainExtractor,
    LLMListwiseRerank,
)

from blueprints.prompts import QUERY_PROMPT
from BM25 import BM25SRetriever

load_dotenv()

# Dense embedding model shared by the Qdrant vector store and the embeddings filter.
EMBEDDING = OpenAIEmbeddings(model='text-embedding-3-small')


class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser that splits LLM output into a list of non-empty lines."""

    def parse(self, text: str) -> List[str]:
        lines = text.strip().split("\n")
        return list(filter(None, lines))
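# Example: LineListOutputParser().parse("query A\n\nquery B") returns
# ["query A", "query B"]; blank lines in the model output are discarded.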
if __name__ == "__main__":
    client = QdrantClient(url="http://localhost:6333")

    # Dense retriever over the pre-built Qdrant collection (top 3 by similarity).
    stsv = Qdrant(client, collection_name="eval_collection2", embeddings=EMBEDDING)
    retriever = stsv.as_retriever(search_kwargs={'k': 3})

    # Pre-chunked student-handbook documents for the lexical BM25 index.
    with open('/home/justtuananh/AI4TUAN/DOAN2024/offical/pipelines/documents_new.pkl', 'rb') as f:
        sotaysinhvien = pickle.load(f)

    retriever_bm25 = BM25SRetriever.from_documents(sotaysinhvien, k=5, activate_numba=True)

    llm = ChatMistralAI(
        model="mistral-large-2407",
        temperature=0,
        max_retries=2,
    )

    # Chain that rewrites a user question into alternative queries, one per line.
    output_parser = LineListOutputParser()
    llm_chain = QUERY_PROMPT | llm | output_parser

    # Base system message: "Based on the following information, answer the
    # question in Vietnamese."
    messages = [
        {"role": "system", "content": "Dựa vào thông tin sau, trả lời câu hỏi bằng tiếng việt"}
    ]

    def duy_phen():
        user_message = input("Nhập câu hỏi của bạn: ")  # "Enter your question: "
        start_time = time.time()

        if retriever is not None:
            # Expand the question into several queries against the dense retriever.
            retriever_multi = MultiQueryRetriever(
                retriever=retriever, llm_chain=llm_chain, parser_key="lines"
            )

            # Merge lexical and dense candidates with equal weight.
            ensemble_retriever = EnsembleRetriever(
                retrievers=[retriever_bm25, retriever_multi], weights=[0.5, 0.5]
            )

            # Alternative compressors (LLM extraction, embedding-similarity filtering);
            # only the listwise reranker below is wired into the final retriever.
            compressor = LLMChainExtractor.from_llm(llm)
            embeddings_filter = EmbeddingsFilter(embeddings=EMBEDDING, similarity_threshold=0.5)

            # Rerank the merged candidates with the LLM and keep the top 5.
            _filter2 = LLMListwiseRerank.from_llm(llm, top_n=5)
            compression = ContextualCompressionRetriever(
                base_compressor=_filter2, base_retriever=ensemble_retriever
            )

            print(compression.invoke(user_message))
            end_time = time.time()
            print(f'Time elapsed: {end_time - start_time:.2f}s')
        else:
            print('Retriever is not defined. Check the Qdrant connection and collection before querying.')

    duy_phen()
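# Running this script (hypothetical invocation below) assumes a Qdrant server at
# http://localhost:6333 holding "eval_collection2", the pickled documents at the
# path above, and MISTRAL_API_KEY / OPENAI_API_KEY available via the .env file:
#   python rag_ensemble_demo.py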