datascientist22 commited on
Commit
0953464
·
verified ·
1 Parent(s): a9e7e55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -34
app.py CHANGED
@@ -8,6 +8,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
8
  from sentence_transformers import SentenceTransformer
9
  import bs4
10
  import torch
 
11
 
12
  # Define the embedding class
13
  class SentenceTransformerEmbedding:
@@ -106,16 +107,18 @@ query = st.text_input("Ask a question based on the blog post", placeholder="Type
106
  if 'chat_history' not in st.session_state:
107
  st.session_state['chat_history'] = []
108
 
109
- # CustomLanguageModel class with proper context argument
110
  class CustomLanguageModel:
 
 
 
111
  def generate(self, prompt, context):
112
- # Implement logic to generate a response based on prompt and context
113
- return f"Generated response: '{prompt}'. Key points from the context: '{self.summarize_context(context)}'."
114
 
115
  def summarize_context(self, context):
116
- # Summarize the context to extract key information
117
- # You could use an NLP summarization model for a more sophisticated approach
118
- return " ".join(context.split()[:100]) # Returning the first 100 words as a simple summary
119
 
120
  # Define a callable class for RAGPrompt
121
  class RAGPrompt:
@@ -136,46 +139,43 @@ if st.button("Submit Query"):
136
  parse_only=bs4.SoupStrainer() # Adjust based on the user's URL structure
137
  ),
138
  )
139
- try:
140
- docs = loader.load()
141
 
142
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
143
- splits = text_splitter.split_documents(docs)
144
 
145
- # Initialize the embedding model
146
- embedding_model = SentenceTransformerEmbedding('all-MiniLM-L6-v2')
147
 
148
- # Initialize Chroma with the embedding class
149
- vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model)
150
 
151
- # Retrieve and generate using the relevant snippets of the blog
152
- retriever = vectorstore.as_retriever()
153
 
154
- # Retrieve relevant documents
155
- retrieved_docs = retriever.get_relevant_documents(query)
156
 
157
- # Format the retrieved documents
158
- def format_docs(docs):
159
- return "\n\n".join(doc.page_content for doc in docs)
160
 
161
- context = format_docs(retrieved_docs)
162
 
163
- # Initialize the language model
164
- custom_llm = CustomLanguageModel()
165
 
166
- # Initialize RAG chain using the prompt
167
- prompt = RAGPrompt()
168
 
169
- # Apply the prompt directly to the data (no chaining using `|`)
170
- prompt_data = prompt({"question": query, "context": context})
171
 
172
- # Generate the response using the language model, focusing on the answer from the retrieved context
173
- result = custom_llm.generate(prompt_data["question"], prompt_data["context"])
174
 
175
- # Store query and response in session for chat history
176
- st.session_state['chat_history'].append((query, result))
177
- except Exception as e:
178
- st.error(f"Error loading the blog or processing the query: {e}")
179
 
180
  # Display chat history
181
  for q, r in st.session_state['chat_history']:
 
8
  from sentence_transformers import SentenceTransformer
9
  import bs4
10
  import torch
11
+ from transformers import pipeline
12
 
13
  # Define the embedding class
14
  class SentenceTransformerEmbedding:
 
107
  if 'chat_history' not in st.session_state:
108
  st.session_state['chat_history'] = []
109
 
110
+ # CustomLanguageModel class with summarization
111
  class CustomLanguageModel:
112
+ def __init__(self):
113
+ self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn") # Replace with desired model
114
+
115
  def generate(self, prompt, context):
116
+ summary = self.summarize_context(context)
117
+ return f"Generated response: '{prompt}'. Summary: '{summary}'."
118
 
119
  def summarize_context(self, context):
120
+ summarized = self.summarizer(context, max_length=200, min_length=100, do_sample=False)
121
+ return summarized[0]['summary_text'] # Ensure it outputs full, meaningful sentences
 
122
 
123
  # Define a callable class for RAGPrompt
124
  class RAGPrompt:
 
139
  parse_only=bs4.SoupStrainer() # Adjust based on the user's URL structure
140
  ),
141
  )
142
+ docs = loader.load()
 
143
 
144
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=300)
145
+ splits = text_splitter.split_documents(docs)
146
 
147
+ # Initialize the embedding model
148
+ embedding_model = SentenceTransformerEmbedding('all-MiniLM-L6-v2')
149
 
150
+ # Initialize Chroma with the embedding class
151
+ vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model)
152
 
153
+ # Retrieve and generate using the relevant snippets of the blog
154
+ retriever = vectorstore.as_retriever()
155
 
156
+ # Retrieve relevant documents
157
+ retrieved_docs = retriever.get_relevant_documents(query)
158
 
159
+ # Format the retrieved documents
160
+ def format_docs(docs):
161
+ return "\n\n".join(doc.page_content for doc in docs)
162
 
163
+ context = format_docs(retrieved_docs)
164
 
165
+ # Initialize the language model
166
+ custom_llm = CustomLanguageModel()
167
 
168
+ # Initialize RAG chain using the prompt
169
+ prompt = RAGPrompt()
170
 
171
+ # Apply the prompt directly to the data (no chaining using `|`)
172
+ prompt_data = prompt({"question": query, "context": context})
173
 
174
+ # Generate the response using the language model, focusing on the answer from the retrieved context
175
+ result = custom_llm.generate(prompt_data["question"], prompt_data["context"])
176
 
177
+ # Store query and response in session for chat history
178
+ st.session_state['chat_history'].append((query, result))
 
 
179
 
180
  # Display chat history
181
  for q, r in st.session_state['chat_history']: