# RAG/app.py
import streamlit as st
import torch
import os
import time
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
# --- Hugging Face Token ---
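# The token is read from the app's secrets store (e.g. .streamlit/secrets.toml or the
# hosting platform's secrets settings); st.secrets raises an error if the key is absent.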
HF_TOKEN = st.secrets["HF_TOKEN"]
# --- Page Config ---
st.set_page_config(page_title="DigiTwin RAG", page_icon="📂", layout="centered")
st.title("📂 DigiTs the Twin")
# --- File Upload UI ---
with st.sidebar:
    st.header("📄 Upload Knowledge Files")
    uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
    if uploaded_files:
        st.success(f"{len(uploaded_files)} file(s) uploaded")
# --- Load Model & Tokenizer ---
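# Cached with st.cache_resource so the checkpoint is loaded only once per process,
# not on every Streamlit rerun.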
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("amiguel/GM_Qwen1.8B_Finetune", trust_remote_code=True, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "amiguel/GM_Qwen1.8B_Finetune",
        device_map="auto",
        # Use bfloat16 only when a CUDA device that supports it is available; fall back to float32 on CPU
        torch_dtype=torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float32,
        trust_remote_code=True,
        token=HF_TOKEN
    )
    return model, tokenizer

model, tokenizer = load_model()
# --- System Prompt ---
SYSTEM_PROMPT = (
    "You are DigiTwin, a digital expert and senior topside engineer specializing in inspection and maintenance "
    "of offshore piping systems, structural elements, mechanical equipment, floating production units, pressure vessels "
    "(with emphasis on Visual Internal Inspection - VII), and pressure safety devices (PSDs). Rely on uploaded documents "
    "and context to provide practical, standards-driven, and technically accurate responses. Your guidance reflects deep "
    "field experience, industry regulations, and proven methodologies in asset integrity and reliability engineering."
)
# --- Prompt Builder ---
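# Assembles a ChatML-style prompt: a system turn carrying the persona and the retrieved
# context, the recent chat turns, and a trailing open assistant turn for generation.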
def build_prompt(messages, context=""):
    prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}\n\nContext:\n{context}<|im_end|>\n"
    for msg in messages:
        role = msg["role"]
        prompt += f"<|im_start|>{role}\n{msg['content']}<|im_end|>\n"
    prompt += "<|im_start|>assistant\n"
    return prompt
# --- Embed Uploaded Documents ---
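# Uploads are written to /tmp, loaded, split into 512-character chunks (64 overlap),
# embedded with MiniLM, and indexed in an in-memory FAISS store; st.cache_resource
# memoizes the index so it is not rebuilt on every rerun for the same uploads.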
@st.cache_resource
def embed_uploaded_files(files):
    raw_docs = []
    for f in files:
        path = f"/tmp/{f.name}"
        with open(path, "wb") as out_file:
            out_file.write(f.read())
        loader = PyPDFLoader(path) if f.name.endswith(".pdf") else TextLoader(path)
        raw_docs.extend(loader.load())
    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
    chunks = splitter.split_documents(raw_docs)
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    db = FAISS.from_documents(chunks, embedding=embeddings)
    return db

retriever = embed_uploaded_files(uploaded_files) if uploaded_files else None
# --- Streaming Generator ---
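# model.generate() runs in a background thread and writes decoded text into a
# TextIteratorStreamer; the caller iterates over the streamer to render tokens as they arrive.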
def generate_response(prompt_text):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    thread = Thread(target=model.generate, kwargs={
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "streamer": streamer
    })
    thread.start()
    return streamer
# --- Avatars ---
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
# --- Initialize Chat Memory ---
if "messages" not in st.session_state:
st.session_state.messages = []
# --- Display Message History ---
for msg in st.session_state.messages:
with st.chat_message(msg["role"], avatar=USER_AVATAR if msg["role"] == "user" else BOT_AVATAR):
st.markdown(msg["content"])
# --- Chat Interface ---
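# Per user turn: retrieve the top-3 most similar chunks (when documents are indexed),
# build the prompt from the last 6 messages plus that context, then stream the answer.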
if prompt := st.chat_input("Ask something based on uploaded documents..."):
    st.chat_message("user", avatar=USER_AVATAR).markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    context = ""
    docs = []
    if retriever:
        docs = retriever.similarity_search(prompt, k=3)
        context = "\n\n".join([doc.page_content for doc in docs])

    # Limit to last 6 messages for memory
    recent_messages = st.session_state.messages[-6:]
    full_prompt = build_prompt(recent_messages, context)

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        start = time.time()
        container = st.empty()
        answer = ""
        for chunk in generate_response(full_prompt):
            answer += chunk
            container.markdown(answer + "▌", unsafe_allow_html=True)
        container.markdown(answer)
        end = time.time()

    st.session_state.messages.append({"role": "assistant", "content": answer})

    input_tokens = len(tokenizer(full_prompt)["input_ids"])
    output_tokens = len(tokenizer(answer)["input_ids"])
    speed = output_tokens / (end - start)

    with st.expander("📊 Debug Info"):
        st.caption(
            f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
            f"🕒 Speed: {speed:.1f} tokens/sec"
        )
        for i, doc in enumerate(docs):
            st.markdown(f"**Chunk #{i+1}**")
            st.code(doc.page_content.strip()[:500])