File size: 6,033 Bytes
44ddcbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import sqlite3
from sqlite3 import Error
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
import os
import torch
from huggingface_hub import login
import ast

class QAInfer:
    def __init__(self):
        torch.cuda.empty_cache()
        qa_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
        bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype="float16",
                bnb_4bit_use_double_quant=False,
            )
        self.qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
        self.qa_tokenizer.pad_token = self.qa_tokenizer.eos_token

        self.qa_model = AutoModelForCausalLM.from_pretrained(
                qa_model_name,
                quantization_config=bnb_config,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
                cache_dir="cache/mistral"
            )

    def extract_text_from_pdf(self, pdf_path):
        """Extract text from a PDF file."""
        reader = PdfReader(pdf_path)
        text = ''
        for page in reader.pages:
                text += page.extract_text()
        return text
    
    def qa_infer(self, query, rows, col):
        """QA inference function."""
        print(query)
       
        print(tuple(col))   
        file_index = -1

        if "additional_info" not in col:
            pass
        else:
            file_index = [i for i in range(len(col)) if col[i] == "additional_info"][0]
            initiate_qa = input("\nDo you wish to ask questions about the properties [y/n]?: ").lower()

            if initiate_qa in ['y', 'yes']:
                for row in rows:
                    pdf_text = self.extract_text_from_pdf(row[file_index])
                    print("Extracted text from PDF", os.path.basename(row[file_index]))

                    while True:
                        user_question = input("\nWhat do you want to know about this property? (Press Enter to exit): ").strip()
                        if not user_question:
                            break
                        # Construct QA prompt directly here
                        question = user_question if user_question else "Who is lashkar e taiba"
                        prompt = f"""Below is a question and context, search the context to find the answer for the question and return the response ###question:{question} ###context:{pdf_text} ###response:"""

                        # Run the language model to generate a response
                        inputs = self.qa_tokenizer(prompt, return_tensors='pt', truncation=True, max_length=512)
                        pipe = pipeline(
                            "text-generation", 
                            model=self.qa_model, 
                            tokenizer=self.qa_tokenizer, 
                            torch_dtype=torch.bfloat16, 
                            device_map="auto"
                        )
                        sequences = pipe(
                            prompt,
                            do_sample=True,
                            max_new_tokens=200, 
                            temperature=0.7, 
                            top_k=50, 
                            top_p=0.95,
                            num_return_sequences=1,
                        )
                        answer = sequences[0]['generated_text']
                        print("Answer:", answer)
            else:
                continue_to_next = input("Do you want to continue with the next property? [y/n]: ").lower()
                if continue_to_next != 'y':
                    return
    
    def qa_infer_interface(self, row, query_question):
        """This method is used for gradio interface only"""
        file_path = row[-1]  # Assuming the last element in row contains the PDF file path
        pdf_text = self.extract_text_from_pdf(file_path)
        
        # prompt = f"""Below is a question and context, search the context to find the answer for the question and return the response , if related answer cannot be found return "Answer not in the context" ###question:{query_question} ###context:{pdf_text} ###response:"""
        prompt = f"""You have been provided with a question and a corresponding context. Your task is to search the context to find the answer to the question. If the answer is found, return the response. If the answer cannot be found in the context, please respond with "Answer not found in the context".

        === Question ===
        {query_question}

        === Context ===
        {pdf_text}

        === Response ===
        If the answer cannot be found in the context, please respond with "Answer not found in the context""""

        inputs = self.qa_tokenizer(prompt, return_tensors='pt', truncation=True, max_length=512)
        pipe = pipeline(
            "text-generation", 
            model=self.qa_model, 
            tokenizer=self.qa_tokenizer, 
            torch_dtype=torch.bfloat16, 
            device_map="auto"
        )
        sequences = pipe(
            prompt,
            do_sample=True,
            max_new_tokens=200, 
            temperature=0.7, 
            top_k=50, 
            top_p=0.95,
            num_return_sequences=1,
        )
        answer = sequences[0]['generated_text']
        return answer

if __name__ == '__main__':
    # Demo driver: run the interactive QA loop over two sample Indian
    # properties, mimicking the rows a SQL query would have returned.
    engine = QAInfer()
    sample_query = 'SELECT * FROM sql_pdf WHERE country = "India" '
    sample_rows = [
        (3, 'Taj Mahal Palace', 455, 76, 0.15, 'Mumbai', 'India', 795, 748, 67000, 'pdf_files/pdf/Taj_Mahal_palace.pdf'),
        (6, 'Antilia', 455, 70, 0.46, 'Mumbai', 'India', 612, 2520, 179000, 'pdf_files/pdf/Antilia.pdf'),
    ]
    column_names = [
        "property_id", "name", "bed", "bath", "acre_lot", "city", "country",
        "zip_code", "house_size", "price", "additional_info",
    ]
    engine.qa_infer(sample_query, sample_rows, column_names)