# Source: ataliba / app.py
# Author: amiguel
# Commit c941abd ("Update app.py", verified)
import streamlit as st
import torch
import os
import tempfile
import time
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.schema import Document
from langchain.docstore.document import Document as LangchainDocument
# --- Avatars ---
# Remote images used as the chat-message avatars for the user and the bot.
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

# --- Hugging Face Token ---
# Read from Streamlit secrets; passed as `token=` when downloading the model.
HF_TOKEN = st.secrets["HF_TOKEN"]

# --- Page Setup ---
# The icon/title previously held mojibake (UTF-8 emoji bytes decoded as a
# Windows code page), which is not a valid emoji for `page_icon`.
st.set_page_config(page_title="Hybrid RAG Chat", page_icon="🤖", layout="centered")
st.title("🤖 DigiTwin Streaming")
# --- Sidebar Upload ---
# Upload widget plus generation/conversation controls, all in the sidebar.
with st.sidebar:
    # Header emoji restored from mojibake ("πŸ"€" -> 📤).
    st.header("📤 Upload Documents")
    uploaded_files = st.file_uploader("PDFs or .txt files only", type=["pdf", "txt"], accept_multiple_files=True)
    max_tokens = st.slider("🧠 Max Response Tokens", 100, 2048, 512, step=50)
    clear_chat = st.button("🧹 Clear Conversation")

# --- Chat Memory ---
# Create the history on first run, and wipe it when "Clear Conversation" is pressed.
if "messages" not in st.session_state or clear_chat:
    st.session_state.messages = []
# --- Load LLM ---
@st.cache_resource
def load_model():
    """Load the fine-tuned causal LM and its tokenizer (cached by Streamlit).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model_id = "amiguel/GM_Qwen1.8B_Finetune"  # alternative kept for reference: "tiiuae/falcon-7b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    # float16 only pays off on GPU; on CPU-only hosts half-precision matmuls are
    # slow and, on older PyTorch builds, unsupported, so fall back to float32.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
        token=HF_TOKEN,
    )
    return tokenizer, model


tokenizer, model = load_model()
# --- Document Processing ---
def process_documents(files):
    """Load uploaded PDF / plain-text files into LangChain documents.

    The loaders work on filesystem paths, so each upload is written to a
    temporary file that is deleted again as soon as it has been loaded.

    Args:
        files: Iterable of Streamlit ``UploadedFile`` objects.

    Returns:
        A flat list of loaded documents, in upload order. Each document's
        ``metadata["source"]`` is set to the uploaded file's name (the temp
        path it was loaded from no longer exists).
    """
    documents = []
    for file in files:
        # Case-insensitive so "REPORT.PDF" is parsed as a PDF, not read as text.
        is_pdf = file.name.lower().endswith(".pdf")
        suffix = ".pdf" if is_pdf else ".txt"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # getvalue() returns the whole payload regardless of the current
            # read position, unlike read().
            tmp.write(file.getvalue())
            path = tmp.name
        try:
            loader = PyPDFLoader(path) if is_pdf else TextLoader(path)
            loaded = loader.load()
        finally:
            # delete=False was needed so the loader could reopen the file by
            # path; clean it up ourselves so uploads don't accumulate on disk.
            os.remove(path)
        for doc in loaded:
            doc.metadata["source"] = file.name
        documents.extend(loaded)
    return documents
def chunk_documents(docs, chunk_size=500, chunk_overlap=50):
    """Split documents into overlapping character chunks for retrieval.

    Args:
        docs: Documents to split.
        chunk_size: Target maximum characters per chunk (default 500).
        chunk_overlap: Characters shared between neighbouring chunks (default 50).

    Returns:
        The list of chunk documents produced by the splitter.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_documents(docs)
@st.cache_resource
def _load_embeddings(model_name):
    """Load the sentence-transformer embedding model once and reuse it across calls."""
    return HuggingFaceEmbeddings(model_name=model_name)


def build_hybrid_retriever(chunks, k=5, weights=(0.5, 0.5)):
    """Build a dense (FAISS) + sparse (BM25) ensemble retriever over *chunks*.

    Args:
        chunks: Non-empty list of chunk documents to index.
        k: Number of results each sub-retriever returns (default 5).
        weights: ``(dense_weight, sparse_weight)`` for the ensemble fusion.

    Returns:
        An ``EnsembleRetriever`` combining both retrievers.

    Raises:
        ValueError: If *chunks* is empty (neither index can be built from nothing).
    """
    if not chunks:
        raise ValueError("build_hybrid_retriever() needs at least one chunk to index")
    embeddings = _load_embeddings("sentence-transformers/all-MiniLM-L6-v2")
    dense = FAISS.from_documents(chunks, embeddings).as_retriever(
        search_type="similarity", search_kwargs={"k": k}
    )
    # Index the chunks as-is so BM25 hits keep their metadata (source, page),
    # just like the FAISS hits do.
    sparse = BM25Retriever.from_documents(chunks)
    sparse.k = k
    return EnsembleRetriever(retrievers=[dense, sparse], weights=list(weights))
