import streamlit as st
import torch
import os
import tempfile
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever

# --- Avatars ---
USER_AVATAR = "πŸ‘€"
BOT_AVATAR = "πŸ€–"

# --- HF Token ---
HF_TOKEN = st.secrets.get("HF_TOKEN")
if not HF_TOKEN:
    st.error("HF_TOKEN is missing from Streamlit secrets; add it to .streamlit/secrets.toml.")
    st.stop()

# --- Page Config ---
st.set_page_config(page_title="DigiTwin RAG", page_icon="πŸ“‚", layout="centered")
st.title("πŸ“‚ DigiTs the Twin")

# --- Sidebar ---
with st.sidebar:
    st.header("πŸ“„ Upload Knowledge Files")
    uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
    hybrid_toggle = st.checkbox("πŸ”€ Enable Hybrid Search", value=True)
    clear_chat = st.button("🧹 Clear Chat History")

# --- Session State ---
if "messages" not in st.session_state or clear_chat:
    st.session_state.messages = []

# --- Load Model ---
@st.cache_resource
def load_model():
    model_id = "tiiuae/falcon-7b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
    return tokenizer, model
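
# Optional: on smaller GPUs the checkpoint can instead be loaded in 4-bit via
# bitsandbytes. A minimal sketch, assuming the bitsandbytes package is
# installed (swap it into load_model() in place of the from_pretrained call):
#
#   from transformers import BitsAndBytesConfig
#   quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
#   model = AutoModelForCausalLM.from_pretrained(
#       model_id, quantization_config=quant_config, device_map="auto", token=HF_TOKEN
#   )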

tokenizer, model = load_model()

# --- Load & Chunk Documents ---
def process_documents(files):
    documents = []
    for file in files:
        suffix = ".pdf" if file.name.endswith(".pdf") else ".txt"
        # The loaders need a real path, so spill each upload to a temp file first.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(file.read())
            tmp_file_path = tmp_file.name
        try:
            loader = PyPDFLoader(tmp_file_path) if suffix == ".pdf" else TextLoader(tmp_file_path)
            documents.extend(loader.load())
        finally:
            os.remove(tmp_file_path)  # delete=False would otherwise leak the temp file
    return documents

def chunk_documents(documents):
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    return splitter.split_documents(documents)
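
# Note: RecursiveCharacterTextSplitter measures chunk_size and chunk_overlap in
# characters by default (its length_function is len), so 500 here is roughly
# 100-125 English tokens, not 500 tokens.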

def build_retrievers(chunks):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    faiss_vectorstore = FAISS.from_documents(chunks, embeddings)
    faiss_retriever = faiss_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    bm25_retriever = BM25Retriever.from_documents(chunks)  # chunks are already Documents; re-wrapping would drop their metadata
    bm25_retriever.k = 5
    return faiss_retriever, EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
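
# The EnsembleRetriever fuses the two ranked lists with weighted Reciprocal
# Rank Fusion; the 0.5/0.5 weights give dense (FAISS) and sparse (BM25) results
# equal say. For keyword- or tag-heavy documents, shifting weight toward BM25
# (e.g. weights=[0.3, 0.7]) is a reasonable experiment.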

# --- Prompt Builder ---
def build_prompt(history, context=""):
    conversation = ""
    for turn in history:
        role = "User" if turn["role"] == "user" else "Assistant"
        conversation += f"{role}: {turn['content']}\n"
    return (
        "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
        "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), and pressure safety devices (PSD's).\n\n"
        f"Context:\n{context}\n\n"
        f"{conversation}Assistant:"
    )
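
# For a single-turn history the assembled prompt looks roughly like this
# (illustrative only; the question is a made-up example):
#
#   You are DigiTwin, an expert advisor in asset integrity, ...
#
#   Context:
#   <retrieved chunks>
#
#   User: What typically drives the inspection interval for a relief valve?
#   Assistant: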

# --- Generator ---
def generate_response(prompt):
    # Run generate() on a background thread and yield tokens as the streamer receives them.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=300,
        pad_token_id=tokenizer.eos_token_id,  # Falcon's tokenizer defines no pad token
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    yield from streamer
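
# Note: generate() stops only at EOS or max_new_tokens, so the model can run on
# and start a spurious "User:" turn. Newer transformers releases (4.39+) accept
# stop strings; an optional tweak (an assumption, not verified against the
# transformers version pinned here) would be:
#
#   generation_kwargs["stop_strings"] = ["\nUser:"]
#   generation_kwargs["tokenizer"] = tokenizer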

# --- Main App ---
retriever = None
if uploaded_files:
    # Every chat message triggers a Streamlit rerun, so re-indexing here on each
    # run is expensive; cache the retrievers in session state and rebuild only
    # when the set of uploads changes.
    file_key = tuple(f.name for f in uploaded_files)
    if st.session_state.get("indexed_files") != file_key:
        with st.spinner("Processing documents..."):
            docs = process_documents(uploaded_files)
            chunks = chunk_documents(docs)
            st.session_state.retrievers = build_retrievers(chunks)
            st.session_state.indexed_files = file_key
        st.success("Documents processed. Ask away!")
    faiss_retriever, hybrid_retriever = st.session_state.retrievers
    retriever = hybrid_retriever if hybrid_toggle else faiss_retriever

for msg in st.session_state.messages:
    with st.chat_message(msg["role"], avatar=USER_AVATAR if msg["role"] == "user" else BOT_AVATAR):
        st.markdown(msg["content"])

# --- Chat UI ---
if prompt := st.chat_input("Ask something based on uploaded documents..."):
    st.chat_message("user", avatar=USER_AVATAR).markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    context = ""
    if retriever:
        docs = retriever.invoke(prompt)  # get_relevant_documents() is deprecated in current LangChain
        context = "\n\n".join([d.page_content for d in docs])

    full_prompt = build_prompt(st.session_state.messages, context=context)

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        streamer = generate_response(full_prompt)
        container = st.empty()
        answer = ""
        for chunk in streamer:
            answer += chunk
            container.markdown(answer + "β–Œ")  # render as plain markdown; model output should not be treated as HTML
        container.markdown(answer)
        st.session_state.messages.append({"role": "assistant", "content": answer})