import os
import time
import ast
from typing import List, Optional

from pydantic import BaseModel

# from semantic_router.route import Route
from Router.router import Evaluator
from semantic_router.samples import rag_sample, chitchatSample
from utils.pipelines.main import get_last_user_message, add_or_update_system_message, pop_system_message
from blueprints.rag_utils import format_docs
from blueprints.prompts import accurate_rag_prompt, QUERY_PROMPT, evaluator_intent, basic_template, chitchat_prompt
from BM25 import BM25SRetriever
from SafetyChecker import SafetyChecker
from langchain.retrievers import EnsembleRetriever
from semantic_cache.main import SemanticCache
from sentence_transformers import SentenceTransformer
# from database_Routing import DB_Router
from langchain.retrievers.multi_query import MultiQueryRetriever
import cohere
from langchain_core.output_parsers import BaseOutputParser
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_groq import ChatGroq
from langchain_core.runnables import RunnablePassthrough

from dotenv import load_dotenv

load_dotenv()

qdrant_url = os.getenv('URL_QDRANT')
qdrant_api = os.getenv('API_QDRANT')

# Fail fast at import time when the Cohere key is missing (raises KeyError,
# exactly like the bare `os.environ["COHERE_API_KEY"]` lookup).
# NOTE(review): presumably consumed implicitly by CohereRerank — confirm.
if "COHERE_API_KEY" not in os.environ:
    raise KeyError("COHERE_API_KEY")


