# similarity/similarity.py
# Last change: "Update similarity.py (#6)" by MarioPrzBasto, commit f87d9f1 (7.39 kB).
import base64
import logging
import os
import subprocess
import tempfile
from io import BytesIO
from typing import List

import cv2
import numpy as np
import requests
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics.pairwise import cosine_similarity
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import img_to_array

from models import RequestModel, ResponseModel
# Root logger setup shared by every function in this module.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# ImageNet-pretrained MobileNetV2 backbone without its classifier head;
# pooling='avg' makes predict() return one pooled feature vector per image.
# Loaded once at import time and reused by mobilenet_sim.
mobilenet = MobileNetV2(include_top=False, pooling='avg', weights="imagenet")
def preprocess_image_for_mobilenet(image):
    """Turn an OpenCV image into a one-image batch ready for MobileNetV2.

    Single-channel input (2-D, or 3-D with a single channel) is expanded to
    RGB. Any other input is treated as OpenCV BGR order and swapped to RGB.
    The result is resized to 224x224, given a leading batch axis, and passed
    through MobileNetV2's ``preprocess_input``.
    """
    single_channel = len(image.shape) == 2 or image.shape[2] == 1
    color_code = cv2.COLOR_GRAY2RGB if single_channel else cv2.COLOR_BGR2RGB
    rgb = cv2.cvtColor(image, color_code)
    resized = cv2.resize(rgb, (224, 224))
    batch = np.expand_dims(img_to_array(resized), axis=0)
    return preprocess_input(batch)
def mobilenet_sim(img1, img2, img1AssetCode, img2AssetCode):
    """Score how similar two images look using MobileNetV2 embeddings.

    Both images are preprocessed with ``preprocess_image_for_mobilenet`` and
    embedded with the module-level ``mobilenet`` model. The two feature
    vectors are compared with cosine similarity, which is rescaled linearly
    from [-1, 1] to [0, 100].

    Args:
        img1, img2: OpenCV images (grayscale or BGR).
        img1AssetCode, img2AssetCode: identifiers used only in log messages.

    Returns:
        The similarity score as a float, or ``0.0`` if anything fails
        (the error is logged with its traceback).
    """
    try:
        # One forward pass for both images instead of two predict() calls.
        batch = np.concatenate(
            [preprocess_image_for_mobilenet(img1), preprocess_image_for_mobilenet(img2)],
            axis=0,
        )
        features = mobilenet.predict(batch, verbose=0)
        # Slices keep the 2-D (n_samples, n_features) shape cosine_similarity expects.
        sim = cosine_similarity(features[0:1], features[1:2])[0][0]
        sim_score = float((sim + 1) * 50)
        logging.info(f"MobileNet similarity score from {img1AssetCode} and {img2AssetCode} is {sim_score}")
        return sim_score
    except Exception:
        logging.error("Erro ao calcular similaridade com MobileNet", exc_info=True)
        return 0.0
def orb_sim(img1, img2, img1AssetCode, img2AssetCode, max_distance=20):
    """Percentage of ORB feature matches between two images that are "close".

    ORB descriptors from both images are matched by brute force using Hamming
    distance with cross-checking. The score is the share of those matches
    whose distance is below ``max_distance``, as a percentage (0-100).

    Args:
        img1, img2: images accepted by ``cv2.ORB.detectAndCompute``.
        img1AssetCode, img2AssetCode: identifiers used only in log messages.
        max_distance: Hamming distance below which a match counts as similar.

    Returns:
        The score. It is 0 when no descriptors or matches exist, or on error
        (the error is logged). Scores strictly between 0 and 1 are reported as 1.
    """
    score = 0
    try:
        orb = cv2.ORB_create()
        _, desc_a = orb.detectAndCompute(img1, None)
        _, desc_b = orb.detectAndCompute(img2, None)
        # detectAndCompute yields None descriptors when no keypoints are found
        # (e.g. a flat or blank image). bf.match would raise on that, so treat
        # it as "nothing in common" instead of logging a spurious error.
        if desc_a is None or desc_b is None:
            logging.info(f"Orb found no keypoints in {img1AssetCode} or {img2AssetCode}")
            return 0
        bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        matches = bf.match(desc_a, desc_b)
        if matches:
            close_matches = sum(1 for match in matches if match.distance < max_distance)
            score = (close_matches / len(matches)) * 100
        if score > 0:
            logging.info(f"Orb score from {img1AssetCode} and {img2AssetCode} is {score}")
    except Exception:
        logging.error("Erro ao verificar similaridade ORB", exc_info=True)
    # Bump tiny non-zero scores to 1. Presumably this keeps them from showing
    # as 0 after rounding downstream; TODO confirm intent.
    return 1 if 0 < score < 1 else score
