arjunanand13 committed
Commit 44ddcbd · verified · 1 Parent(s): aa3f8e4

Create qa_bot2.py

Files changed (1): qa_bot2.py (+142, -0)
qa_bot2.py ADDED
import os

import torch
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

class QAInfer:
    def __init__(self):
        torch.cuda.empty_cache()
        qa_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="float16",
            bnb_4bit_use_double_quant=False,
        )
        self.qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
        self.qa_tokenizer.pad_token = self.qa_tokenizer.eos_token

        self.qa_model = AutoModelForCausalLM.from_pretrained(
            qa_model_name,
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            cache_dir="cache/mistral",
        )
    def extract_text_from_pdf(self, pdf_path):
        """Extract text from a PDF file."""
        reader = PdfReader(pdf_path)
        text = ''
        for page in reader.pages:
            # extract_text can yield empty/None for image-only pages
            text += page.extract_text() or ''
        return text
    def qa_infer(self, query, rows, col):
        """Interactive QA over the PDF attached to each result row."""
        print(query)
        print(tuple(col))

        if "additional_info" not in col:
            return  # no PDF path column, nothing to answer questions from

        file_index = col.index("additional_info")
        initiate_qa = input("\nDo you wish to ask questions about the properties [y/n]?: ").lower()
        if initiate_qa not in ('y', 'yes'):
            return

        # Build the generation pipeline once, rather than once per question
        pipe = pipeline(
            "text-generation",
            model=self.qa_model,
            tokenizer=self.qa_tokenizer,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        for row in rows:
            pdf_text = self.extract_text_from_pdf(row[file_index])
            print("Extracted text from PDF", os.path.basename(row[file_index]))

            while True:
                user_question = input("\nWhat do you want to know about this property? (Press Enter to exit): ").strip()
                if not user_question:
                    break
                prompt = (
                    "Below is a question and context; search the context to find the "
                    f"answer for the question and return the response ###question:{user_question} "
                    f"###context:{pdf_text} ###response:"
                )
                # Truncate long PDF text so the prompt fits the model's context budget
                input_ids = self.qa_tokenizer(prompt, truncation=True, max_length=512)["input_ids"]
                prompt = self.qa_tokenizer.decode(input_ids, skip_special_tokens=True)
                sequences = pipe(
                    prompt,
                    do_sample=True,
                    max_new_tokens=200,
                    temperature=0.7,
                    top_k=50,
                    top_p=0.95,
                    num_return_sequences=1,
                )
                print("Answer:", sequences[0]['generated_text'])

            continue_to_next = input("Do you want to continue with the next property? [y/n]: ").lower()
            if continue_to_next != 'y':
                return
    def qa_infer_interface(self, row, query_question):
        """QA entry point used by the Gradio interface."""
        file_path = row[-1]  # the last element in row holds the PDF file path
        pdf_text = self.extract_text_from_pdf(file_path)

        prompt = f"""You have been provided with a question and a corresponding context. Your task is to search the context to find the answer to the question. If the answer is found, return the response. If the answer cannot be found in the context, respond with "Answer not found in the context".

=== Question ===
{query_question}

=== Context ===
{pdf_text}

=== Response ===
"""

        # Truncate long PDF text so the prompt fits the model's context budget
        input_ids = self.qa_tokenizer(prompt, truncation=True, max_length=512)["input_ids"]
        prompt = self.qa_tokenizer.decode(input_ids, skip_special_tokens=True)

        pipe = pipeline(
            "text-generation",
            model=self.qa_model,
            tokenizer=self.qa_tokenizer,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        sequences = pipe(
            prompt,
            do_sample=True,
            max_new_tokens=200,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            num_return_sequences=1,
        )
        return sequences[0]['generated_text']
if __name__ == '__main__':
    qa_infer = QAInfer()
    query = 'SELECT * FROM sql_pdf WHERE country = "India"'
    rows = [
        (3, 'Taj Mahal Palace', 455, 76, 0.15, 'Mumbai', 'India', 795, 748, 67000, 'pdf_files/pdf/Taj_Mahal_palace.pdf'),
        (6, 'Antilia', 455, 70, 0.46, 'Mumbai', 'India', 612, 2520, 179000, 'pdf_files/pdf/Antilia.pdf'),
    ]
    col = [
        "property_id", "name", "bed", "bath", "acre_lot", "city", "country",
        "zip_code", "house_size", "price", "additional_info",
    ]
    qa_infer.qa_infer(query, rows, col)
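The `__main__` block exercises only the interactive `qa_infer` path. For the Gradio path, a minimal usage sketch of `qa_infer_interface` follows; the row tuple and PDF path are illustrative, mirroring the `__main__` example, and assume the referenced file exists locally:

# Hypothetical call into the Gradio entry point; the row tuple and PDF
# path are copied from the __main__ example and assume the file exists.
qa = QAInfer()
row = (3, 'Taj Mahal Palace', 455, 76, 0.15, 'Mumbai', 'India',
       795, 748, 67000, 'pdf_files/pdf/Taj_Mahal_palace.pdf')
print(qa.qa_infer_interface(row, "How many bedrooms does the property have?"))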