datascientist22 committed (verified)
Commit 479c15b · 1 Parent(s): fc71a0f

Update app.py

Files changed (1):
  1. app.py +74 -51
app.py CHANGED
@@ -2,16 +2,35 @@ import streamlit as st
 import re
 import os
 from langchain.chains import ConversationalRetrievalChain
-from langchain.document_loaders import WebBaseLoader
-from langchain.vectorstores import Chroma
-from langchain.prompts import load_prompt
-from langchain.chat_models import ChatGroq
-from langchain.output_parsers import StrOutputParser
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.runnables import RunnablePassthrough
-import torch
+from langchain_chroma import Chroma
+from langchain_community.document_loaders import WebBaseLoader
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough
+from langchain_text_splitters import RecursiveCharacterTextSplitter
 from sentence_transformers import SentenceTransformer
+from langchain import hub
 import bs4
+import torch
+
+# Define the embedding class
+class SentenceTransformerEmbedding:
+    def __init__(self, model_name):
+        self.model = SentenceTransformer(model_name)
+
+    def embed_documents(self, texts):
+        embeddings = self.model.encode(texts, convert_to_tensor=True)
+        if isinstance(embeddings, torch.Tensor):
+            return embeddings.cpu().detach().numpy().tolist()  # Convert tensor to list
+        return embeddings
+
+    def embed_query(self, query):
+        embedding = self.model.encode([query], convert_to_tensor=True)
+        if isinstance(embedding, torch.Tensor):
+            return embedding.cpu().detach().numpy().tolist()[0]  # Convert tensor to list
+        return embedding[0]
+
+# Streamlit UI setup
+st.title("🤖 Chatbot with URL-based Document Retrieval")
 
 # Sidebar Style with Multicolored Background
 sidebar_bg_style = """
@@ -92,49 +111,53 @@ if 'chat_history' not in st.session_state:
 
 # Submit button for chat
 if st.button("Submit Query"):
-    if query:
-        if url_input:
-            # Blog loading logic based on user input URL
-            loader = WebBaseLoader(
-                web_paths=(url_input,),  # Use the user-input URL
-                bs_kwargs=dict(
-                    parse_only=bs4.SoupStrainer()  # Adjust based on the user's URL structure
-                ),
-            )
-            docs = loader.load()
-
-            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-            splits = text_splitter.split_documents(docs)
-
-            # Initialize the embedding model
-            embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
-
-            # Initialize Chroma with the embedding class
-            vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model)
-
-            # Retrieve and generate using the relevant snippets of the blog
-            retriever = vectorstore.as_retriever()
-            prompt = load_prompt("rlm/rag-prompt")
-
-            def format_docs(docs):
-                return "\n\n".join(doc.page_content for doc in docs)
-
-            rag_chain = (
-                {"context": retriever | format_docs, "question": RunnablePassthrough()}
-                | prompt
-                | ChatGroq(model="llama3-8b-8192")  # Replace `llm` with an appropriate language model
-                | StrOutputParser()
-            )
-
-            # Generate the answer using the user's query
-            result = rag_chain.invoke(query)
-
-            # Store query and response in session for chat history
-            st.session_state['chat_history'].append((query, result))
-        else:
-            st.warning("Please enter a valid URL.")
-    else:
-        st.warning("Please enter a question.")
+    if query and url_input:
+        # Blog loading logic based on user input URL
+        loader = WebBaseLoader(
+            web_paths=(url_input,),  # Use the user-input URL
+            bs_kwargs=dict(
+                parse_only=bs4.SoupStrainer()  # Adjust based on the user's URL structure
+            ),
+        )
+        docs = loader.load()
+
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+        splits = text_splitter.split_documents(docs)
+
+        # Initialize the embedding model
+        embedding_model = SentenceTransformerEmbedding('all-MiniLM-L6-v2')
+
+        # Initialize Chroma with the embedding class
+        vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model)
+
+        # Retrieve and generate using the relevant snippets of the blog
+        retriever = vectorstore.as_retriever()
+        prompt = hub.pull("rlm/rag-prompt")
+
+        def format_docs(docs):
+            return "\n\n".join(doc.page_content for doc in docs)
+
+        # Replace llm with an appropriate model or implement your logic
+        class CustomLanguageModel:
+            def generate(self, prompt, context):
+                # Custom implementation or call to an API
+                # For demonstration, let's use a simple placeholder response
+                return f"Response to query '{prompt}' based on context."
+
+        custom_llm = CustomLanguageModel()
+
+        rag_chain = (
+            {"context": retriever | format_docs, "question": RunnablePassthrough()}
+            | prompt
+            | custom_llm.generate  # Adjust based on actual usage
+            | StrOutputParser()
+        )
+
+        # Generate the answer using the user's query
+        result = rag_chain.invoke(query)
+
+        # Store query and response in session for chat history
+        st.session_state['chat_history'].append((query, result))
 
 # Display chat history
 for q, r in st.session_state['chat_history']:
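Note on the SentenceTransformerEmbedding wrapper added above: Chroma only needs an object exposing embed_documents and embed_query, so the duck-typed class works, but subclassing langchain_core's Embeddings base class makes that contract explicit and drops the torch round-trip (encode returns numpy arrays by default when convert_to_tensor is left off). A minimal sketch of that variant:

from typing import List

from langchain_core.embeddings import Embeddings
from sentence_transformers import SentenceTransformer

class SentenceTransformerEmbedding(Embeddings):
    """Adapts a sentence-transformers model to LangChain's Embeddings interface."""

    def __init__(self, model_name: str):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # encode returns a numpy array by default; tolist() yields the List[List[float]] LangChain expects
        return self.model.encode(texts).tolist()

    def embed_query(self, text: str) -> List[float]:
        return self.model.encode([text])[0].tolist()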
 
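Separately, bs4.SoupStrainer() called with no arguments matches every tag, so the parse_only filter in the loader is currently a no-op. If the goal is to keep only the article body, the strainer needs selectors; a sketch, with class names that are purely illustrative and must be adapted to the target site's markup:

import bs4
from langchain_community.document_loaders import WebBaseLoader

loader = WebBaseLoader(
    web_paths=(url_input,),  # url_input comes from the app's input field
    bs_kwargs=dict(
        # Hypothetical class names; inspect the blog's HTML and adjust
        parse_only=bs4.SoupStrainer(class_=("post-content", "post-title", "post-header"))
    ),
)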
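One wiring issue in the new rag_chain: LCEL coerces a bare callable into a runnable that is invoked with a single argument (here the PromptValue produced by prompt), but CustomLanguageModel.generate is declared with two parameters (prompt, context), so invoke would raise a TypeError. A sketch of one way to make the placeholder run, wrapping the call in a RunnableLambda; the one-argument generate signature below is an assumption, not what the commit ships:

from langchain_core.runnables import RunnableLambda

class CustomLanguageModel:
    def generate(self, prompt_text: str) -> str:
        # Placeholder; swap in a real model call here
        return f"Response to query '{prompt_text}' based on context."

custom_llm = CustomLanguageModel()

# retriever, format_docs, prompt, RunnablePassthrough, StrOutputParser as defined above
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | RunnableLambda(lambda pv: custom_llm.generate(pv.to_string()))  # PromptValue -> str
    | StrOutputParser()
)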
 
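Finally, the commit removes ChatGroq entirely rather than fixing its import (the old path, langchain.chat_models, does not provide it). If a real model is preferred over the placeholder, the maintained package is langchain_groq; a sketch, assuming GROQ_API_KEY is set in the environment:

from langchain_groq import ChatGroq  # pip install langchain-groq

llm = ChatGroq(model="llama3-8b-8192")  # reads GROQ_API_KEY from the environment

rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)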