import os

import cv2
import gdown
import numpy as np
import onnxruntime
from duckduckgo_search import DDGS
from flask import Flask, render_template, request, redirect, url_for, flash, send_file
from flask_bcrypt import Bcrypt
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
from werkzeug.utils import secure_filename
# Initialize the Flask app and Bcrypt (set up for password hashing, though no route uses it yet)
app = Flask(__name__)
app.secret_key = os.environ.get('SECRET_KEY', 'your_secret_key')  # prefer an env var over a hardcoded key
bcrypt = Bcrypt(app)
models_folder = "models"
os.makedirs(models_folder, exist_ok=True)
# Google Drive file IDs of the model artifacts fetched at startup
modelx2_file_id = "1Hvt3_t8S2W5CNYUCFgd2L_KitedAJEmH"
trained_model_file_id = "1VCcCkj6jXBwiJAcHdAmHg_o6u32u4V7i"
vqa_model_file_id = "1YlUXkLx2qQMFAcT0xZ2zfXU5xx2eRFEV"
# Set upload folder and allowed extensions
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['UPSCALED_FOLDER'] = 'static/upscaled'
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['UPSCALED_FOLDER'], exist_ok=True)
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
# Preload models and processors for efficiency
caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
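# The BLIP models are used for inference only, so eval mode (which disables
# dropout) is appropriate; if a GPU were available, a .to("cuda") call could
# be added here as well.
caption_model.eval()
vqa_model.eval()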
def download_model(file_id, model_path):
    """Fetch a model file from Google Drive unless it is already cached locally."""
    if not os.path.exists(model_path):
        print(f"Downloading {model_path}...")
        url = f"https://drive.google.com/uc?export=download&id={file_id}"
        gdown.download(url, model_path, quiet=False)
        print(f"{model_path} downloaded successfully.")
    else:
        print(f"{model_path} already exists.")
model_path = os.path.join(models_folder, "modelx2.ort")
caption_model_path = os.path.join(models_folder, "trained_model.pkl")
vqa_model_path = os.path.join(models_folder, "vqa_model.pkl")

download_model(modelx2_file_id, model_path)
# Note: the two pickled models are downloaded but never loaded below;
# captioning and VQA go through the preloaded BLIP models instead.
download_model(trained_model_file_id, caption_model_path)
download_model(vqa_model_file_id, vqa_model_path)
# Helper functions
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
def convert_pil_to_cv2(image):
    open_cv_image = np.array(image)
    if open_cv_image.ndim == 2:
        # Grayscale: promote to 3 channels so downstream code can assume H, W, C
        return cv2.cvtColor(open_cv_image, cv2.COLOR_GRAY2BGR)
    if open_cv_image.shape[2] == 4:
        # RGBA -> BGRA (a bare [::-1] flip would reverse the alpha channel too)
        return cv2.cvtColor(open_cv_image, cv2.COLOR_RGBA2BGRA)
    # RGB -> BGR
    return open_cv_image[:, :, ::-1].copy()
def pre_process(img: np.ndarray) -> np.ndarray:
    # H, W, C -> C, H, W (keeping only the first three channels)
    img = np.transpose(img[:, :, 0:3], (2, 0, 1))
    # C, H, W -> 1, C, H, W (add a batch dimension)
    img = np.expand_dims(img, axis=0).astype(np.float32)
    return img
def post_process(img: np.ndarray) -> np.ndarray:
    # 1, C, H, W -> C, H, W (drop the batch dimension)
    img = np.squeeze(img)
    # C, H, W -> H, W, C, with the channel order reversed (BGR <-> RGB)
    img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
    return img
def inference(model_path: str, img_array: np.ndarray) -> np.ndarray:
    # A fresh ONNX Runtime session is built on every call; see the cached
    # variant sketched below.
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1
    ort_session = onnxruntime.InferenceSession(model_path, options)
    ort_inputs = {ort_session.get_inputs()[0].name: img_array}
    ort_outs = ort_session.run(None, ort_inputs)
    return ort_outs[0]
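# Optional sketch: caching one session per model file would avoid reloading the
# network on every request. `_cached_session` is a hypothetical helper and is
# not wired into `inference` above.
from functools import lru_cache

@lru_cache(maxsize=None)
def _cached_session(path: str) -> onnxruntime.InferenceSession:
    return onnxruntime.InferenceSession(path)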
def upscale(image_path: str, model: str = "modelx2") -> np.ndarray:
    # Resolve the ONNX file for the requested model name
    # (only modelx2.ort is downloaded at startup)
    onnx_path = os.path.join(models_folder, f"{model}.ort")
    pil_image = Image.open(image_path)
    img = convert_pil_to_cv2(pil_image)
    if img.shape[2] == 4:
        # Upscale the alpha channel separately, as a 3-channel grayscale image
        alpha = img[:, :, 3]  # GRAY
        alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)  # BGR
        alpha_output = post_process(inference(onnx_path, pre_process(alpha)))
        alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)  # GRAY

        img = img[:, :, 0:3]  # BGR
        image_output = post_process(inference(onnx_path, pre_process(img)))  # RGB
        image_output = cv2.cvtColor(image_output, cv2.COLOR_RGB2BGRA)  # BGRA
        image_output[:, :, 3] = alpha_output
    else:
        # post_process flips the channel order, so flip back to BGR for cv2.imwrite
        image_output = post_process(inference(onnx_path, pre_process(img)))
        image_output = cv2.cvtColor(image_output, cv2.COLOR_RGB2BGR)
    return image_output
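# Example (hypothetical file name): upscale("static/uploads/cat.png") returns an
# upscaled BGR array ready for cv2.imwrite; pass model="modelx4" only if a
# matching models/modelx4.ort has been downloaded.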
# Main route
@app.route('/')
def index():
    # "modelx4" is offered in the UI, but only modelx2.ort is fetched at startup
    return render_template('index.html', models=["modelx2", "modelx4"])
@app.route('/upload', methods=['POST'])
def upload_file():
    # Browsers submit an empty file part when no file is selected, so check both
    if 'file' not in request.files or request.files['file'].filename == '':
        flash('Please upload an image.')
        return redirect(url_for('index'))
    file = request.files['file']
    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)
        similar_images = []
        try:
            # Upscale the upload and save the result
            upscaled_img = upscale(filepath)
            upscaled_filename = f"upscaled_{filename}"
            upscaled_path = os.path.join(app.config['UPSCALED_FOLDER'], upscaled_filename)
            cv2.imwrite(upscaled_path, upscaled_img)
            # Caption the upscaled image and use the caption as an image-search query
            image = Image.open(upscaled_path).convert("RGB")
            caption = generate_caption(image)
            results = DDGS().images(
                keywords=caption,
                region="wt-wt",
                safesearch="off",
                max_results=100,
            )
            similar_images = [result['image'] for result in results]
            image_url = url_for('serve_upscaled_file', filename=upscaled_filename)
            return render_template('index.html', input_image_url=filepath,
                                   image_url=image_url, similar_images=similar_images,
                                   show_buttons=True)
        except Exception as e:
            flash(f"Upscaling failed: {e}")
            return redirect(url_for('index'))
    else:
        flash('Invalid file format. Please upload a PNG, JPG, or JPEG file.')
        return redirect(url_for('index'))
@app.route('/process_image', methods=['POST'])
def process_image():
    filename = os.path.basename(request.form.get('image_url', ''))
    filepath = os.path.join(app.config['UPSCALED_FOLDER'], filename)
    if os.path.exists(filepath):
        # Only open the image once we know the file exists
        image = Image.open(filepath).convert("RGB")
        image_url = url_for('serve_upscaled_file', filename=filename)
        if 'vqa' in request.form:
            question = request.form.get('question')
            if question:
                answer = answer_question(image, question)
                return render_template('index.html', image_url=image_url, answer=answer,
                                       show_buttons=True, question=question)
            else:
                flash("Please enter a question.")
        elif 'caption' in request.form:
            caption = generate_caption(image)
            return render_template('index.html', image_url=image_url, caption=caption,
                                   show_buttons=True)
    else:
        flash("File not found. Please re-upload the image.")
    return redirect(url_for('index'))
def generate_caption(image):
    # Prepare the image for the captioning model
    inputs = caption_processor(images=image, return_tensors="pt")
    # Generate the caption as a sequence of token IDs
    out = caption_model.generate(**inputs)
    # Decode the token IDs back into text
    caption = caption_processor.decode(out[0], skip_special_tokens=True)
    return caption
def answer_question(image, question):
    # Prepare the image and the question for the VQA model
    inputs = vqa_processor(images=image, text=question, return_tensors="pt")
    # Generate the answer as a sequence of token IDs
    out = vqa_model.generate(**inputs)
    # Decode the token IDs back into text
    answer = vqa_processor.decode(out[0], skip_special_tokens=True)
    return answer
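# Example (hypothetical inputs): answer_question(Image.open("dog.jpg").convert("RGB"),
# "what animal is this?") typically returns a short answer such as "dog".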
# File-serving routes for the uploaded and upscaled images
@app.route('/uploads/<filename>')
def serve_uploaded_file(filename):
    return send_file(os.path.join(app.config['UPLOAD_FOLDER'], filename))

@app.route('/upscaled/<filename>')
def serve_upscaled_file(filename):
    return send_file(os.path.join(app.config['UPSCALED_FOLDER'], filename))
# Run the app
if __name__ == '__main__':
    app.run(debug=True)
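# For anything beyond local testing, run under a WSGI server instead of the
# Flask debug server, e.g.:
#   gunicorn -b 0.0.0.0:7860 app:app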