import os
from operator import itemgetter

from dotenv import load_dotenv
import chainlit as cl
from sentence_transformers import SentenceTransformer

from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.vectorstores import Qdrant
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import RunnablePassthrough

# Load environment variables (e.g., OPENAI_API_KEY)
load_dotenv()


# Custom wrapper so a SentenceTransformer model satisfies LangChain's Embeddings interface
class LangchainSentenceTransformerEmbeddings(Embeddings):
    def __init__(self, model_name: str):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        # encode() returns a numpy array; convert to plain lists of floats
        return self.model.encode(texts).tolist()

    def embed_query(self, text: str) -> list[float]:
        # Encode a single query and return its vector as a list of floats
        return self.model.encode([text])[0].tolist()


# Initialize the custom embedding model
embedding_model = LangchainSentenceTransformerEmbeddings("Cheselle/finetuned-arctic-sentence")

# Load the source PDFs using PyMuPDFLoader
ai_framework_document = PyMuPDFLoader(
    file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
).load()
ai_blueprint_document = PyMuPDFLoader(
    file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf"
).load()


# Split a document into chunks and tag each chunk with a source name
def metadata_generator(document, name):
    fixed_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", "!", "?"],
    )
    collection = fixed_text_splitter.split_documents(document)
    for doc in collection:
        doc.metadata["source"] = name
    return collection


# Chunk and tag both documents, then combine the chunks
recursive_framework_document = metadata_generator(ai_framework_document, "AI Framework")
recursive_blueprint_document = metadata_generator(ai_blueprint_document, "AI Blueprint")
combined_documents = recursive_framework_document + recursive_blueprint_document

# Full text of each document (kept for reference; not used by the chain below)
ai_framework_text = "".join([doc.page_content for doc in ai_framework_document])
ai_blueprint_text = "".join([doc.page_content for doc in ai_blueprint_document])

# Build an in-memory Qdrant vectorstore from the chunks with the custom embedding model
vectorstore = Qdrant.from_documents(
    documents=combined_documents,   # List of chunked, tagged documents
    embedding=embedding_model,      # Custom LangChain wrapper for SentenceTransformer
    location=":memory:",
    collection_name="ai_policy",
)

# Set up the retriever
retriever = vectorstore.as_retriever()

# LLM configuration
llm = ChatOpenAI(model="gpt-4o-mini")

# Define the RAG (Retrieval-Augmented Generation) prompt template
RAG_PROMPT = """\
You are an AI Policy Expert.
Given a provided context and question, you must answer the question based only on the context.
Think through your answer carefully and step by step.

Context: {context}
Question: {question}
"""
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

# Define the retrieval-augmented QA chain:
# 1. route the question to the retriever to fetch context documents,
# 2. carry the context forward (this explicit Runnable step also lets the
#    surrounding dict literals coerce into runnables),
# 3. fill the prompt and call the LLM, returning both response and context.
retrieval_augmented_qa_chain = (
    {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt | llm, "context": itemgetter("context")}
)


# Chainlit event handler for incoming messages
@cl.on_message
async def handle_message(message):
    try:
        # Run the RAG chain on the incoming question; ainvoke keeps the
        # event loop free while retrieval and generation run
        result = await retrieval_augmented_qa_chain.ainvoke({"question": message.content})
        # Send the model's answer back to the user
        await cl.Message(content=result["response"].content).send()
    except Exception as e:
        # Report the failure to the user and log it server-side
        await cl.Message(content=f"An error occurred: {str(e)}").send()
        print(f"Error occurred: {e}")


# Chainlit serves this app from the CLI (there is no cl.run() entry point):
#   chainlit run app.py
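
# --- Optional smoke test: a minimal sketch for exercising the chain outside
# Chainlit. Assumes OPENAI_API_KEY is set and the PDFs above loaded
# successfully; the sample question is hypothetical. Runs only under
# `python app.py`, not under `chainlit run app.py`.
if __name__ == "__main__":
    sample_question = "What does the Blueprint for an AI Bill of Rights cover?"  # hypothetical
    result = retrieval_augmented_qa_chain.invoke({"question": sample_question})
    # The chain returns both the retrieved context documents and the LLM response
    print("Sources:", {doc.metadata["source"] for doc in result["context"]})
    print("Answer:", result["response"].content)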