Spaces:
Sleeping
Sleeping
import os | |
import time | |
import pdfplumber | |
import docx | |
import nltk | |
import gradio as gr | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.vectorstores import FAISS | |
from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTextSplitter | |
from sentence_transformers import SentenceTransformer | |
from transformers import AutoTokenizer | |
from nltk import sent_tokenize | |
from typing import List, Tuple | |
from transformers import AutoModel, AutoTokenizer | |
import spacy | |
# Load the spaCy English pipeline, downloading it only when it is missing so
# that a warm start does not pay for a redundant download.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # spacy.load signals a missing pipeline package with OSError.
    spacy.cli.download("en_core_web_sm")  # Ensure the model is available
    nlp = spacy.load("en_core_web_sm")  # Load the model
# Ensure nltk sentence tokenizer is downloaded
nltk.download('punkt')
# Directory that uploaded documents are stored in and read back from.
FILES_DIR = './files'
# Supported embedding models: UI key -> Hugging Face Hub repository id.
# Ids are fully qualified (namespace/name) so they resolve on the Hub without
# relying on any implicit default namespace.
MODELS = {
    'e5-base': "danielheinz/e5-base-sts-en-de",
    'multilingual-e5-base': "intfloat/multilingual-e5-base",
    'paraphrase-miniLM': "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    'paraphrase-mpnet': "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    'gte-large': "thenlper/gte-large",
    'gbert-base': "deepset/gbert-base",
}
class FileHandler:
    """Extract plain text from the supported document types (.pdf, .docx, .txt).

    All methods are static, so they can be called on the class or an instance.
    """

    @staticmethod
    def extract_text(file_path):
        """Return the text of ``file_path``, dispatching on its file extension.

        The extension comparison is case-insensitive.

        Raises:
            ValueError: if the extension is not ``.pdf``, ``.docx`` or ``.txt``.
        """
        ext = os.path.splitext(file_path)[-1].lower()
        if ext == '.pdf':
            return FileHandler._extract_from_pdf(file_path)
        elif ext == '.docx':
            return FileHandler._extract_from_docx(file_path)
        elif ext == '.txt':
            return FileHandler._extract_from_txt(file_path)
        else:
            raise ValueError(f"Unsupported file type: {ext}")

    @staticmethod
    def _extract_from_pdf(file_path):
        """Return the text of every PDF page, joined with single spaces."""
        with pdfplumber.open(file_path) as pdf:
            # A page without extractable text may yield None (e.g. scanned
            # pages); substitute an empty string so the join cannot fail.
            return ' '.join(page.extract_text() or '' for page in pdf.pages)

    @staticmethod
    def _extract_from_docx(file_path):
        """Return the text of every paragraph, joined with single spaces."""
        doc = docx.Document(file_path)
        return ' '.join(para.text for para in doc.paragraphs)

    @staticmethod
    def _extract_from_txt(file_path):
        """Return the full content of a UTF-8 encoded text file."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
class EmbeddingModel:
    """Thin wrapper that owns a ``HuggingFaceEmbeddings`` instance.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, model_name, max_tokens=None):
        """Create the underlying embedder for ``model_name``.

        ``max_tokens`` is only stored on the instance; nothing in this class
        reads it back.
        """
        # The wrapped embedder that does the actual work.
        self.model = HuggingFaceEmbeddings(model_name=model_name)
        # Kept for callers; not applied during embedding.
        self.max_tokens = max_tokens

    def embed(self, text):
        """Embed a single text.

        The text is submitted as a one-element batch, so the value returned is
        whatever ``embed_documents`` yields for that batch (a list holding one
        entry per submitted text).
        """
        single_item_batch = [text]
        vectors = self.model.embed_documents(single_item_batch)
        return vectors
def process_files(model_name, split_strategy, chunk_size=500, overlap_size=50, max_tokens=None):
    """Extract, chunk and embed every document stored in ``FILES_DIR``.

    Args:
        model_name: key into ``MODELS`` selecting the embedding model.
        split_strategy: ``'sentence'`` selects the sentence splitter; any
            other value selects ``RecursiveCharacterTextSplitter``.
        chunk_size: chunk size handed to the splitter.
        overlap_size: chunk overlap handed to the splitter.
        max_tokens: forwarded to ``EmbeddingModel`` (stored there, not applied).

    Returns:
        ``(embeddings, chunks)`` where ``embeddings[i]`` is the vector for
        ``chunks[i]``. Both are empty lists when there is nothing to embed.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODELS``.
        ValueError: if a file in ``FILES_DIR`` has an unsupported extension.
    """
    # Resolve the model id first so an unknown name fails before any file work.
    model_id = MODELS[model_name]

    # File processing: read files in a stable order and skip anything that is
    # not a regular file (sub-directories would otherwise fail extraction).
    texts = []
    if os.path.isdir(FILES_DIR):
        for file in sorted(os.listdir(FILES_DIR)):
            file_path = os.path.join(FILES_DIR, file)
            if os.path.isfile(file_path):
                texts.append(FileHandler.extract_text(file_path))
    # Separate documents with a newline so the last word of one file does not
    # fuse with the first word of the next.
    text = "\n".join(texts)

    # Split text
    if split_strategy == 'sentence':
        # NOTE(review): SentenceTextSplitter does not appear to be a class that
        # langchain.text_splitter provides - confirm the file-level import
        # resolves, or switch to a splitter that exists.
        splitter = SentenceTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap_size)
    else:
        splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap_size)
    chunks = splitter.split_text(text)
    if not chunks:
        return [], []

    model = EmbeddingModel(model_id, max_tokens=max_tokens)
    # Embed the chunks (one batched call) rather than the whole concatenated
    # text, so each returned vector corresponds to exactly one chunk.
    embeddings = model.model.embed_documents(chunks)
    return embeddings, chunks
