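# NOTE: "Spaces: Sleeping / Sleeping" is hosting-page status text captured
# during extraction; it is not part of this module.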
import numpy as np | |
from sklearn.metrics.pairwise import pairwise_distances | |
from typing import List, Dict | |
from utils.config import Config | |
from PIL import Image | |
import pandas as pd | |
import tensorflow as tf | |
import io | |
import os | |
# --- Product catalogue ------------------------------------------------------
# The dataset location is read from the 'dataset' key of the 'app' config
# section (replace with the actual path to your dataset).
dataset_path = Config.read('app', 'dataset')

# Fail early with a clear message when the configured file is absent.
if not os.path.exists(dataset_path):
    raise FileNotFoundError(f"The dataset file at {dataset_path} was not found.")

# NOTE(review): unpickling can execute arbitrary code, so the configured path
# must only ever point at a trusted file.
data = pd.read_pickle(dataset_path)

# Columns that the similarity lookup reads from the catalogue; abort on the
# first one that is missing.
required_columns = ['asin', 'title', 'brand', 'medium_image_url']
for col in required_columns:
    if col in data.columns:
        continue
    raise ValueError(f"Missing required column: {col} in the dataset")

# --- Precomputed CNN features -----------------------------------------------
# Feature matrix for the catalogue images, promoted to float64, plus the ASIN
# list used to map a feature row back to a product (presumably aligned by
# position -- the similarity lookup indexes both with the same index).
bottleneck_features_train = np.load(Config.read('app', 'cnnmodel')).astype(np.float64)
asins = list(np.load(Config.read('app', 'cssasins')))
# Lazily-built VGG16 feature extractor, shared across calls so the network is
# constructed (and its ImageNet weights loaded) only once per process.
_vgg16_feature_model = None


def _get_vgg16_feature_model():
    """Return the shared headless VGG16 model, building it on first use."""
    global _vgg16_feature_model
    if _vgg16_feature_model is None:
        _vgg16_feature_model = tf.keras.applications.VGG16(
            weights='imagenet', include_top=False, input_shape=(224, 224, 3)
        )
    return _vgg16_feature_model


def extract_features_from_image(image_bytes):
    """Extract VGG16 convolutional features from an uploaded image.

    Args:
        image_bytes: Raw encoded image data (any format Pillow can decode).

    Returns:
        A 1-D NumPy array: the flattened output of the VGG16 convolutional
        base (``include_top=False``) for the image resized to 224x224.

    Raises:
        Whatever ``PIL.Image.open`` raises when the bytes are not a decodable
        image; the error is propagated unchanged.
    """
    with Image.open(io.BytesIO(image_bytes)) as uploaded:
        # The model's input shape is (224, 224, 3). Uploads may be RGBA,
        # grayscale or palette images, so force three channels before
        # resizing; otherwise the array shape would not match the model.
        image = uploaded.convert('RGB').resize((224, 224))

    # Scale pixel values to [0, 1].
    # NOTE(review): this must match the preprocessing used when the stored
    # catalogue features were generated -- confirm before changing it.
    image_array = np.array(image) / 255.0
    # Add the leading batch dimension expected by ``predict``.
    image_array = np.expand_dims(image_array, axis=0)

    features = _get_vgg16_feature_model().predict(image_array)
    return features.flatten()
def get_similar_products_cnn(image_features, num_results: int) -> List[Dict]:
    """Return catalogue products whose CNN features are closest to the query.

    Distances are computed between every row of the module-level
    ``bottleneck_features_train`` matrix and ``image_features`` using
    ``pairwise_distances`` with its default metric. The nearest rows are
    mapped to ASINs via ``asins`` and then to product details via ``data``.

    Args:
        image_features: NumPy array of query features; it is reshaped to a
            single row, so its size must equal the stored feature width.
        num_results: Maximum number of nearest feature rows to look up.
            Values of zero or less yield an empty list.

    Returns:
        A list of dicts with keys ``'asin'``, ``'brand'``, ``'title'`` and
        ``'url'``, ordered from nearest to farthest. An ASIN that matches
        several catalogue rows contributes one entry per row, and an ASIN
        with no catalogue row contributes none, so the list length may
        differ from ``num_results``.
    """
    # A negative count would otherwise slice "all but the last k" rows and
    # silently return almost the whole catalogue.
    if num_results <= 0:
        return []

    distances = pairwise_distances(
        bottleneck_features_train, image_features.reshape(1, -1)
    ).flatten()
    # Indices of the closest products, nearest first.
    nearest = np.argsort(distances)[:num_results]

    # Project the needed columns once instead of on every iteration.
    catalog = data[['asin', 'brand', 'title', 'medium_image_url']]

    results = []
    for idx in nearest:
        matches = catalog.loc[catalog['asin'] == asins[idx]]
        for _, row in matches.iterrows():
            results.append({
                'asin': row['asin'],
                'brand': row['brand'],
                'title': row['title'],
                'url': row['medium_image_url'],
            })
    return results