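"""PDF question-answering Space: Gradio UI plus a FastAPI endpoint, backed by FAISS retrieval and Google Gemma via the Hugging Face InferenceClient."""
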
import gradio as gr
import fitz # PyMuPDF for PDF text extraction
import faiss # FAISS for vector search
import numpy as np
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient
from typing import List, Tuple
from fastapi import FastAPI, Query
import uvicorn
# Default settings
class ChatConfig:
    MODEL = "google/gemma-3-27b-it"
    DEFAULT_SYSTEM_MSG = "You are an AI assistant answering only based on the uploaded PDF."
    DEFAULT_MAX_TOKENS = 512
    DEFAULT_TEMP = 0.3
    DEFAULT_TOP_P = 0.95
client = InferenceClient(ChatConfig.MODEL)
embed_model = SentenceTransformer("all-MiniLM-L6-v2") # Lightweight embedding model
vector_dim = 384 # all-MiniLM-L6-v2 produces 384-dimensional embeddings
index = faiss.IndexFlatL2(vector_dim) # Exact (brute-force) L2 FAISS index
documents = [] # Store extracted text
def extract_text_from_pdf(pdf_path):
"""Extracts text from PDF"""
doc = fitz.open(pdf_path)
text_chunks = [page.get_text("text") for page in doc]
return text_chunks
def create_vector_db(text_chunks):
"""Embeds text chunks and adds them to FAISS index"""
global documents, index
documents = text_chunks
embeddings = embed_model.encode(text_chunks)
index.add(np.array(embeddings, dtype=np.float32))
def search_relevant_text(query):
"""Finds the most relevant text chunk for the given query"""
query_embedding = embed_model.encode([query])
_, closest_idx = index.search(np.array(query_embedding, dtype=np.float32), k=3)
return "\n".join([documents[i] for i in closest_idx[0]])
def generate_response_sync(message: str) -> str:
"""Generates response synchronously for FastAPI"""
if not documents:
return "Please upload a PDF first."
context = search_relevant_text(message) # Get relevant content from PDF
messages = [
{"role": "system", "content": ChatConfig.DEFAULT_SYSTEM_MSG},
{"role": "user", "content": f"Context: {context}\nQuestion: {message}"}
]
response = ""
for chunk in client.chat_completion(
messages,
max_tokens=ChatConfig.DEFAULT_MAX_TOKENS,
stream=True,
temperature=ChatConfig.DEFAULT_TEMP,
top_p=ChatConfig.DEFAULT_TOP_P,
):
token = chunk.choices[0].delta.content or ""
response += token
return response
def handle_upload(pdf_file):
"""Handles PDF upload and creates vector DB"""
text_chunks = extract_text_from_pdf(pdf_file.name)
create_vector_db(text_chunks)
return "PDF uploaded and indexed successfully!"
def create_interface() -> gr.Blocks:
"""Creates the Gradio interface"""
with gr.Blocks() as interface:
gr.Markdown("# PDF-Based Chatbot using Google Gemma")
with gr.Row():
chatbot = gr.Chatbot(label="Chat with Your PDF", type="messages")
pdf_upload = gr.File(label="Upload PDF", type="filepath")
with gr.Row():
user_input = gr.Textbox(label="Ask a question", placeholder="Type here...")
send_button = gr.Button("Send")
output = gr.Textbox(label="Response", lines=5)
# Upload PDF handler
pdf_upload.change(handle_upload, inputs=[pdf_upload], outputs=[])
# Chat function
send_button.click(
generate_response_sync,
inputs=[user_input],
outputs=[output]
)
return interface
# FastAPI Integration
app = FastAPI()
@app.get("/chat")
def chat_with_pdf(msg: str = Query(..., title="User Message")):
"""API endpoint to receive a message and return AI response"""
response = generate_response_sync(msg)
return {"response": response}
if __name__ == "__main__":
    import threading

    # Start the Gradio UI in a separate thread
    def run_gradio():
        gradio_app = create_interface()
        gradio_app.launch(server_name="0.0.0.0", server_port=7860, share=True)

    threading.Thread(target=run_gradio, daemon=True).start()

    # Start FastAPI (blocks the main thread)
    uvicorn.run(app, host="0.0.0.0", port=8000)
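
# Example usage once both servers are up (a sketch; host and port values assume
# the defaults configured above):
#   curl "http://localhost:8000/chat?msg=What%20is%20this%20PDF%20about%3F"
#   -> {"response": "..."}
# The Gradio UI is served separately on port 7860.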