Nick White committed on
Commit
aa1c44a
·
1 Parent(s): 1660dbb

ADD initial files

Browse files
Files changed (2) hide show
  1. app.py +205 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import gc
4
+ import base64
5
+ import tempfile
6
+ import uuid
7
+
8
+ from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
9
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
10
+ from llama_index.llms.huggingface import HuggingFaceLLM
11
+ from llama_index.prompts import PromptTemplate
12
+
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM
14
+ import torch
15
+
16
+ # ----------------------------
17
+ # 1) LLM LOADING
18
+ # ----------------------------
19
@st.cache_resource
def load_llm():
    """Load DeepSeek-R1 from Hugging Face and wrap it as a LlamaIndex LLM.

    The tokenizer and model are fetched for ``deepseek-ai/DeepSeek-R1`` with
    ``trust_remote_code=True``; the model is loaded with bitsandbytes 4-bit
    quantization, ``device_map="auto"`` and ``torch.float16``.  The function
    is decorated with ``st.cache_resource`` so the loaded object is reused
    instead of being rebuilt on each call.

    Returns:
        HuggingFaceLLM: wrapper around the loaded model and tokenizer,
        configured for up to 512 new tokens with sampling at temperature 0.7.
    """
    # Local import: keeps this block self-contained (the file-level import
    # line only pulls in AutoTokenizer / AutoModelForCausalLM).
    from transformers import BitsAndBytesConfig

    model_id = "deepseek-ai/DeepSeek-R1"

    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
    )

    # model in 4-bit; the quantization settings go through
    # `quantization_config` rather than the deprecated bare `load_in_4bit`
    # keyword argument.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map="auto",  # auto-shard across all available GPUs
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        torch_dtype=torch.float16,
    )

    # Wrap with LlamaIndex's HuggingFaceLLM.  Sampling options are forwarded
    # to `generate()` via `generate_kwargs`; `do_sample=True` is required for
    # the temperature to have any effect.  Streaming is requested by the
    # query engine (`as_query_engine(streaming=True)`), not by the wrapper.
    llm = HuggingFaceLLM(
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        generate_kwargs={"temperature": 0.7, "do_sample": True},
    )
    return llm
51
+
52
+ # ----------------------------
53
+ # 2) STREAMLIT + INDEX SETUP
54
+ # ----------------------------
55
# First run of this browser session: mint a session id (used to namespace
# cache keys) and create the per-session cache of query engines.
if "id" not in st.session_state:
    st.session_state["id"] = uuid.uuid4()
    st.session_state["file_cache"] = {}
58
+
59
def reset_chat():
    """Empty the chat history kept in session state and run a GC pass."""
    st.session_state["messages"] = []
    gc.collect()
62
+
63
def display_pdf(file):
    """Show an inline preview of a PDF under a "PDF Preview" heading.

    Args:
        file: binary file-like object; its remaining content is consumed
            with ``read()`` and embedded as a base64 ``data:`` URL in an
            ``<iframe>``.
    """
    st.markdown("### PDF Preview")

    pdf_bytes = file.read()
    pdf_b64 = base64.b64encode(pdf_bytes).decode("utf-8")

    # Raw HTML is needed for the iframe, hence unsafe_allow_html below.
    iframe_html = f"""
    <iframe src="data:application/pdf;base64,{pdf_b64}"
    width="400" height="100%"
    style="height:100vh; width:100%">
    </iframe>
    """
    st.markdown(iframe_html, unsafe_allow_html=True)
