import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from PyPDF2 import PdfReader
import gradio as gr
from datasets import Dataset, load_from_disk

# Extract text from PDF
def extract_text_from_pdf(pdf_path):
    text = ""
    with open(pdf_path, "rb") as f:
        reader = PdfReader(f)
        for page in reader.pages:
            # extract_text() can return None for image-only pages
            text += page.extract_text() or ""
    return text

# Load model and tokenizer (half precision with device_map="auto" keeps the
# 8B model within single-GPU memory; device_map requires the `accelerate` package)
model_name = "scb10x/llama-3-typhoon-v1.5x-8b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)

# Extract text from the provided PDF
pdf_path = "/home/user/app/TOPF 2564.pdf"  # Ensure this path is correct
pdf_text = extract_text_from_pdf(pdf_path)
passages = [{"title": "", "text": line} for line in pdf_text.split('\n') if line.strip()]
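
# Sketch: splitting on newlines yields very short, uneven chunks; an
# alternative is fixed-size word windows with overlap (the sizes below are
# illustrative assumptions, not tuned values).
def chunk_words(text, size=120, overlap=20):
    words = text.split()
    step = size - overlap
    return [" ".join(words[i:i + size]) for i in range(0, len(words), step)]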

# Create a Dataset
dataset = Dataset.from_dict({"title": [p["title"] for p in passages], "text": [p["text"] for p in passages]})

# Paths for the saved dataset and the FAISS index
dataset_path = "/home/user/app/rag_document_dataset"
index_path = "/home/user/app/rag_document_index"

# Ensure the directory exists
os.makedirs(dataset_path, exist_ok=True)
os.makedirs(index_path, exist_ok=True)

# Save the dataset to disk. add_faiss_index() expects a column of embedding
# vectors rather than raw text, so the FAISS index is built from passage
# embeddings in build_faiss_index() below.
dataset.save_to_disk(dataset_path)
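
# Sketch: build a real FAISS index over passage embeddings. The encoder model
# name is an illustrative assumption (any sentence-transformers model works),
# and the function is optional, so its extra dependency is imported locally.
def build_faiss_index():
    from sentence_transformers import SentenceTransformer  # assumed extra dependency
    embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    indexed = load_from_disk(dataset_path).map(
        lambda batch: {"embeddings": embedder.encode(batch["text"])}, batched=True
    )
    indexed.add_faiss_index(column="embeddings")
    indexed.save_faiss_index("embeddings", os.path.join(index_path, "passages.faiss"))
    return indexed, embedder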

# Custom retriever. This is a placeholder that ignores `query` and returns the
# first k passages verbatim, so the prompt cannot overflow the model context;
# retrieve_faiss() below sketches a real FAISS similarity search.
def retrieve(query, k=5):
    retrieved_passages = " ".join(p["text"] for p in passages[:k])  # simplified for demo
    return retrieved_passages
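
# Sketch: FAISS-backed retrieval, assuming `indexed` and `embedder` come from
# build_faiss_index() above. get_nearest_examples() returns (scores, examples).
def retrieve_faiss(query, indexed, embedder, k=5):
    query_vector = embedder.encode(query)
    scores, examples = indexed.get_nearest_examples("embeddings", query_vector, k=k)
    return " ".join(examples["text"])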

# Define the chat function
def answer_question(question):
    retrieved_context = retrieve(question)
    inputs = tokenizer(question + " " + retrieved_context, return_tensors="pt").to(model.device)

    # Generate the answer (cap new tokens so generation terminates promptly)
    outputs = model.generate(**inputs, max_new_tokens=256)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer
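
# Sketch: prompting through the model's chat template instead of raw string
# concatenation, which instruct-tuned models like this one expect. The system
# prompt wording is an illustrative assumption.
def answer_question_chat(question):
    retrieved_context = retrieve(question)
    messages = [
        {"role": "system", "content": "Answer the question using only the provided context."},
        {"role": "user", "content": f"Context: {retrieved_context}\n\nQuestion: {question}"},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(input_ids, max_new_tokens=256)
    # Decode only the newly generated tokens, not the echoed prompt
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)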

# Gradio interface setup
def ask(question):
    return answer_question(question)

demo = gr.Interface(
    fn=ask,
    inputs=gr.Textbox(lines=2, placeholder="Ask something..."),
    outputs="text",
    title="Document QA with RAG",
    description="Ask questions based on the provided document."
)

if __name__ == "__main__":
    demo.launch()