"""Text generation helpers with optional retrieval-augmented generation (RAG).

``generate`` is the entry point: when PDFs or scraped URL content are supplied
it builds a Chroma vector store and answers with citations (``generate_rag``);
otherwise it sends the prompt straight to the selected chat model
(``generate_base``).
"""

import os
from pprint import pprint

from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.documents import Document
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain.schema import StrOutputParser
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_anthropic import ChatAnthropic
from dotenv import load_dotenv
from langchain_core.output_parsers import XMLOutputParser
from langchain.prompts import ChatPromptTemplate

load_dotenv()

# suppress grpc and glog logs for gemini
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"

# RAG parameters
CHUNK_SIZE = 1024
CHUNK_OVERLAP = CHUNK_SIZE // 8
K = 10  # number of documents returned by the MMR retriever
FETCH_K = 20  # number of candidates fetched before MMR re-ranking

# Display name -> provider model identifier.
llm_model_translation = {
    "LLaMA 3": "llama3-70b-8192",
    "OpenAI GPT 4o Mini": "gpt-4o-mini",
    "OpenAI GPT 4o": "gpt-4o",
    "OpenAI GPT 4": "gpt-4-turbo",
    "Gemini 1.5 Pro": "gemini-1.5-pro",
    "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
}

# Provider model identifier -> LangChain chat model class.
llm_classes = {
    "llama3-70b-8192": ChatGroq,
    "gpt-4o-mini": ChatOpenAI,
    "gpt-4o": ChatOpenAI,
    "gpt-4-turbo": ChatOpenAI,
    "gemini-1.5-pro": ChatGoogleGenerativeAI,
    "claude-3-5-sonnet-20240620": ChatAnthropic,
}

# NOTE(review): the prompt announces "the following format" but only "..."
# follows, while generate_rag() reads result['cited_text'] from the parsed XML.
# Presumably an XML example of the expected output belongs here -- confirm
# against the intended prompt before relying on it.
xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, \
fulfill all the requirements of the prompt and provide citations. If a part of the generated text does \
not use any of the sources, don't put a citation for that part. \
Otherwise, list all sources used for that part of the text. \
At the end of each relevant part, add a citation in square brackets, \
numbered sequentially starting from [0], regardless of the source's original ID. \
Remember, you must return both the requested text and citations. A citation consists of a VERBATIM quote that \
justifies the text and a sequential number (starting from 0) for the quote's article. Return a citation for every quote across all articles \
that justify the text. Use the following format for your final output: ... Here are the sources:{context}"""

xml_prompt = ChatPromptTemplate.from_messages(
    [("system", xml_system), ("human", "{input}")]
)


def format_docs_xml(docs: list[Document]) -> str:
    """Render retrieved documents into the context string fed to ``xml_prompt``.

    Each document contributes its ``source`` and ``title`` metadata followed by
    its page content. Missing metadata is rendered as an empty string: documents
    built from scraped URL content only carry a ``source`` key, so indexing
    ``metadata['title']`` directly would raise ``KeyError`` for them.

    Args:
        docs: Documents returned by the retriever.

    Returns:
        The formatted documents joined by newlines, prefixed with a blank line.
    """
    # NOTE(review): no per-document id or wrapping markup is emitted here, yet
    # the system prompt talks about numbered sources -- confirm intended format.
    formatted = [
        f"{doc.metadata.get('source', '')} {doc.metadata.get('title', '')} {doc.page_content} "
        for doc in docs
    ]
    return "\n\n" + "\n".join(formatted)


def citations_to_html(citations_data):
    """Convert parsed citation data to an HTML fragment.

    Returns:
        An empty string for both empty and non-empty input.
    """
    # NOTE(review): both branches yield an empty string, so no citation markup
    # is ever produced -- presumably the HTML rendering is unfinished; confirm.
    if citations_data:
        html_output = ""
        return html_output
    return ""


def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
    """Instantiate the chat model that corresponds to a display name.

    Args:
        model: Display name, a key of ``llm_model_translation``.
        api_key: Accepted for interface compatibility; not used by this
            function (no credential is passed to the model class here).
        temperature: Sampling temperature passed to the model class.
        max_length: Passed to the model class as ``max_tokens``.

    Returns:
        The instantiated chat model, or ``None`` if construction failed.

    Raises:
        ValueError: If ``model`` does not map to a supported model class.
    """
    model_name = llm_model_translation.get(model)
    llm_class = llm_classes.get(model_name)
    if not llm_class:
        raise ValueError(f"Model {model} not supported.")

    try:
        return llm_class(model_name=model_name, temperature=temperature, max_tokens=max_length)
    except Exception as e:
        # Best-effort: callers check for None and abort gracefully.
        print(f"An error occurred: {e}")
        return None