def search_embeddings(query, model_name, top_k):
    """Embed ``query`` with the model registered under ``model_name``.

    Despite the name, no similarity search is performed here: the function
    returns the query embedding itself, and ``top_k`` is accepted but unused.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODELS``.
    """
    encoder = HuggingFaceEmbeddings(model_name=MODELS[model_name])
    return encoder.embed_query(query)
def calculate_statistics(embeddings, start_time=None):
    """Summarise an embedding run.

    Args:
        embeddings: sized collection of embeddings; only its length is used.
        start_time: optional ``time.time()`` value captured when the run
            started. When given, ``time_taken`` is the elapsed seconds since
            then. When omitted, ``time_taken`` is the current ``time.time()``
            timestamp, which is what this function historically returned.

    Returns:
        A dict with ``"tokens"`` (``len(embeddings)``, i.e. the number of
        entries, not a tokenizer count) and ``"time_taken"``.
    """
    now = time.time()
    time_taken = now if start_time is None else now - start_time
    return {"tokens": len(embeddings), "time_taken": time_taken}
# Gradio frontend
def upload_file(file, model_name, split_strategy, chunk_size, overlap_size, max_tokens, query, top_k):
    """Store an uploaded file, embed the stored documents and embed the query.

    Args:
        file: the uploaded file as delivered by the ``gr.File`` input - either
            a path string or an object exposing the path as ``.name``.
        model_name: key into ``MODELS``.
        split_strategy: splitter selection passed to ``process_files``.
        chunk_size, overlap_size, max_tokens, top_k: slider values; coerced to
            ``int`` because sliders may deliver floats.
        query: search query text.

    Returns:
        ``{"results": ..., "stats": ...}`` where ``stats["time_taken"]`` is the
        elapsed processing time in seconds.

    Raises:
        ValueError: if no file was uploaded.
    """
    if file is None:
        raise ValueError("No file was uploaded")
    started = time.perf_counter()

    source_path = file if isinstance(file, str) else file.name
    os.makedirs(FILES_DIR, exist_ok=True)
    # Use only the base name: joining FILES_DIR with an absolute upload path
    # would resolve to the upload itself, and opening that for writing would
    # truncate it before it could be read.
    target_path = os.path.join(FILES_DIR, os.path.basename(source_path))
    if os.path.abspath(source_path) != os.path.abspath(target_path):
        with open(source_path, "rb") as src, open(target_path, "wb") as dst:
            dst.write(src.read())

    chunk_size = int(chunk_size)
    overlap_size = int(overlap_size)
    max_tokens = int(max_tokens) if max_tokens is not None else None
    top_k = int(top_k)

    # Process files and get embeddings
    embeddings, chunks = process_files(model_name, split_strategy, chunk_size, overlap_size, max_tokens)
    # Perform search
    results = search_embeddings(query, model_name, top_k)
    # Calculate statistics
    stats = calculate_statistics(embeddings)
    # Report a real duration for this request instead of a wall-clock timestamp.
    stats["time_taken"] = time.perf_counter() - started
    return {"results": results, "stats": stats}
# Gradio interface
# The model dropdown and split-strategy radio get explicit defaults so that a
# submission made without touching them never passes None to upload_file
# (MODELS[None] would raise KeyError). "recursive" matches the branch that an
# unset strategy already fell through to.
iface = gr.Interface(
    fn=upload_file,
    inputs=[
        gr.File(label="Upload File"),
        gr.Dropdown(choices=list(MODELS.keys()), value=list(MODELS.keys())[0], label="Embedding Model"),
        gr.Radio(choices=["sentence", "recursive"], value="recursive", label="Split Strategy"),
        gr.Slider(100, 1000, step=100, value=500, label="Chunk Size"),
        gr.Slider(0, 100, step=10, value=50, label="Overlap Size"),
        gr.Slider(50, 500, step=50, value=200, label="Max Tokens"),
        gr.Textbox(label="Search Query"),
        gr.Slider(1, 10, step=1, value=5, label="Top K"),
    ],
    outputs="json",
)
iface.launch()