AreesaAshfaq commited on
Commit
0a6ed15
·
verified ·
1 Parent(s): bf56332

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from sentence_transformers import SentenceTransformer
3
+ from langchain import hub
4
+ from langchain_chroma import Chroma
5
+ from langchain_community.document_loaders import WebBaseLoader
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_core.runnables import RunnablePassthrough
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
+ import bs4
10
+ import torch
11
+
12
+ # Define the embedding class
13
+ class SentenceTransformerEmbedding:
14
+ def __init__(self, model_name):
15
+ self.model = SentenceTransformer(model_name)
16
+
17
+ def embed_documents(self, texts):
18
+ embeddings = self.model.encode(texts, convert_to_tensor=True)
19
+ if isinstance(embeddings, torch.Tensor):
20
+ return embeddings.cpu().detach().numpy().tolist() # Convert tensor to list
21
+ return embeddings
22
+
23
+ def embed_query(self, query):
24
+ embedding = self.model.encode([query], convert_to_tensor=True)
25
+ if isinstance(embedding, torch.Tensor):
26
+ return embedding.cpu().detach().numpy().tolist()[0] # Convert tensor to list
27
+ return embedding[0]
28
+
29
+ # Initialize the embedding class
30
+ embedding_model = SentenceTransformerEmbedding('all-MiniLM-L6-v2')
31
+
32
+ # Load, chunk, and index the contents of the blog
33
+ def load_data():
34
+ loader = WebBaseLoader(
35
+ web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
36
+ bs_kwargs=dict(
37
+ parse_only=bs4.SoupStrainer(
38
+ class_=("post-content", "post-title", "post-header")
39
+ )
40
+ ),
41
+ )
42
+ docs = loader.load()
43
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
44
+ splits = text_splitter.split_documents(docs)
45
+ vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model)
46
+ return vectorstore
47
+
48
+ vectorstore = load_data()
49
+
50
+ # Streamlit UI
51
+ st.title("Blog Retrieval and Question Answering")
52
+
53
+ question = st.text_input("Enter your question:")
54
+
55
+ if question:
56
+ retriever = vectorstore.as_retriever()
57
+ prompt = hub.pull("rlm/rag-prompt")
58
+
59
+ def format_docs(docs):
60
+ return "\n\n".join(doc.page_content for doc in docs)
61
+
62
+ rag_chain = (
63
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
64
+ | prompt
65
+ | lambda x: x # Replace with your LLM or appropriate function if needed
66
+ | StrOutputParser()
67
+ )
68
+
69
+ # Example invocation
70
+ try:
71
+ result = rag_chain.invoke(question)
72
+ st.write("Answer:", result)
73
+ except Exception as e:
74
+ st.error(f"An error occurred: {e}")