Upload main.py
Browse files
main.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import textwrap
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration, DetrImageProcessor, DetrForObjectDetection
|
5 |
+
from PyPDF2 import PdfReader
|
6 |
+
import google.generativeai as genai
|
7 |
+
import google.ai.generativelanguage as glm
|
8 |
+
from PIL import Image
|
9 |
+
import torch
|
10 |
+
from IPython.display import Markdown
|
11 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
12 |
+
from langchain.text_splitter import CharacterTextSplitter
|
13 |
+
from langchain.vectorstores import FAISS
|
14 |
+
|
15 |
+
|
16 |
+
# Used to securely store your API key
|
17 |
+
from google.colab import userdata
|
18 |
+
from IPython.display import Markdown
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
class ImageProcessor:
|
25 |
+
def __init__(self, image_path):
|
26 |
+
self.image_path = image_path
|
27 |
+
|
28 |
+
def get_caption(self, image_path):
|
29 |
+
# Implement image captioning logic here
|
30 |
+
"""
|
31 |
+
Generates a short caption for the provided image.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
image_path (str): The path to the image file.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
str: A string representing the caption for the image.
|
38 |
+
"""
|
39 |
+
image = Image.open(image_path).convert('RGB')
|
40 |
+
|
41 |
+
model_name = "Salesforce/blip-image-captioning-large"
|
42 |
+
device = "cpu" # cuda
|
43 |
+
|
44 |
+
processor = BlipProcessor.from_pretrained(model_name)
|
45 |
+
model = BlipForConditionalGeneration.from_pretrained(model_name).to(device)
|
46 |
+
|
47 |
+
inputs = processor(image, return_tensors='pt').to(device)
|
48 |
+
output = model.generate(**inputs, max_new_tokens=20)
|
49 |
+
|
50 |
+
caption = processor.decode(output[0], skip_special_tokens=True)
|
51 |
+
|
52 |
+
return caption
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
def detect_objects(self, image_path):
|
57 |
+
# Implement object detection logic here
|
58 |
+
"""
|
59 |
+
Detects objects in the provided image.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
image_path (str): The path to the image file.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
str: A string with all the detected objects. Each object as '[x1, x2, y1, y2, class_name, confindence_score]'.
|
66 |
+
"""
|
67 |
+
image = Image.open(image_path).convert('RGB')
|
68 |
+
|
69 |
+
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
|
70 |
+
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
|
71 |
+
|
72 |
+
inputs = processor(images=image, return_tensors="pt")
|
73 |
+
outputs = model(**inputs)
|
74 |
+
|
75 |
+
# convert outputs (bounding boxes and class logits) to COCO API
|
76 |
+
# let's only keep detections with score > 0.9
|
77 |
+
target_sizes = torch.tensor([image.size[::-1]])
|
78 |
+
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
|
79 |
+
|
80 |
+
detections = ""
|
81 |
+
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
|
82 |
+
detections += '[{}, {}, {}, {}]'.format(int(box[0]), int(box[1]), int(box[2]), int(box[3]))
|
83 |
+
detections += ' {}'.format(model.config.id2label[int(label)])
|
84 |
+
detections += ' {}\n'.format(float(score))
|
85 |
+
|
86 |
+
return detections
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
def make_prompt(self, query, image_captions, objects_detections):
|
91 |
+
# Implement prompt creation logic here
|
92 |
+
escaped_captions = image_captions.replace("'", "").replace('"', "").replace("\n", " ")
|
93 |
+
escaped_objects = objects_detections.replace("'", "").replace('"', "").replace("\n", " ")
|
94 |
+
prompt = textwrap.dedent("""You are a helpful and informative bot that answers questions using text from the image captions and objects detected included below. \
|
95 |
+
Be sure to respond in a complete sentence, being comprehensive, including all relevant background information. \
|
96 |
+
However, you are talking to a non-technical audience, so be sure to break down complicated concepts and \
|
97 |
+
strike a friendly and conversational tone. \
|
98 |
+
If the image captions or objects detected are irrelevant to the answer, you may ignore them.
|
99 |
+
QUESTION: '{query}'
|
100 |
+
IMAGE CAPTIONS: '{image_captions}'
|
101 |
+
OBJECTS DETECTED: '{objects_detected}'
|
102 |
+
|
103 |
+
ANSWER:
|
104 |
+
""").format(query=query, image_captions=escaped_captions, objects_detected=escaped_objects)
|
105 |
+
|
106 |
+
return prompt
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
def generate_answer(self, prompt):
|
111 |
+
# Implement answer generation logic here
|
112 |
+
model = genai.GenerativeModel('gemini-pro')
|
113 |
+
answer = model.generate_content(prompt)
|
114 |
+
|
115 |
+
return answer.text
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
class PDFProcessor:
|
123 |
+
def __init__(self, pdf_path):
|
124 |
+
self.pdf_path = pdf_path
|
125 |
+
|
126 |
+
def create_embedding_df(self, pdf_path):
|
127 |
+
# Implement PDF content vector store creation logic here
|
128 |
+
# Provide the path of the PDF file
|
129 |
+
pdfreader = PdfReader(pdf_path)
|
130 |
+
|
131 |
+
# Read text from PDF and divide it into smaller chunks
|
132 |
+
documents = []
|
133 |
+
for i, page in enumerate(pdfreader.pages):
|
134 |
+
content = page.extract_text()
|
135 |
+
if content:
|
136 |
+
# Create a document for each page
|
137 |
+
document = {
|
138 |
+
"Title": f"Page {i+1}", # Use the page number as the title
|
139 |
+
"Text": content
|
140 |
+
}
|
141 |
+
documents.append(document)
|
142 |
+
|
143 |
+
# Create a DataFrame from the documents
|
144 |
+
df = pd.DataFrame(documents)
|
145 |
+
|
146 |
+
# Define the model
|
147 |
+
model = 'models/embedding-001'
|
148 |
+
|
149 |
+
# Define a function to generate embeddings
|
150 |
+
def embed_fn(title, text):
|
151 |
+
return genai.embed_content(
|
152 |
+
model=model,
|
153 |
+
content=text,
|
154 |
+
task_type="retrieval_document",
|
155 |
+
title=title
|
156 |
+
)["embedding"]
|
157 |
+
|
158 |
+
# Generate embeddings for each document and store them in the DataFrame
|
159 |
+
df['Embeddings'] = df.apply(lambda row: embed_fn(row['Title'], row['Text']), axis=1)
|
160 |
+
|
161 |
+
return df
|
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
def find_best_passage(self, query, dataframe):
|
166 |
+
# Implement logic to find the best passage based on query
|
167 |
+
"""
|
168 |
+
Compute the distances between the query and each document in the dataframe
|
169 |
+
using the dot product.
|
170 |
+
"""
|
171 |
+
model = 'models/embedding-001'
|
172 |
+
query_embedding = genai.embed_content(model=model,
|
173 |
+
content=query,
|
174 |
+
task_type="retrieval_query")
|
175 |
+
dot_products = np.dot(np.stack(dataframe['Embeddings']), query_embedding["embedding"])
|
176 |
+
idx = np.argmax(dot_products)
|
177 |
+
# Return text from index with max value
|
178 |
+
return dataframe.iloc[idx]['Text']
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
def make_prompt(self, query, relevant_passage):
|
183 |
+
# Implement prompt creation logic for PDF processing
|
184 |
+
escaped = relevant_passage.replace("'", "").replace('"', "").replace("\n", " ")
|
185 |
+
prompt = textwrap.dedent("""You are a helpful and informative bot that answers questions using text from the reference passage included below. \
|
186 |
+
Be sure to respond in a complete sentence, being comprehensive, including all relevant background information. \
|
187 |
+
However, you are talking to a non-technical audience, so be sure to break down complicated concepts and \
|
188 |
+
strike a friendly and converstional tone. \
|
189 |
+
If the passage is irrelevant to the answer, you may ignore it.
|
190 |
+
QUESTION: '{query}'
|
191 |
+
PASSAGE: '{relevant_passage}'
|
192 |
+
|
193 |
+
ANSWER:
|
194 |
+
""").format(query=query, relevant_passage=escaped)
|
195 |
+
|
196 |
+
return prompt
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
def generate_answer(self, prompt):
|
201 |
+
# Implement answer generation logic for PDF processing
|
202 |
+
model = genai.GenerativeModel('gemini-pro')
|
203 |
+
answer = model.generate_content(prompt)
|
204 |
+
|
205 |
+
return answer.text
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
|