|
import os
import re
import uuid
from urllib.parse import urlparse

import bs4
import streamlit as st
import torch
from langchain import hub
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
|
|
|
|
|
# Paint the sidebar with a four-colour diagonal gradient. The sidebar
# container is selected through its data-testid attribute; the CSS is
# injected as raw HTML because st.markdown escapes HTML by default.
sidebar_bg_style = """
<style>
[data-testid="stSidebar"] {
    background: linear-gradient(135deg, #ffafbd, #ffc3a0, #2193b0, #6dd5ed);
}
</style>
"""
st.markdown(
    sidebar_bg_style,
    unsafe_allow_html=True,
)
|
|
|
|
|
# Pink-to-purple gradient applied to the main content area by both rules below.
_MAIN_GRADIENT = "linear-gradient(135deg, #ff9a9e, #fad0c4, #fbc2eb, #a18cd1)"

# NOTE(review): ".css-18e3th9" looks like an auto-generated (emotion) class
# name from a specific Streamlit release, and ".main .block-container" relies
# on Streamlit's internal DOM; either rule may silently stop matching after a
# Streamlit upgrade — confirm against the deployed version.
main_bg_style = f"""
<style>
.main .block-container {{
    background: {_MAIN_GRADIENT};
    padding: 2rem;
}}
.css-18e3th9 {{
    background: {_MAIN_GRADIENT};
}}
</style>
"""
st.markdown(
    main_bg_style,
    unsafe_allow_html=True,
)
|
|
|
|
|
st.sidebar.title("Settings")

# URL of the blog post that the chatbot will answer questions about.
url_input = st.sidebar.text_input(
    "Enter Blog Post URL",
    placeholder="e.g., https://example.com/blog",
    help="Paste the full URL of the blog post you want to retrieve data from",
)

# Give immediate feedback on the entered URL. A prefix-only regex match would
# accept strings such as "http://x not a url"; instead the scheme and host are
# checked explicitly and whitespace inside the URL is rejected. Surrounding
# whitespace (common when pasting) is ignored.
if url_input:
    candidate_url = url_input.strip()
    parsed_url = urlparse(candidate_url)
    url_is_valid = (
        parsed_url.scheme in ("http", "https")
        and bool(parsed_url.netloc)
        and not any(ch.isspace() for ch in candidate_url)
    )
    if url_is_valid:
        st.sidebar.markdown('<p style="color:green; font-weight:bold;">URL is correctly entered</p>', unsafe_allow_html=True)
    else:
        st.sidebar.markdown('<p style="color:red; font-weight:bold;">Invalid URL, please enter a valid one</p>', unsafe_allow_html=True)
|
|
|
|
|
# Let the user either paste their own LangChain/Groq keys or fall back to the
# app's built-in ones.
use_preprovided_keys = st.sidebar.checkbox("Use pre-provided API keys")

if not use_preprovided_keys:
    api_key_1 = st.sidebar.text_input("Enter LangChain API Key", type="password", placeholder="Enter your LangChain API Key", help="Please enter a valid LangChain API key here")
    api_key_2 = st.sidebar.text_input("Enter Groq API Key", type="password", placeholder="Enter your Groq API Key", help="Please enter your Groq API key here")
else:
    # NOTE(review): these are placeholder values, not real keys — replace them
    # (ideally by reading st.secrets or the environment) before deploying.
    api_key_1 = "your-preprovided-langchain-api-key"
    api_key_2 = "your-preprovided-groq-api-key"
    st.sidebar.markdown('<p style="color:blue; font-weight:bold;">Using pre-provided API keys</p>', unsafe_allow_html=True)

# Pasted keys often carry stray spaces/newlines, and a whitespace-only entry
# must not count as "filled in" below.
api_key_1 = api_key_1.strip()
api_key_2 = api_key_2.strip()

if st.sidebar.button("Submit API Keys"):
    if use_preprovided_keys or (api_key_1 and api_key_2):
        # NOTE(review): os.environ is process-wide; in a multi-user deployment
        # every session served by this process will see (and use) these keys.
        os.environ["LANGCHAIN_API_KEY"] = api_key_1
        os.environ["GROQ_API_KEY"] = api_key_2
        st.sidebar.markdown('<p style="color:green; font-weight:bold;">API keys are set</p>', unsafe_allow_html=True)
    else:
        st.sidebar.markdown('<p style="color:red; font-weight:bold;">Please fill in both API keys or select the option to use pre-provided keys</p>', unsafe_allow_html=True)
