Chris4K commited on
Commit
ccb8edf
·
verified ·
1 Parent(s): dfa227d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import pdfplumber
4
+ import docx
5
+ import nltk
6
+ import gradio as gr
7
+ from langchain.embeddings import HuggingFaceEmbeddings
8
+ from langchain.vectorstores import FAISS
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, SentenceTextSplitter
10
+ from sentence_transformers import SentenceTransformer
11
+ from transformers import AutoTokenizer
12
+ from nltk import sent_tokenize
13
+ from typing import List, Tuple
14
+ from transformers import AutoModel, AutoTokenizer
15
+
16
+ # Ensure nltk sentence tokenizer is downloaded
17
+ nltk.download('punkt')
18
+
19
+ FILES_DIR = './files'
20
+
21
+ # Supported embedding models
22
+ MODELS = {
23
+ 'e5-base': "danielheinz/e5-base-sts-en-de",
24
+ 'multilingual-e5-base': "multilingual-e5-base",
25
+ 'paraphrase-miniLM': "paraphrase-multilingual-MiniLM-L12-v2",
26
+ 'paraphrase-mpnet': "paraphrase-multilingual-mpnet-base-v2",
27
+ 'gte-large': "gte-large",
28
+ 'gbert-base': "gbert-base"
29
+ }
30
+
31
+ class FileHandler:
32
+ @staticmethod
33
+ def extract_text(file_path):
34
+ ext = os.path.splitext(file_path)[-1].lower()
35
+ if ext == '.pdf':
36
+ return FileHandler._extract_from_pdf(file_path)
37
+ elif ext == '.docx':
38
+ return FileHandler._extract_from_docx(file_path)
39
+ elif ext == '.txt':
40
+ return FileHandler._extract_from_txt(file_path)
41
+ else:
42
+ raise ValueError(f"Unsupported file type: {ext}")
43
+
44
+ @staticmethod
45
+ def _extract_from_pdf(file_path):
46
+ with pdfplumber.open(file_path) as pdf:
47
+ return ' '.join([page.extract_text() for page in pdf.pages])
48
+
49
+ @staticmethod
50
+ def _extract_from_docx(file_path):
51
+ doc = docx.Document(file_path)
52
+ return ' '.join([para.text for para in doc.paragraphs])
53
+
54
+ @staticmethod
55
+ def _extract_from_txt(file_path):
56
+ with open(file_path, 'r', encoding='utf-8') as f:
57
+ return f.read()
58
+
59
+ class EmbeddingModel:
60
+ def __init__(self, model_name, max_tokens=None):
61
+ self.model = HuggingFaceEmbeddings(model_name=model_name)
62
+ self.max_tokens = max_tokens
63
+
64
+ def embed(self, text):
65
+ return self.model.embed_documents([text])
66
+
67
+ def process_files(model_name, split_strategy, chunk_size=500, overlap_size=50, max_tokens=None):
68
+ # File processing
69
+ text = ""
70
+ for file in os.listdir(FILES_DIR):
71
+ file_path = os.path.join(FILES_DIR, file)
72
+ text += FileHandler.extract_text(file_path)
73
+
74
+ # Split text
75
+ if split_strategy == 'sentence':
76
+ splitter = SentenceTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap_size)
77
+ else:
78
+ splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap_size)
79
+
80
+ chunks = splitter.split_text(text)
81
+ model = EmbeddingModel(MODELS[model_name], max_tokens=max_tokens)
82
+ embeddings = model.embed(text)
83
+
84
+ return embeddings, chunks
85
+
86
+ def search_embeddings(query, model_name, top_k):
87
+ model = HuggingFaceEmbeddings(model_name=MODELS[model_name])
88
+ embeddings = model.embed_query(query)
89
+ return embeddings
90
+
91
+ def calculate_statistics(embeddings):
92
+ # Return time taken, token count, etc.
93
+ return {"tokens": len(embeddings), "time_taken": time.time()}
94
+
95
+ # Gradio frontend
96
+ def upload_file(file, model_name, split_strategy, chunk_size, overlap_size, max_tokens, query, top_k):
97
+ with open(os.path.join(FILES_DIR, file.name), "wb") as f:
98
+ f.write(file.read())
99
+
100
+ # Process files and get embeddings
101
+ embeddings, chunks = process_files(model_name, split_strategy, chunk_size, overlap_size, max_tokens)
102
+
103
+ # Perform search
104
+ results = search_embeddings(query, model_name, top_k)
105
+
106
+ # Calculate statistics
107
+ stats = calculate_statistics(embeddings)
108
+
109
+ return {"results": results, "stats": stats}
110
+
111
+ # Gradio interface
112
+ iface = gr.Interface(
113
+ fn=upload_file,
114
+ inputs=[
115
+ gr.File(label="Upload File"),
116
+ gr.Dropdown(choices=list(MODELS.keys()), label="Embedding Model"),
117
+ gr.Radio(choices=["sentence", "recursive"], label="Split Strategy"),
118
+ gr.Slider(100, 1000, step=100, value=500, label="Chunk Size"),
119
+ gr.Slider(0, 100, step=10, value=50, label="Overlap Size"),
120
+ gr.Slider(50, 500, step=50, value=200, label="Max Tokens"),
121
+ gr.Textbox(label="Search Query"),
122
+ gr.Slider(1, 10, step=1, value=5, label="Top K")
123
+ ],
124
+ outputs="json"
125
+ )
126
+
127
+ iface.launch()
128
+