# NOTE(review): the six lines that originally sat here were Hugging Face
# Spaces page residue captured by the scrape — runtime status ("Runtime
# error"), "File size: 4,862 Bytes", a row of commit hashes, and the
# viewer's line-number gutter. They were not Python source; they are
# recorded in this comment so the file parses while nothing is lost.
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from typing import cast
from dotenv import load_dotenv
import os
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Qdrant
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
import chainlit as cl
from pathlib import Path
from sentence_transformers import SentenceTransformer
# Load environment variables from a local .env file, if one exists.
load_dotenv()

# Fail fast with an actionable message. The original
# `os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")` raises an
# opaque `TypeError: str expected, not NoneType` when the key is missing,
# because environ values must be strings.
_openai_api_key = os.getenv("OPENAI_API_KEY")
if _openai_api_key is None:
    raise EnvironmentError(
        "OPENAI_API_KEY is not set; add it to .env or the environment."
    )
os.environ["OPENAI_API_KEY"] = _openai_api_key
class SentenceTransformerEmbedding:
    """Adapter exposing a SentenceTransformer model through the LangChain
    embeddings interface (``embed_documents`` / ``embed_query``).

    The original class only implemented ``embed_documents`` plus a
    ``__call__`` shim, so LangChain vector stores — which call
    ``embed_query`` at retrieval time — would fail on the query side.
    """

    def __init__(self, model_name):
        # Load the model once; every encode call reuses it.
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """Embed a batch of texts, returning plain Python lists of floats.

        ``convert_to_numpy=True`` goes straight to a numpy array instead of
        the original tensor round-trip (``convert_to_tensor=True`` followed
        by ``.tolist()``); ``.tolist()`` keeps the output JSON-serializable
        for the vector store.
        """
        return self.model.encode(texts, convert_to_numpy=True).tolist()

    def embed_query(self, text):
        """Embed a single query string (required by LangChain retrievers)."""
        return self.embed_documents([text])[0]

    def __call__(self, texts):
        # Backwards compatibility: some integrations invoke the object directly.
        return self.embed_documents(texts)
@cl.on_chat_start
async def on_chat_start():
    """Build the RAG pipeline when a chat session starts.

    Loads two AI-policy PDFs, splits them into overlapping chunks tagged
    with their source, indexes them in an in-memory Qdrant collection, and
    stores the model, retriever, and prompt template in the Chainlit user
    session for reuse by ``on_message``.

    NOTE(review): both PDFs are fetched over HTTP and re-embedded on every
    new chat session — consider caching the vector store if startup latency
    matters.
    """
    model = ChatOpenAI(streaming=True)

    # Source documents (downloaded at session start).
    ai_framework_document = PyMuPDFLoader(
        file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
    ).load()
    ai_blueprint_document = PyMuPDFLoader(
        file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf"
    ).load()

    RAG_PROMPT = """\
Given a provided context and question, you must answer the question based only on context.
Context: {context}
Question: {question}
"""
    rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

    # ~Sentence-sized chunks: 500 chars with 100-char overlap so retrieved
    # passages keep enough surrounding context.
    sentence_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", "!", "?"],
    )

    def metadata_generator(document, name, splitter):
        # Tag every chunk with its source document so answers can be traced.
        collection = splitter.split_documents(document)
        for doc in collection:
            doc.metadata["source"] = name
        return collection

    sentence_framework = metadata_generator(
        ai_framework_document, "AI Framework", sentence_text_splitter
    )
    sentence_blueprint = metadata_generator(
        ai_blueprint_document, "AI Blueprint", sentence_text_splitter
    )
    sentence_combined_documents = sentence_framework + sentence_blueprint

    # Custom wrapper around a fine-tuned sentence-transformers model.
    embedding_model = SentenceTransformerEmbedding('Cheselle/finetuned-arctic-sentence')

    # In-memory Qdrant index, rebuilt per session.
    sentence_vectorstore = Qdrant.from_documents(
        documents=sentence_combined_documents,
        embedding=embedding_model,
        location=":memory:",
        collection_name="AI Policy",
    )
    # `as_retriever()` always returns a retriever object, so the original
    # `if sentence_retriever is None` guard was dead code and is removed.
    sentence_retriever = sentence_vectorstore.as_retriever()

    # Stash the pipeline components for `on_message`.
    cl.user_session.set("runnable", model)
    cl.user_session.set("retriever", sentence_retriever)
    cl.user_session.set("prompt_template", rag_prompt)
@cl.on_message
async def on_message(message: cl.Message):
    """Answer a user message with retrieval-augmented generation.

    Pulls the model, retriever, and prompt template out of the session,
    retrieves context relevant to the message, builds the final prompt,
    and streams the model's answer back token by token.
    """
    model = cast(ChatOpenAI, cl.user_session.get("runnable"))
    retriever = cl.user_session.get("retriever")
    prompt_template = cl.user_session.get("prompt_template")

    print(f"Received message: {message.content}")

    # Guard: the session may lack a retriever if on_chat_start failed.
    if retriever is None:
        print("Retriever is not available.")
        await cl.Message(content="Sorry, the retriever is not initialized.").send()
        return

    # NOTE(review): `get_relevant_documents` is deprecated in recent
    # LangChain releases in favor of `retriever.invoke(...)` — kept here to
    # preserve behavior on the pinned version.
    relevant_docs = retriever.get_relevant_documents(message.content)
    print(f"Retrieved {len(relevant_docs)} documents.")

    if not relevant_docs:
        print("No relevant documents found.")
        await cl.Message(content="Sorry, I couldn't find any relevant documents.").send()
        return

    # Join retrieved chunks into a single context string for the prompt.
    context = "\n\n".join([doc.page_content for doc in relevant_docs])
    print(f"Context: {context}")

    final_prompt = prompt_template.format(context=context, question=message.content)
    print(f"Final prompt: {final_prompt}")

    # Stream the model response into a single Chainlit message.
    msg = cl.Message(content="")
    async for chunk in model.astream(
        final_prompt,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk.content)
    await msg.send()