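"""Minimal RAG demo: answer questions about a PDF with a Llama-3-based model.

Pipeline: extract the PDF text, split it into passages, embed and index them
with FAISS, then retrieve the most relevant passages per query and feed them
to the model together with the question.

Assumed dependencies: transformers, datasets, faiss-cpu, PyPDF2, gradio, and
sentence-transformers (used here for the passage/query embeddings).
"""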
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from PyPDF2 import PdfReader
import gradio as gr
from datasets import Dataset, load_from_disk
# Assumed here for passage/query embeddings; any sentence-embedding model works
from sentence_transformers import SentenceTransformer
# Extract text from a PDF file
def extract_text_from_pdf(pdf_path):
    text = ""
    with open(pdf_path, "rb") as f:
        reader = PdfReader(f)
        for page in reader.pages:
            # extract_text() can return None for image-only pages
            text += page.extract_text() or ""
    return text
# Load model and tokenizer (an 8B model; half precision on GPU keeps memory manageable)
model_name = "scb10x/llama-3-typhoon-v1.5x-8b-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
# Extract text from the provided PDF
pdf_path = "/home/user/app/TOPF 2564.pdf" # Ensure this path is correct
pdf_text = extract_text_from_pdf(pdf_path)
passages = [{"title": "", "text": line} for line in pdf_text.split('\n') if line.strip()]
# Create a Dataset
dataset = Dataset.from_dict({"title": [p["title"] for p in passages], "text": [p["text"] for p in passages]})
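# NOTE: newline splitting is a crude chunking strategy; grouping several lines
# into overlapping chunks usually gives the retriever better context.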
# Save the dataset and build a FAISS index
dataset_path = "/home/user/app/rag_document_dataset"
index_path = "/home/user/app/rag_document_index"

# Ensure the directories exist
os.makedirs(dataset_path, exist_ok=True)
os.makedirs(index_path, exist_ok=True)

# FAISS indexes vectors, not raw text, so embed each passage first
# (all-MiniLM-L6-v2 is an assumed default; swap in any embedding model)
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
dataset = dataset.map(lambda batch: {"embeddings": embedder.encode(batch["text"])}, batched=True)

# Save the dataset to disk, then build and persist the index
dataset.save_to_disk(dataset_path)
dataset = load_from_disk(dataset_path)
dataset.add_faiss_index(column="embeddings")
dataset.save_faiss_index("embeddings", os.path.join(index_path, "index.faiss"))
# Custom retriever: embed the query and search the FAISS index
def retrieve(query, k=3):
    query_embedding = embedder.encode(query)
    scores, results = dataset.get_nearest_examples("embeddings", query_embedding, k=k)
    return " ".join(results["text"])
# Define the chat function
def answer_question(question, context):
    # `context` (the full PDF text) is kept for the call signature;
    # retrieval narrows it down to the most relevant passages
    retrieved_context = retrieve(question)
    inputs = tokenizer(question + " " + retrieved_context, return_tensors="pt", truncation=True).to(device)
    # Generate the answer (set max_new_tokens explicitly; the default length is very short)
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=256,
    )
    # Decode only the newly generated tokens, not the echoed prompt
    answer = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    return answer
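
# NOTE: Typhoon is an instruct-tuned model, so formatting the prompt with
# tokenizer.apply_chat_template() would likely improve answers over the plain
# concatenation used above.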
# Gradio interface setup
def ask(question):
    return answer_question(question, pdf_text)

demo = gr.Interface(
    fn=ask,
    inputs=gr.Textbox(lines=2, placeholder="Ask something..."),  # gr.inputs.* was removed in Gradio 4
    outputs="text",
    title="Document QA with RAG",
    description="Ask questions based on the provided document.",
)
if __name__ == "__main__":
    demo.launch()