"""Local RAG chatbot: a Gradio UI over LangChain + FAISS with a quantized DeepScaleR LLM."""

import gradio as gr
import os
import time
from typing import List, Tuple, Optional
from pathlib import Path

from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.docstore.document import Document
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
)
import torch

# Model and runtime configuration
EMBEDDING_MODEL = "BAAI/bge-m3"
MODEL_NAME = "agentica-org/DeepScaleR-1.5B-Preview"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_CONTEXT_LENGTH = 8192

# 4-bit NF4 quantization is only usable on CUDA; on CPU the model loads in full precision.
bnb_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    if DEVICE == "cuda"
    else None
)


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the most recently generated token is a stop token."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        stop_ids = [0]
        return input_ids[0][-1] in stop_ids


def validate_file_paths(file_paths: List[str]) -> List[str]:
    """Return only the paths that exist and have a .pdf or .txt extension."""
    valid_paths = []
    for path in file_paths:
        try:
            if Path(path).exists() and Path(path).suffix.lower() in [".pdf", ".txt"]:
                valid_paths.append(path)
        except (OSError, PermissionError) as e:
            print(f"File validation error: {str(e)}")
    return valid_paths


def load_documents(file_paths: List[str]) -> List[Document]:
    """Load PDF/TXT files and split them into overlapping chunks."""
    documents = []
    valid_paths = validate_file_paths(file_paths)
    if not valid_paths:
        raise ValueError("No valid PDF/TXT files found!")
    for path in valid_paths:
        try:
            # validate_file_paths guarantees the suffix is either .pdf or .txt
            if Path(path).suffix.lower() == ".pdf":
                loader = PyPDFLoader(path)
            else:
                loader = TextLoader(path)
            docs = loader.load()
            if docs:
                documents.extend(docs)
        except Exception as e:
            print(f"Error loading {Path(path).name}: {str(e)}")
    if not documents:
        raise ValueError("All documents failed to load.")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=128,
        length_function=len,
        add_start_index=True,
        separators=["\n\n", "\n", "。", " ", ""],
    )
    return text_splitter.split_documents(documents)


def create_vector_store(documents: List[Document]) -> FAISS:
    """Embed the chunks and index them in a FAISS vector store."""
    if not documents:
        raise ValueError("No documents to index.")
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={"device": DEVICE},
        encode_kwargs={"normalize_embeddings": True},
    )
    return FAISS.from_documents(documents, embeddings)


def initialize_deepseek_model(
    vector_store: FAISS,
    temperature: float = 0.7,
    max_new_tokens: int = 1024,
    top_k: int = 50,
) -> ConversationalRetrievalChain:
    """Load the LLM, wrap it in a text-generation pipeline, and build the retrieval chain."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME, use_fast=True, trust_remote_code=True
        )
        torch_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            device_map="auto" if DEVICE == "cuda" else None,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )
        text_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            repetition_penalty=1.1,
            stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
            batch_size=1,
            return_full_text=False,
        )
        llm = HuggingFacePipeline(
            pipeline=text_pipeline, model_kwargs={"temperature": temperature}
        )
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
            input_key="question",
        )
        return ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=vector_store.as_retriever(
                search_type="mmr", search_kwargs={"k": 5, "fetch_k": 10}
            ),
            memory=memory,
            chain_type="stuff",
            return_source_documents=True,
            verbose=False,
            max_tokens_limit=MAX_CONTEXT_LENGTH,
        )
    except Exception as e:
        raise RuntimeError(f"Model initialization failed: {str(e)}")


def format_sources(source_docs: List[Document]) -> List[Tuple[str, int]]:
    """Return exactly three (excerpt, page) pairs for the source reference panel."""
    sources = []
    try:
        for doc in source_docs[:3]:
            content = doc.page_content.strip()[:500] + "..."
            page = doc.metadata.get("page", 0) + 1
            sources.append((content, page))
        while len(sources) < 3:
            sources.append(("No source found", 0))
    except Exception:
        return [("Source processing error", 0)] * 3
    return sources


def handle_conversation(
    qa_chain: Optional[ConversationalRetrievalChain],
    message: str,
    history: List[Tuple[str, str]],
) -> Tuple:
    """Run one question through the chain and flatten the sources into the six UI outputs."""
    start_time = time.time()
    if not qa_chain:
        return (None, "", history, *(("System Error", 0) * 3))
    try:
        response = qa_chain.invoke({"question": message, "chat_history": history})
        answer = response["answer"].strip()
        sources = format_sources(response.get("source_documents", []))
        new_history = history + [(message, answer)]
        elapsed = f"{(time.time() - start_time):.2f}s"
        print(f"Response generated in {elapsed}")
        return (
            qa_chain,
            "",
            new_history,
            *[item for pair in sources for item in pair],
        )
    except Exception as e:
        error_msg = f"⚠️ Error: {str(e)}"
        return (
            qa_chain,
            "",
            history + [(message, error_msg)],
            *(("Error", 0) * 3),
        )


def create_interface() -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Default()) as interface:
        qa_chain = gr.State()
        vector_store = gr.State()

        gr.Markdown(
            """
