Update app.py
app.py
CHANGED
@@ -6,7 +6,6 @@ from langchain import OpenAI
 from langchain.chains import RetrievalQAWithSourcesChain
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.document_loaders import UnstructuredURLLoader
-# from langchain.embeddings import OpenAIEmbeddings
 from langchain.embeddings import FakeEmbeddings
 from langchain.llms import HuggingFaceHub
 from langchain.chains import LLMChain
@@ -16,67 +15,64 @@ from dotenv import load_dotenv
 load_dotenv()  # take environment variables from .env (especially openai api key)
 os.environ["HUGGINGFACEHUB_API_TOKEN"] = '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'

-st.title("RockyBot: News Research Tool 📈")
-st.sidebar.title("News Article URLs")
-        chunk_size=1000
-    )
-    docs = text_splitter.split_documents(loader.load())
-    vectorstore_openai = FAISS.from_documents(docs, embeddings)
-        pickle.dump(vectorstore_openai, f)
+class RockyBot:
+    def __init__(self, llm):
+        self.llm = llm
+        self.vectorstore = None
+
+    def process_urls(self, urls):
+        """Processes the given URLs and saves the FAISS index to a pickle file."""
+
+        # load data
+        loader = UnstructuredURLLoader(urls=urls)
+
+        # split data
+        text_splitter = RecursiveCharacterTextSplitter(
+            separators=['\n\n', '\n', '.', ','],
+            chunk_size=1000
+        )
+        docs = text_splitter.split_documents(loader.load())
+
+        # create embeddings and save it to FAISS index
+        embeddings = FakeEmbeddings(size=1352)
+        self.vectorstore = FAISS.from_documents(docs, embeddings)
+
+        # Save the FAISS index to a pickle file
+        with open("faiss_store_openai.pkl", "wb") as f:
+            pickle.dump(self.vectorstore, f)
+
+    def answer_question(self, question):
+        """Answers the given question using the LLM and retriever."""
+
+        chain = RetrievalQAWithSourcesChain.from_llm(llm=self.llm, retriever=self.vectorstore.as_retriever())
+        result = chain({"question": question}, return_only_outputs=True)
+
+        return result["answer"], result.get("sources", "")
+
+
+if __name__ == '__main__':
+    llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 0.5, "max_length": 64})
+    rockybot = RockyBot(llm)
+
+    # Process URLs if the button is clicked
+    if st.sidebar.button("Process URLs"):
+        rockybot.process_urls(st.sidebar.text_input("URL 1"), st.sidebar.text_input("URL 2"), st.sidebar.text_input("URL 3"))
         st.progress(100.0)
-        if sources:
-            st.subheader("Sources:")
-            sources_list = sources.split("\n")  # Split the sources by newline
-            for source in sources_list:
-                st.write(source)
-    except Exception as e:
-        st.error(e)
+    # Answer the question if it is not empty
+    query = st.text_input("Question: ")
+    if query:
+        answer, sources = rockybot.answer_question(query)
+
+        st.header("Answer")
+        st.write(answer)
+
+        # Display sources, if available
+        if sources:
+            st.subheader("Sources:")
+            for source in sources.split("\n"):
+                st.write(source)

 if __name__ == '__main__':
-    st.main()
+    st.main()
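In the committed `__main__` block, `process_urls` receives three separate `st.sidebar.text_input` values even though its signature takes a single `urls` argument, and the file ends with a second `if __name__ == '__main__':` guard calling `st.main()`, which is not a Streamlit function. A minimal sketch of how the driver section could be written instead, assuming the `RockyBot` class from the diff above is in scope; the list comprehension over the URL inputs and the vectorstore guard are illustrative assumptions, not part of the commit:

import streamlit as st
from langchain.llms import HuggingFaceHub

# Assumes RockyBot (defined in app.py above) is in scope.
st.title("RockyBot: News Research Tool 📈")
st.sidebar.title("News Article URLs")

# Collect the three sidebar inputs into one list, matching process_urls(self, urls).
urls = [st.sidebar.text_input(f"URL {i + 1}") for i in range(3)]

llm = HuggingFaceHub(repo_id="google/flan-t5-xxl",
                     model_kwargs={"temperature": 0.5, "max_length": 64})
rockybot = RockyBot(llm)

if st.sidebar.button("Process URLs"):
    rockybot.process_urls([u for u in urls if u])  # skip empty inputs
    st.progress(100)  # st.progress takes an int 0-100 or a float 0.0-1.0

query = st.text_input("Question: ")
if query and rockybot.vectorstore is not None:  # needs an index built in this run
    answer, sources = rockybot.answer_question(query)
    st.header("Answer")
    st.write(answer)
    if sources:
        st.subheader("Sources:")
        for source in sources.split("\n"):
            st.write(source)

Since process_urls also writes faiss_store_openai.pkl, a later run could unpickle that file into rockybot.vectorstore instead of reprocessing the URLs.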