Spaces:
Sleeping
Sleeping
import os | |
import pickle | |
from flask import Flask, render_template, request, redirect, url_for, flash, send_file | |
from flask_bcrypt import Bcrypt | |
from PIL import Image | |
import numpy as np | |
import cv2 | |
import onnxruntime | |
from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering | |
from werkzeug.utils import secure_filename | |
import pandas as pd | |
from duckduckgo_search import DDGS | |
import os | |
import urllib.request | |
import gdown | |
# Initialize Flask app and Bcrypt for password hashing
app = Flask(__name__)
# The secret key signs the session cookie that carries flash() messages.
# A real secret can be supplied through the FLASK_SECRET_KEY environment
# variable; the old placeholder stays as the fallback so local runs behave as
# before. Do not deploy publicly with the placeholder value.
app.secret_key = os.environ.get('FLASK_SECRET_KEY', 'your_secret_key')
bcrypt = Bcrypt(app)
# Local directory for downloaded model artifacts (created if missing).
models_folder = "models"
os.makedirs(models_folder, exist_ok=True)
# Google Drive file IDs of the model artifacts handed to download_model().
modelx2_file_id = "1Hvt3_t8S2W5CNYUCFgd2L_KitedAJEmH"
trained_model_file_id = "1VCcCkj6jXBwiJAcHdAmHg_o6u32u4V7i"
vqa_model_file_id = "1YlUXkLx2qQMFAcT0xZ2zfXU5xx2eRFEV"
# Upload/output folders (relative to the working directory) and the image
# extensions accepted by allowed_file().
app.config.update(
    UPLOAD_FOLDER='static/uploads',
    UPSCALED_FOLDER='static/upscaled',
)
for _folder_key in ('UPLOAD_FOLDER', 'UPSCALED_FOLDER'):
    os.makedirs(app.config[_folder_key], exist_ok=True)
del _folder_key
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
# Load the BLIP captioning and VQA processors/models once, at import time, so
# every request handler reuses the same instances.
_CAPTION_CHECKPOINT = "Salesforce/blip-image-captioning-base"
_VQA_CHECKPOINT = "Salesforce/blip-vqa-capfilt-large"

caption_processor = BlipProcessor.from_pretrained(_CAPTION_CHECKPOINT)
caption_model = BlipForConditionalGeneration.from_pretrained(_CAPTION_CHECKPOINT)
vqa_processor = BlipProcessor.from_pretrained(_VQA_CHECKPOINT)
vqa_model = BlipForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)
def download_model(file_id, model_path):
    """Download a Google Drive file to *model_path* unless it already exists.

    Args:
        file_id: Google Drive file ID of the artifact.
        model_path: Local destination path.

    Returns:
        True if *model_path* exists after the call, False if the download did
        not produce the file. In that case a warning is printed instead of
        the previous (false) success message.
    """
    if os.path.exists(model_path):
        print(f"{model_path} already exists.")
        return True
    print(f"Downloading {model_path}...")
    url = f"https://drive.google.com/uc?export=download&id={file_id}"
    # gdown.download returns the output path, or None when it could not
    # retrieve the file (some versions raise instead, which propagates).
    result = gdown.download(url, model_path, quiet=False)
    if result is None or not os.path.exists(model_path):
        print(f"WARNING: failed to download {model_path} (file id {file_id}).")
        return False
    print(f"{model_path} downloaded successfully.")
    return True
# Local destinations of the model artifacts. `model_path` is also read as a
# global by the ONNX upscaling code, so these names must stay module-level.
model_path = os.path.join(models_folder, "modelx2.ort")
caption_model_path = os.path.join(models_folder, "trained_model.pkl")
vqa_model_path = os.path.join(models_folder, "vqa_model.pkl")

# Fetch each artifact in turn (download_model skips files already on disk).
# NOTE(review): nothing visible in this file unpickles caption_model_path or
# vqa_model_path; if they are ever loaded, remember pickle.load executes code
# from the file, so only load artifacts from a trusted source.
for _file_id, _target in (
    (modelx2_file_id, model_path),
    (trained_model_file_id, caption_model_path),
    (vqa_model_file_id, vqa_model_path),
):
    download_model(_file_id, _target)