# DeepScale R1

A Safe and Strong Local RAG System by Adarsh Pandey!
            """,
            elem_id="header-section",
        )

        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                gr.Markdown("### Step 1: Document Processing")
                file_input = gr.Files(
                    file_types=[".pdf", ".txt"], file_count="multiple"
                )
                process_btn = gr.Button("Process Documents", variant="primary")
                process_status = gr.Textbox(label="Status", interactive=False)

                gr.Markdown("### Step 2: Model Configuration")
                with gr.Accordion("Advanced Parameters", open=False):
                    temp_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                    )
                    token_slider = gr.Slider(
                        minimum=256,
                        maximum=4096,
                        value=1024,
                        step=128,
                        label="Response Length",
                    )
                    topk_slider = gr.Slider(
                        minimum=1, maximum=100, value=50, step=5, label="Top-K Sampling"
                    )
                init_btn = gr.Button("Initialize Model", variant="primary")
                model_status = gr.Textbox(label="Model Status", interactive=False)

            with gr.Column(scale=1, min_width=500):
                chatbot = gr.Chatbot(
                    label="Conversation History",
                    height=450,
                    avatar_images=["2.png", "3.png"],
                )
                msg_input = gr.Textbox(
                    label="Your Query",
                    placeholder="Ask a question about your documents...",
                )
                with gr.Row():
                    submit_btn = gr.Button("Submit", variant="primary")
                    clear_btn = gr.ClearButton([msg_input, chatbot], value="Clear Chat")

                # Keep references to the source boxes so the chat events can update them.
                source_outputs = []
                with gr.Accordion("Source References", open=True):
                    for i in range(3):
                        with gr.Row():
                            ref_text = gr.Textbox(
                                label=f"Reference {i + 1}", max_lines=4, interactive=False
                            )
                            ref_page = gr.Number(label="Page", value=0, interactive=False)
                            source_outputs.extend([ref_text, ref_page])

        process_btn.click(
            fn=lambda files: (
                create_vector_store(load_documents([f.name for f in files])),
                "Documents processed successfully.",
            ),
            inputs=file_input,
            outputs=[vector_store, process_status],
            api_name="process_docs",
        )

        init_btn.click(
            fn=lambda vs, temp, tokens, k: (
                initialize_deepseek_model(vs, temp, tokens, k),
                "Model initialized successfully.",
            ),
            inputs=[vector_store, temp_slider, token_slider, topk_slider],
            outputs=[qa_chain, model_status],
            api_name="init_model",
        )

        msg_input.submit(
            fn=handle_conversation,
            inputs=[qa_chain, msg_input, chatbot],
            outputs=[qa_chain, msg_input, chatbot, *source_outputs],
            api_name="chat",
        )

        submit_btn.click(
            fn=handle_conversation,
            inputs=[qa_chain, msg_input, chatbot],
            outputs=[qa_chain, msg_input, chatbot, *source_outputs],
            api_name="chat",
        )

    return interface


if __name__ == "__main__":
    app = create_interface()
    app.launch(
        server_name="0.0.0.0" if os.getenv("DOCKER") else "localhost",
        server_port=7860,
        show_error=True,
        share=True,
        favicon_path="1.png",
    )