def create_db_with_langchain(path: list[str], url_content: dict):
    """Build a Chroma vector store from PDF files and scraped URL content.

    Args:
        path: PDF file paths loaded with ``PyMuPDFLoader``; may be empty.
        url_content: Mapping of URL to page text; may be empty.

    Returns:
        A ``Chroma`` store containing the chunked documents.

    Raises:
        AssertionError: If neither input yields any document chunk.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    all_docs = []

    if path:
        for file in path:
            data = PyMuPDFLoader(file).load()
            # split it into chunks
            all_docs.extend(text_splitter.split_documents(data))

    if url_content:
        for url, content in url_content.items():
            doc = Document(page_content=content, metadata={"source": url})
            # split it into chunks
            all_docs.extend(text_splitter.split_documents([doc]))

    # print docs
    for idx, doc in enumerate(all_docs):
        print(f"Doc: {idx} | Length = {len(doc.page_content)}")

    # Explicit raise instead of `assert`, which is stripped under `python -O`;
    # the exception type is kept so existing handlers continue to work.
    if not all_docs:
        raise AssertionError("No PDFs or scrapped data provided")

    # Build the embedding model only once there is something to embed.
    embedding_function = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
    return Chroma.from_documents(all_docs, embedding_function)


def generate_rag(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    """Answer ``prompt`` using documents retrieved for ``topic``.

    Args:
        prompt: User prompt sent as the human message.
        topic: Query used to retrieve relevant chunks (MMR search).
        model: Display name of the model, see ``llm_model_translation``.
        url_content: Mapping of URL to scraped text to index.
        path: PDF file paths to index.
        temperature: Sampling temperature for the model.
        max_length: Maximum number of generated tokens.
        api_key: Forwarded to ``load_llm``.
        sys_message: Unused by this function.

    Returns:
        ``None`` if the model could not be loaded; otherwise a tuple of the
        generated text and the HTML rendering of its citations.
    """
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None

    db = create_db_with_langchain(path, url_content)
    retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})

    # `invoke` replaces the deprecated `get_relevant_documents`.
    docs = retriever.invoke(topic)
    formatted_docs = format_docs_xml(docs)

    rag_chain = (
        RunnablePassthrough.assign(context=lambda _: formatted_docs)
        | xml_prompt
        | llm
        | XMLOutputParser()
    )
    result = rag_chain.invoke({"input": prompt})
    pprint(result)

    # The parsed output is indexed positionally: first child holds the text,
    # second child holds the citations.
    cited = result['cited_text']
    return cited[0]['text'], citations_to_html(cited[1]['citations'])


def generate_base(
    prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
):
    """Send ``prompt`` directly to the model without retrieval.

    Args:
        prompt: Text passed to the model.
        topic: Unused by this function.
        model: Display name of the model, see ``llm_model_translation``.
        temperature: Sampling temperature for the model.
        max_length: Maximum number of generated tokens.
        api_key: Forwarded to ``load_llm``.
        sys_message: Unused by this function.

    Returns:
        The model's response content, or ``None`` if the model could not be
        loaded or the call failed.
    """
    llm = load_llm(model, api_key, temperature, max_length)
    if llm is None:
        print("Failed to load LLM. Aborting operation.")
        return None

    try:
        return llm.invoke(prompt).content
    except Exception as e:
        print(f"An error occurred while running the model: {e}")
        return None


def generate(
    prompt: str,
    topic: str,
    model: str,
    url_content: dict,
    path: list[str],
    temperature: float = 1.0,
    max_length: int = 2048,
    api_key: str = "",
    sys_message="",
):
    """Generate text, using RAG when any source material is supplied.

    Dispatches to ``generate_rag`` if ``path`` or ``url_content`` is non-empty
    and to ``generate_base`` otherwise. Note that the two return different
    shapes: ``generate_rag`` returns a ``(text, citations_html)`` tuple while
    ``generate_base`` returns a plain string; both return ``None`` on failure.
    """
    if path or url_content:
        return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
    return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)