import os
import pickle
from flask import Flask, render_template, request, redirect, url_for, flash, send_file
from flask_bcrypt import Bcrypt
from PIL import Image
import numpy as np
import cv2
import onnxruntime
from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
from werkzeug.utils import secure_filename
import pandas as pd
from duckduckgo_search import DDGS
import urllib.request
import gdown

# Initialize Flask app and Bcrypt for password hashing
app = Flask(__name__)
app.secret_key = 'your_secret_key'
bcrypt = Bcrypt(app)

models_folder = "models"
os.makedirs(models_folder, exist_ok=True)

# Google Drive file IDs for the pretrained model files
modelx2_file_id = "1Hvt3_t8S2W5CNYUCFgd2L_KitedAJEmH"
trained_model_file_id = "1VCcCkj6jXBwiJAcHdAmHg_o6u32u4V7i"
vqa_model_file_id = "1YlUXkLx2qQMFAcT0xZ2zfXU5xx2eRFEV"

# Set upload folder and allowed extensions
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['UPSCALED_FOLDER'] = 'static/upscaled'
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['UPSCALED_FOLDER'], exist_ok=True)
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# Preload models and processors for efficiency
caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")


def download_model(file_id, model_path):
    """Download a model file from Google Drive if it is not already cached locally."""
    if not os.path.exists(model_path):
        print(f"Downloading {model_path}...")
        url = f"https://drive.google.com/uc?export=download&id={file_id}"
        gdown.download(url, model_path, quiet=False)
        print(f"{model_path} downloaded successfully.")
    else:
        print(f"{model_path} already exists.")


model_path = os.path.join(models_folder, "modelx2.ort")
caption_model_path = os.path.join(models_folder, "trained_model.pkl")
vqa_model_path = os.path.join(models_folder, "vqa_model.pkl")

download_model(modelx2_file_id, model_path)
download_model(trained_model_file_id, caption_model_path)
download_model(vqa_model_file_id, vqa_model_path)
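
# Optional sanity check (a sketch, not part of the original app): fail fast at
# startup if a model file is missing or zero bytes, e.g. after an interrupted
# gdown download, rather than erroring later inside a request handler.
def verify_models(paths):
    for p in paths:
        if not os.path.exists(p) or os.path.getsize(p) == 0:
            raise RuntimeError(f"Model file missing or empty: {p}")

# Uncomment to enable the check:
# verify_models([model_path, caption_model_path, vqa_model_path])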
# Helper functions
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def convert_pil_to_cv2(image):
    # pil_image = image.convert("RGB")
    open_cv_image = np.array(image)
    if open_cv_image.ndim == 3:
        # RGB to BGR (single-channel images have no channel axis to reverse)
        open_cv_image = open_cv_image[:, :, ::-1].copy()
    return open_cv_image


def pre_process(img: np.array) -> np.array:
    # H, W, C -> C, H, W
    img = np.transpose(img[:, :, 0:3], (2, 0, 1))
    # C, H, W -> 1, C, H, W
    img = np.expand_dims(img, axis=0).astype(np.float32)
    return img


def post_process(img: np.array) -> np.array:
    # 1, C, H, W -> C, H, W
    img = np.squeeze(img)
    # C, H, W -> H, W, C
    img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
    return img


def inference(model_path: str, img_array: np.array) -> np.array:
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1
    ort_session = onnxruntime.InferenceSession(model_path, options)
    ort_inputs = {ort_session.get_inputs()[0].name: img_array}
    ort_outs = ort_session.run(None, ort_inputs)
    return ort_outs[0]
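
# Sketch (an assumption, not in the original code): inference() above builds a
# new onnxruntime.InferenceSession on every call, which re-loads the model for
# each request. A module-level cache like get_session() would reuse one session
# per model file; inference() could then call get_session(model_path) instead.
_session_cache = {}

def get_session(path: str) -> onnxruntime.InferenceSession:
    if path not in _session_cache:
        options = onnxruntime.SessionOptions()
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        _session_cache[path] = onnxruntime.InferenceSession(path, options)
    return _session_cache[path]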
def upscale(image_path: str, model="modelx2"):
    # Note: only the preloaded modelx2 weights are used; the `model` argument
    # is currently ignored.
    pil_image = Image.open(image_path)
    img = convert_pil_to_cv2(pil_image)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if img.shape[2] == 4:
        # Upscale the alpha channel separately, then re-attach it.
        alpha = img[:, :, 3]  # GRAY
        alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)  # BGR
        alpha_output = post_process(inference(model_path, pre_process(alpha)))  # BGR
        alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)  # GRAY

        img = img[:, :, 0:3]  # BGR
        image_output = post_process(inference(model_path, pre_process(img)))  # BGR
        image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)  # BGRA
        image_output[:, :, 3] = alpha_output
    elif img.shape[2] == 3:
        image_output = post_process(inference(model_path, pre_process(img)))
        image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2RGB)
    return image_output


# Main route
@app.route('/')
def index():
    return render_template('index.html', models=["modelx2", "modelx4"])


@app.route('/upload', methods=['POST'])
def upload_file():
    if 'file' not in request.files:
        flash('Please upload an image.')
        return redirect(url_for('index'))

    file = request.files['file']
    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        similar_images = []
        try:
            upscaled_img = upscale(filepath)
            upscaled_filename = f"upscaled_{filename}"
            upscaled_path = os.path.join(app.config['UPSCALED_FOLDER'], upscaled_filename)
            cv2.imwrite(upscaled_path, upscaled_img)

            # Caption the upscaled image and use the caption as a reverse-search query.
            image = Image.open(upscaled_path).convert("RGB")
            caption = generate_caption(image)
            results = DDGS().images(
                keywords=caption,
                region="wt-wt",
                safesearch="off",
                size=None,
                color="Monochrome",
                type_image=None,
                layout=None,
                license_image=None,
                max_results=100,
            )
            for i in results:
                similar_images.append(i['image'])

            image_url = url_for('serve_upscaled_file', filename=upscaled_filename)
            return render_template('index.html',
                                   input_image_url=filepath,
                                   image_url=image_url,
                                   similar_images=similar_images,
                                   show_buttons=True)
        except Exception as e:
            flash(f"Upscaling failed: {e}")
            return redirect(url_for('index'))
    else:
        flash('Invalid file format. Please upload a PNG, JPG, or JPEG file.')
        return redirect(url_for('index'))


@app.route('/process_image', methods=['POST'])
def process_image():
    image_url = os.path.basename(request.form.get('image_url'))
    filepath = os.path.join(app.config['UPSCALED_FOLDER'], image_url)
    print(filepath)

    if os.path.exists(filepath):
        # Only open the file once we know it exists.
        image = Image.open(filepath).convert("RGB")
        if 'vqa' in request.form:
            question = request.form.get('question')
            if question:
                answer = answer_question(image, question)
                return render_template('index.html', image_url=filepath, answer=answer,
                                       show_buttons=True, question=question)
            else:
                flash("Please enter a question.")
        elif 'caption' in request.form:
            caption = generate_caption(image)
            return render_template('index.html', image_url=filepath, caption=caption,
                                   show_buttons=True)
    else:
        flash("File not found. Please re-upload the image.")
    return redirect(url_for('index'))


def generate_caption(image):
    # Process the image and prepare it for input to the model
    inputs = caption_processor(images=image, return_tensors="pt")
    # Generate the caption (the model's output is a tensor of token IDs)
    out = caption_model.generate(**inputs)
    # Decode the generated token IDs back into text
    caption = caption_processor.decode(out[0], skip_special_tokens=True)
    return caption


def answer_question(image, question):
    # Process the image and the question, prepare them for input to the model
    inputs = vqa_processor(images=image, text=question, return_tensors="pt")
    # Generate an answer (the model's output is a tensor of token IDs)
    out = vqa_model.generate(**inputs)
    # Decode the generated token IDs back into the answer text
    answer = vqa_processor.decode(out[0], skip_special_tokens=True)
    return answer


@app.route('/uploads/<filename>')
def serve_uploaded_file(filename):
    return send_file(os.path.join(app.config['UPLOAD_FOLDER'], filename))


@app.route('/upscaled/<filename>')
def serve_upscaled_file(filename):
    return send_file(os.path.join(app.config['UPSCALED_FOLDER'], filename))


# Run app
if __name__ == '__main__':
    app.run(debug=True)
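
# Example usage (illustrative; sample.jpg is a placeholder filename, and 5000
# is the Flask development server's default port):
#   $ python app.py
#   $ curl -F "file=@sample.jpg" http://127.0.0.1:5000/upload
# The upscaled result is written to static/upscaled/upscaled_sample.jpg and
# served at /upscaled/upscaled_sample.jpg.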