del _file_id, _target
# Helper functions | |
def allowed_file(filename):
    """Tell whether *filename* ends in one of the ALLOWED_EXTENSIONS.

    Only the text after the last dot is compared, case-insensitively; a
    name without any dot is rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rpartition('.')[2]
    return extension.lower() in ALLOWED_EXTENSIONS
def convert_pil_to_cv2(image):
    """Convert a PIL image into an OpenCV-ordered uint8 array.

    Images in ``L``, ``RGB`` or ``RGBA`` mode are converted directly. Other
    modes (palette, ``LA``, CMYK, ...) are first converted with
    ``image.convert``: to ``RGBA`` when they carry transparency, otherwise to
    ``RGB``.

    Returns:
        A new array. It is 2-D for grayscale input, H x W x 3 in BGR order
        for RGB, and H x W x 4 in BGRA order for RGBA (alpha stays last, as
        ``upscale`` expects when it reads channel 3).
    """
    mode = getattr(image, "mode", None)
    if mode is not None and mode not in ("L", "RGB", "RGBA"):
        info = getattr(image, "info", {})
        has_alpha = mode in ("LA", "PA") or "transparency" in info
        image = image.convert("RGBA" if has_alpha else "RGB")
    # np.array() copies the pixels, so the in-place swap below does not touch
    # the caller's data and the result stays contiguous for OpenCV.
    open_cv_image = np.array(image)
    if open_cv_image.ndim == 3 and open_cv_image.shape[2] >= 3:
        # Swap R and B only. Reversing every channel (the old behaviour)
        # turned RGBA into ABGR and put the red channel where alpha belongs.
        open_cv_image[:, :, 0:3] = open_cv_image[:, :, 2::-1].copy()
    return open_cv_image
def pre_process(img: np.ndarray) -> np.ndarray:
    """Turn an H x W x C image array into a 1 x C' x H x W float32 model input.

    At most the first three channels are kept (C' <= 3), so any alpha channel
    is dropped. Channel order and pixel values are passed through unchanged
    (no normalisation).
    """
    # (H, W, C) -> (C, H, W), keeping at most three channels.
    chw = np.transpose(img[:, :, 0:3], (2, 0, 1))
    # (C, H, W) -> (1, C, H, W): add the batch axis for the ONNX session.
    return chw[np.newaxis, ...].astype(np.float32)
def post_process(img: np.ndarray) -> np.ndarray:
    """Turn a 1 x C x H x W model output into an H x W x C uint8 image.

    The channel axis is reversed (flipping the channel order of the model
    output). Values are clipped to [0, 255] before the cast.
    """
    # Drop only the batch axis: a bare np.squeeze() would also remove an H or
    # W of size 1 and break the 3-axis transpose below.
    chw = np.squeeze(img, axis=0)
    # (C, H, W) -> (H, W, C), then reverse the channel order.
    hwc = np.transpose(chw, (1, 2, 0))[:, :, ::-1]
    # Float -> uint8 casting is not saturating: out-of-range values (common
    # slight over/undershoot of super-resolution models) would wrap around.
    return np.clip(hwc, 0, 255).astype(np.uint8)
# ONNX Runtime sessions keyed by model path. Creating a session loads the
# model file, so it is done once per path and reused across requests instead
# of on every call. Two threads racing on a cold cache simply build the
# session twice; the last one stored wins.
_ort_sessions = {}


def inference(model_path: str, img_array: np.ndarray) -> np.ndarray:
    """Run the ONNX Runtime model stored at *model_path* on *img_array*.

    Args:
        model_path: Path of the ONNX Runtime model file.
        img_array: Tensor fed to the model's first input (``pre_process``
            produces a 1 x C x H x W float32 array).

    Returns:
        The model's first output.
    """
    session = _ort_sessions.get(model_path)
    if session is None:
        options = onnxruntime.SessionOptions()
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        # ORT builds with several execution providers (e.g. GPU builds)
        # refuse to create a session unless the providers are listed.
        session = onnxruntime.InferenceSession(
            model_path, options, providers=["CPUExecutionProvider"]
        )
        _ort_sessions[model_path] = session
    ort_inputs = {session.get_inputs()[0].name: img_array}
    return session.run(None, ort_inputs)[0]
def upscale(image_path: str, model="modelx2"):
    """Upscale the image at *image_path* with an ONNX Runtime model.

    Args:
        image_path: Path of the image file to read.
        model: Base name of the ``.ort`` file inside ``models_folder``. The
            default ``"modelx2"`` resolves to ``models/modelx2.ort``, the file
            downloaded at start-up.

    Returns:
        A uint8 array in the channel order ``cv2.imwrite`` expects. It is
        3-channel for opaque images and 4-channel (BGRA, with the alpha plane
        upscaled separately) for images with an alpha channel.

    Raises:
        ValueError: if the decoded image has an unsupported channel count.
    """
    # Honour `model` (it used to be ignored); basename() keeps the lookup
    # inside models_folder.
    ort_path = os.path.join(models_folder, f"{os.path.basename(model)}.ort")
    # Close the file handle once the pixels have been decoded.
    with Image.open(image_path) as pil_image:
        img = convert_pil_to_cv2(pil_image)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if img.shape[2] == 4:
        # Upscale the alpha plane as a 3-channel gray image, then collapse it.
        alpha = img[:, :, 3]
        alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)
        alpha_output = post_process(inference(ort_path, pre_process(alpha)))
        alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)

        color = img[:, :, 0:3]
        image_output = post_process(inference(ort_path, pre_process(color)))
        # post_process reverses the channel axis, so (assuming the model keeps
        # channel order) this is RGB. The 3-channel branch below swaps it back
        # to BGR; do the same here. BGR2BGRA used to leave R and B swapped.
        image_output = cv2.cvtColor(image_output, cv2.COLOR_RGB2BGRA)
        image_output[:, :, 3] = alpha_output
    elif img.shape[2] == 3:
        image_output = post_process(inference(ort_path, pre_process(img)))
        # Undo post_process's channel reversal so cv2.imwrite gets BGR.
        image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2RGB)
    else:
        # Previously fell through and hit an unbound `image_output`.
        raise ValueError(f"Unsupported number of image channels: {img.shape[2]}")
    return image_output
# Main route | |
# NOTE(review): the route decorators are missing from this copy of the file,
# yet url_for('index') needs a registered endpoint. The path is reconstructed;
# confirm it against templates/index.html.
@app.route('/')
def index():
    """Render the landing page, passing the upscaling model names to the template."""
    # NOTE(review): only modelx2.ort is downloaded at start-up; "modelx4" has
    # no model file here.
    return render_template('index.html', models=["modelx2", "modelx4"])
def _find_similar_images(caption):
    """Return image URLs from a DuckDuckGo image search for *caption*.

    Best effort: any search error (network failure, rate limiting, ...) is
    printed and yields an empty list, so a failed search no longer throws away
    an upscale that already succeeded.
    """
    try:
        results = DDGS().images(
            keywords=caption,
            region="wt-wt",
            # NOTE(review): safe search is off and results are restricted to
            # monochrome images; confirm both are intended for a public app.
            safesearch="off",
            size=None,
            color="Monochrome",
            type_image=None,
            layout=None,
            license_image=None,
            max_results=100,
        )
    except Exception as e:
        print(f"Similar-image search failed: {e}")
        return []
    # Skip malformed entries instead of failing on a missing 'image' key.
    return [r['image'] for r in results or [] if r.get('image')]


# NOTE(review): route decorator reconstructed (missing in this copy of the
# file); confirm the path against templates/index.html.
@app.route('/upload', methods=['POST'])
def upload_file():
    """Handle an image upload: upscale it, caption it, find similar images.

    On success, renders ``index.html`` with the original and upscaled image
    paths and the similar-image URLs. A missing or invalid file, or an error
    while upscaling/captioning, flashes a message and redirects to ``index``.
    """
    if 'file' not in request.files:
        flash('Please upload an image.')
        return redirect(url_for('index'))
    file = request.files['file']
    if not (file and allowed_file(file.filename)):
        flash('Invalid file format. Please upload a PNG, JPG, or JPEG file.')
        return redirect(url_for('index'))

    filename = secure_filename(file.filename)
    if not allowed_file(filename):
        # secure_filename() drops non-ASCII characters, so e.g. "写真.png"
        # becomes "png" and loses its extension (cv2.imwrite then fails).
        # Use a random name that keeps the already-validated extension.
        import uuid
        extension = file.filename.rsplit('.', 1)[1].lower()
        filename = f"{uuid.uuid4().hex}.{extension}"
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)

    try:
        upscaled_img = upscale(filepath)
        upscaled_filename = f"upscaled_{filename}"
        upscaled_path = os.path.join(app.config['UPSCALED_FOLDER'], upscaled_filename)
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(upscaled_path, upscaled_img):
            raise OSError(f"could not write {upscaled_path}")
        with Image.open(upscaled_path) as saved:
            image = saved.convert("RGB")
        caption = generate_caption(image)
    except Exception as e:
        flash(f"Upscaling failed: {e}")
        return redirect(url_for('index'))

    similar_images = _find_similar_images(caption)
    return render_template(
        'index.html',
        input_image_url=filepath,
        image_url=upscaled_path,
        similar_images=similar_images,
        show_buttons=True,
    )
# NOTE(review): route decorator reconstructed (missing in this copy of the
# file); confirm the path against templates/index.html.
@app.route('/process', methods=['POST'])
def process_image():
    """Answer a question about, or caption, a previously upscaled image.

    Reads ``image_url`` from the form and looks the file up in
    ``UPSCALED_FOLDER``. If the form has ``vqa``, it answers ``question``; if
    it has ``caption``, it generates a caption. Otherwise, or on a missing
    question or file, it flashes a message (where applicable) and redirects
    to ``index``.
    """
    # basename() keeps the lookup inside UPSCALED_FOLDER (no path traversal);
    # `or ''` avoids a TypeError when the form field is missing.
    image_url = os.path.basename(request.form.get('image_url') or '')
    filepath = os.path.join(app.config['UPSCALED_FOLDER'], image_url)
    print(filepath)
    # Check before opening: opening first raised FileNotFoundError, so this
    # "not found" message could never be shown.
    if not os.path.isfile(filepath):
        flash("File not found. Please re-upload the image.")
        return redirect(url_for('index'))
    with Image.open(filepath) as stored:
        image = stored.convert("RGB")

    if 'vqa' in request.form:
        question = request.form.get('question')
        if question:
            answer = answer_question(image, question)
            return render_template('index.html', image_url=filepath, answer=answer, show_buttons=True, question=question)
        flash("Please enter a question.")
    elif 'caption' in request.form:
        caption = generate_caption(image)
        return render_template('index.html', image_url=filepath, caption=caption, show_buttons=True)
    return redirect(url_for('index'))
def generate_caption(image):
    """Return the caption produced by the preloaded BLIP captioning model.

    Args:
        image: An image accepted by ``caption_processor`` (callers in this
            file pass an RGB PIL image).

    Returns:
        The first generated sequence decoded to text, with special tokens
        removed.
    """
    model_inputs = caption_processor(images=image, return_tensors="pt")
    generated_ids = caption_model.generate(**model_inputs)
    first_sequence = generated_ids[0]
    return caption_processor.decode(first_sequence, skip_special_tokens=True)
def answer_question(image, question):
    """Answer *question* about *image* with the preloaded BLIP VQA model.

    Args:
        image: An image accepted by ``vqa_processor`` (callers in this file
            pass an RGB PIL image).
        question: The question text.

    Returns:
        The first generated sequence decoded to text, with special tokens
        removed.
    """
    model_inputs = vqa_processor(images=image, text=question, return_tensors="pt")
    generated_ids = vqa_model.generate(**model_inputs)
    first_sequence = generated_ids[0]
    return vqa_processor.decode(first_sequence, skip_special_tokens=True)
# NOTE(review): route decorator reconstructed (missing in this copy of the
# file); confirm the path against the templates.
@app.route('/uploads/<filename>')
def serve_uploaded_file(filename):
    """Send a previously uploaded file from ``UPLOAD_FOLDER``.

    send_from_directory() rejects names that would escape the folder and
    answers 404 for missing files. The old code hand-joined the
    request-supplied name onto the folder path.
    """
    from flask import send_from_directory

    return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
# NOTE(review): url_for('serve_upscaled_file', ...) needs this endpoint to be
# registered; the route path is reconstructed, so confirm it against the templates.
@app.route('/upscaled/<filename>')
def serve_upscaled_file(filename):
    """Send an upscaled result image from ``UPSCALED_FOLDER``.

    send_from_directory() rejects names that would escape the folder and
    answers 404 for missing files. The old code hand-joined the
    request-supplied name onto the folder path.
    """
    from flask import send_from_directory

    return send_from_directory(app.config['UPSCALED_FOLDER'], filename)
# Run app | |
if __name__ == '__main__':
    # Flask's development server. debug=True enables the Werkzeug interactive
    # debugger, which can execute code from the browser. It stays on by
    # default (previous behaviour) but can be turned off with FLASK_DEBUG=0
    # (or "false"/"no") without editing the code.
    debug_enabled = os.environ.get('FLASK_DEBUG', '1').strip().lower() not in ('0', 'false', 'no')
    app.run(debug=debug_enabled)