def ssim_sim(img1, img2):
    """Return the mean structural similarity of two images, rescaled to [0, 100].

    scikit-image's SSIM lies in [-1, 1]; it is mapped linearly so that
    -1 -> 0 and 1 -> 100. Both images must have the same shape.
    """
    mean_ssim = ssim(img1, img2, full=True)[0]
    return 50 * (mean_ssim + 1)
def _extract_frame_from_video(video_path_or_url, time_sec, ffmpeg_path='ffmpeg'):
    """Extract one frame at ``time_sec`` from a video and return it as grayscale.

    Runs the ffmpeg CLI to write a single JPEG frame to a temporary file, then
    reads it back with OpenCV. The temporary file is always removed, including
    when ffmpeg fails.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status.
        ValueError: the frame file is missing or OpenCV cannot decode it.
        OSError: ffmpeg could not be launched (e.g. ``ffmpeg_path`` not found).
    """
    print(f"[INFO] A extrair frame do vídeo: {video_path_or_url} no segundo {time_sec}")
    # delete=False: ffmpeg must open the path after our handle is closed.
    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_frame:
        frame_path = temp_frame.name
    try:
        command = [
            ffmpeg_path,
            "-ss", str(time_sec),
            "-i", video_path_or_url,
            "-frames:v", "1",
            "-q:v", "2",
            "-y",  # overwrite the empty temp file created above
            frame_path,
        ]
        print(f"[DEBUG] Comando ffmpeg: {' '.join(command)}")
        # Argument list without a shell, so the path/URL is never shell-parsed.
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            print(f"[ERRO] ffmpeg falhou com código {result.returncode}")
            # errors='replace': non-UTF-8 bytes in ffmpeg output must not mask
            # the real failure with a UnicodeDecodeError.
            print(f"[ERRO] stderr: {result.stderr.decode('utf-8', errors='replace')}")
            raise RuntimeError("Erro ao extrair frame com ffmpeg.")
        if not os.path.exists(frame_path):
            print("[ERRO] Frame não criado. Verifica se o caminho do vídeo está correto e acessível.")
            raise ValueError("Frame não encontrado após execução do ffmpeg.")
        frame = cv2.imread(frame_path, cv2.IMREAD_GRAYSCALE)
    finally:
        if os.path.exists(frame_path):
            os.remove(frame_path)
    if frame is None:
        print("[ERRO] Falha ao ler frame extraído com OpenCV.")
        raise ValueError("Erro ao carregar frame extraído.")
    print(f"[SUCESSO] Frame extraído com sucesso de {video_path_or_url}")
    return frame


def _extract_frame_from_video_bytes(video_bytes, time_sec, ffmpeg_path='ffmpeg'):
    """Write raw video bytes to a temp ``.mp4`` file and extract one grayscale frame.

    The temporary video file is removed even if extraction fails.
    """
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_video:
        temp_video.write(video_bytes)
        temp_video_path = temp_video.name
    try:
        return _extract_frame_from_video(temp_video_path, time_sec, ffmpeg_path)
    finally:
        os.remove(temp_video_path)


def _decode_image_bytes(img_bytes):
    """Decode encoded image bytes with Pillow into a 2-D uint8 grayscale array."""
    # NOTE(review): this disables Pillow's decompression-bomb guard for the
    # whole process, and the bytes come from the request payload. Consider a
    # finite limit instead of None.
    Image.MAX_IMAGE_PIXELS = None
    with Image.open(BytesIO(img_bytes)) as pil_image:
        # convert("L") works for every Pillow mode (L, P, LA, RGB, RGBA, CMYK...).
        # cv2.cvtColor(..., COLOR_RGB2GRAY) rejects 1- and 2-channel arrays.
        return np.array(pil_image.convert("L"))


