import os

from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate
from langchain_anthropic import ChatAnthropic
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import XMLOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

load_dotenv()

# Suppress grpc and glog logs for Gemini.
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"

# RAG parameters
CHUNK_SIZE = 1024
CHUNK_OVERLAP = CHUNK_SIZE // 8
K = 10
FETCH_K = 20

llm_model_translation = {
    "LLaMA 3": "llama3-70b-8192",
    "OpenAI GPT 4o Mini": "gpt-4o-mini",
    "OpenAI GPT 4o": "gpt-4o",
    "OpenAI GPT 4": "gpt-4-turbo",
    "Gemini 1.5 Pro": "gemini-1.5-pro",
    "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
}

llm_classes = {
    "llama3-70b-8192": ChatGroq,
    "gpt-4o-mini": ChatOpenAI,
    "gpt-4o": ChatOpenAI,
    "gpt-4-turbo": ChatOpenAI,
    "gemini-1.5-pro": ChatGoogleGenerativeAI,
    "claude-3-5-sonnet-20240620": ChatAnthropic,
}

xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, fulfill all the requirements \
of the prompt and provide citations. If a chunk of the generated text does not use any of the sources (for example, \
introductions or general text), don't put a citation for that chunk and just leave citations empty. Otherwise, \
list all sources used for that chunk of the text. Don't add inline citations in the text itself. Add all citations to the \
separate citations section. Use explicit new lines in the text to show paragraph splits. \
Return a citation for every quote across all articles that justify the text.

Use the following format for your final output:

<cited_text>
    <chunk>
        <text>...</text>
        <citations>
            <citation>...</citation>
            ...
        </citations>
    </chunk>
    ...
</cited_text>

The entire text should be wrapped in one <cited_text> element. For the References section (if the prompt asks for one), \
don't add citations. For a source id, give a valid integer alone without a key.

Here are the sources:{context}"""

xml_prompt = ChatPromptTemplate.from_messages(
    [("system", xml_system), ("human", "{input}")]
)


def format_docs_xml(docs: list[Document]) -> str:
    """Render the retrieved documents as numbered XML sources for the prompt.

    The exact tag names here are an assumption; what the rest of the pipeline
    relies on is the numeric ``id`` matching each document's position in ``docs``.
    """
    formatted = []
    for i, doc in enumerate(docs):
        doc_str = f"""\
    <source id="{i}">
        <path>{doc.metadata['source']}</path>
        <content>{doc.page_content}</content>
    </source>"""
        formatted.append(doc_str)
    return "\n\n<sources>\n" + "\n".join(formatted) + "\n</sources>"


def get_doc_content(docs, doc_id):
    return docs[doc_id].page_content


def process_cited_text(data, docs):
    """Flatten the parsed XML output into display text plus a dict of unique citations."""
    combined_text = ""
    citations = {}

    # Each item in `cited_text` is a chunk: element [0] holds the text,
    # element [1] holds the citations for that chunk.
    for item in data["cited_text"]:
        chunk_text = item["chunk"][0]["text"]
        combined_text += chunk_text
        citation_ids = []

        # Collect the source ids cited for this chunk.
        if item["chunk"][1]["citations"]:
            for c in item["chunk"][1]["citations"]:
                if c and "citation" in c:
                    citation = c["citation"]
                    if isinstance(citation, dict) and "source_id" in citation:
                        citation = citation["source_id"]
                    if isinstance(citation, str):
                        try:
                            citation_ids.append(int(citation))
                        except ValueError:
                            pass  # Skip citations that are not valid integers.

        if citation_ids:
            citation_texts = [f"<{cid}-{docs[cid].metadata['source']}>" for cid in citation_ids]
            combined_text += " " + " ".join(citation_texts)
        combined_text += "\n\n"

        # Store each cited source once, keyed by its id.
        for citation_id in citation_ids:
            if citation_id not in citations:
                citations[citation_id] = {
                    "source": docs[citation_id].metadata["source"],
                    "content": docs[citation_id].page_content,
                }

    return combined_text.strip(), citations
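

# Illustrative shape of the XMLOutputParser result that process_cited_text
# consumes; the values are hypothetical and assume the model followed the
# xml_system format above:
#
# {
#     "cited_text": [
#         {"chunk": [{"text": "A generated paragraph."},
#                    {"citations": [{"citation": "0"}, {"citation": "3"}]}]},
#     ]
# }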


def citations_to_html(citations):
    """Render the unique citations as an HTML list."""
    html_content = "<ul>"
    for citation_id, citation_info in citations.items():
        html_content += (
            f"<li>Source ID: {citation_id}<br>"
            f"Path: {citation_info['source']}<br>"
            f"Page Content: {citation_info['content']}</li>"
        )
    html_content += "</ul>"
    return html_content
  • " ) html_content += "" return html_content def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048): model_name = llm_model_translation.get(model) llm_class = llm_classes.get(model_name) if not llm_class: raise ValueError(f"Model {model} not supported.") try: llm = llm_class(model_name=model_name, temperature=temperature, max_tokens=max_length) except Exception as e: print(f"An error occurred: {e}") llm = None return llm def create_db_with_langchain(path: list[str], url_content: dict): all_docs = [] text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) embedding_function = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2") if path: for file in path: loader = PyMuPDFLoader(file) data = loader.load() # split it into chunks docs = text_splitter.split_documents(data) all_docs.extend(docs) if url_content: for url, content in url_content.items(): doc = Document(page_content=content, metadata={"source": url}) # split it into chunks docs = text_splitter.split_documents([doc]) all_docs.extend(docs) # print docs for idx, doc in enumerate(all_docs): print(f"Doc: {idx} | Length = {len(doc.page_content)}") assert len(all_docs) > 0, "No PDFs or scrapped data provided" db = Chroma.from_documents(all_docs, embedding_function) return db def generate_rag( prompt: str, topic: str, model: str, url_content: dict, path: list[str], temperature: float = 1.0, max_length: int = 2048, api_key: str = "", sys_message="", ): llm = load_llm(model, api_key, temperature, max_length) if llm is None: print("Failed to load LLM. Aborting operation.") return None db = create_db_with_langchain(path, url_content) retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K}) rag_prompt = hub.pull("rlm/rag-prompt") def format_docs(docs): if all(isinstance(doc, Document) for doc in docs): return "\n\n".join(doc.page_content for doc in docs) else: raise TypeError("All items in docs must be instances of Document.") docs = retriever.get_relevant_documents(topic) formatted_docs = format_docs_xml(docs) rag_chain = ( RunnablePassthrough.assign(context=lambda _: formatted_docs) | xml_prompt | llm | XMLOutputParser() ) result = rag_chain.invoke({"input": prompt}) text, citations = process_cited_text(result, docs) return text, citations def generate_base( prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message="" ): llm = load_llm(model, api_key, temperature, max_length) if llm is None: print("Failed to load LLM. Aborting operation.") return None try: output = llm.invoke(prompt).content return output except Exception as e: print(f"An error occurred while running the model: {e}") return None def generate( prompt: str, topic: str, model: str, url_content: dict, path: list[str], temperature: float = 1.0, max_length: int = 2048, api_key: str = "", sys_message="", ): if path or url_content: return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message) else: return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)