|
|
|
import os
import tempfile
import time
from threading import Thread

import streamlit as st
import torch
from langchain.docstore.document import Document as LangchainDocument
from langchain.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.schema import Document
from langchain.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
|
|
|
|
# Hugging Face access token, read from Streamlit's secrets store.
HF_TOKEN = st.secrets["HF_TOKEN"]

# Page chrome: browser tab metadata and the main heading.
st.set_page_config(
    page_title="DigiTwin RAG",
    page_icon="π",
    layout="centered",
)
st.title("π DigiTs the Twin")

# Sidebar controls: the knowledge-file uploader and the hybrid-search switch.
# Both values are read later at module level by the query flow.
with st.sidebar:
    st.header("π Upload Knowledge Files")
    uploaded_files = st.file_uploader(
        "Upload PDFs or .txt files",
        type=["pdf", "txt"],
        accept_multiple_files=True,
    )
    hybrid_toggle = st.checkbox("π Enable Hybrid Search", value=True)
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the Falcon-7B-Instruct tokenizer and causal LM.

    Wrapped in ``st.cache_resource`` so the returned objects are reused
    rather than reloaded each time the script runs.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is loaded in float16 with
        ``device_map="auto"``.
    """
    checkpoint = "tiiuae/falcon-7b-instruct"
    tok = AutoTokenizer.from_pretrained(checkpoint, token=HF_TOKEN)
    lm = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        token=HF_TOKEN,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return tok, lm


tokenizer, model = load_model()
|
|
|
|
|
def process_documents(files):
    """Load uploaded PDF / text files into LangChain documents.

    ``PyPDFLoader`` and ``TextLoader`` take a filesystem path, whereas the
    Streamlit uploader hands back in-memory file objects. Each upload is
    therefore written to a temporary file, loaded, and the temporary file is
    removed again.

    Args:
        files: Iterable of uploaded file objects exposing ``name`` and
            ``getvalue()``.

    Returns:
        A flat list with the documents loaded from every file, in upload
        order. Each document's ``metadata["source"]`` is set to the original
        upload name instead of the temporary path.
    """
    documents = []
    for file in files:
        # Compare the extension case-insensitively so "REPORT.PDF" is
        # still routed to the PDF loader.
        suffix = os.path.splitext(file.name)[1].lower()
        loader_cls = PyPDFLoader if suffix == ".pdf" else TextLoader

        tmp_path = None
        try:
            # delete=False: the loader re-opens the file by path, which is
            # not possible on every platform while the handle is still open.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(file.getvalue())
                tmp_path = tmp.name
            docs = loader_cls(tmp_path).load()
        finally:
            if tmp_path is not None:
                os.remove(tmp_path)

        for doc in docs:
            # Replace the throwaway temp path with the user's file name.
            doc.metadata["source"] = file.name
        documents.extend(docs)
    return documents
|
|
|
def chunk_documents(documents):
    """Split documents into 500-character chunks with a 50-character overlap.

    Args:
        documents: Documents to split.

    Returns:
        Whatever ``RecursiveCharacterTextSplitter.split_documents`` produces
        for the given documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_overlap=50,
        chunk_size=500,
    )
    pieces = text_splitter.split_documents(documents)
    return pieces
|
|
|
|
|
def build_retrievers(chunks):
    """Build the dense (FAISS) retriever and the hybrid FAISS + BM25 ensemble.

    Args:
        chunks: Non-empty list of document chunks to index.

    Returns:
        A ``(faiss_retriever, ensemble_retriever)`` tuple. Both retrievers
        are configured to return 5 results; the ensemble weights FAISS and
        BM25 equally.

    Raises:
        ValueError: If ``chunks`` is empty, since neither index can be built
            from an empty corpus.
    """
    if not chunks:
        raise ValueError(
            "No text chunks to index; the uploaded documents contain no extractable text."
        )

    top_k = 5

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    faiss_vectorstore = FAISS.from_documents(chunks, embeddings)
    faiss_retriever = faiss_vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": top_k},
    )

    # Index the chunks themselves rather than content-only copies, so BM25
    # hits keep the same metadata as the FAISS hits.
    bm25_retriever = BM25Retriever.from_documents(chunks)
    bm25_retriever.k = top_k

    ensemble = EnsembleRetriever(
        retrievers=[faiss_retriever, bm25_retriever],
        weights=[0.5, 0.5],
    )
    return faiss_retriever, ensemble
|
|
|
|
|
def generate_answer(query, retriever):
    """Stream an answer to ``query`` grounded in retrieved context.

    Retrieves documents for the query, joins their text into the prompt, and
    runs ``model.generate`` on a worker thread while this generator consumes
    the token streamer.

    Args:
        query: The user's question.
        retriever: Retriever exposing ``get_relevant_documents(query)``.

    Yields:
        The cumulative answer text after each streamed piece, so the caller
        can simply re-render the latest value.

    Raises:
        RuntimeError: If generation fails on the worker thread; the original
            exception is chained as the cause.
    """
    docs = retriever.get_relevant_documents(query)
    context = "\n".join(doc.page_content for doc in docs)

    system_prompt = (
        "You are DigiTwin, an expert advisor in asset integrity, reliability, inspection, and maintenance "
        "of topside piping, structural, mechanical systems, floating units, pressure vessels (VII), and pressure safety devices (PSD's). "
        "Use the context below to answer professionally.\n\nContext:\n" + context + "\n\nQuery: " + query + "\nAnswer:"
    )

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(system_prompt, return_tensors="pt").to(model.device)
    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=300)

    errors = []

    def _run_generation():
        # Runs on the worker thread. If generate() raises, the streamer would
        # never receive its end-of-stream signal and the consumer loop below
        # would block forever, so close the stream explicitly on failure.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # thread boundary: recorded and re-raised below
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    answer = ""
    for token in streamer:
        answer += token
        yield answer

    # The stream has ended, so the worker is finished or about to finish.
    thread.join()
    if errors:
        raise RuntimeError("Text generation failed") from errors[0]
|
|
|
|
|
if uploaded_files:
    # Streamlit re-executes this script on every interaction (including each
    # submitted query). Cache the retrievers in session state, keyed by the
    # current set of uploads, so documents are only parsed, embedded and
    # indexed again when the uploads actually change.
    upload_key = tuple((f.name, f.size) for f in uploaded_files)
    if st.session_state.get("retriever_key") != upload_key:
        with st.spinner("Processing documents..."):
            docs = process_documents(uploaded_files)
            chunks = chunk_documents(docs)
            if not chunks:
                # Nothing to index (e.g. empty or image-only files); report it
                # instead of failing inside the index builders.
                st.error("No text could be extracted from the uploaded files.")
                st.stop()
            st.session_state["retrievers"] = build_retrievers(chunks)
            st.session_state["retriever_key"] = upload_key

    faiss_retriever, hybrid_retriever = st.session_state["retrievers"]
    st.success("Documents processed successfully.")

    query = st.text_input("π Ask a question based on the uploaded documents")
    if query:
        st.subheader("π€ Answer")
        retriever = hybrid_retriever if hybrid_toggle else faiss_retriever
        response_placeholder = st.empty()
        # Each yielded value is the full answer so far; overwrite in place.
        for partial_response in generate_answer(query, retriever):
            response_placeholder.markdown(partial_response)
|
|