amiguel committed
Commit ba95cd5 · verified · 1 Parent(s): 95dae9c

Update app.py

Files changed (1)
  app.py  +48 -37
app.py CHANGED
@@ -9,8 +9,7 @@ from langchain_community.document_loaders import PyPDFLoader, TextLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.retrievers import BM25Retriever
-from langchain.retrievers import EnsembleRetriever
+from langchain.retrievers import BM25Retriever, EnsembleRetriever
 from langchain.schema import Document
 from langchain.docstore.document import Document as LangchainDocument
 
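A note on the consolidated import above: LangChain's BM25Retriever relies on the rank_bm25 package at runtime, which is easy to forget in a Space's requirements.txt. A small guard like the following (hypothetical, not part of this commit) fails fast with a clear message:

    # Hypothetical guard, not in the commit: BM25Retriever needs rank_bm25.
    try:
        import rank_bm25  # noqa: F401 -- used indirectly by BM25Retriever
    except ImportError as exc:
        raise RuntimeError(
            "Hybrid search requires rank_bm25; add 'rank_bm25' to requirements.txt"
        ) from exc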
@@ -21,13 +20,18 @@ HF_TOKEN = st.secrets["HF_TOKEN"]
 st.set_page_config(page_title="DigiTwin RAG", page_icon="📂", layout="centered")
 st.title("📂 DigiTs the Twin")
 
-# --- Upload Files Sidebar ---
+# --- Sidebar ---
 with st.sidebar:
     st.header("📄 Upload Knowledge Files")
     uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
     hybrid_toggle = st.checkbox("🔀 Enable Hybrid Search", value=True)
+    clear_chat = st.button("🧹 Clear Chat History")
 
-# --- Model Loading ---
+# --- Session State ---
+if "messages" not in st.session_state or clear_chat:
+    st.session_state.messages = []
+
+# --- Load Model + Tokenizer ---
 @st.cache_resource
 def load_model():
     model_id = "tiiuae/falcon-7b-instruct"
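The new sidebar button drives a simple reset idiom: st.button returns True only on the script rerun triggered by the click, so the `or clear_chat` clause initializes history on first load and wipes it once per press. The same pattern in isolation (a minimal sketch mirroring the commit's names):

    import streamlit as st

    # True only for the rerun caused by the click.
    clear_chat = st.sidebar.button("🧹 Clear Chat History")

    # Initialize on first load, or reset when the button was just pressed.
    if "messages" not in st.session_state or clear_chat:
        st.session_state.messages = []

    st.sidebar.caption(f"{len(st.session_state.messages)} messages stored")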
@@ -37,7 +41,7 @@ def load_model():
 
 tokenizer, model = load_model()
 
-# --- Document Processing ---
+# --- Process Documents ---
 def process_documents(files):
     documents = []
     for file in files:
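This hunk cuts off right after the loop header, so the loader logic itself is not visible in the diff. A plausible body, assuming uploads are spooled to temporary files and dispatched to PyPDFLoader or TextLoader by extension (a reconstruction, not the commit's actual code):

    import os
    import tempfile
    from langchain_community.document_loaders import PyPDFLoader, TextLoader

    def process_documents(files):
        documents = []
        for file in files:
            # Streamlit's UploadedFile lives in memory; the loaders want a file path.
            suffix = os.path.splitext(file.name)[1].lower()
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(file.read())
                tmp_path = tmp.name
            loader = PyPDFLoader(tmp_path) if suffix == ".pdf" else TextLoader(tmp_path)
            documents.extend(loader.load())
            os.unlink(tmp_path)
        return documents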
@@ -53,55 +57,62 @@ def chunk_documents(documents):
     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
     return splitter.split_documents(documents)
 
-# --- Embedding and Retrieval ---
+# --- Build Hybrid Retriever ---
 def build_retrievers(chunks):
     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     faiss_vectorstore = FAISS.from_documents(chunks, embeddings)
     faiss_retriever = faiss_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
-
     bm25_retriever = BM25Retriever.from_documents([LangchainDocument(page_content=d.page_content) for d in chunks])
     bm25_retriever.k = 5
+    hybrid = EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
+    return faiss_retriever, hybrid
 
-    ensemble = EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
-    return faiss_retriever, ensemble
-
-# --- Inference ---
-def generate_answer(query, retriever):
-    docs = retriever.get_relevant_documents(query)
-    context = "\n".join([doc.page_content for doc in docs])
-
-    system_prompt = (
-        "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
-        "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), and pressure safety devices (PSD's). "
-        "Use the context below to answer professionally.\n\nContext:\n" + context + "\n\nQuery: " + query + "\nAnswer:"
-    )
-
+# --- Inference with Streaming ---
+def generate_stream_response(system_prompt):
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
     inputs = tokenizer(system_prompt, return_tensors="pt").to(model.device)
     generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=300)
-
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-
-    answer = ""
+    partial_output = ""
     for token in streamer:
-        answer += token
-        yield answer
+        partial_output += token
+        yield partial_output
 
-# --- Main App ---
+# --- Main App Logic ---
 if uploaded_files:
     with st.spinner("Processing documents..."):
         docs = process_documents(uploaded_files)
         chunks = chunk_documents(docs)
         faiss_retriever, hybrid_retriever = build_retrievers(chunks)
-        st.success("Documents processed successfully.")
-
-    query = st.text_input("🔍 Ask a question based on the uploaded documents")
-    if query:
-        st.subheader("📤 Answer")
         retriever = hybrid_retriever if hybrid_toggle else faiss_retriever
-        response_placeholder = st.empty()
-        full_response = ""
-        for partial_response in generate_answer(query, retriever):
-            full_response = partial_response
-            response_placeholder.markdown(full_response)
+    st.success("Knowledge base ready. Ask your question below.")
+
+    for msg in st.session_state.messages:
+        with st.chat_message(msg["role"]):
+            st.markdown(msg["content"])
+
+    user_input = st.chat_input("💬 Ask DigiTwin something...")
+    if user_input:
+        st.chat_message("user").markdown(user_input)
+        st.session_state.messages.append({"role": "user", "content": user_input})
+
+        with st.chat_message("assistant"):
+            context_docs = retriever.get_relevant_documents(user_input)
+            context_text = "\n".join([doc.page_content for doc in context_docs])
+
+            system_prompt = (
+                "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
+                "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), and pressure safety devices (PSD's).\n\n"
+                f"Context:\n{context_text}\n\n"
+                f"User: {user_input}\nAssistant:"
+            )
+
+            full_response = ""
+            response_area = st.empty()
+            for partial_output in generate_stream_response(system_prompt):
+                full_response = partial_output
+                response_area.markdown(full_response)
+            st.session_state.messages.append({"role": "assistant", "content": full_response})
+else:
+    st.info("👈 Upload one or more PDFs or .txt files to begin.")
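build_retrievers now returns both the plain FAISS retriever and the hybrid one, so the sidebar toggle can switch strategies at query time without rebuilding the index. EnsembleRetriever merges its children's ranked lists via weighted Reciprocal Rank Fusion; the 0.5/0.5 weights give dense (FAISS) and lexical (BM25) matches an equal vote. A toy sketch of the API (two BM25 children stand in for the app's FAISS+BM25 pair to keep the sketch dependency-light):

    from langchain.retrievers import BM25Retriever, EnsembleRetriever
    from langchain.docstore.document import Document

    docs = [Document(page_content=t) for t in
            ["pump seal failure report", "PSD valve test procedure", "hull inspection log"]]

    bm25_a = BM25Retriever.from_documents(docs)
    bm25_b = BM25Retriever.from_documents(docs)
    bm25_a.k = bm25_b.k = 2

    # Weighted Reciprocal Rank Fusion over the two ranked lists.
    hybrid = EnsembleRetriever(retrievers=[bm25_a, bm25_b], weights=[0.5, 0.5])
    print([d.page_content for d in hybrid.get_relevant_documents("PSD valve test")])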
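On the streaming refactor: TextIteratorStreamer must be drained on a different thread from model.generate, which is why generate() runs in a background Thread while the generator loop consumes tokens as they arrive. Yielding the accumulated partial_output, rather than each token, lets the caller repaint a single st.empty() placeholder per step. The same pattern outside Streamlit, with a tiny model so it actually runs (the model id is illustrative, not the app's Falcon-7B):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    inputs = tok("Integrity check:", return_tensors="pt")
    Thread(target=model.generate,
           kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20)).start()

    text = ""
    for token in streamer:  # blocks until generate() emits the next chunk
        text += token
        print(text)  # the app does response_area.markdown(text) here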