Update app.py
app.py CHANGED
```diff
@@ -8,6 +8,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
 from sentence_transformers import SentenceTransformer
 import bs4
 import torch
+from transformers import pipeline
 
 # Define the embedding class
 class SentenceTransformerEmbedding:
```
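For reference, `Chroma.from_documents(..., embedding=...)` later in this commit expects the embedding object to expose `embed_documents` and `embed_query` methods. The body of `SentenceTransformerEmbedding` sits outside the changed hunks, so the diff only shows its `class` line; a minimal sketch of what such a wrapper typically looks like (the method bodies here are an assumption, not the committed code):

```python
from sentence_transformers import SentenceTransformer

class SentenceTransformerEmbedding:
    """Sketch of the wrapper; assumed, since the diff only shows the class line."""

    def __init__(self, model_name):
        # e.g. model_name = 'all-MiniLM-L6-v2', as used further down in this file
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        # One vector (a list of floats) per document, as Chroma expects
        return self.model.encode(texts, convert_to_numpy=True).tolist()

    def embed_query(self, text):
        # A single vector for the query string
        return self.model.encode([text], convert_to_numpy=True)[0].tolist()
```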
```diff
@@ -106,16 +107,18 @@ query = st.text_input("Ask a question based on the blog post", placeholder="Type
 if 'chat_history' not in st.session_state:
     st.session_state['chat_history'] = []
 
-# CustomLanguageModel class with
+# CustomLanguageModel class with summarization
 class CustomLanguageModel:
+    def __init__(self):
+        self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn")  # Replace with desired model
+
     def generate(self, prompt, context):
-
-        return f"Generated response: '{prompt}'."
+        summary = self.summarize_context(context)
+        return f"Generated response: '{prompt}'. Summary: '{summary}'."
 
     def summarize_context(self, context):
-
-        #
-        return " ".join(context.split()[:100])  # Returning the first 100 words as a simple summary
+        summarized = self.summarizer(context, max_length=200, min_length=100, do_sample=False)
+        return summarized[0]['summary_text']  # Ensure it outputs full, meaningful sentences
 
 # Define a callable class for RAGPrompt
 class RAGPrompt:
```
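A caveat on the new summarizer: `facebook/bart-large-cnn` accepts roughly 1024 input tokens, and `min_length=100` forces long summaries even when the retrieved context is short. Below is a standalone way to exercise the new summarization path; `truncation=True` is a defensive addition in this sketch, not something in the commit:

```python
from transformers import pipeline

summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

context = "..."  # a long retrieved-context string
# truncation=True trims input beyond the model's token window instead of erroring
summary = summarizer(context, max_length=200, min_length=100, do_sample=False, truncation=True)
print(summary[0]['summary_text'])
```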
```diff
@@ -136,46 +139,43 @@ if st.button("Submit Query"):
                 parse_only=bs4.SoupStrainer()  # Adjust based on the user's URL structure
             ),
         )
-
-        docs = loader.load()
+        docs = loader.load()
 
-
-
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=300)
+        splits = text_splitter.split_documents(docs)
 
-
-
+        # Initialize the embedding model
+        embedding_model = SentenceTransformerEmbedding('all-MiniLM-L6-v2')
 
-
-
+        # Initialize Chroma with the embedding class
+        vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model)
 
-
-
+        # Retrieve and generate using the relevant snippets of the blog
+        retriever = vectorstore.as_retriever()
 
-
-
+        # Retrieve relevant documents
+        retrieved_docs = retriever.get_relevant_documents(query)
 
-
-
-
+        # Format the retrieved documents
+        def format_docs(docs):
+            return "\n\n".join(doc.page_content for doc in docs)
 
-
+        context = format_docs(retrieved_docs)
 
-
-
+        # Initialize the language model
+        custom_llm = CustomLanguageModel()
 
-
-
+        # Initialize RAG chain using the prompt
+        prompt = RAGPrompt()
 
-
-
+        # Apply the prompt directly to the data (no chaining using `|`)
+        prompt_data = prompt({"question": query, "context": context})
 
-
-
+        # Generate the response using the language model, focusing on the answer from the retrieved context
+        result = custom_llm.generate(prompt_data["question"], prompt_data["context"])
 
-
-
-    except Exception as e:
-        st.error(f"Error loading the blog or processing the query: {e}")
+        # Store query and response in session for chat history
+        st.session_state['chat_history'].append((query, result))
 
 # Display chat history
 for q, r in st.session_state['chat_history']:
```
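`RAGPrompt` also falls outside the changed hunks. From the call site above, where `prompt({"question": query, "context": context})` is followed by `prompt_data["question"]` and `prompt_data["context"]`, it must be a callable that returns a dict with those two keys. A plausible minimal shape (assumed, not shown in the diff):

```python
class RAGPrompt:
    """Assumed pass-through shape, inferred from the call site in this commit."""

    def __call__(self, data):
        # Hand the question and retrieved context straight to the model;
        # a richer implementation might template them into one prompt string.
        return {"question": data["question"], "context": data["context"]}
```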
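One small API note: `retriever.get_relevant_documents(query)` still works but is deprecated in recent LangChain releases in favor of the runnable interface, so a forward-compatible version of that line (assuming a reasonably current langchain-core) would be:

```python
retrieved_docs = retriever.invoke(query)
```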