import os
from operator import itemgetter

from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_core.output_parsers import JsonOutputParser
from langchain.prompts import PromptTemplate

from lib.models import MODELS_MAP
from lib.utils import format_docs, retrieve_answer, load_embeddings
from lib.entities import LLMEvalResult


def create_retriever(llm_name, db_path, docs, collection_name="local-rag"):
    """Return a retriever backed by a Chroma vector store persisted at *db_path*.

    If *db_path* does not exist yet, *docs* are split into chunks, embedded and
    written to a new store at that location. If it already exists, the store
    is reopened as-is and *docs* is not used.

    Args:
        llm_name: Identifier handed to ``load_embeddings`` to obtain the
            embedding function.
        db_path: Directory used as the Chroma ``persist_directory``.
        docs: Documents to index when the store has to be created.
        collection_name: Name of the Chroma collection to create or open.

    Returns:
        The object returned by ``vectorstore.as_retriever()``.
    """
    embeddings = load_embeddings(llm_name)

    if os.path.exists(db_path):
        # A persisted store is already on disk: reopen it without touching docs.
        vectorstore = Chroma(
            persist_directory=db_path,
            embedding_function=embeddings,
            collection_name=collection_name,
        )
    else:
        # Splitting is only needed when the store is built, so it lives in
        # this branch instead of running unconditionally.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=60
        )
        splits = text_splitter.split_documents(docs)
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embeddings,
            persist_directory=db_path,
            collection_name=collection_name,
        )

    return vectorstore.as_retriever()


def create_qa_chain(llm, retriever, prompts_text):
    """Build a chain that answers a question and then has the LLM grade it.

    The chain takes the question as its input and runs these stages in order:

    1. Retrieve context with *retriever* and format it with ``format_docs``.
    2. Fill the initial prompt, call *llm*, and post-process the result with
       ``retrieve_answer``.
    3. Fill the evaluation prompt from the values produced so far.
    4. Call *llm* on the evaluation prompt.
    5. Parse the evaluation with a ``JsonOutputParser`` built from
       ``LLMEvalResult``.

    Args:
        llm: Runnable language model used for both answering and evaluating.
        retriever: Runnable that fetches documents for a question.
        prompts_text: Mapping with the template strings under the keys
            ``"initial_prompt"`` and ``"evaluation_prompt"``.

    Returns:
        A runnable whose result is a dict with the keys ``"output"`` (the
        answer), ``"evaluation"`` (the parsed evaluation) and ``"context"``.

    Raises:
        KeyError: If *prompts_text* lacks one of the two required keys.
    """
    initial_prompt = PromptTemplate(
        template=prompts_text["initial_prompt"],
        input_variables=["question", "context"],
    )

    json_parser = JsonOutputParser(pydantic_object=LLMEvalResult)
    qa_eval_prompt = PromptTemplate(
        template=prompts_text["evaluation_prompt"],
        input_variables=["question", "answer"],
        partial_variables={
            "format_instructions": json_parser.get_format_instructions()
        },
    )

    # Stage 1: question -> {"context", "question"}.
    gather_context = RunnableParallel(
        context=retriever | format_docs,
        question=RunnablePassthrough(),
    )

    # Stage 2: produce the answer while carrying question and context forward.
    answer_question = RunnableParallel(
        answer=initial_prompt | llm | retrieve_answer,
        question=itemgetter("question"),
        context=itemgetter("context"),
    )

    # Stage 3: render the evaluation prompt; the question is not forwarded
    # past this point because no later stage reads it.
    build_eval_prompt = RunnableParallel(
        input=qa_eval_prompt,
        context=itemgetter("context"),
        answer=itemgetter("answer"),
    )

    # Stage 4: ask the model to evaluate its own answer.
    run_evaluation = RunnableParallel(
        evaluation=itemgetter("input") | llm,
        context=itemgetter("context"),
        answer=itemgetter("answer"),
    )

    # Stage 5: shape the final payload and parse the evaluation as JSON.
    assemble_result = RunnableParallel(
        output=itemgetter("answer"),
        evaluation=itemgetter("evaluation") | json_parser,
        context=itemgetter("context"),
    )

    return (
        gather_context
        | answer_question
        | build_eval_prompt
        | run_evaluation
        | assemble_result
    )