amiguel committed
Commit be6c945 · verified · 1 Parent(s): d01bbfc

Update app.py

Files changed (1)
  1. app.py +59 -53
app.py CHANGED
@@ -1,8 +1,8 @@
 import streamlit as st
 import torch
 import os
-import time
 import tempfile
+import time
 from threading import Thread
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from langchain_community.document_loaders import PyPDFLoader, TextLoader
@@ -13,6 +13,10 @@ from langchain.retrievers import BM25Retriever, EnsembleRetriever
 from langchain.schema import Document
 from langchain.docstore.document import Document as LangchainDocument
 
+# --- Avatars ---
+USER_AVATAR = "👤"
+BOT_AVATAR = "🤖"
+
 # --- HF Token ---
 HF_TOKEN = st.secrets["HF_TOKEN"]
 
@@ -31,7 +35,7 @@ with st.sidebar:
 if "messages" not in st.session_state or clear_chat:
     st.session_state.messages = []
 
-# --- Load Model + Tokenizer ---
+# --- Load Model ---
 @st.cache_resource
 def load_model():
     model_id = "tiiuae/falcon-7b-instruct"
@@ -41,7 +45,7 @@ def load_model():
 
 tokenizer, model = load_model()
 
-# --- Process Documents ---
+# --- Load & Chunk Documents ---
 def process_documents(files):
     documents = []
     for file in files:
@@ -49,75 +53,77 @@ def process_documents(files):
         with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
             tmp_file.write(file.read())
            tmp_file_path = tmp_file.name
-
-        if suffix == ".pdf":
-            loader = PyPDFLoader(tmp_file_path)
-        else:
-            loader = TextLoader(tmp_file_path)
-        docs = loader.load()
-        documents.extend(docs)
+        loader = PyPDFLoader(tmp_file_path) if suffix == ".pdf" else TextLoader(tmp_file_path)
+        documents.extend(loader.load())
     return documents
 
 def chunk_documents(documents):
     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
     return splitter.split_documents(documents)
 
-# --- Build Hybrid Retriever ---
 def build_retrievers(chunks):
     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     faiss_vectorstore = FAISS.from_documents(chunks, embeddings)
     faiss_retriever = faiss_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
     bm25_retriever = BM25Retriever.from_documents([LangchainDocument(page_content=d.page_content) for d in chunks])
     bm25_retriever.k = 5
-    hybrid = EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
-    return faiss_retriever, hybrid
-
-# --- Inference with Streaming ---
-def generate_stream_response(system_prompt):
+    return faiss_retriever, EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
+
+# --- Prompt Builder ---
+def build_prompt(history, context=""):
+    conversation = ""
+    for turn in history:
+        role = "User" if turn["role"] == "user" else "Assistant"
+        conversation += f"{role}: {turn['content']}\n"
+    return (
+        "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
+        "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), and pressure safety devices (PSD's).\n\n"
+        f"Context:\n{context}\n\n"
+        f"{conversation}Assistant:"
+    )
+
+# --- Generator ---
+def generate_response(prompt):
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    inputs = tokenizer(system_prompt, return_tensors="pt").to(model.device)
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=300)
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-    partial_output = ""
     for token in streamer:
-        partial_output += token
-        yield partial_output
+        yield token
 
-# --- Main App Logic ---
+# --- Main App ---
+retriever = None
 if uploaded_files:
     with st.spinner("Processing documents..."):
         docs = process_documents(uploaded_files)
         chunks = chunk_documents(docs)
-        faiss_retriever, hybrid_retriever = build_retrievers(chunks)
-        retriever = hybrid_retriever if hybrid_toggle else faiss_retriever
-    st.success("Knowledge base ready. Ask your question below.")
-
-    for msg in st.session_state.messages:
-        with st.chat_message(msg["role"]):
-            st.markdown(msg["content"])
-
-    user_input = st.chat_input("💬 Ask DigiTwin something...")
-    if user_input:
-        st.chat_message("user").markdown(user_input)
-        st.session_state.messages.append({"role": "user", "content": user_input})
-
-        with st.chat_message("assistant"):
-            context_docs = retriever.get_relevant_documents(user_input)
-            context_text = "\n".join([doc.page_content for doc in context_docs])
-
-            system_prompt = (
-                "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
-                "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), and pressure safety devices (PSD's).\n\n"
-                f"Context:\n{context_text}\n\n"
-                f"User: {user_input}\nAssistant:"
-            )
-
-            full_response = ""
-            response_area = st.empty()
-            for partial_output in generate_stream_response(system_prompt):
-                full_response = partial_output
-                response_area.markdown(full_response)
-            st.session_state.messages.append({"role": "assistant", "content": full_response})
-else:
-    st.info("👈 Upload one or more PDFs or .txt files to begin.")
+        faiss, hybrid = build_retrievers(chunks)
+        retriever = hybrid if hybrid_toggle else faiss
+    st.success("Documents processed. Ask away!")
+
+for msg in st.session_state.messages:
+    with st.chat_message(msg["role"], avatar=USER_AVATAR if msg["role"] == "user" else BOT_AVATAR):
+        st.markdown(msg["content"])
+
+# --- Chat UI ---
+if prompt := st.chat_input("Ask something based on uploaded documents..."):
+    st.chat_message("user", avatar=USER_AVATAR).markdown(prompt)
+    st.session_state.messages.append({"role": "user", "content": prompt})
+
+    context = ""
+    if retriever:
+        docs = retriever.get_relevant_documents(prompt)
+        context = "\n\n".join([d.page_content for d in docs])
+
+    full_prompt = build_prompt(st.session_state.messages, context=context)
+
+    with st.chat_message("assistant", avatar=BOT_AVATAR):
+        streamer = generate_response(full_prompt)
+        container = st.empty()
+        answer = ""
+        for chunk in streamer:
+            answer += chunk
+            container.markdown(answer + "▌", unsafe_allow_html=True)
+        container.markdown(answer)
+    st.session_state.messages.append({"role": "assistant", "content": answer})
 
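For reference, build_retrievers() returns the plain FAISS retriever together with an EnsembleRetriever that fuses it with BM25, so the sidebar toggle can switch between dense-only and hybrid search. Below is a minimal standalone sketch of the same fusion pattern, assuming the langchain / langchain-community packages app.py already imports; the sample chunks and query are illustrative only, and import paths can vary across LangChain versions:

from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.docstore.document import Document

# Illustrative stand-ins for the chunks produced from uploaded files
chunks = [
    Document(page_content="Pressure safety devices relieve overpressure in vessels."),
    Document(page_content="Topside piping circuits need periodic wall-thickness checks."),
]

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
faiss_retriever = FAISS.from_documents(chunks, embeddings).as_retriever(search_kwargs={"k": 5})
bm25_retriever = BM25Retriever.from_documents(chunks)
bm25_retriever.k = 5

# Equal-weight fusion of sparse (BM25) and dense (FAISS) rankings
hybrid = EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
print(hybrid.get_relevant_documents("How are pressure vessels protected?"))

With only two documents both retrievers return everything; the 0.5/0.5 weights only start to matter once the sparse and dense rankings diverge on a larger corpus.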
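The new generate_response() also changes the streaming contract: generate_stream_response() used to yield the whole accumulated string on every token, while the new generator yields each decoded fragment and leaves accumulation to the chat UI's answer += chunk loop. Here is a minimal sketch of the same TextIteratorStreamer pattern outside Streamlit, assuming a small causal LM; the model id is a placeholder, not the app's tiiuae/falcon-7b-instruct:

from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

model_id = "distilgpt2"  # placeholder; app.py loads tiiuae/falcon-7b-instruct
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

def stream_tokens(prompt, max_new_tokens=40):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # generate() blocks, so it runs on a worker thread while the
    # streamer yields decoded text fragments on the caller's thread
    thread = Thread(target=model.generate,
                    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens))
    thread.start()
    for fragment in streamer:
        yield fragment
    thread.join()

for chunk in stream_tokens("Pressure vessel inspection intervals"):
    print(chunk, end="", flush=True)

Yielding fragments keeps total work linear in the output length and lets the caller decide how to render partial text; the app appends a "▌" cursor while tokens arrive, then rewrites the container with the final answer.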