import argparse
import base64
import io
import os
import re
import sys
import traceback
import uuid
from typing import List, Optional

import cv2
import numpy as np
import pandas as pd
import pinecone
import pyiqa
import timm
import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from PIL import Image, ImageEnhance
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util

load_dotenv()
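# Pinecone credentials are read from the environment: PINECONE_KEY and PINECONE_ENV
# must be set (for example via the .env file loaded above).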
pinecone.init(api_key=os.getenv("PINECONE_KEY"), environment=os.getenv("PINECONE_ENV"))

IMAGE_SIMILARITY_DEMO = "/find-similar-image/"
IMAGE_SIMILARITY_PINECONE_DEMO = "/find-similar-image-pinecone/"
INDEX_NAME = "imagesearch-demo"
INDEX_DIMENSION = 512
TMP_DIR = "tmp"

image_sim_model = SentenceTransformer("clip-ViT-B-32")
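# Note: clip-ViT-B-32 produces 512-dimensional embeddings, which is why INDEX_DIMENSION
# above is set to 512; the two values must stay in sync.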
def enhance_image(pil_image):
    # Convert PIL Image to OpenCV format (RGB -> BGR)
    open_cv_image = np.array(pil_image)
    open_cv_image = open_cv_image[:, :, ::-1].copy()
    # Convert to grayscale
    gray = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2GRAY)
    # Global histogram equalization (computed here but not used in the final output)
    equ = cv2.equalizeHist(gray)
    # Adaptive histogram equalization (CLAHE)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    adaptive_hist_eq = clahe.apply(gray)
    # Gaussian blurring and median-filter noise reduction (also not used in the final output)
    gaussian_blurred = cv2.GaussianBlur(adaptive_hist_eq, (5, 5), 0)
    denoised = cv2.medianBlur(gaussian_blurred, 3)
    # Brightness & contrast adjustment: apply CLAHE to the L channel in Lab colour space
    lab = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2Lab)
    l, a, b = cv2.split(lab)
    cl = clahe.apply(l)
    limg = cv2.merge((cl, a, b))
    enhanced_image = cv2.cvtColor(limg, cv2.COLOR_Lab2BGR)
    # Convert back to a PIL Image
    enhanced_pil_image = Image.fromarray(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
    # Image augmentation: a simple brightness adjustment for demonstration purposes;
    # in practice, choose the augmentations that suit your task.
    enhancer = ImageEnhance.Brightness(enhanced_pil_image)
    enhanced_pil_image = enhancer.enhance(1.2)  # Brighten the image by 20%
    return enhanced_pil_image
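# Example usage (illustrative; "sample.jpg" is a placeholder path, not part of this repo):
#   enhanced = enhance_image(Image.open("sample.jpg"))
#   enhanced.save("sample_enhanced.jpg")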
print("checking pinecone Index") | |
if INDEX_NAME not in pinecone.list_indexes(): | |
# delete the current index and create the new index if it does not exist | |
for delete_index in pinecone.list_indexes(): | |
print(f"Deleting exitsing pinecone Index : {delete_index}") | |
pinecone.delete_index(delete_index) | |
print(f"Creating new pinecone Index : {INDEX_NAME}") | |
pinecone.create_index(INDEX_NAME, dimension=INDEX_DIMENSION, metric="cosine") | |
print("Connecting to Pinecone Index") | |
index = pinecone.Index(INDEX_NAME) | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
os.makedirs(TMP_DIR, exist_ok=True) | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
os.makedirs(TMP_DIR, exist_ok=True) | |
app = FastAPI(title="CV Demos")


# Root route: point callers at the actual demo endpoint
@app.get("/")
def root_route():
    return {"error": f"Use POST {IMAGE_SIMILARITY_PINECONE_DEMO} instead of the root route!"}
# Endpoints are registered under the path constants defined at the top of the file
@app.post(IMAGE_SIMILARITY_DEMO)
async def image_search_local(
    images_to_search: List[UploadFile] = File(...), query_image: UploadFile = File(...), top_k: int = 5
):
    print(
        f"Received {len(images_to_search)} images; retrieving the top {top_k} most similar images"
    )
    try:
        has_valid_extension = query_image.filename.split(".")[-1].lower() in ("jpg", "jpeg", "png")
        search_images = []
        search_filenames = []
        print("Processing request...")
        for image in images_to_search:
            if image.filename.split(".")[-1].lower() not in ("jpg", "jpeg", "png"):
                return "Image must be jpg or png format!"
            # read the image content
            search_filenames.append(image.filename)
            contents = await image.read()
            search_images.append(Image.open(io.BytesIO(contents)))
print("Indexing images to search...") | |
corpus_embeddings = image_sim_model.encode( | |
search_images, convert_to_tensor=True, show_progress_bar=True | |
) | |
        if not has_valid_extension:
            return "Image must be jpg or png format!"
        # read the query image content
        contents = await query_image.read()
        query_image = Image.open(io.BytesIO(contents))
        print("Indexing query image...")
        prompt_embedding = image_sim_model.encode(query_image, convert_to_tensor=True)
        print("Searching query image...")
        hits = util.semantic_search(prompt_embedding, corpus_embeddings, top_k=top_k)
        # hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])
print("Creating the result..") | |
similar_images = [] | |
print("hits ", hits) | |
for hit in hits[0]: | |
# print("Finding the image ") | |
# print("Type of images list ", type(search_images), "similar image id ", hit['corpus_id']) | |
open_cv_image = np.array(search_images[hit["corpus_id"]].convert("RGB"))[:, :, ::-1] | |
# print("cv2.imencode the image ") | |
_, encoded_img = cv2.imencode(".PNG", open_cv_image) | |
# print("base64 the image ") | |
encoded_img = base64.b64encode(encoded_img) | |
# print("Appending the image ") | |
similar_images.append( | |
{ | |
"filename": search_filenames[hit["corpus_id"]], | |
"dimensions": str(open_cv_image.shape), | |
"score": hit["score"], | |
"encoded_img": encoded_img, | |
} | |
) | |
print("Sending result..") | |
return {"similar_images": similar_images} | |
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
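# Example request (illustrative; assumes the server runs on localhost:8000 and the listed
# image files exist on the client side):
#   curl -X POST "http://localhost:8000/find-similar-image/?top_k=3" \
#        -F "images_to_search=@cat1.jpg" -F "images_to_search=@cat2.jpg" \
#        -F "query_image=@query.jpg"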
@app.post(IMAGE_SIMILARITY_PINECONE_DEMO)
async def image_search_pinecone(
    images_to_search: Optional[List[UploadFile]] = File(None),
    query_image: Optional[UploadFile] = File(None),
    top_k: int = 5,
    namespace: str = "av_local",
    action: str = "query",
):
    try:
        print(f"Received request with images_to_search: {images_to_search} and query_image: {query_image} with action: {action}")
        if action == "delete":
            # delete all vectors stored under the given namespace
            index = pinecone.Index(INDEX_NAME)
            delete_response = index.delete(delete_all=True, namespace=namespace)
            return {f"Deleted the namespace: {namespace}": delete_response}
elif action == "query" and query_image is not None: | |
extension = query_image.filename.split(".")[-1] in ("jpg", "jpeg", "png", "JPG", "PNG", "JPEG") | |
if not extension: | |
return "Image must be jpg or png format!" | |
# read image contain | |
contents = await query_image.read() | |
query_image = Image.open(io.BytesIO(contents)) | |
print("Indexing query image...") | |
query_image = enhance_image(query_image) | |
prompt_embedding = image_sim_model.encode(query_image, convert_to_tensor=True).tolist() | |
            if INDEX_NAME not in pinecone.list_indexes():
                return {"similar_images": [], "status": "No index found for images"}
            else:
                index = pinecone.Index(INDEX_NAME)
                query_response = index.query(
                    namespace=namespace,
                    top_k=top_k,
                    include_values=True,
                    include_metadata=True,
                    vector=prompt_embedding,
                )
                result_images = [d["metadata"]["file_path"] for d in query_response["matches"]]
                print("Creating the result..")
                similar_images = []
                print("Retrieved matches ", query_response["matches"])
                for file_path in result_images:
                    try:
                        # read the stored image back from disk, then encode it as PNG and base64
                        # so it can be returned inside the JSON response
                        open_cv_image = cv2.imread(file_path)
                        _, encoded_img = cv2.imencode(".png", open_cv_image)
                        encoded_img = base64.b64encode(encoded_img).decode("utf-8")
                        similar_images.append(
                            {
                                "filename": file_path,
                                "dimensions": str(open_cv_image.shape),
                                "score": 0,
                                "encoded_img": encoded_img,
                            }
                        )
                    except Exception:
                        # the file could not be read back; return the path without image data
                        similar_images.append(
                            {
                                "filename": file_path,
                                "dimensions": None,
                                "score": 0,
                                "encoded_img": None,
                            }
                        )
                print("Sending result..")
                return {"similar_images": similar_images}
elif action == "index" and len(images_to_search) > 0: | |
print( | |
f"Recived images of length: {len(images_to_search)} needs to retrieve top k : {top_k} similar images as result" | |
) | |
print(f"Action indexing is executing for : {len(images_to_search)} images") | |
# if the index does not already exist, we create it | |
# check if the abstractive-question-answering index exists | |
print("checking pinecone Index") | |
if INDEX_NAME not in pinecone.list_indexes(): | |
# delete the current index and create the new index if it does not exist | |
for delete_index in pinecone.list_indexes(): | |
print(f"Deleting exitsing pinecone Index : {delete_index}") | |
pinecone.delete_index(delete_index) | |
print(f"Creating new pinecone Index : {INDEX_NAME}") | |
pinecone.create_index(INDEX_NAME, dimension=INDEX_DIMENSION, metric="cosine") | |
# instantiate connection to your Pinecone index | |
print(f"Connecting to pinecone Index : {INDEX_NAME}") | |
index = pinecone.Index(INDEX_NAME) | |
            search_images = []
            meta_datas = []
            ids = []
            print("Processing request...")
            for image in images_to_search:
                if image.filename.split(".")[-1].lower() not in ("jpg", "jpeg", "png"):
                    return "Image must be jpg or png format!"
                # read the image content and keep a copy on disk so query results can be served later
                contents = await image.read()
                pil_image = Image.open(io.BytesIO(contents))
                tmp_file = f"{TMP_DIR}/{image.filename}"
                pil_image.save(tmp_file)
                meta_datas.append({"file_path": tmp_file})
                search_images.append(pil_image)
                ids.append(str(uuid.uuid1()).replace("-", ""))
            print("Encoding images to vectors...")
            corpus_embeddings = image_sim_model.encode(
                search_images, convert_to_tensor=True, show_progress_bar=True
            ).tolist()
            print(f"Indexing images to pinecone Index : {INDEX_NAME}")
            # each vector is upserted as an (id, embedding, metadata) tuple
            index.upsert(
                vectors=list(zip(ids, corpus_embeddings, meta_datas)), namespace=namespace
            )
            return {"similar_images": [], "status": "Indexing successful for uploaded files"}
        else:
            return {"similar_images": []}
    except Exception as e:
        print(f"exception happened: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))
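# Example requests (illustrative; assumes the server runs on localhost:8000 and the listed
# image files exist on the client side):
#   index:  curl -X POST "http://localhost:8000/find-similar-image-pinecone/?action=index" \
#                -F "images_to_search=@cat1.jpg" -F "images_to_search=@cat2.jpg"
#   query:  curl -X POST "http://localhost:8000/find-similar-image-pinecone/?action=query&top_k=3" \
#                -F "query_image=@query.jpg"
#   delete: curl -X POST "http://localhost:8000/find-similar-image-pinecone/?action=delete"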
if __name__ == "__main__":
    # Parse the port and start the API server with uvicorn
    parser = argparse.ArgumentParser(description="FastAPI app exposing CLIP-based image similarity search")
    parser.add_argument("--port", default=8000, type=int, help="port number")
    opt = parser.parse_args()
    uvicorn.run(app, port=opt.port)