Rathapoom committed on
Commit
b23b89a
·
verified ·
1 Parent(s): ff34bbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from PyPDF2 import PdfReader
5
  import gradio as gr
6
  from datasets import Dataset, load_from_disk
 
7
 
8
  # Extract text from PDF
9
  def extract_text_from_pdf(pdf_path):
@@ -19,13 +20,19 @@ model_name = "scb10x/llama-3-typhoon-v1.5x-8b-instruct"
19
  tokenizer = AutoTokenizer.from_pretrained(model_name)
20
  model = AutoModelForCausalLM.from_pretrained(model_name)
21
 
 
 
 
22
  # Extract text from the provided PDF
23
  pdf_path = "/home/user/app/TOPF 2564.pdf" # Ensure this path is correct
24
  pdf_text = extract_text_from_pdf(pdf_path)
25
  passages = [{"title": "", "text": line} for line in pdf_text.split('\n') if line.strip()]
26
 
27
- # Create a Dataset
28
- dataset = Dataset.from_dict({"title": [p["title"] for p in passages], "text": [p["text"] for p in passages]})
 
 
 
29
 
30
  # Save the dataset and create an index in the current working directory
31
  dataset_path = "/home/user/app/rag_document_dataset"
@@ -37,15 +44,15 @@ os.makedirs(index_path, exist_ok=True)
37
 
38
  # Save the dataset to disk and create an index
39
  dataset.save_to_disk(dataset_path)
40
- dataset.load_from_disk(dataset_path).add_faiss_index(column="text").save(index_path)
 
41
 
42
  # Custom retriever
43
  def retrieve(query):
44
  # Use FAISS index to retrieve relevant passages
45
- query_embedding = tokenizer(query, return_tensors="pt")["input_ids"]
46
- # Perform retrieval (this is a placeholder, actual retrieval code will be more complex)
47
- # retrieved_passages = faiss_search(query_embedding)
48
- retrieved_passages = " ".join([passage['text'] for passage in passages]) # Simplified for demo
49
  return retrieved_passages
50
 
51
  # Define the chat function
 
4
  from PyPDF2 import PdfReader
5
  import gradio as gr
6
  from datasets import Dataset, load_from_disk
7
+ from sentence_transformers import SentenceTransformer
8
 
9
  # Extract text from PDF
10
  def extract_text_from_pdf(pdf_path):
 
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
21
  model = AutoModelForCausalLM.from_pretrained(model_name)
22
 
23
+ # Load a sentence transformer model for embedding generation
24
+ embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
25
+
26
  # Extract text from the provided PDF
27
  pdf_path = "/home/user/app/TOPF 2564.pdf" # Ensure this path is correct
28
  pdf_text = extract_text_from_pdf(pdf_path)
29
  passages = [{"title": "", "text": line} for line in pdf_text.split('\n') if line.strip()]
30
 
31
# Embed every passage once so each dataset row carries its own vector.
passage_texts = [p["text"] for p in passages]
embeddings = embedding_model.encode(passage_texts)

# Build a Dataset pairing each passage (title + text) with its embedding.
dataset = Dataset.from_dict(
    {
        "title": [p["title"] for p in passages],
        "text": passage_texts,
        "embeddings": embeddings.tolist(),
    }
)
36
 
37
  # Save the dataset and create an index in the current working directory
38
  dataset_path = "/home/user/app/rag_document_dataset"
 
44
 
45
# Persist the dataset, reload it, and build a FAISS index over the embeddings.
dataset.save_to_disk(dataset_path)
dataset = load_from_disk(dataset_path)
dataset.add_faiss_index(column="embeddings")
# NOTE: add_faiss_index returns the Dataset itself (which has no .save()
# method), not the index object. The index must be persisted explicitly
# with save_faiss_index; index_path is a directory (created with
# os.makedirs above), so write the index to a file inside it.
dataset.save_faiss_index("embeddings", os.path.join(index_path, "embeddings.faiss"))
49
 
50
# Custom retriever
def retrieve(query):
    """Return the 5 passages most similar to *query*, joined into one string.

    Uses the FAISS index built over the "embeddings" column of the
    module-level `dataset`; embeds the query with `embedding_model`.
    """
    # encode() on a bare string yields a 1-D vector, which is what
    # get_nearest_examples expects; encode([query]) would return a
    # (1, dim) 2-D array and break the FAISS search.
    query_embedding = embedding_model.encode(query)
    # get_nearest_examples returns (scores, examples) where `examples` is a
    # dict mapping column name -> list of values — NOT a list of row dicts.
    # Iterating it would yield column-name strings and raise TypeError on
    # sample["text"]; index the "text" column directly instead.
    scores, samples = dataset.get_nearest_examples("embeddings", query_embedding, k=5)
    retrieved_passages = " ".join(samples["text"])
    return retrieved_passages
57
 
58
  # Define the chat function