|
|
|
|
|
# Page header: a horizontally scrolling author credit followed by the title,
# both rendered as raw HTML.
# NOTE(review): <marquee> is a deprecated, non-standard element; browsers still
# render it, but a CSS animation is the standards-based equivalent.
_CREDIT_BANNER_HTML = """
<marquee behavior="scroll" direction="left" scrollamount="10">
<p style='font-size:24px; color:#FF5733; font-weight:bold;'>
Created by: <a href="https://www.linkedin.com/in/datascientisthameshraj/" target="_blank" style="color:#1E90FF; text-decoration:none;">Engr. Hamesh Raj</a>
</p>
</marquee>
"""
_PAGE_TITLE_HTML = '<h1 style="color:#4CAF50; font-weight:bold;">🤖 Chatbot with URL-based Document Retrieval</h1>'

for _html_fragment in (_CREDIT_BANNER_HTML, _PAGE_TITLE_HTML):
    st.markdown(_html_fragment, unsafe_allow_html=True)
|
|
|
|
|
# Free-text question about the blog post (an empty string until the user types).
query = st.text_input(
    "Ask a question based on the blog post",
    placeholder="Type your question here...",
    help="Enter a question related to the content of the blog post",
)

# Per-session list of (question, answer) tuples; created once and then kept
# in session state across script reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
|
|
|
|
|
class SentenceTransformerEmbedding:
    """Adapt a ``SentenceTransformer`` model to LangChain's embeddings interface.

    Chroma calls ``embed_documents`` while indexing and ``embed_query`` while
    searching; both return plain Python lists of floats.
    """

    def __init__(self, model_name):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """Return one embedding (a list of floats) per input text."""
        # encode() returns a NumPy array by default, so no torch round-trip
        # (tensor -> cpu -> numpy) is needed to get plain lists.
        return self.model.encode(list(texts)).tolist()

    def embed_query(self, query):
        """Return the embedding (a list of floats) of a single query string."""
        return self.model.encode([query])[0].tolist()


def _format_docs(docs):
    """Join the retrieved chunks into one context string for the prompt."""
    return "\n\n".join(doc.page_content for doc in docs)


@st.cache_resource(show_spinner=False)
def _load_embedding_model(model_name):
    """Load the sentence-transformers model once and reuse it across reruns."""
    return SentenceTransformerEmbedding(model_name)


@st.cache_resource(show_spinner=False)
def _load_rag_prompt():
    """Pull the RAG prompt template from the LangChain hub once."""
    return hub.pull("rlm/rag-prompt")


@st.cache_resource(show_spinner=False)
def _build_retriever(url):
    """Fetch *url*, split it into overlapping chunks and index them in Chroma.

    Cached per URL, so several questions about the same post download and
    embed it only once. One index per distinct URL is kept for the lifetime
    of the process.

    Raises:
        ValueError: if no text could be extracted from the page.
    """
    loader = WebBaseLoader(
        web_paths=(url,),
        bs_kwargs=dict(parse_only=bs4.SoupStrainer()),
    )
    docs = loader.load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = splitter.split_documents(docs)
    if not splits:
        raise ValueError(f"No text could be extracted from {url}")
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=_load_embedding_model("all-MiniLM-L6-v2"),
        # The default in-memory Chroma client is shared within the process, so
        # reusing the default collection name would mix chunks from earlier
        # builds (other URLs, other sessions) into every search.
        collection_name=f"blog-{uuid.uuid4().hex}",
    )
    return vectorstore.as_retriever()


if st.button("Submit Query"):
    question = query.strip()
    blog_url = url_input.strip()
    parsed_blog_url = urlparse(blog_url)
    if not question or not blog_url:
        st.warning("Please enter both a blog post URL and a question.")
    elif (
        parsed_blog_url.scheme not in ("http", "https")
        or not parsed_blog_url.netloc
        or any(ch.isspace() for ch in blog_url)
    ):
        st.error("Invalid URL, please enter a valid one")
    else:
        try:
            with st.spinner("Fetching the blog post and generating an answer..."):
                retriever = _build_retriever(blog_url)
                rag_chain = (
                    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
                    | _load_rag_prompt()
                    # NOTE(review): confirm this model id is still served by Groq.
                    | ChatGroq(model="llama3-8b-8192")
                    | StrOutputParser()
                )
                result = rag_chain.invoke(question)
        except Exception as exc:  # UI boundary: show fetch/LLM/API-key failures instead of crashing
            st.error(f"Could not answer the question: {exc}")
        else:
            st.session_state['chat_history'].append((question, result))
|
|
|
|
|
# Replay the whole conversation, oldest exchange first, on every script rerun
# so earlier answers stay on screen.
for asked, answered in st.session_state.chat_history:
    st.write(f"**User:** {asked}")
    st.write(f"**Bot:** {answered}")