amiguel committed on
Commit f39631c · verified · 1 Parent(s): be6c945

Update app.py

Files changed (1)
  1. app.py +79 -77
app.py CHANGED
@@ -1,7 +1,6 @@
 import streamlit as st
 import torch
 import os
-import tempfile
 import time
 from threading import Thread
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
@@ -9,13 +8,7 @@ from langchain_community.document_loaders import PyPDFLoader, TextLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.retrievers import BM25Retriever, EnsembleRetriever
 from langchain.schema import Document
-from langchain.docstore.document import Document as LangchainDocument
-
-# --- Avatars ---
-USER_AVATAR = "👤"
-BOT_AVATAR = "🤖"
 
 # --- HF Token ---
 HF_TOKEN = st.secrets["HF_TOKEN"]
@@ -24,86 +17,94 @@ HF_TOKEN = st.secrets["HF_TOKEN"]
 st.set_page_config(page_title="DigiTwin RAG", page_icon="📂", layout="centered")
 st.title("📂 DigiTs the Twin")
 
-# --- Sidebar ---
+# --- Upload Files Sidebar ---
 with st.sidebar:
     st.header("📄 Upload Knowledge Files")
     uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
-    hybrid_toggle = st.checkbox("🔀 Enable Hybrid Search", value=True)
-    clear_chat = st.button("🧹 Clear Chat History")
-
-# --- Session State ---
-if "messages" not in st.session_state or clear_chat:
-    st.session_state.messages = []
+    if uploaded_files:
+        st.success(f"{len(uploaded_files)} file(s) uploaded")
 
-# --- Load Model ---
+# --- Model Loading ---
 @st.cache_resource
 def load_model():
-    model_id = "tiiuae/falcon-7b-instruct"
-    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
-    return tokenizer, model
-
-tokenizer, model = load_model()
-
-# --- Load & Chunk Documents ---
-def process_documents(files):
-    documents = []
-    for file in files:
-        suffix = ".pdf" if file.name.endswith(".pdf") else ".txt"
-        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
-            tmp_file.write(file.read())
-            tmp_file_path = tmp_file.name
-        loader = PyPDFLoader(tmp_file_path) if suffix == ".pdf" else TextLoader(tmp_file_path)
-        documents.extend(loader.load())
-    return documents
-
-def chunk_documents(documents):
-    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
-    return splitter.split_documents(documents)
-
-def build_retrievers(chunks):
-    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-    faiss_vectorstore = FAISS.from_documents(chunks, embeddings)
-    faiss_retriever = faiss_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
-    bm25_retriever = BM25Retriever.from_documents([LangchainDocument(page_content=d.page_content) for d in chunks])
-    bm25_retriever.k = 5
-    return faiss_retriever, EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
-
-# --- Prompt Builder ---
-def build_prompt(history, context=""):
-    conversation = ""
-    for turn in history:
-        role = "User" if turn["role"] == "user" else "Assistant"
-        conversation += f"{role}: {turn['content']}\n"
-    return (
-        "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
-        "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), and pressure safety devices (PSD's).\n\n"
-        f"Context:\n{context}\n\n"
-        f"{conversation}Assistant:"
+    tokenizer = AutoTokenizer.from_pretrained("amiguel/GM_Qwen1.8B_Finetune", trust_remote_code=True, token=HF_TOKEN)
+    model = AutoModelForCausalLM.from_pretrained(
+        "amiguel/GM_Qwen1.8B_Finetune",
+        device_map="auto",
+        torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
+        trust_remote_code=True,
+        token=HF_TOKEN
     )
+    return model, tokenizer
+
+model, tokenizer = load_model()
+
+# --- Prompt Helper ---
+SYSTEM_PROMPT = (
+    "You are DigiTwin, a digital expert and senior topside engineer specializing in inspection and maintenance "
+    "of offshore piping systems, structural elements, mechanical equipment, floating production units, pressure vessels "
+    "(with emphasis on Visual Internal Inspection - VII), and pressure safety devices (PSDs). Rely on uploaded documents "
+    "and context to provide practical, standards-driven, and technically accurate responses. Your guidance reflects deep "
+    "field experience, industry regulations, and proven methodologies in asset integrity and reliability engineering."
+)
+
+
+def build_prompt(messages, context=""):
+    prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}\n\nContext:\n{context}<|im_end|>\n"
+    for msg in messages:
+        role = msg["role"]
+        prompt += f"<|im_start|>{role}\n{msg['content']}<|im_end|>\n"
+    prompt += "<|im_start|>assistant\n"
+    return prompt
 
-# --- Generator ---
-def generate_response(prompt):
+
+# --- RAG Embedding and Search ---
+@st.cache_resource
+def embed_uploaded_files(files):
+    raw_docs = []
+    for f in files:
+        file_path = f"/tmp/{f.name}"
+        with open(file_path, "wb") as out_file:
+            out_file.write(f.read())
+
+        loader = PyPDFLoader(file_path) if f.name.endswith(".pdf") else TextLoader(file_path)
+        raw_docs.extend(loader.load())
+
+    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
+    chunks = splitter.split_documents(raw_docs)
+    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+    db = FAISS.from_documents(chunks, embedding=embeddings)
+    return db
+
+retriever = embed_uploaded_files(uploaded_files) if uploaded_files else None
+
+# --- Streaming Response ---
+def generate_response(prompt_text):
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=300)
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
+    thread = Thread(target=model.generate, kwargs={
+        "input_ids": inputs["input_ids"],
+        "attention_mask": inputs["attention_mask"],
+        "max_new_tokens": 1024,
+        "temperature": 0.7,
+        "top_p": 0.9,
+        "repetition_penalty": 1.1,
+        "do_sample": True,
+        "streamer": streamer
+    })
     thread.start()
-    for token in streamer:
-        yield token
-
-# --- Main App ---
-retriever = None
-if uploaded_files:
-    with st.spinner("Processing documents..."):
-        docs = process_documents(uploaded_files)
-        chunks = chunk_documents(docs)
-        faiss, hybrid = build_retrievers(chunks)
-        retriever = hybrid if hybrid_toggle else faiss
-    st.success("Documents processed. Ask away!")
+    return streamer
+
+# --- Avatars & Messages ---
+USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
+BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
+
+if "messages" not in st.session_state:
+    st.session_state.messages = []
 
 for msg in st.session_state.messages:
-    with st.chat_message(msg["role"], avatar=USER_AVATAR if msg["role"] == "user" else BOT_AVATAR):
+    avatar = USER_AVATAR if msg["role"] == "user" else BOT_AVATAR
+    with st.chat_message(msg["role"], avatar=avatar):
         st.markdown(msg["content"])
 
 # --- Chat UI ---
@@ -113,12 +114,13 @@ if prompt := st.chat_input("Ask something based on uploaded documents..."):
 
     context = ""
     if retriever:
-        docs = retriever.get_relevant_documents(prompt)
+        docs = retriever.similarity_search(prompt, k=3)
         context = "\n\n".join([d.page_content for d in docs])
 
     full_prompt = build_prompt(st.session_state.messages, context=context)
 
     with st.chat_message("assistant", avatar=BOT_AVATAR):
+        start_time = time.time()
         streamer = generate_response(full_prompt)
         container = st.empty()
         answer = ""
@@ -126,4 +128,4 @@ if prompt := st.chat_input("Ask something based on uploaded documents..."):
             answer += chunk
             container.markdown(answer + "▌", unsafe_allow_html=True)
         container.markdown(answer)
-    st.session_state.messages.append({"role": "assistant", "content": answer})
+    st.session_state.messages.append({"role": "assistant", "content": answer})
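For reference, a minimal sketch of the ChatML-style prompt that the new build_prompt() assembles for the Qwen-based model. The system prompt is abbreviated and the conversation history and context strings below are made-up examples, not values from the app; the snippet only reproduces the string-formatting logic shown in the diff.

# Illustrative sketch only: condenses build_prompt() from the updated app.py.
# SYSTEM_PROMPT is abbreviated; the history and context values are hypothetical.
SYSTEM_PROMPT = "You are DigiTwin, a senior topside engineer..."

def build_prompt(messages, context=""):
    prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}\n\nContext:\n{context}<|im_end|>\n"
    for msg in messages:
        prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
    prompt += "<|im_start|>assistant\n"  # generation continues from this open assistant turn
    return prompt

history = [{"role": "user", "content": "What does VII cover for a pressure vessel?"}]
print(build_prompt(history, context="(retrieved document chunks would go here)"))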