vakodiya committed on
Commit
e4d6763
·
verified ·
1 Parent(s): fdbe98b

Create main.py

Files changed (1)
  1. main.py +103 -0
main.py ADDED
@@ -0,0 +1,103 @@
+ import os
+ import streamlit as st
+ import pickle
+ import time
+ from typing import Any
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from langchain.llms.base import LLM
+ from langchain.chains import RetrievalQAWithSourcesChain
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.document_loaders import UnstructuredURLLoader
+ from langchain.vectorstores import FAISS
+ from secret_key import google_genai_api_key
+ from langchain.embeddings import HuggingFaceEmbeddings
+
+
+ # Minimal LangChain LLM wrapper around a local Hugging Face causal language model.
+ class CustomHuggingFaceLLM(LLM):
+     # Declare the attributes as fields: the pydantic-based LLM base class
+     # rejects assignment to attributes that are not declared on the class.
+     model: Any = None
+     tokenizer: Any = None
+     temperature: float = 0.7
+
+     def __init__(self, model_name, temperature=0.7, **kwargs):
+         super().__init__(**kwargs)
+         self.model = AutoModelForCausalLM.from_pretrained(model_name)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.temperature = temperature
+
+     def _call(self, prompt, stop=None):
+         # Tokenize the prompt, sample a completion, and decode it back to text.
+         input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+         output = self.model.generate(
+             input_ids,
+             max_length=512,
+             temperature=self.temperature,
+             do_sample=True,
+             top_p=0.95,
+             top_k=3
+         )
+         generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
+         return generated_text
+
+     @property
+     def _identifying_params(self):
+         return {"model_name": self.model.config._name_or_path, "temperature": self.temperature}
+
+     @property
+     def _llm_type(self):
+         return "custom_huggingface"
+
+
+ main_directory = os.path.dirname(os.path.abspath(__file__))
+
+ st.title("Web Page search Bot: Research Tool 📈")
+ st.sidebar.title("Article URLs")
+
+ urls = []
+ for i in range(3):
+     url = st.sidebar.text_input(f"URL {i+1}")
+     urls.append(url)
+
+ process_url_clicked = st.sidebar.button("Process URLs")
+ file_path_faiss = "faiss_store.pkl"
+
+ main_placeholder = st.empty()
+
+ # Load a pre-trained embedding model
+ embedding_model = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
+ llm = CustomHuggingFaceLLM(model_name="meta-llama/Meta-Llama-3.1-8B", temperature=0.6)
+
+ if process_url_clicked:
+     # Load the pages behind the submitted URLs
+     loader = UnstructuredURLLoader(urls=urls)
+     main_placeholder.text("Data Loading...Started...✅✅✅")
+     data = loader.load()
+
+     # Split the documents into chunks.
+     # Avoid separators such as ',' and '.'; they produce chunks that are too small.
+     text_splitter = RecursiveCharacterTextSplitter(
+         separators=['\n\n'],
+         chunk_size=1000,
+         chunk_overlap=100
+     )
+     main_placeholder.text("Text Splitter...Started...✅✅✅")
+     docs = text_splitter.split_documents(data)
+
+     # Create embeddings and store them in a FAISS index
+     vectorstore_faiss = FAISS.from_documents(documents=docs, embedding=embedding_model)
+     main_placeholder.text("Embedding Vector Started Building...✅✅✅")
+     time.sleep(2)
+
+     # Save the FAISS index to a pickle file
+     with open(file_path_faiss, "wb") as f:
+         pickle.dump(vectorstore_faiss, f)
+
+ query = main_placeholder.text_input("Question: ")
+ if query:
+     if os.path.exists(file_path_faiss):
+         with open(file_path_faiss, "rb") as f:
+             vectorstore = pickle.load(f)
+             chain = RetrievalQAWithSourcesChain.from_llm(llm=llm, retriever=vectorstore.as_retriever(), verbose=True)  # type: ignore
+             result = chain({"question": query}, return_only_outputs=True)
+             # result is a dictionary of the form {"answer": "...", "sources": "..."}
+             st.header("Answer")
+             st.write(result["answer"])
+
+             # Display sources, if available
+             sources = result.get("sources", "")
+             if sources:
+                 st.subheader("Sources:")
+                 sources_list = sources.split("\n")  # Split the sources by newline
+                 for source in sources_list:
+                     st.write(source)