Spaces:
Runtime error
Runtime error
File size: 3,803 Bytes
409f81b 2c02a9e 409f81b 84f3457 8ceb607 47ecda0 f812db9 2c02a9e f812db9 2c02a9e f812db9 2c02a9e f812db9 2c02a9e f812db9 409f81b f812db9 409f81b f812db9 409f81b 2c02a9e 47ecda0 f812db9 409f81b f812db9 409f81b 261cad3 ba470cd f812db9 ba470cd 409f81b 2c02a9e 8ceb607 6e6d28c 70fd172 f812db9 ba470cd f812db9 6e6d28c ba470cd 6e6d28c 2c02a9e f812db9 2c02a9e ba470cd f812db9 ba470cd 2c02a9e f812db9 2c02a9e f812db9 2c02a9e 261cad3 ba470cd 6e6d28c 70fd172 47ecda0 0385c04 84f3457 409f81b d7100c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import os
import fitz
from docx import Document
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import pickle
import gradio as gr
from typing import List
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from nltk.tokenize import sent_tokenize # Import for sentence segmentation
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Extract plain text from a PDF file with PyMuPDF (imported as ``fitz``).
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of the PDF at *pdf_path*, pages joined by newlines.

    Args:
        pdf_path: path to a PDF file.

    Returns:
        The extracted text. Scanned pages without a text layer contribute
        empty strings.
    """
    # The ``with`` block guarantees the document (and its file handle) is closed.
    with fitz.open(pdf_path) as pdf:
        return "\n".join(page.get_text() for page in pdf)
# Extract plain text from a Word (.docx) document with python-docx.
def extract_text_from_docx(docx_path):
    """Return the text of the body paragraphs of the .docx file at *docx_path*.

    Paragraphs are joined by newlines. Text inside tables is not included.

    Args:
        docx_path: path to a .docx file.

    Returns:
        The concatenated paragraph text.
    """
    word_doc = Document(docx_path)
    return "\n".join(paragraph.text for paragraph in word_doc.paragraphs)
# Sentence-embedding model shared by indexing and querying. The FAISS index
# dimension must match its output size.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Hugging Face API token: startup fails fast when it is missing.
# NOTE(review): the token is not referenced in the visible part of this file —
# confirm it is still needed before keeping the hard failure.
api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
if not api_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")

# Seq2seq models for the RAG stages (replace with your chosen models).
# NOTE(review): facebook/bart-base is a general pretrained checkpoint, presumably
# not tuned for question answering — consider a QA/instruction-tuned model.
generator_model_name = "facebook/bart-base"
retriever_model_name = "facebook/bart-base"  # Can be the same as generator
generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
if retriever_model_name == generator_model_name:
    # Same checkpoint: reuse the loaded weights instead of holding a second copy.
    retriever = generator
    retriever_tokenizer = generator_tokenizer
else:
    retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
    retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
# Load or create the FAISS index. Invariant relied on by upload/query:
# vector i in ``index`` is the embedding of ``document_texts[i]``.
index_path = "faiss_index.pkl"
document_texts_path = "document_texts.pkl"


def _load_or_create_index():
    """Return ``(index, texts)`` restored from disk, or a fresh empty pair.

    The index file is read with ``faiss.read_index``; the existing ".pkl" file
    name is kept for path compatibility. Files that are missing, unreadable,
    or inconsistent fall back to an empty ``IndexFlatL2``. Inconsistent means
    a wrong vector dimension, a count mismatch, or a texts object that is not
    a list. The new index is sized to the embedding model's output dimension.
    """
    dim = embedding_model.get_sentence_embedding_dimension()
    if os.path.exists(index_path) and os.path.exists(document_texts_path):
        try:
            saved_index = faiss.read_index(index_path)
            # SECURITY: only unpickle a file this app wrote itself —
            # pickle.load on untrusted data can execute arbitrary code.
            with open(document_texts_path, "rb") as fh:
                saved_texts = pickle.load(fh)
        except (OSError, RuntimeError, EOFError, pickle.UnpicklingError) as e:
            print(f"Could not load saved FAISS index, starting fresh: {e}")
        else:
            if (
                isinstance(saved_texts, list)
                and saved_index.d == dim
                and saved_index.ntotal == len(saved_texts)
            ):
                return saved_index, saved_texts
            print("Saved FAISS index does not match the stored texts/embedding model, starting fresh")
    return faiss.IndexFlatL2(dim), []


index, document_texts = _load_or_create_index()
def preprocess_text(text):
    """Split *text* into cleaned, non-empty sentences ready for embedding.

    Runs of whitespace are collapsed to single spaces first. This includes the
    hard line breaks that PDF extraction produces. The result is then
    segmented with NLTK's ``sent_tokenize``.

    Args:
        text: raw document or query text. ``None`` and blank strings are
            accepted.

    Returns:
        A list of stripped sentences. It is empty when there is no text.
    """
    if not text:
        return []
    normalized = " ".join(text.split())
    if not normalized:
        return []
    try:
        sentences = sent_tokenize(normalized)
    except LookupError:
        # NLTK's "punkt" data is not installed. Fall back to a simple
        # punctuation split rather than failing the whole upload/query.
        import re
        sentences = re.split(r"(?<=[.!?])\s+", normalized)
    return [sentence.strip() for sentence in sentences if sentence.strip()]
def _extract_text(file_path):
    """Dispatch *file_path* to the PDF or DOCX extractor based on its extension.

    Raises:
        ValueError: if the extension is neither ``.pdf`` nor ``.docx``.
    """
    extension = os.path.splitext(file_path)[1].lower()
    if extension == ".pdf":
        return extract_text_from_pdf(file_path)
    if extension == ".docx":
        return extract_text_from_docx(file_path)
    raise ValueError(f"Unsupported file type (expected .pdf or .docx): {file_path}")


def upload_files(files):
    """Extract, segment, embed and index the text of the uploaded files.

    Each sentence from ``preprocess_text`` becomes one vector in the global
    FAISS ``index`` and one entry in ``document_texts`` at the same position.
    This lets search ids map back to text. Both are written to
    ``index_path`` / ``document_texts_path`` after the batch.

    Args:
        files: uploaded files, given as plain paths or as Gradio temp-file
            objects that expose ``.name``.

    Returns:
        A status message for the UI. Errors are reported in the message
        rather than raised.
    """
    global index, document_texts
    if not files:
        return "No files uploaded"
    try:
        for file_obj in files:
            file_path = os.fspath(getattr(file_obj, "name", file_obj))
            sentences = preprocess_text(_extract_text(file_path))
            if not sentences:
                # Nothing indexable (e.g. a PDF without a text layer): skip
                # instead of sending an empty batch through encode/add.
                continue
            embeddings = np.asarray(embedding_model.encode(sentences), dtype="float32")
            index.add(embeddings)
            # Flat FAISS indexes number vectors sequentially, so the texts must
            # be appended in the same order to keep id -> text aligned.
            document_texts.extend(sentences)
        faiss.write_index(index, index_path)
        with open(document_texts_path, "wb") as fh:
            pickle.dump(document_texts, fh)
        return "Files processed successfully"
    except Exception as e:
        print(f"Error processing files: {e}")
        return f"Error processing files: {e}"
def query_text(text, top_k=5, max_new_tokens=128):
    """Answer *text* by generating from the query plus its nearest indexed sentences.

    The whole query is embedded with ``embedding_model``. The ``top_k``
    nearest sentences are fetched from the global FAISS ``index`` and joined
    into a single ``question: ... context: ...`` prompt for ``generator``.

    Args:
        text: the user's query.
        top_k: how many nearest sentences to retrieve as context.
        max_new_tokens: upper bound on the generated answer length.

    Returns:
        The decoded answer. A human-readable message is returned instead for
        empty input, an empty index, or an error.
    """
    if not text or not text.strip():
        return "Please enter a query."
    try:
        k = min(top_k, index.ntotal)
        if k <= 0:
            return "No documents have been indexed yet."
        # Embed the query as one unit. Searching per sentence and keeping only
        # the first sentence's hits silently ignored the rest of the question.
        query_embedding = np.asarray(embedding_model.encode([text]), dtype="float32")
        _, ids = index.search(query_embedding, k)
        # FAISS pads missing results with -1; also guard against id/text drift.
        retrieved_docs = [document_texts[i] for i in ids[0] if 0 <= i < len(document_texts)]

        # The context must be in the SAME input sequence as the question to
        # condition the answer. Separate batch rows never influence each other.
        # Question first, so truncation to the model's max length drops
        # context rather than the question.
        context = " ".join(retrieved_docs)
        prompt = f"question: {text} context: {context}"
        inputs = generator_tokenizer(prompt, return_tensors="pt", truncation=True)
        # A plain forward pass only yields logits; generate() produces token ids.
        output_ids = generator.generate(**inputs, max_new_tokens=max_new_tokens)
        return generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error querying text: {e}")
        return f"Error querying text: {e}"
# Create Gradio interface
with gr.Blocks() as demo:
# ... (rest of the Gradio interface definition)
query_button.click(fn=query_text, inputs
|