# --- Prompt Builder ---
def build_prompt(history, context="", max_messages=None):
    """Render the retrieved context and chat history into a plain-text prompt.

    Args:
        history: Sequence of ``{"role": ..., "content": ...}`` dicts. Any role
            other than ``"user"`` is labelled ``Assistant``.
        context: Retrieved document text placed under ``Context:``.
        max_messages: If given, keep only the most recent ``max_messages``
            history entries (``0`` or less keeps none), so long conversations
            do not grow the prompt without bound. ``None`` (default) keeps all.

    Returns:
        The prompt string, ending in ``Assistant:`` so the model continues
        speaking as the assistant.
    """
    if max_messages is not None:
        # Guard 0 explicitly: history[-0:] would return the whole list.
        history = history[-max_messages:] if max_messages > 0 else []
    lines = []
    for msg in history:
        role = "User" if msg["role"] == "user" else "Assistant"
        lines.append(f"{role}: {msg['content']}\n")
    # join() instead of repeated += keeps this linear in the history length.
    dialog = "".join(lines)
    return f"""You are DigiTwin, a highly professional and experienced assistant in inspection, integrity, and maintenance of topside equipment, piping systems, pressure vessels, structures, and safety systems. Use the following context to provide expert-level answers.
Context:
{context}
{dialog}
Assistant:"""
# --- Response Generator ---
def generate_response(prompt, max_tokens):
    """Stream the model's completion of *prompt*, yielding the text so far.

    ``model.generate`` runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator drains the streamer.

    Args:
        prompt: Full prompt string to complete.
        max_tokens: Maximum number of new tokens to generate.

    Yields:
        The cumulative reply text (each value extends the previous one).

    Raises:
        Whatever ``model.generate`` raised, re-raised once the stream has ended.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    failures = []

    def _generate():
        # The streamer has no timeout, so if generate() raised without ending
        # it, the consuming loop below would block forever. End it explicitly
        # and hand the exception back to the caller's thread.
        try:
            model.generate(**inputs, streamer=streamer, max_new_tokens=max_tokens)
        except Exception as exc:  # re-raised below, not swallowed
            failures.append(exc)
            streamer.end()

    worker = Thread(target=_generate, daemon=True)
    worker.start()
    output = ""
    for token in streamer:
        output += token
        yield output
    worker.join()
    if failures:
        raise failures[0]
# --- Retrieval Logic ---
# Streamlit re-executes this script on every interaction, including every chat
# message. Rebuilding the index each time would re-embed every chunk, so the
# retriever is kept in session state and rebuilt only when the uploads change.
retriever = None
if uploaded_files:
    upload_signature = tuple((f.name, hash(f.getvalue())) for f in uploaded_files)
    if st.session_state.get("index_signature") != upload_signature:
        with st.spinner("🔍 Indexing documents..."):
            docs = process_documents(uploaded_files)
            chunks = chunk_documents(docs)
            # Files with no extractable text (e.g. scanned PDFs) yield no
            # chunks, and neither FAISS nor BM25 can index an empty corpus.
            st.session_state.retriever = build_hybrid_retriever(chunks) if chunks else None
            st.session_state.index_signature = upload_signature
    retriever = st.session_state.get("retriever")
    if retriever is not None:
        st.success("✅ Documents ready for hybrid search.")
    else:
        st.warning("No extractable text was found in the uploaded documents.")
# --- Display Conversation ---
# Re-render every message stored in st.session_state on each script run,
# choosing the avatar from the message's role.
for entry in st.session_state.messages:
    role = entry["role"]
    avatar = USER_AVATAR if role == "user" else BOT_AVATAR
    with st.chat_message(role, avatar=avatar):
        st.markdown(entry["content"])
# --- Chat Input ---
def _render_rag_debug(input_tokens, output_tokens, speed, matched_chunks):
    """Show token/speed statistics and the retrieved chunks in an expander."""
    with st.expander("📊 Response Stats & RAG Debug"):
        st.caption(
            f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
            f"🕒 Speed: {speed:.1f} tokens/sec"
        )
        for i, doc in enumerate(matched_chunks, start=1):
            # Documents only carry a score if the retriever attached one; a
            # genuine 0.0 must still be shown rather than treated as missing.
            score = getattr(doc, "score", None)
            metadata = getattr(doc, "metadata", {})
            st.markdown(f"**Chunk #{i}**")
            st.code(doc.page_content.strip()[:500])
            st.text(
                f"🔍 Similarity Score: {score if score is not None else 'N/A'} "
                f"| Metadata: {metadata}"
            )


if query := st.chat_input("Ask DigiTwin anything..."):
    st.chat_message("user", avatar=USER_AVATAR).markdown(query)
    st.session_state.messages.append({"role": "user", "content": query})

    context = ""
    matched_chunks = []
    if retriever is not None:
        # invoke() is the Runnable API; get_relevant_documents() is deprecated.
        matched_chunks = retriever.invoke(query)
        context = "\n\n".join(doc.page_content for doc in matched_chunks)
    full_prompt = build_prompt(st.session_state.messages, context)

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        start_time = time.perf_counter()
        container = st.empty()
        answer = ""
        for partial in generate_response(full_prompt, max_tokens):
            answer = partial
            # No unsafe_allow_html: model output must not be rendered as raw HTML.
            container.markdown(answer + "▌")
        container.markdown(answer)
        elapsed = time.perf_counter() - start_time

        input_tokens = len(tokenizer(full_prompt)["input_ids"])
        output_tokens = len(tokenizer(answer)["input_ids"])
        speed = output_tokens / elapsed if elapsed > 0 else 0.0
        st.session_state.messages.append({"role": "assistant", "content": answer})

        _render_rag_debug(input_tokens, output_tokens, speed, matched_chunks)