Create app.py
app.py
ADDED
@@ -0,0 +1,132 @@
import streamlit as st
import re
from langchain_groq import ChatGroq
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import bs4
import torch
import os

# Sidebar style with a multicolored gradient background
sidebar_bg_style = """
<style>
[data-testid="stSidebar"] {
    background: linear-gradient(135deg, #ffafbd, #ffc3a0, #2193b0, #6dd5ed);
}
</style>
"""
st.markdown(sidebar_bg_style, unsafe_allow_html=True)

# Sidebar: inputs for the blog URL and API keys
st.sidebar.title("Settings")

# Input field for entering the URL dynamically, with placeholder and help text
url_input = st.sidebar.text_input("Enter Blog Post URL", placeholder="e.g., https://example.com/blog", help="Paste the full URL of the blog post you want to retrieve data from")

# Validate the URL and show a success message when it is well formed
if url_input:
    if re.match(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+", url_input):
        st.sidebar.markdown('<p style="color:green; font-weight:bold;">URL is correctly entered</p>', unsafe_allow_html=True)
    else:
        st.sidebar.markdown('<p style="color:red; font-weight:bold;">Invalid URL, please enter a valid one</p>', unsafe_allow_html=True)

# Input fields for API keys with placeholders and helper text
api_key_1 = st.sidebar.text_input("Enter LangChain API Key", type="password", placeholder="Enter your LangChain API Key", help="Please enter a valid LangChain API key here")
api_key_2 = st.sidebar.text_input("Enter Groq API Key", type="password", placeholder="Enter your Groq API Key", help="Please enter your Groq API key here")

# Submit button for API keys with a success/warning message
if st.sidebar.button("Submit API Keys"):
    if api_key_1 and api_key_2:
        os.environ["LANGCHAIN_API_KEY"] = api_key_1
        os.environ["GROQ_API_KEY"] = api_key_2
        st.sidebar.markdown('<p style="color:green; font-weight:bold;">Both API keys are entered</p>', unsafe_allow_html=True)
    else:
        st.sidebar.markdown('<p style="color:red; font-weight:bold;">Please fill in both API keys</p>', unsafe_allow_html=True)

# Main section with a multicolored background and the chatbot title
main_bg_style = """
<style>
body {
    background: linear-gradient(135deg, #ff9a9e, #fad0c4, #fbc2eb, #a18cd1);
}
</style>
"""
st.markdown(main_bg_style, unsafe_allow_html=True)

# Title of the chatbot
st.markdown('<h1 style="color:#4CAF50; font-weight:bold;">🤖 Chatbot with URL-based Document Retrieval</h1>', unsafe_allow_html=True)

# Chat query input field with placeholder and help text
query = st.text_input("Ask a question based on the blog post", placeholder="Type your question here...", help="Enter a question related to the content of the blog post")

# Initialize chat history in session state so earlier turns persist across reruns
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

# Submit button for chat
if st.button("Submit Query"):
    if query and url_input:
        # Load the blog post from the user-supplied URL
        loader = WebBaseLoader(
            web_paths=(url_input,),  # use the user-input URL
            bs_kwargs=dict(
                parse_only=bs4.SoupStrainer()  # no filter: keep the whole page; narrow this for a specific site layout
            ),
        )
        docs = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)

        # Define the embedding class
        class SentenceTransformerEmbedding:
            def __init__(self, model_name):
                self.model = SentenceTransformer(model_name)

            def embed_documents(self, texts):
                embeddings = self.model.encode(texts, convert_to_tensor=True)
                if isinstance(embeddings, torch.Tensor):
                    return embeddings.cpu().detach().numpy().tolist()  # convert tensor to list
                return embeddings

            def embed_query(self, query):
                embedding = self.model.encode([query], convert_to_tensor=True)
                if isinstance(embedding, torch.Tensor):
                    return embedding.cpu().detach().numpy().tolist()[0]  # convert tensor to list
                return embedding[0]

        # Initialize the embedding model
        embedding_model = SentenceTransformerEmbedding('all-MiniLM-L6-v2')

        # Initialize Chroma with the embedding class
        vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model)

        # Retrieve and generate using the relevant snippets of the blog
        retriever = vectorstore.as_retriever()
        prompt = hub.pull("rlm/rag-prompt")

        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)

        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | ChatGroq(model="llama3-8b-8192")  # Groq-hosted Llama 3 8B as the chat model
            | StrOutputParser()
        )

        # Generate the answer using the user's query
        result = rag_chain.invoke(query)

        # Store the query and response in session state for chat history
        st.session_state['chat_history'].append((query, result))

# Display chat history
for q, r in st.session_state['chat_history']:
    st.write(f"**User:** {q}")
    st.write(f"**Bot:** {r}")
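
As a quick sanity check outside the Streamlit app, the embedding wrapper's behavior can be reproduced directly with sentence-transformers. This is a minimal sketch, assuming the package is installed and using the same all-MiniLM-L6-v2 model as above (which produces 384-dimensional vectors):

# Standalone check of the SentenceTransformer encoding used by the embedding wrapper.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')
doc_vecs = model.encode(["first chunk", "second chunk"], convert_to_tensor=True)
print(doc_vecs.shape)   # torch.Size([2, 384]): one 384-dim vector per document
query_vec = model.encode(["a test question"], convert_to_tensor=True)[0]
print(query_vec.shape)  # torch.Size([384]): a single vector, as embed_query returns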