Cheselle's picture
Update app.py
dca2082 verified
raw
history blame
5.57 kB
import re
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import StrOutputParser
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.vectorstores import Qdrant
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_core.documents import Document
from operator import itemgetter
import os
from dotenv import load_dotenv
import chainlit as cl
# Pull variables from a local .env file into the process environment
# (presumably the OpenAI API key that ChatOpenAI needs further down).
load_dotenv()

# Source PDFs that make up the RAG corpus.
NIST_AI_600_1_URL = "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
AI_BILL_OF_RIGHTS_URL = "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf"

# Fetch and parse each PDF (framework first, then blueprint); every loader
# call yields the list of Documents that is chunked further below.
ai_framework_document, ai_blueprint_document = (
    PyMuPDFLoader(file_path=source_url).load()
    for source_url in (NIST_AI_600_1_URL, AI_BILL_OF_RIGHTS_URL)
)
def metadata_generator(
    document,
    name,
    *,
    chunk_size=500,
    chunk_overlap=100,
    separators=None,
):
    """Split ``document`` into overlapping chunks and label each with ``name``.

    Args:
        document: list of LangChain ``Document`` objects, e.g. the output of
            ``PyMuPDFLoader(...).load()``.
        name: label written to every chunk's ``metadata["source"]``. This
            replaces whatever ``source`` value the loader had stored.
        chunk_size: target maximum chunk length, measured by the splitter's
            length function (characters by default).
        chunk_overlap: amount of text shared by consecutive chunks.
        separators: split points tried in order. Defaults to paragraph break,
            line break, then sentence punctuation.

    Returns:
        A new list of chunk ``Document`` objects, each tagged with ``name``.
    """
    if separators is None:
        # NOTE(review): there is no " " / "" fallback here, so a stretch of text
        # containing none of these separators can yield a chunk longer than
        # ``chunk_size``. Pass a custom list to change that.
        separators = ["\n\n", "\n", ".", "!", "?"]
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=list(separators),
    )
    collection = splitter.split_documents(document)
    for doc in collection:
        doc.metadata["source"] = name
    return collection
# Chunk each source on its own so every chunk carries that source's label;
# the framework is processed first, then the blueprint.
_labelled_sources = (
    (ai_framework_document, "AI Framework"),
    (ai_blueprint_document, "AI Blueprint"),
)
recursive_framework_document, recursive_blueprint_document = (
    metadata_generator(pages, label) for pages, label in _labelled_sources
)

# One corpus for indexing: all framework chunks followed by all blueprint chunks.
combined_documents = [*recursive_framework_document, *recursive_blueprint_document]
# NOTE(review): The triple-quoted string below is disabled code kept from an
# earlier attempt. It is a bare string expression that is evaluated and
# discarded, so it has no runtime effect. Do not simply re-enable it:
#   * ``encode`` is handed the loaders' Document lists instead of their text.
#   * ``Qdrant.from_documents`` receives the precomputed vectors as
#     ``embedding=``. It presumably expects a LangChain Embeddings object;
#     confirm against the LangChain docs before reviving this.
"""
#from transformers import AutoTokenizer, AutoModel
#import torch
#embedding = AutoModel.from_pretrained("Cheselle/finetuned-arctic-sentence")
#tokenizer = AutoTokenizer.from_pretrained("Cheselle/finetuned-arctic-sentence")
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer("Cheselle/finetuned-arctic-sentence")
embeddings = embedding_model.encode(ai_framework_document + ai_blueprint_document)
vectorstore = Qdrant.from_documents(
documents=combined_documents,
embedding=embeddings,
location=":memory:",
collection_name="ai_policy"
)
"""
from transformers import AutoTokenizer, AutoModel
import torch
from qdrant_client import QdrantClient
# Hugging Face checkpoint supplying both the tokenizer and the encoder weights.
EMBEDDING_MODEL_ID = "Cheselle/finetuned-arctic-sentence"

