import streamlit as st
import torch
import os
import tempfile
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.docstore.document import Document as LangchainDocument
# --- HF Token ---
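# Read from Streamlit's secrets store (.streamlit/secrets.toml locally, or the
# Secrets panel on Streamlit Cloud / HF Spaces); used to authenticate downloads
# from the Hugging Face Hub.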
HF_TOKEN = st.secrets["HF_TOKEN"]
# --- Page Config ---
st.set_page_config(page_title="DigiTwin RAG", page_icon="🤖", layout="centered")
st.title("🤖 DigiTs the Twin")
# --- Sidebar ---
with st.sidebar:
    st.header("📁 Upload Knowledge Files")
    uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
    hybrid_toggle = st.checkbox("🔀 Enable Hybrid Search", value=True)
    clear_chat = st.button("🧹 Clear Chat History")
# --- Session State ---
if "messages" not in st.session_state or clear_chat:
    st.session_state.messages = []
# --- Load Model + Tokenizer ---
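# st.cache_resource keeps a single model/tokenizer pair alive across Streamlit
# reruns and sessions; without it, the 7B checkpoint would be reloaded on
# every user interaction.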
@st.cache_resource
def load_model():
    model_id = "tiiuae/falcon-7b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN
    )
    return tokenizer, model

tokenizer, model = load_model()
# --- Process Documents ---
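# PDFs go through PyPDFLoader (one Document per page); everything else goes
# through TextLoader.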
def process_documents(files):
    documents = []
    for file in files:
        # Streamlit uploads are in-memory buffers; the LangChain loaders need
        # a filesystem path, so write each upload to a temp file first.
        temp_path = os.path.join(tempfile.gettempdir(), file.name)
        with open(temp_path, "wb") as f:
            f.write(file.getvalue())
        if file.name.endswith(".pdf"):
            loader = PyPDFLoader(temp_path)
        else:
            loader = TextLoader(temp_path)
        documents.extend(loader.load())
    return documents
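# 500-character chunks with 50 characters of overlap keep each chunk small
# enough for the MiniLM embedder below while preserving continuity across
# chunk boundaries.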
def chunk_documents(documents):
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    return splitter.split_documents(documents)
# --- Build Hybrid Retriever ---
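# Hybrid search pairs dense retrieval (FAISS over MiniLM embeddings, good for
# paraphrases) with sparse BM25 (good for exact terms such as tag numbers).
# EnsembleRetriever merges the two ranked lists via weighted Reciprocal Rank
# Fusion; weights=[0.5, 0.5] treats both retrievers equally.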
def build_retrievers(chunks):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    faiss_vectorstore = FAISS.from_documents(chunks, embeddings)
    faiss_retriever = faiss_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    bm25_retriever = BM25Retriever.from_documents(
        [LangchainDocument(page_content=d.page_content) for d in chunks]
    )
    bm25_retriever.k = 5
    hybrid = EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
    return faiss_retriever, hybrid
# --- Inference with Streaming ---
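# model.generate() blocks until completion, so it runs on a worker thread
# while TextIteratorStreamer hands decoded tokens back to the main thread as
# they are produced. skip_prompt=True drops the echoed input from the stream.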
def generate_stream_response(system_prompt):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(system_prompt, return_tensors="pt").to(model.device)
    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=300)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    partial_output = ""
    for token in streamer:
        partial_output += token
        yield partial_output
# --- Main App Logic ---
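# Note: the knowledge base is rebuilt on every Streamlit rerun; for large
# document sets, caching build_retrievers (e.g. with st.cache_resource) would
# avoid the repeated embedding work.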
if uploaded_files:
    with st.spinner("Processing documents..."):
        docs = process_documents(uploaded_files)
        chunks = chunk_documents(docs)
        faiss_retriever, hybrid_retriever = build_retrievers(chunks)
        retriever = hybrid_retriever if hybrid_toggle else faiss_retriever
    st.success("Knowledge base ready. Ask your question below.")

    # Replay the conversation so far.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    user_input = st.chat_input("💬 Ask DigiTwin something...")
    if user_input:
        st.chat_message("user").markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("assistant"):
            # Retrieve supporting chunks and inline them into the prompt.
            context_docs = retriever.get_relevant_documents(user_input)
            context_text = "\n".join([doc.page_content for doc in context_docs])
            system_prompt = (
                "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
                "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), and pressure safety devices (PSDs).\n\n"
                f"Context:\n{context_text}\n\n"
                f"User: {user_input}\nAssistant:"
            )
            full_response = ""
            response_area = st.empty()
            # Render the growing response in place as tokens stream in.
            for partial_output in generate_stream_response(system_prompt):
                full_response = partial_output
                response_area.markdown(full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})
else:
    st.info("📄 Upload one or more PDFs or .txt files to begin.")
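# To launch (assuming this script is saved as app.py and HF_TOKEN is set in
# .streamlit/secrets.toml): streamlit run app.py
# The first question after an upload is slow: the Falcon-7B weights are
# downloaded and moved to the GPU once, then cached for later reruns.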