from operator import itemgetter

import chainlit as cl
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.vectorstores import Qdrant
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import RunnableParallel
from langchain_openai import ChatOpenAI
from sentence_transformers import SentenceTransformer

# Load environment variables (ChatOpenAI reads OPENAI_API_KEY from the environment)
load_dotenv()
# Custom wrapper so a SentenceTransformer model satisfies LangChain's Embeddings interface
class LangchainSentenceTransformerEmbeddings(Embeddings):
    def __init__(self, model_name: str):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        # encode() returns a numpy array; convert to plain lists as the interface promises
        return self.model.encode(texts).tolist()

    def embed_query(self, text: str) -> list[float]:
        # Encode a single query and return its vector as a plain list
        return self.model.encode([text])[0].tolist()
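
# Minimal usage sketch (illustrative only; the query string is hypothetical):
#   emb = LangchainSentenceTransformerEmbeddings("Cheselle/finetuned-arctic-sentence")
#   vector = emb.embed_query("What does the Blueprint say about data privacy?")
#   # vector is a plain list[float], ready for any LangChain vector store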
# Initialize the custom embedding model
embedding_model = LangchainSentenceTransformerEmbeddings("Cheselle/finetuned-arctic-sentence")
# Load the source PDFs using PyMuPDFLoader
ai_framework_document = PyMuPDFLoader(file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf").load()
ai_blueprint_document = PyMuPDFLoader(file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()
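
# Both file_path values are web URLs, so the loader downloads each PDF to a
# temporary file at import time (network access required). Local copies load
# the same way (hypothetical path):
#   ai_framework_document = PyMuPDFLoader(file_path="docs/NIST.AI.600-1.pdf").load()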
# Split a document into chunks and tag each chunk's metadata with a source name
def metadata_generator(document, name):
    fixed_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", "!", "?"]
    )
    collection = fixed_text_splitter.split_documents(document)
    for doc in collection:
        doc.metadata["source"] = name
    return collection
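
# Each chunk comes back as a Document tagged with its origin, roughly:
#   Document(page_content="...", metadata={"source": "AI Framework", ...})
# The 100-character overlap keeps text that straddles a 500-character boundary
# available in both neighboring chunks.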
# Split and tag both documents, then pool the chunks for indexing
recursive_framework_document = metadata_generator(ai_framework_document, "AI Framework")
recursive_blueprint_document = metadata_generator(ai_blueprint_document, "AI Blueprint")
combined_documents = recursive_framework_document + recursive_blueprint_document

# Full plain-text concatenation of each document (not used further in this script)
ai_framework_text = "".join([doc.page_content for doc in ai_framework_document])
ai_blueprint_text = "".join([doc.page_content for doc in ai_blueprint_document])
# Pass the custom embedding model to Qdrant to create a vectorstore
vectorstore = Qdrant.from_documents(
    documents=combined_documents,   # list of tagged chunks
    embedding=embedding_model,      # custom LangChain wrapper for SentenceTransformer
    location=":memory:",
    collection_name="ai_policy"
)
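
# ":memory:" keeps the collection ephemeral, so documents are re-embedded on
# every start. For persistence across restarts, Qdrant's local on-disk mode
# should work here too (hypothetical path):
#   vectorstore = Qdrant.from_documents(
#       documents=combined_documents, embedding=embedding_model,
#       path="./qdrant_db", collection_name="ai_policy"
#   )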
# Set up the retriever over the vectorstore
retriever = vectorstore.as_retriever()
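
# as_retriever() defaults to similarity search returning the top 4 chunks;
# both are tunable, e.g. (illustrative value):
#   retriever = vectorstore.as_retriever(search_kwargs={"k": 6})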
# LLM configuration
llm = ChatOpenAI(model="gpt-4o-mini")
# Define the RAG (Retrieval-Augmented Generation) prompt template
RAG_PROMPT = """\
You are an AI Policy Expert.
Given the provided context and a question, answer the question using only the context.
Think through your answer carefully and step by step.

Context: {context}
Question: {question}
"""
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
# Define the retrieval-augmented QA chain: fetch context for the question in
# parallel with passing the question through, then generate the response while
# surfacing the retrieved context alongside it
retrieval_augmented_qa_chain = (
    RunnableParallel(
        context=itemgetter("question") | retriever,
        question=itemgetter("question"),
    )
    | {"response": rag_prompt | llm, "context": itemgetter("context")}
)
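
# Shape of the chain's output, sketched under the definitions above:
#   result = retrieval_augmented_qa_chain.invoke({"question": "..."})
#   result["response"]  # AIMessage from the LLM; .content holds the answer text
#   result["context"]   # list[Document]: the chunks the retriever pulled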
# Chainlit event handler for incoming user messages
@cl.on_message
async def handle_message(message: cl.Message):
    try:
        # Run the RAG chain on the question (ainvoke keeps the handler non-blocking)
        result = await retrieval_augmented_qa_chain.ainvoke({"question": message.content})
        # Send the generated answer back to the user
        await cl.Message(content=result["response"].content).send()
    except Exception as e:
        # Report the failure to the user and log it server-side
        await cl.Message(content=f"An error occurred: {str(e)}").send()
        print(f"Error occurred: {e}")
# Chainlit apps are normally launched from the CLI:
#   chainlit run app.py
# Running this file directly can delegate to the same entry point (assuming
# chainlit.cli.run_chainlit is available in the installed Chainlit version):
if __name__ == "__main__":
    from chainlit.cli import run_chainlit
    run_chainlit(__file__)