# Loaded once at import time; the embedding helper below reuses both objects.
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_ID)
model = AutoModel.from_pretrained(EMBEDDING_MODEL_ID)
# Define a wrapper function for embedding documents
def embed(documents):
    """Embed text with the fine-tuned encoder using attention-masked mean pooling.

    Args:
        documents: a string or a list of strings (text the Hugging Face
            tokenizer accepts), not LangChain ``Document`` objects.

    Returns:
        A NumPy array with one embedding row per input text.
    """
    inputs = tokenizer(documents, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    token_states = outputs.last_hidden_state

    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        # No mask from the tokenizer: treat every position as a real token.
        return token_states.mean(dim=1).numpy()

    # Average only over real tokens. A plain ``mean(dim=1)`` would also
    # average in the padding positions that ``padding=True`` adds to shorter
    # texts in a batch. For a single text there is no padding, so the result
    # matches the plain mean up to float rounding.
    # NOTE(review): arctic-embed base models are usually pooled via the CLS
    # token; confirm which pooling this fine-tune was trained with.
    mask = attention_mask.unsqueeze(-1).to(token_states.dtype)
    summed = (token_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).numpy()
# Initialize Qdrant client (in-memory: the index is rebuilt on every start)
from langchain_core.embeddings import Embeddings
from qdrant_client.models import Distance, PointStruct, VectorParams

qdrant_client = QdrantClient(":memory:")


class _FinetunedArcticEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around the module-level ``embed`` helper.

    The LangChain ``Qdrant`` store needs an ``Embeddings`` object so the
    retriever embeds each incoming question with the same model that embedded
    the stored chunks. Texts are embedded one at a time, so no padding is
    ever introduced into the pooled vectors.
    """

    def embed_documents(self, texts):
        """Return one embedding (list of floats) per input text."""
        return [embed([text])[0].tolist() for text in texts]

    def embed_query(self, text):
        """Return the embedding (list of floats) of a single query string."""
        return embed([text])[0].tolist()


embedding_adapter = _FinetunedArcticEmbeddings()

# Embed each chunk's *text*; the tokenizer cannot consume Document objects.
chunk_vectors = embedding_adapter.embed_documents(
    [doc.page_content for doc in combined_documents]
)
if not chunk_vectors:
    raise RuntimeError("No document chunks to index into Qdrant.")

# Size the collection from the model's actual output width instead of a
# hard-coded guess, so the two can never disagree.
qdrant_client.create_collection(
    collection_name="ai_policy",
    vectors_config=VectorParams(size=len(chunk_vectors[0]), distance=Distance.COSINE),
)

# Store text and metadata under the payload keys the LangChain Qdrant store
# reads back by default ("page_content" / "metadata").
qdrant_client.upsert(
    collection_name="ai_policy",
    points=[
        PointStruct(
            id=i,
            vector=vector,
            payload={"page_content": doc.page_content, "metadata": doc.metadata},
        )
        for i, (doc, vector) in enumerate(zip(combined_documents, chunk_vectors))
    ],
)

# ``upsert`` only reports an operation status, so wrap the populated
# collection in a LangChain vector store to obtain a retriever from it.
vectorstore = Qdrant(
    client=qdrant_client,
    collection_name="ai_policy",
    embeddings=embedding_adapter,
)
retriever = vectorstore.as_retriever()
## Generation LLM
# OpenAI chat model that writes the final answer in the RAG chain.
GENERATION_MODEL_NAME = "gpt-4o-mini"
llm = ChatOpenAI(model=GENERATION_MODEL_NAME)
# Answer-generation template. ChatPromptTemplate fills the {context} and
# {question} placeholders.
RAG_PROMPT = (
    "You are an AI Policy Expert.\n"
    "Given a provided context and question, you must answer the question "
    "based only on context.\n"
    "Think through your answer carefully and step by step.\n"
    "Context: {context}\n"
    "Question: {question}\n"
)
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)


def _format_context(inputs):
    """Return the retrieved chunks in ``inputs["context"]`` as prompt-ready text.

    Joins each Document's ``page_content`` with blank lines. The prompt then
    sees only the passage text, rather than the ``repr`` of a list of
    Documents, which would also pull every chunk's metadata into the prompt.
    """
    return "\n\n".join(doc.page_content for doc in inputs["context"])


retrieval_augmented_qa_chain = (
    # INVOKE CHAIN WITH: {"question" : "<<SOME USER QUESTION>>"}
    # Step 1:
    #   "context"  - the question piped into the retriever (a list of Documents).
    #   "question" - the question, passed through unchanged.
    # RunnableParallel (not a bare dict) starts the chain, so the ``|`` below
    # builds a runnable sequence instead of merging two dicts.
    RunnableParallel(
        context=itemgetter("question") | retriever,
        question=itemgetter("question"),
    )
    # Step 2:
    #   "response" - the prompt, filled with the chunk text and the question,
    #                piped into the LLM.
    #   "context"  - the raw retrieved Documents, returned alongside the answer.
    | {
        "response": (
            {"context": _format_context, "question": itemgetter("question")}
            | rag_prompt
            | llm
        ),
        "context": itemgetter("context"),
    }
)
# Example: retrieval_augmented_qa_chain.invoke({"question": "What is the AI framework all about?"})
@cl.on_message
async def handle_message(message):
    """Answer one incoming Chainlit chat message with the RAG chain.

    Args:
        message: the incoming Chainlit message; its ``content`` is the question.

    Any exception from the chain or from sending is reported back to the user
    and printed, not propagated.
    """
    try:
        # Await the async variant: a synchronous ``invoke`` here would block the
        # event loop for the whole retrieval + LLM round trip. That stalls every
        # other chat session served by this process.
        result = await retrieval_augmented_qa_chain.ainvoke({"question": message.content})
        # "response" is the chat model's reply message; send only its text.
        response_message = cl.Message(content=result["response"].content)
        await response_message.send()
    except Exception as e:
        # Top-level boundary of the handler: tell the user and log to stdout.
        error_message = cl.Message(content=f"An error occurred: {str(e)}")
        await error_message.send()
        print(f"Error occurred: {e}")
# Run the ChainLit server
if __name__ == "__main__":
    # This branch runs only when the file is executed directly (python app.py).
    # The usual launch path is ``chainlit run app.py``. The ``chainlit`` package
    # does not expose a ``run()`` function, so the old ``cl.run()`` could only
    # fail; use Chainlit's programmatic launcher instead.
    try:
        from chainlit.cli import run_chainlit

        run_chainlit(__file__)
    except Exception as e:
        print(f"Server error occurred: {e}")
        # Exit non-zero so a failed launch is not reported as success.
        raise SystemExit(1) from e