# DeepScaleR1 / app.py
import gradio as gr
import os
import time
from typing import List, Tuple, Optional
from pathlib import Path
from threading import Thread
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.docstore.document import Document
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
BitsAndBytesConfig,
StoppingCriteria,
StoppingCriteriaList,
)
import torch
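# --- Configuration ---
# BGE-M3 sentence embeddings drive retrieval, DeepScaleR-1.5B-Preview is the
# chat model, and retrieved context is capped at 8192 tokens in the chain.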
EMBEDDING_MODEL = "BAAI/bge-m3"
MODEL_NAME = "agentica-org/DeepScaleR-1.5B-Preview"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_CONTEXT_LENGTH = 8192
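# Quantize to 4-bit NF4 (bitsandbytes) when a GPU is available; on CPU the
# model is loaded unquantized in float32.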
bnb_config = (
BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
if DEVICE == "cuda"
else None
)
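# Stopping criterion for generation: halt as soon as the most recently emitted
# token matches a designated stop id.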
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
        stop_ids = [0]  # token id 0 is treated as the end-of-sequence marker
        return input_ids[0][-1].item() in stop_ids
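# Keep only files that exist on disk and carry a .pdf or .txt extension.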
def validate_file_paths(file_paths: List[str]) -> List[str]:
valid_paths = []
for path in file_paths:
try:
if Path(path).exists() and Path(path).suffix.lower() in [".pdf", ".txt"]:
valid_paths.append(path)
except (OSError, PermissionError) as e:
print(f"File validation error: {str(e)}")
return valid_paths
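# Load every valid PDF/TXT file and split the result into overlapping chunks
# (1024 characters, 128-character overlap) ready for indexing.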
def load_documents(file_paths: List[str]) -> List[Document]:
documents = []
valid_paths = validate_file_paths(file_paths)
if not valid_paths:
raise ValueError("No valid PDF/TXT files found!")
for path in valid_paths:
try:
            suffix = Path(path).suffix.lower()
            if suffix == ".pdf":
                loader = PyPDFLoader(path)
            elif suffix == ".txt":
                loader = TextLoader(path)
            else:
                continue
docs = loader.load()
if docs:
documents.extend(docs)
except Exception as e:
print(f"Error loading {Path(path).name}: {str(e)}")
if not documents:
raise ValueError("All documents failed to load.")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1024,
chunk_overlap=128,
length_function=len,
add_start_index=True,
separators=["\n\n", "\n", "。", " ", ""],
)
return text_splitter.split_documents(documents)
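# Index the chunks in an in-memory FAISS store using normalized BGE-M3
# embeddings, so similarity search behaves like cosine similarity.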
def create_vector_store(documents: List[Document]) -> FAISS:
if not documents:
raise ValueError("No documents to index.")
embeddings = HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL,
model_kwargs={"device": DEVICE},
encode_kwargs={"normalize_embeddings": True},
)
return FAISS.from_documents(documents, embeddings)
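# Build the full RAG chain: load the DeepScaleR model into a text-generation
# pipeline, attach an MMR retriever over the FAISS index, and keep chat
# history in a conversation buffer so follow-up questions stay in context.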
def initialize_deepseek_model(
vector_store: FAISS,
temperature: float = 0.7,
max_new_tokens: int = 1024,
top_k: int = 50,
) -> ConversationalRetrievalChain:
try:
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME, use_fast=True, trust_remote_code=True
)
torch_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto" if DEVICE == "cuda" else None,
torch_dtype=torch_dtype,
trust_remote_code=True,
)
text_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
            do_sample=True,  # sampling must be enabled for temperature/top_k to apply
            temperature=temperature,
max_new_tokens=max_new_tokens,
top_k=top_k,
repetition_penalty=1.1,
stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
batch_size=1,
return_full_text=False,
)
llm = HuggingFacePipeline(
pipeline=text_pipeline, model_kwargs={"temperature": temperature}
)
memory = ConversationBufferMemory(
memory_key="chat_history",
return_messages=True,
output_key="answer",
input_key="question",
)
return ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=vector_store.as_retriever(
search_type="mmr", search_kwargs={"k": 5, "fetch_k": 10}
),
memory=memory,
chain_type="stuff",
return_source_documents=True,
verbose=False,
max_tokens_limit=MAX_CONTEXT_LENGTH,
)
except Exception as e:
raise RuntimeError(f"Model initialization failed: {str(e)}")
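# Reduce the retrieved documents to exactly three (snippet, page) pairs so the
# fixed set of reference widgets in the UI always gets a value.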
def format_sources(source_docs: List[Document]) -> List[Tuple[str, int]]:
sources = []
try:
        for doc in source_docs[:3]:
            content = doc.page_content.strip()[:500] + "..."
            page = doc.metadata.get("page", 0) + 1
            sources.append((content, page))
        # Pad to exactly three entries so every reference widget gets a value.
        while len(sources) < 3:
            sources.append(("No source found", 0))
except Exception:
return [("Source processing error", 0)] * 3
return sources
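# Run one chat turn: query the retrieval chain, extend the history with the
# answer, and flatten the three sources into (snippet, page) outputs.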
def handle_conversation(
qa_chain: Optional[ConversationalRetrievalChain],
message: str,
history: List[Tuple[str, str]],
) -> Tuple:
start_time = time.time()
if not qa_chain:
        return None, "", history, *(["System Error", 0] * 3)
try:
response = qa_chain.invoke({"question": message, "chat_history": history})
answer = response["answer"].strip()
sources = format_sources(response.get("source_documents", []))
new_history = history + [(message, answer)]
elapsed = f"{(time.time() - start_time):.2f}s"
print(f"Response generated in {elapsed}")
return (
qa_chain,
"",
new_history,
*[item for sublist in sources for item in sublist],
)
except Exception as e:
error_msg = f"⚠️ Error: {str(e)}"
        return qa_chain, "", history + [(message, error_msg)], *(["Error", 0] * 3)
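# Assemble the Gradio Blocks UI: document upload/processing and model
# configuration in the left column, the chat panel and source references in
# the right column, plus the callbacks that wire everything together.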
def create_interface() -> gr.Blocks:
with gr.Blocks(theme=gr.themes.Default()) as interface:
qa_chain = gr.State()
vector_store = gr.State()
gr.Markdown(
"""
            <h1 style="text-align:center; color: #00ffff;">
                DeepScale R1
            </h1>
            <p style="text-align:center; color: #008080;">
                A Safe and Strong Local RAG System by Adarsh Pandey!
            </p>
""",
elem_id="header-section",
)
with gr.Row():
with gr.Column(scale=1, min_width=300):
gr.Markdown("### Step 1: Document Processing")
file_input = gr.Files(
file_types=[".pdf", ".txt"], file_count="multiple"
)
process_btn = gr.Button("Process Documents", variant="primary")
process_status = gr.Textbox(label="Status", interactive=False)
gr.Markdown("### Step 2: Model Configuration")
with gr.Accordion("Advanced Parameters", open=False):
temp_slider = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.1,
label="Temperature",
)
token_slider = gr.Slider(
minimum=256,
maximum=4096,
value=1024,
step=128,
label="Response Length",
)
topk_slider = gr.Slider(
minimum=1, maximum=100, value=50, step=5, label="Top-K Sampling"
)
init_btn = gr.Button("Initialize Model", variant="primary")
model_status = gr.Textbox(label="Model Status", interactive=False)
with gr.Column(scale=1, min_width=500):
chatbot = gr.Chatbot(
label="Conversation History",
height=450,
avatar_images=["2.png", "3.png"],
)
msg_input = gr.Textbox(
label="Your Query",
placeholder="Ask a question about your documents...",
)
with gr.Row():
submit_btn = gr.Button("Submit", variant="primary")
clear_btn = gr.ClearButton([msg_input, chatbot], value="Clear Chat")
                with gr.Accordion("Source References", open=True):
                    # Capture the reference widgets so the chat callbacks can
                    # update them with retrieved snippets and page numbers.
                    source_refs = []
                    for i in range(3):
                        with gr.Row():
                            ref_box = gr.Textbox(
                                label=f"Reference {i+1}", max_lines=4, interactive=False
                            )
                            page_box = gr.Number(label="Page", value=0, interactive=False)
                            source_refs.extend([ref_box, page_box])
process_btn.click(
fn=lambda files: (
create_vector_store(load_documents([f.name for f in files])),
"Documents processed successfully.",
),
inputs=file_input,
outputs=[vector_store, process_status],
api_name="process_docs",
)
init_btn.click(
fn=lambda vs, temp, tokens, k: (
initialize_deepseek_model(vs, temp, tokens, k),
"Model initialized successfully.",
),
inputs=[vector_store, temp_slider, token_slider, topk_slider],
outputs=[qa_chain, model_status],
api_name="init_model",
)
        msg_input.submit(
            fn=handle_conversation,
            inputs=[qa_chain, msg_input, chatbot],
            outputs=[qa_chain, msg_input, chatbot, *source_refs],
            api_name="chat",
        )
        submit_btn.click(
            fn=handle_conversation,
            inputs=[qa_chain, msg_input, chatbot],
            outputs=[qa_chain, msg_input, chatbot, *source_refs],
            api_name="chat",
        )
return interface
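# Entry point: serve the app on port 7860, binding to 0.0.0.0 only when the
# DOCKER environment variable is set (container deployments), else localhost.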
if __name__ == "__main__":
app = create_interface()
app.launch(
server_name="0.0.0.0" if os.getenv("DOCKER") else "localhost",
server_port=7860,
show_error=True,
share=True,
favicon_path="1.png",
)