73
+
74
+ # Sidebar for file upload
75
# Sidebar: upload a PDF, build (or reuse) its query engine, preview the file.
with st.sidebar:
    st.header("Add your documents!")

    uploaded_file = st.file_uploader("Choose a `.pdf` file", type="pdf")

    if uploaded_file:
        try:
            # Cache key format must stay in sync with the chat section, which
            # rebuilds the same key to look the engine up.
            file_key = f"{st.session_state.id}-{uploaded_file.name}"

            if file_key not in st.session_state.get('file_cache', {}):
                # Only write the file to disk and announce indexing when the
                # document has not been indexed yet in this session.
                st.write("Indexing your document...")

                with tempfile.TemporaryDirectory() as temp_dir:
                    # The name comes from the client; keep only its final
                    # path component so it cannot escape temp_dir.
                    safe_name = os.path.basename(uploaded_file.name) or "upload.pdf"
                    file_path = os.path.join(temp_dir, safe_name)
                    with open(file_path, "wb") as f:
                        f.write(uploaded_file.getvalue())

                    loader = SimpleDirectoryReader(
                        input_dir=temp_dir,
                        required_exts=[".pdf"],
                        recursive=True,
                    )
                    docs = loader.load_data()

                    # Load the HF-based LLM (DeepSeek-R1)
                    llm = load_llm()

                    # HuggingFace Embeddings for the VectorStore
                    embed_model = HuggingFaceEmbedding(
                        model_name="answerdotai/ModernBERT-large",
                        trust_remote_code=True,
                    )

                    # create a service context
                    service_context = ServiceContext.from_defaults(
                        llm=llm,
                        embed_model=embed_model,
                    )

                    # build the index
                    index = VectorStoreIndex.from_documents(
                        docs,
                        service_context=service_context,
                        show_progress=True,
                    )

                    query_engine = index.as_query_engine(streaming=True)

                    # custom QA prompt
                    qa_prompt_tmpl_str = (
                        "Context information is below.\n"
                        "---------------------\n"
                        "{context_str}\n"
                        "---------------------\n"
                        "Given the context info above, provide a concise answer.\n"
                        "If you don't know, say 'I don't know'.\n"
                        "Query: {query_str}\n"
                        "Answer: "
                    )
                    qa_prompt = PromptTemplate(qa_prompt_tmpl_str)
                    query_engine.update_prompts(
                        {"response_synthesizer:text_qa_template": qa_prompt}
                    )

                    # store in session state
                    st.session_state.file_cache[file_key] = query_engine
            else:
                query_engine = st.session_state.file_cache[file_key]

            st.success("Ready to Chat!")
            display_pdf(uploaded_file)

        except Exception as e:
            # Top-level UI boundary: surface the failure and halt this run.
            st.error(f"An error occurred: {e}")
            st.stop()
155
+
156
# Page header: title on the left, chat-reset button on the right.
title_col, clear_col = st.columns([6, 1])
with title_col:
    st.markdown("# RAG with DeepSeek-R1 (700B)")
with clear_col:
    st.button("Clear ↺", on_click=reset_chat)

# Make sure a message history exists before rendering it.
if "messages" not in st.session_state:
    reset_chat()

# Replay the conversation so far.
for past in st.session_state.messages:
    with st.chat_message(past["role"]):
        st.markdown(past["content"])

# Handle a newly submitted question, if any.
prompt = st.chat_input("Ask a question about your PDF...")
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Look up the engine built for the currently uploaded file, if there is one.
    query_engine = None
    if uploaded_file:
        file_key = f"{st.session_state.id}-{uploaded_file.name}"
        query_engine = st.session_state.file_cache.get(file_key)

    if query_engine:
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""

            # Consume the streamed answer, redrawing the partial text with a
            # cursor glyph after each chunk.
            streaming_response = query_engine.query(prompt)
            for chunk in streaming_response.response_gen:
                full_response = full_response + chunk
                message_placeholder.markdown(full_response + "▌")

            # Final redraw without the cursor.
            message_placeholder.markdown(full_response)

        st.session_state.messages.append({"role": "assistant", "content": full_response})
    else:
        # Nothing indexed yet: answer with a fixed hint instead of querying.
        answer = "No documents indexed. Please upload a PDF first."
        st.session_state.messages.append({"role": "assistant", "content": answer})
        with st.chat_message("assistant"):
            st.markdown(answer)
205
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ llama-index
3
+ transformers>=4.30.2
4
+ accelerate>=0.20.3
5
+ sentencepiece
6
+ bitsandbytes