import streamlit as st
import os
import gc
import base64
import tempfile
import uuid

# llama-index 0.9.x-style imports; newer releases (0.10+) moved these under
# llama_index.core and deprecate ServiceContext in favor of Settings.
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.prompts import PromptTemplate

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


# ----------------------------
# 1) LLM LOADING
# ----------------------------
@st.cache_resource
def load_llm():
    """
    Load the DeepSeek-R1 model (~671B parameters) from Hugging Face,
    using 4-bit quantization and automatic device mapping.
    """
    model_id = "deepseek-ai/DeepSeek-R1"

    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True
    )

    # model in 4-bit (requires bitsandbytes + accelerate)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map="auto",        # auto-shard across all available GPUs
        load_in_4bit=True,        # bitsandbytes 4-bit quantization
        torch_dtype=torch.float16
    )

    # wrap with LlamaIndex's HuggingFaceLLM; sampling options go through
    # generate_kwargs, and streaming is requested later on the query engine
    llm = HuggingFaceLLM(
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        generate_kwargs={"temperature": 0.7, "do_sample": True},
    )
    return llm


# ----------------------------
# 2) STREAMLIT + INDEX SETUP
# ----------------------------
if "id" not in st.session_state:
    st.session_state.id = uuid.uuid4()
    st.session_state.file_cache = {}


def reset_chat():
    st.session_state.messages = []
    gc.collect()


def display_pdf(file):
    st.markdown("### PDF Preview")
    base64_pdf = base64.b64encode(file.read()).decode("utf-8")
    # embed the PDF via a data-URI iframe so it renders inside the sidebar
    pdf_display = f"""
        <iframe src="data:application/pdf;base64,{base64_pdf}"
                width="100%" height="600" type="application/pdf"></iframe>
    """
    st.markdown(pdf_display, unsafe_allow_html=True)


# Sidebar for file upload
with st.sidebar:
    st.header("Add your documents!")
    uploaded_file = st.file_uploader("Choose a `.pdf` file", type="pdf")

    if uploaded_file:
        try:
            # Index the uploaded document
            with tempfile.TemporaryDirectory() as temp_dir:
                file_path = os.path.join(temp_dir, uploaded_file.name)
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getvalue())

                file_key = f"{st.session_state.id}-{uploaded_file.name}"
                st.write("Indexing your document...")

                if file_key not in st.session_state.get("file_cache", {}):
                    if os.path.exists(temp_dir):
                        loader = SimpleDirectoryReader(
                            input_dir=temp_dir,
                            required_exts=[".pdf"],
                            recursive=True
                        )
                    else:
                        st.error("Could not find the file. Please reupload.")
                        st.stop()

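                    # Build the RAG pipeline over the freshly saved PDF:
                    # parse it, load the LLM and embedding model, build the
                    # vector index, and expose a streaming query engine with
                    # a custom QA prompt.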
Please reupload.") st.stop() docs = loader.load_data() # Load the HF-based LLM (DeepSeek-R1) llm = load_llm() # HuggingFace Embeddings for the VectorStore embed_model = HuggingFaceEmbedding( model_name="answerdotai/ModernBERT-large", trust_remote_code=True ) # create a service context service_context = ServiceContext.from_defaults( llm=llm, embed_model=embed_model ) # build the index index = VectorStoreIndex.from_documents( docs, service_context=service_context, show_progress=True ) query_engine = index.as_query_engine(streaming=True) # custom QA prompt qa_prompt_tmpl_str = ( "Context information is below.\n" "---------------------\n" "{context_str}\n" "---------------------\n" "Given the context info above, provide a concise answer.\n" "If you don't know, say 'I don't know'.\n" "Query: {query_str}\n" "Answer: " ) qa_prompt = PromptTemplate(qa_prompt_tmpl_str) query_engine.update_prompts( {"response_synthesizer:text_qa_template": qa_prompt} ) # store in session state st.session_state.file_cache[file_key] = query_engine else: query_engine = st.session_state.file_cache[file_key] st.success("Ready to Chat!") display_pdf(uploaded_file) except Exception as e: st.error(f"An error occurred: {e}") st.stop() col1, col2 = st.columns([6, 1]) with col1: st.markdown("# RAG with DeepSeek-R1 (700B)") with col2: st.button("Clear ↺", on_click=reset_chat) # Initialize chat if needed if "messages" not in st.session_state: reset_chat() # Render past messages for message in st.session_state.messages: with st.chat_message(message["role"]): st.markdown(message["content"]) # Chat input if prompt := st.chat_input("Ask a question about your PDF..."): st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("user"): st.markdown(prompt) # Retrieve the engine if uploaded_file: file_key = f"{st.session_state.id}-{uploaded_file.name}" query_engine = st.session_state.file_cache.get(file_key) else: query_engine = None # If no docs, just return a quick message if not query_engine: answer = "No documents indexed. Please upload a PDF first." st.session_state.messages.append({"role": "assistant", "content": answer}) with st.chat_message("assistant"): st.markdown(answer) else: with st.chat_message("assistant"): message_placeholder = st.empty() full_response = "" # Streaming generator from LlamaIndex streaming_response = query_engine.query(prompt) for chunk in streaming_response.response_gen: full_response += chunk message_placeholder.markdown(full_response + "▌") message_placeholder.markdown(full_response) st.session_state.messages.append({"role": "assistant", "content": full_response})