Update app.py
app.py

@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from PyPDF2 import PdfReader
 import gradio as gr
 from datasets import Dataset, load_from_disk
+from sentence_transformers import SentenceTransformer
 
 # Extract text from PDF
 def extract_text_from_pdf(pdf_path):
@@ -19,13 +20,19 @@ model_name = "scb10x/llama-3-typhoon-v1.5x-8b-instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
 
+# Load a sentence transformer model for embedding generation
+embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+
 # Extract text from the provided PDF
 pdf_path = "/home/user/app/TOPF 2564.pdf"  # Ensure this path is correct
 pdf_text = extract_text_from_pdf(pdf_path)
 passages = [{"title": "", "text": line} for line in pdf_text.split('\n') if line.strip()]
 
-#
-
+# Convert text to embeddings
+embeddings = embedding_model.encode([passage["text"] for passage in passages])
+
+# Create a Dataset with embeddings
+dataset = Dataset.from_dict({"title": [p["title"] for p in passages], "text": [p["text"] for p in passages], "embeddings": embeddings.tolist()})
 
 # Save the dataset and create an index in the current working directory
 dataset_path = "/home/user/app/rag_document_dataset"
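For reference, the embedding and dataset-building step added in this hunk can be exercised on its own. A minimal sketch, assuming the same all-MiniLM-L6-v2 model and a couple of stand-in passages in place of the extracted PDF text:

from datasets import Dataset
from sentence_transformers import SentenceTransformer

# Same embedding model as in the diff.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Stand-in passages; in app.py these come from extract_text_from_pdf().
passages = [{"title": "", "text": t} for t in ["First passage.", "Second passage."]]

# encode() on a list of strings returns a NumPy array,
# shape (num_passages, 384) for this model.
embeddings = embedding_model.encode([p["text"] for p in passages])

# The Arrow-backed dataset stores the vectors as lists of floats, hence .tolist().
dataset = Dataset.from_dict({
    "title": [p["title"] for p in passages],
    "text": [p["text"] for p in passages],
    "embeddings": embeddings.tolist(),
})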
@@ -37,15 +44,15 @@ os.makedirs(index_path, exist_ok=True)
 
 # Save the dataset to disk and create an index
 dataset.save_to_disk(dataset_path)
-dataset
+dataset = load_from_disk(dataset_path)
+dataset.add_faiss_index(column="embeddings").save(index_path)
 
 # Custom retriever
 def retrieve(query):
     # Use FAISS index to retrieve relevant passages
-    query_embedding =
-
-
-    retrieved_passages = " ".join([passage['text'] for passage in passages])  # Simplified for demo
+    query_embedding = embedding_model.encode([query])
+    scores, samples = dataset.get_nearest_examples("embeddings", query_embedding, k=5)
+    retrieved_passages = " ".join([sample["text"] for sample in samples])
     return retrieved_passages
 
 # Define the chat function
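Two of the lines added in this hunk may not run as committed: Dataset.add_faiss_index() returns the dataset itself, which has no .save() method (persisting an index goes through save_faiss_index()), and get_nearest_examples() expects a single 1-D query vector and returns (scores, examples) where examples is a dict of columns, not a list of rows. A sketch of the same steps against the documented 🤗 Datasets API, assuming embedding_model, dataset, and index_path as defined in app.py:

import os
import numpy as np

# Build the FAISS index over the embeddings column, then persist it.
# add_faiss_index() returns the dataset; saving uses save_faiss_index().
dataset.add_faiss_index(column="embeddings")
dataset.save_faiss_index("embeddings", os.path.join(index_path, "embeddings.faiss"))

def retrieve(query):
    # encode() on a list returns a 2-D array; take row 0 to get the
    # single 1-D vector the search expects (FAISS works on float32).
    query_embedding = embedding_model.encode([query])[0].astype(np.float32)
    scores, samples = dataset.get_nearest_examples("embeddings", query_embedding, k=5)
    # samples is a dict of columns ({"title": [...], "text": [...]}),
    # so the retrieved passages are samples["text"].
    return " ".join(samples["text"])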
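With the variant above, the retriever can be smoke-tested before it is wired into the chat function:

# Hypothetical query; the joined top-5 passages come back as one context string.
context = retrieve("What is this document about?")
print(context[:200])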