Update app.py
app.py
CHANGED
@@ -9,8 +9,7 @@ from langchain_community.document_loaders import PyPDFLoader, TextLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.retrievers import BM25Retriever
-from langchain.retrievers import EnsembleRetriever
+from langchain.retrievers import BM25Retriever, EnsembleRetriever
 from langchain.schema import Document
 from langchain.docstore.document import Document as LangchainDocument
 
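A version note on the consolidated import: on recent LangChain splits (assuming langchain>=0.1 with the langchain-community package installed), BM25Retriever and FAISS live in the community namespace, while EnsembleRetriever remains in core. A hedged equivalent, should the pinned version warn about these paths:

from langchain_community.vectorstores import FAISS        # replaces langchain.vectorstores
from langchain_community.retrievers import BM25Retriever  # replaces langchain.retrievers (needs rank_bm25 installed)
from langchain.retrievers import EnsembleRetriever        # still exported from core langchain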
@@ -21,13 +20,18 @@ HF_TOKEN = st.secrets["HF_TOKEN"]
 st.set_page_config(page_title="DigiTwin RAG", page_icon="π", layout="centered")
 st.title("π DigiTs the Twin")
 
-# ---
+# --- Sidebar ---
 with st.sidebar:
     st.header("π Upload Knowledge Files")
     uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
     hybrid_toggle = st.checkbox("π Enable Hybrid Search", value=True)
+    clear_chat = st.button("🧹 Clear Chat History")
 
-# ---
+# --- Session State ---
+if "messages" not in st.session_state or clear_chat:
+    st.session_state.messages = []
+
+# --- Load Model + Tokenizer ---
 @st.cache_resource
 def load_model():
     model_id = "tiiuae/falcon-7b-instruct"
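Reviewer note on the new clear-chat wiring: st.button returns True only during the single script rerun triggered by the click, so the `or clear_chat` clause re-initializes the history exactly once per press. A minimal standalone sketch of the same pattern (hypothetical "items" key, not from this app):

import streamlit as st

clear = st.button("Clear")
if "items" not in st.session_state or clear:
    st.session_state.items = []  # runs on first load and again on every button press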
@@ -37,7 +41,7 @@ def load_model():
 
 tokenizer, model = load_model()
 
-# ---
+# --- Process Documents ---
 def process_documents(files):
     documents = []
     for file in files:
@@ -53,55 +57,62 @@ def chunk_documents(documents):
     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
     return splitter.split_documents(documents)
 
-# ---
+# --- Build Hybrid Retriever ---
 def build_retrievers(chunks):
     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     faiss_vectorstore = FAISS.from_documents(chunks, embeddings)
     faiss_retriever = faiss_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
-
     bm25_retriever = BM25Retriever.from_documents([LangchainDocument(page_content=d.page_content) for d in chunks])
     bm25_retriever.k = 5
+    hybrid = EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
+    return faiss_retriever, hybrid
 
-
-
-
-# --- Inference ---
-def generate_answer(query, retriever):
-    docs = retriever.get_relevant_documents(query)
-    context = "\n".join([doc.page_content for doc in docs])
-
-    system_prompt = (
-        "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
-        "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), and pressure safety devices (PSD's). "
-        "Use the context below to answer professionally.\n\nContext:\n" + context + "\n\nQuery: " + query + "\nAnswer:"
-    )
-
+# --- Inference with Streaming ---
+def generate_stream_response(system_prompt):
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
     inputs = tokenizer(system_prompt, return_tensors="pt").to(model.device)
     generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=300)
-
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-
-    answer = ""
+    partial_output = ""
     for token in streamer:
-
-        yield
+        partial_output += token
+        yield partial_output
 
-# --- Main App ---
+# --- Main App Logic ---
 if uploaded_files:
     with st.spinner("Processing documents..."):
         docs = process_documents(uploaded_files)
         chunks = chunk_documents(docs)
         faiss_retriever, hybrid_retriever = build_retrievers(chunks)
-    st.success("Documents processed successfully.")
-
-    query = st.text_input("π Ask a question based on the uploaded documents")
-    if query:
-        st.subheader("🤖 Answer")
     retriever = hybrid_retriever if hybrid_toggle else faiss_retriever
-
-
-
-
-
+    st.success("Knowledge base ready. Ask your question below.")
+
+    for msg in st.session_state.messages:
+        with st.chat_message(msg["role"]):
+            st.markdown(msg["content"])
+
+    user_input = st.chat_input("💬 Ask DigiTwin something...")
+    if user_input:
+        st.chat_message("user").markdown(user_input)
+        st.session_state.messages.append({"role": "user", "content": user_input})
+
+        with st.chat_message("assistant"):
+            context_docs = retriever.get_relevant_documents(user_input)
+            context_text = "\n".join([doc.page_content for doc in context_docs])
+
+            system_prompt = (
+                "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
+                "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), and pressure safety devices (PSD's).\n\n"
+                f"Context:\n{context_text}\n\n"
+                f"User: {user_input}\nAssistant:"
+            )
+
+            full_response = ""
+            response_area = st.empty()
+            for partial_output in generate_stream_response(system_prompt):
+                full_response = partial_output
+                response_area.markdown(full_response)
+            st.session_state.messages.append({"role": "assistant", "content": full_response})
+else:
+    st.info("π Upload one or more PDFs or .txt files to begin.")
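The inference rewrite is the heart of the commit: model.generate blocks until completion, so it is moved to a background thread and TextIteratorStreamer hands decoded text back to the main thread as it is produced. A self-contained sketch of that pattern outside Streamlit (assuming a small model such as gpt2 purely for illustration):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_generate(prompt):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    # generation runs in the background; the streamer is fed as tokens decode
    Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=50)).start()
    text = ""
    for chunk in streamer:  # blocks until the generation thread emits the next piece
        text += chunk
        yield text  # cumulative text, mirroring generate_stream_response above

for partial in stream_generate("Hello"):
    print(partial)

Yielding the cumulative string rather than each delta is what lets the app overwrite a single st.empty() placeholder instead of appending fragments.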
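One design note on the hybrid toggle: as I understand LangChain's implementation, EnsembleRetriever merges the FAISS (dense) and BM25 (sparse) result lists with weighted Reciprocal Rank Fusion, so the 0.5/0.5 weights favor chunks that both retrievers rank highly. A toy illustration of that scoring (hypothetical helper, not LangChain code; c=60 is the conventional RRF constant):

def rrf_score(ranks, weights, c=60):
    """Weighted RRF: ranks are the 1-based positions of one doc in each retriever's list."""
    return sum(w / (c + r) for r, w in zip(ranks, weights))

# a chunk ranked 1st by one retriever and 3rd by the other:
print(rrf_score([1, 3], [0.5, 0.5]))  # ~0.0161; higher fused score sorts earlier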