BlogRetrievalQA / app.py
AreesaAshfaq's picture
Update app.py
c469b78 verified
raw
history blame
3.44 kB
import streamlit as st
from sentence_transformers import SentenceTransformer
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
import bs4
import torch
import getpass
import os

from langchain_groq import ChatGroq

# Prompt the user to enter their LangChain API key
api_key_langchain = st.text_input("Enter your LANGCHAIN_API_KEY", type="password")
# Report whether the LangChain API key has been provided
if api_key_langchain:
    st.write("LangChain API Key is set.")
else:
    st.write("Please enter your LangChain API key.")

# Prompt the user to enter their Groq API key
api_key_Groq = st.text_input("Enter your Groq_API_KEY", type="password")
# Report whether the Groq API key has been provided
if api_key_Groq:
    st.write("Groq API Key is set.")
else:
    st.write("Please enter your Groq API key.")

GROQ_API_KEY = api_key_Groq

# Without a key from the form or the environment the chat model cannot be
# created; halt this script run cleanly instead of failing with a traceback.
if not (api_key_Groq or os.environ.get("GROQ_API_KEY")):
    st.stop()

# Hand the typed key to the client explicitly. It is deliberately not written
# to os.environ, which is shared by every session served by this process.
# When no key was typed, ChatGroq falls back to the GROQ_API_KEY env variable.
llm_kwargs = {"model": "llama3-8b-8192"}
if api_key_Groq:
    llm_kwargs["api_key"] = api_key_Groq
llm = ChatGroq(**llm_kwargs)
# Define the embedding class
class SentenceTransformerEmbedding:
    """Embedding adapter backed by a ``SentenceTransformer`` model.

    Provides ``embed_documents`` and ``embed_query`` so an instance can be
    handed to a vector store as its embedding object.
    """

    def __init__(self, model_name):
        """Load the sentence-transformers model named *model_name*."""
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """Encode *texts* and return one embedding per text.

        A tensor result is converted to nested Python lists; any other
        result type is returned unchanged.
        """
        encoded = self.model.encode(texts, convert_to_tensor=True)
        if not isinstance(encoded, torch.Tensor):
            return encoded
        # Move to host memory before the NumPy / list conversion.
        return encoded.detach().cpu().numpy().tolist()

    def embed_query(self, query):
        """Encode the single string *query* and return its embedding."""
        # The model is called with a one-element batch; row 0 is the answer.
        encoded = self.model.encode([query], convert_to_tensor=True)
        if not isinstance(encoded, torch.Tensor):
            return encoded[0]
        rows = encoded.detach().cpu().numpy().tolist()
        return rows[0]
# Initialize the embedding class
@st.cache_resource
def _load_embedding_model(model_name):
    """Build the embedding wrapper for *model_name*, once per server process.

    Streamlit re-executes this script on every widget interaction; caching
    the instance avoids reloading the model weights on each rerun.
    """
    return SentenceTransformerEmbedding(model_name)


embedding_model = _load_embedding_model('all-MiniLM-L6-v2')
# Load, chunk, and index the contents of the blog
@st.cache_resource
def load_data():
    """Download the blog post, split it into chunks and index them in Chroma.

    Only HTML elements with the classes ``post-content``, ``post-title`` or
    ``post-header`` are parsed. The text is split into 1000-character chunks
    with a 200-character overlap and embedded with the module-level
    ``embedding_model``.

    The result is cached with ``st.cache_resource`` because Streamlit reruns
    the whole script on every interaction; without the cache each rerun would
    fetch the page again and re-embed every chunk (and may add the same
    chunks to the vector store repeatedly).

    Returns:
        The populated ``Chroma`` vector store.
    """
    loader = WebBaseLoader(
        web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
        bs_kwargs=dict(
            parse_only=bs4.SoupStrainer(
                class_=("post-content", "post-title", "post-header")
            )
        ),
    )
    docs = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model)
    return vectorstore
vectorstore = load_data()

# Streamlit UI
st.title("Blog Retrieval and Question Answering")
question = st.text_input("Enter your question:")


def format_docs(docs):
    """Join the page content of the retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


# Ignore empty or whitespace-only input so no LLM call is made for it.
query = question.strip()
if query:
    try:
        retriever = vectorstore.as_retriever()
        # Pulling the prompt is a network call, so it lives inside the try
        # block; the LangChain key entered above is forwarded when present.
        prompt = hub.pull("rlm/rag-prompt", api_key=api_key_langchain or None)
        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        result = rag_chain.invoke(query)
    except Exception as e:
        # Top-level UI boundary: show the failure to the user instead of a traceback.
        st.error(f"An error occurred: {e}")
    else:
        st.write("Answer:", result)