import streamlit as st
import torch
import os
import tempfile
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_core.documents import Document as LangchainDocument

# --- HF Token ---
HF_TOKEN = st.secrets["HF_TOKEN"]
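# Assumes HF_TOKEN is set in the app's secrets (.streamlit/secrets.toml locally,
# or the Secrets panel on a hosted Space); st.secrets raises if the key is absent.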

# --- Page Config ---
st.set_page_config(page_title="DigiTwin RAG", page_icon="πŸ“‚", layout="centered")
st.title("πŸ“‚ DigiTs the Twin")

# --- Upload Files Sidebar ---
with st.sidebar:
    st.header("πŸ“„ Upload Knowledge Files")
    uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
    hybrid_toggle = st.checkbox("πŸ”€ Enable Hybrid Search", value=True)

# --- Model Loading ---
@st.cache_resource  # load once per process; Streamlit reruns reuse the cache
def load_model():
    model_id = "tiiuae/falcon-7b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    # float16 weights for a 7B model need roughly 14 GB of accelerator memory;
    # device_map="auto" places layers across available devices (needs `accelerate`).
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
    return tokenizer, model

tokenizer, model = load_model()

# --- Document Processing ---
def process_documents(files):
    documents = []
    for file in files:
        # PyPDFLoader/TextLoader expect a filesystem path, but Streamlit yields
        # in-memory UploadedFile objects, so write each upload to a temp file.
        suffix = ".pdf" if file.name.endswith(".pdf") else ".txt"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(file.getvalue())
            tmp_path = tmp.name
        loader = PyPDFLoader(tmp_path) if suffix == ".pdf" else TextLoader(tmp_path)
        documents.extend(loader.load())
        os.remove(tmp_path)
    return documents

def chunk_documents(documents):
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    return splitter.split_documents(documents)
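
# Note: chunk_size/chunk_overlap are measured in characters here; 500/50 keeps
# the five retrieved chunks comfortably inside Falcon's 2,048-token context
# window (a rough assumption, since characters-per-token varies with content).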

# --- Embedding and Retrieval ---
def build_retrievers(chunks):
    # Dense retriever: MiniLM sentence embeddings indexed with FAISS.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    faiss_vectorstore = FAISS.from_documents(chunks, embeddings)
    faiss_retriever = faiss_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

    # Sparse retriever: BM25 keyword matching over the same chunks.
    bm25_retriever = BM25Retriever.from_documents([LangchainDocument(page_content=d.page_content) for d in chunks])
    bm25_retriever.k = 5

    # Hybrid retrieval: merge dense and sparse rankings with equal weight.
    ensemble = EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
    return faiss_retriever, ensemble
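
# Quick sanity check of the hybrid retriever (illustrative only; assumes
# `chunks` has already been built):
#   faiss_r, hybrid_r = build_retrievers(chunks)
#   hybrid_r.get_relevant_documents("corrosion under insulation")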

# --- Inference ---
def generate_answer(query, retriever):
    # Retrieve the top-ranked chunks and join them into one context block.
    docs = retriever.get_relevant_documents(query)
    context = "\n".join(doc.page_content for doc in docs)

    system_prompt = (
        "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
        "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), and pressure safety devices (PSDs). "
        "Use the context below to answer professionally.\n\nContext:\n" + context + "\n\nQuery: " + query + "\nAnswer:"
    )

    # TextIteratorStreamer exposes generate() output as an iterator of decoded
    # text; running generate() in a background thread lets us stream tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(system_prompt, return_tensors="pt").to(model.device)
    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=300)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the growing answer so the caller can live-update the placeholder.
    answer = ""
    for token in streamer:
        answer += token
        yield answer
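
# Illustrative call outside the UI (hypothetical query string):
#   for partial in generate_answer("What governs PSV test intervals?", hybrid_retriever):
#       print(partial)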

# --- Main App ---
if uploaded_files:
    # Streamlit reruns this script on every interaction, so cache the built
    # retrievers in session_state instead of re-embedding on each query.
    file_key = tuple(f.name for f in uploaded_files)
    if st.session_state.get("file_key") != file_key:
        with st.spinner("Processing documents..."):
            docs = process_documents(uploaded_files)
            chunks = chunk_documents(docs)
            st.session_state.retrievers = build_retrievers(chunks)
            st.session_state.file_key = file_key
        st.success("Documents processed successfully.")
    faiss_retriever, hybrid_retriever = st.session_state.retrievers

    query = st.text_input("πŸ” Ask a question based on the uploaded documents")
    if query:
        st.subheader("πŸ“€ Answer")
        retriever = hybrid_retriever if hybrid_toggle else faiss_retriever
        response_placeholder = st.empty()
        for partial_response in generate_answer(query, retriever):
            response_placeholder.markdown(partial_response)
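
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py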