def load_image(source, assetCode, contentType=None, ffmpeg_path='ffmpeg', frame_time=1,
               request_timeout=30):
    """Load an image, or one frame of a video, as a grayscale array.

    Args:
        source: a URL starting with ``http``, or a base64-encoded payload.
        assetCode: identifier used only in log messages.
        contentType: MIME type. Video is assumed for ``video/*`` URLs and for
            base64 payloads whose type is not ``image/*``, including ``None``;
            for those a frame is extracted with ffmpeg.
        ffmpeg_path: ffmpeg executable to run.
        frame_time: second of the video at which the frame is taken.
        request_timeout: timeout in seconds passed to ``requests.get``.

    Returns:
        The grayscale image, or ``None`` if loading or decoding fails (the
        error is printed).
    """
    try:
        if source.startswith('http'):
            print(f"[INFO] Content-Type de {assetCode} é {contentType}")
            if contentType and contentType.startswith('video'):
                return _extract_frame_from_video(source, frame_time, ffmpeg_path)
            print(f"[INFO] A carregar imagem {assetCode} a partir de URL")
            response = requests.get(source, timeout=request_timeout)
            # Fail loudly on HTTP errors instead of trying to decode an error page.
            response.raise_for_status()
            img = np.asarray(bytearray(response.content), dtype=np.uint8)
            return cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)

        print(f"[INFO] A tentar carregar base64 de {assetCode} como imagem ou vídeo.")
        try:
            img_bytes = base64.b64decode(source)
            if contentType and contentType.startswith('image'):
                print(f"[INFO] Base64 de {assetCode} identificado como imagem")
                return _decode_image_bytes(img_bytes)
            print(f"[INFO] Base64 de {assetCode} identificado como vídeo")
            return _extract_frame_from_video_bytes(img_bytes, frame_time, ffmpeg_path)
        except Exception as e:
            print(f"[ERRO] Falha ao processar base64 de {assetCode}: {e}")
            raise
    except Exception as e:
        print(f"[ERRO] Falha ao carregar imagem para {assetCode}: {e}")
        return None
def check_similarity(images: List[RequestModel]):
    """Score every image after the first against the first one.

    ``images[0]`` is the reference. Each later entry is loaded, resized to
    the reference's size, and scored with SSIM, ORB and MobileNet. A failure
    for an individual entry is logged and reported as all-zero scores.

    Returns:
        One ``ResponseModel`` per entry in ``images[1:]``, or an empty list
        when ``images`` is empty.

    Raises:
        ValueError: the reference image could not be loaded.
    """
    if not images:
        return []
    reference = images[0]
    logging.info(f"Checking similarity for main source with resource id {reference.originId}")
    # NOTE(review): load_image needs the content type to tell images from
    # videos (base64 with no type is treated as video). This assumes
    # RequestModel may carry a `contentType` field; confirm against models.py.
    original_image = load_image(reference.source, reference.assetCode,
                                getattr(reference, 'contentType', None))
    if original_image is None:
        raise ValueError(f"Could not load main image for resource id {reference.originId}")
    # cv2.resize takes dsize as (width, height).
    target_size = (original_image.shape[1], original_image.shape[0])

    results = []
    for candidate in images[1:]:
        try:
            image = load_image(candidate.source, candidate.assetCode,
                               getattr(candidate, 'contentType', None))
            if image is None:
                raise ValueError("image could not be loaded")
            image = cv2.resize(image, target_size)
            similarity_score = ssim_sim(original_image, image)
            similarity_orb_score = orb_sim(original_image, image, reference.assetCode, candidate.assetCode)
            similarity_mobilenet_score = mobilenet_sim(original_image, image, reference.assetCode, candidate.assetCode)
        except Exception as e:
            logging.error(f"Error loading image for resource id {candidate.originId} : {e}")
            similarity_score = 0
            similarity_orb_score = 0
            similarity_mobilenet_score = 0
        results.append(ResponseModel(
            originId=candidate.originId,
            source=candidate.source,
            sequence=candidate.sequence,
            assetCode=candidate.assetCode,
            similarity=similarity_score,
            similarityOrb=similarity_orb_score,
            similarityMobileNet=similarity_mobilenet_score,
        ))
    return results