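"""Chainlit RAG app that answers questions about two AI policy documents
(the NIST AI Risk Management Framework and the White House Blueprint for
an AI Bill of Rights) using LangChain, an in-memory Qdrant vector store,
and a streaming OpenAI chat model."""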
from typing import cast
import os

from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable.config import RunnableConfig
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant
# HuggingFaceEmbeddings wraps a sentence-transformers model behind
# LangChain's Embeddings interface, which Qdrant.from_documents expects;
# a raw SentenceTransformer instance does not satisfy that interface.
from langchain_community.embeddings import HuggingFaceEmbeddings
import chainlit as cl

# Load OPENAI_API_KEY (and any other settings) from a local .env file.
load_dotenv()
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY is not set; add it to your .env file.")
@cl.on_chat_start
async def on_chat_start():
    model = ChatOpenAI(streaming=True)

    # Load both source PDFs directly from their public URLs.
    ai_framework_document = PyMuPDFLoader(
        file_path="https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
    ).load()
    ai_blueprint_document = PyMuPDFLoader(
        file_path="https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf"
    ).load()

    RAG_PROMPT = """\
Given the provided context and a question, answer the question based only on the context.

Context: {context}
Question: {question}
"""
    rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

    sentence_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", "!", "?"],
    )

    def metadata_generator(document, name, splitter):
        # Split a loaded document into chunks and tag each chunk with a
        # human-readable source name for later attribution.
        collection = splitter.split_documents(document)
        for doc in collection:
            doc.metadata["source"] = name
        return collection

    sentence_framework = metadata_generator(ai_framework_document, "AI Framework", sentence_text_splitter)
    sentence_blueprint = metadata_generator(ai_blueprint_document, "AI Blueprint", sentence_text_splitter)
    sentence_combined_documents = sentence_framework + sentence_blueprint

    # Wrap the fine-tuned sentence-transformers model in a LangChain
    # Embeddings object. Passing the bare SentenceTransformer here fails
    # at runtime because Qdrant.from_documents calls the Embeddings API.
    embedding_model = HuggingFaceEmbeddings(model_name="Cheselle/finetuned-arctic-sentence")

    # Build an in-memory Qdrant vector store from the combined chunks.
    sentence_vectorstore = Qdrant.from_documents(
        documents=sentence_combined_documents,
        embedding=embedding_model,
        location=":memory:",
        collection_name="AI Policy",
    )
    sentence_retriever = sentence_vectorstore.as_retriever()

    # Store the model, retriever, and prompt in the session for reuse.
    cl.user_session.set("runnable", model)
    cl.user_session.set("retriever", sentence_retriever)
    cl.user_session.set("prompt_template", rag_prompt)
@cl.on_message
async def on_message(message: cl.Message):
    # Fetch the model, retriever, and prompt stored at chat start.
    model = cast(ChatOpenAI, cl.user_session.get("runnable"))
    retriever = cl.user_session.get("retriever")
    prompt_template = cl.user_session.get("prompt_template")

    print(f"Received message: {message.content}")

    # Retrieve context relevant to the user's question.
    relevant_docs = retriever.invoke(message.content)
    print(f"Retrieved {len(relevant_docs)} documents.")

    if not relevant_docs:
        print("No relevant documents found.")
        await cl.Message(content="Sorry, I couldn't find any relevant documents.").send()
        return

    context = "\n\n".join(doc.page_content for doc in relevant_docs)
    print(f"Context: {context}")

    # Fill the RAG prompt with the retrieved context and the question.
    final_prompt = prompt_template.format(context=context, question=message.content)
    print(f"Final prompt: {final_prompt}")

    # Stream the model's answer token by token into a single message.
    msg = cl.Message(content="")
    async for chunk in model.astream(
        final_prompt,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk.content)

    await msg.send()
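
# To launch locally (assuming this file is saved as app.py):
#   chainlit run app.py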