class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser that splits LLM output into a list of non-empty lines."""

    def parse(self, text: str) -> List[str]:
        lines = text.strip().split("\n")
        return list(filter(None, lines))  # Remove empty lines


class Pipeline:
    """Filter pipeline that injects a system message into incoming requests.

    ``inlet`` runs a safety check, a semantic-cache lookup and an LLM intent
    classifier; for in-scope intents it retrieves context from the matching
    knowledge base and adds a RAG system prompt. ``outlet`` only logs.
    """

    class Valves(BaseModel):
        # List target pipeline ids (models) that this filter will be connected to.
        # If you want to connect this filter to all pipelines, you can set pipelines to ["*"]
        pipelines: List[str] = []
        # Assign a priority level to the filter pipeline.
        # The priority level determines the order in which the filter pipelines are executed.
        # The lower the number, the higher the priority.
        priority: int = 0
        # Add your custom parameters/configuration here e.g. API_KEY that you want user to configure etc.

    # Set by inlet(): True when the last request was answered from the semantic cache.
    cache_hit = False

    def __init__(self):
        self.type = "filter"
        self.name = "Filter"
        # Models, vector stores and the reranker are created in on_startup().
        self.embedding = None
        self.route = None
        self.stsv_db = None
        self.gthv_db = None
        self.ttts_db = None
        self.reranker = None
        self.valves = self.Valves(**{"pipelines": ["*"]})

    def split_context(self, context):
        """Normalise a rendered RAG prompt into '<system part>\\n<question part>'.

        The prompt is split at the first occurrence of "User question"; both
        halves are stripped and re-joined with a single newline. If the marker
        is absent, the whole prompt is returned stripped.

        :param context: The rendered prompt text.
        :return: The normalised prompt string.
        """
        split_index = context.find("User question")
        if split_index == -1:
            # Without this guard, find() returns -1 and the slices below would
            # split off only the last character.
            return context.strip()
        system_prompt = context[:split_index].strip()
        user_question = context[split_index:].strip()
        return f"{system_prompt}\n{user_question}"

    async def on_startup(self):
        """Load embedding models, Qdrant retrievers, BM25 indexes, cache, reranker and query LLM."""
        # This function is called when the server is started.
        print(f"on_startup:{__name__}")
        import pickle

        from langchain_community.vectorstores import Qdrant
        from langchain_huggingface import HuggingFaceEmbeddings

        self.embedding = SentenceTransformer("dangvantuan/vietnamese-embedding")
        HF_EMBEDDING = HuggingFaceEmbeddings(model_name="dangvantuan/vietnamese-embedding")

        from qdrant_client import QdrantClient

        # Remote alternative: QdrantClient(qdrant_url, api_key=qdrant_api)
        client = QdrantClient(url="http://localhost:6333")
        self.gthv_db = Qdrant(client, collection_name="gioithieuhocvien_db", embeddings=HF_EMBEDDING).as_retriever()
        self.stsv_db = Qdrant(client, collection_name="sotaysinhvien_db", embeddings=HF_EMBEDDING).as_retriever()
        self.ttts_db = Qdrant(client, collection_name="thongtintuyensinh_db", embeddings=HF_EMBEDDING).as_retriever()

        # NOTE(security): pickle.load can execute arbitrary code; these files
        # must come from a trusted, locally built source.
        with open('data/thongtintuyensinh.pkl', 'rb') as f:
            self.thongtintuyensinh = pickle.load(f)
        with open('data/sotaysinhvien.pkl', 'rb') as f:
            self.sotaysinhvien = pickle.load(f)
        with open('data/gioithieuhocvien.pkl', 'rb') as f:
            self.gioithieuhocvien = pickle.load(f)

        self.retriever_bm25_tuyensinh = BM25SRetriever.from_documents(self.thongtintuyensinh, k=5, activate_numba=True)
        self.retriever_bm25_sotay = BM25SRetriever.from_documents(self.sotaysinhvien, k=5, activate_numba=True)
        self.retriever_bm25_hocvien = BM25SRetriever.from_documents(self.gioithieuhocvien, k=5, activate_numba=True)

        self.cache = SemanticCache()
        self.reranker = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)

        # LLM chain that rewrites the user question into several queries
        # (one per line) for MultiQueryRetriever.
        llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1, api_key=os.getenv('llm_api_5'))
        output_parser = LineListOutputParser()
        self.llm_chain = QUERY_PROMPT | llm | output_parser

    async def on_shutdown(self):
        # This function is called when the server is stopped.
        print(f"on_shutdown:{__name__}")

    def get_last_assistant_message(self, messages: List[dict]) -> str:
        """Return the text of the most recent assistant message, or "" if none.

        For multi-part (list) content, the first part whose type is "text" is
        returned; a list with no text part yields "".
        """
        for message in reversed(messages):
            if message["role"] == "assistant":
                content = message["content"]
                if isinstance(content, list):
                    for item in content:
                        if item["type"] == "text":
                            return item["text"]
                    return ""
                return content
        return ""

    def add_or_update_system_message(self, content: str, messages: List[dict]):
        """
        Add ``content`` as system context.

        If the first message is already a system message, ``content`` plus a
        trailing newline is appended to it; otherwise a new system message is
        inserted at the beginning. The list is modified in place.

        :param content: The text to add or append.
        :param messages: The list of message dictionaries.
        :return: The updated list of message dictionaries.
        """
        if messages and messages[0].get("role") == "system":
            messages[0]["content"] += f"{content}\n"
        else:
            # Insert at the beginning
            messages.insert(0, {"role": "system", "content": content})
        return messages

    def add_messages(self, content: str, messages: List[dict]):
        """Insert ``content`` as a new system message at index 0 (in place) and return ``messages``."""
        messages.insert(0, {"role": "system", "content": content})
        return messages

    async def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
        """Add a system message to ``body["messages"]`` before generation.

        Order of checks (each early exit returns ``body``):
        1. Safety check: a non-'safe' verdict is inserted as a system message.
        2. Semantic cache: on a hit, the cached result is added as context.
        3. Intent classification: OUT_OF_SCOPE gets the chit-chat prompt;
           ASK_QUYDINH / ASK_HOCVIEN / ASK_TUYENSINH select a knowledge base
           whose BM25 + multi-query results are reranked and rendered into a
           RAG system prompt. Any other result leaves the messages unchanged.
        """
        messages = body.get("messages", [])
        print(messages)
        user_message = get_last_user_message(messages)
        print(user_message)

        # ---- Guard ----
        checker = SafetyChecker()
        safety_result = checker.check_safety(user_message)
        if safety_result != 'safe':
            print("Safety check :", safety_result)
            construct_msg = f"Dựa vào thông tin trả lời câu hỏi của người dùng bằng Tiếng Việt\n\n : {safety_result}"
            body["messages"] = self.add_messages(construct_msg, messages)
            print(body)
            return body

        # ---- Semantic cache ----
        # (A SemanticRouter over rag_sample / chitchatSample routes was the
        # earlier routing approach; intent routing now uses the Evaluator below.)
        cache_result = self.cache.checker(user_message)
        if cache_result is not None:
            print("###Cache hit!###")
            self.cache_hit = True
            construct_msg = f"Dựa vào thông tin trả lời câu hỏi của người dùng bằng Tiếng Việt \n\n : {cache_result}"
            body["messages"] = self.add_or_update_system_message(construct_msg, messages)
            print(body)
            return body
        self.cache_hit = False
        print("No cache found! Generation continue")

        # ---- Intent routing ----
        evaluator = Evaluator(llm="llama3-8b", prompt=evaluator_intent)
        output = evaluator.classify_text(user_message)
        print(f'Câu hỏi người dùng: {user_message}')
        intent = output.result if output else None

        if intent == 'OUT_OF_SCOPE':
            print('OUT OF SCOPE')
            construct_msg = f"Dựa vào thông tin trả lời câu hỏi của người dùng bằng Tiếng Việt\n\n : {chitchat_prompt}"
            body["messages"] = self.add_or_update_system_message(construct_msg, messages)
            print(body)
            return body

        # intent -> (log label, dense Qdrant retriever, BM25 retriever)
        knowledge_bases = {
            'ASK_QUYDINH': ('SO TAY SINH VIEN DB', self.stsv_db, self.retriever_bm25_sotay),
            'ASK_HOCVIEN': ('HOC VIEN DB', self.gthv_db, self.retriever_bm25_hocvien),
            'ASK_TUYENSINH': ('THONG TIN TUYEN SINH DB', self.ttts_db, self.retriever_bm25_tuyensinh),
        }
        retriever = None
        retriever_bm25 = None
        selected = knowledge_bases.get(intent)
        if selected is not None:
            label, retriever, retriever_bm25 = selected
            print(label)

        if retriever is None:
            print('Retriever is not defined. Check output results and ensure retriever is assigned correctly.')
            # Hand the request on unchanged rather than returning None.
            return body

        # ---- Hybrid retrieval + rerank ----
        retriever_multi = MultiQueryRetriever(
            retriever=retriever, llm_chain=self.llm_chain, parser_key="lines"
        )
        ensemble_retriever = EnsembleRetriever(
            retrievers=[retriever_bm25, retriever_multi], weights=[0.5, 0.5]
        )
        compression = ContextualCompressionRetriever(
            base_compressor=self.reranker, base_retriever=ensemble_retriever
        )
        rag_chain = (
            {"context": compression | format_docs, "question": RunnablePassthrough()}
            | basic_template
        )
        rag_prompt = rag_chain.invoke(user_message).text
        system_message = self.split_context(rag_prompt)
        body["messages"] = self.add_or_update_system_message(system_message, messages)
        print(body)
        # TODO: populate the semantic cache once the answer is known, e.g.
        # self.cache.add_to_cache(question, response_text)
        return body

    async def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
        """Log the last user message and the outgoing body; return ``body`` unchanged."""
        print("##########################")
        messages = body.get("messages", [])
        user_message = get_last_user_message(messages)
        print(user_message)
        print("########### Câu hỏi vừa hỏi #################")
        # TODO: when not self.cache_hit, store the question/answer pair:
        # self.cache.add_to_cache(body['messages'][-2]['content'], body['messages'][-1]['content'])
        print(f"Outlet Body Input: {body}")
        return body