import os
from typing import Any, Dict, List

import tensorflow as tf


class RecommendationEngine:
    """Two-stage product recommender: retrieve candidate products, then rank them.

    Construction loads a retrieval index and a ranking model from the
    ``model`` directory and a product dataset from ``products`` (both paths
    are relative to the current working directory).

    ``get_recommendations`` keeps per-request state on the instance
    (``query_input``, ``top_rec``, ``recommendations``) because the
    ``tf.data`` callbacks ``filter_by_id`` and ``get_score`` read it from
    ``self``.  NOTE(review): this means one instance should not serve
    overlapping requests -- confirm against the caller.
    """

    def __init__(self):
        """Load the retrieval index, the ranking model and the product dataset."""
        # Product-side columns that are forwarded to the ranking model in
        # ``get_score``; other product columns are only used for display.
        self.model_product_features = [
            'product_id', 'product_name', 'category_name',
            'merchant_name', 'merchant_state', 'merchant_region',
            'free_shipping', 'is_sold_out', 'editor_pick', 'on_sale',
            'sales_last_week', 'sales_last_month', 'sales_last_year',
            'price_in_cents', 'reviews',
        ]

        index_path = os.path.join("model", "retrieval_index")
        model_path = os.path.join("model", "ranking_model")
        self.index = tf.keras.models.load_model(index_path)
        self.model = tf.keras.models.load_model(model_path)

        products_path = "products"
        self.products = tf.data.Dataset.load(products_path)

    def get_recommendations(
        self, raw_query: dict, *, batch_size: int = 8
    ) -> List[Dict[str, Any]]:
        """Retrieve, score and rank products for one query.

        Args:
            raw_query: Mapping that must contain ``user_id``, ``channel``,
                ``device_type``, ``query_text`` and ``time``.  Each value is
                converted to a ``tf.string`` tensor.
            batch_size: Number of candidates scored per ranking-model call.
                Defaults to 8, the value that used to be hard-coded.

        Returns:
            The recommendations as a list of display-ready dicts (see
            ``desired_output`` for the keys), highest score first.  The same
            list is also stored in ``self.recommendations``.

        Raises:
            KeyError: If ``raw_query`` lacks one of the required keys.
        """
        query_fields = ('user_id', 'channel', 'device_type', 'query_text', 'time')
        self.query_input = {
            name: tf.convert_to_tensor(raw_query[name], dtype=tf.string)
            for name in query_fields
        }

        # Each value is wrapped in a one-element list before calling the
        # index -- presumably to add a batch dimension (TODO confirm).  The
        # first output is discarded; the second is matched against
        # ``product_id`` in ``filter_by_id``.
        _, self.top_rec = self.index(
            {name: [value] for name, value in self.query_input.items()}
        )

        # Keep only the retrieved products, then attach the query features
        # to every remaining row (product columns win on a name clash).
        candidates = self.products.filter(self.filter_by_id)
        with_query = candidates.map(
            lambda product: {**self.query_input, **product}
        )

        # Score in batches, then go back to one row per product.
        scored = with_query.batch(batch_size).map(self.get_score).unbatch()

        display_rows = scored.map(self.desired_output)
        ranked = self.order_by_score(display_rows)

        self.recommendations = [self.decode_values(row) for row in ranked]
        return self.recommendations

    def filter_by_id(self, item):
        """Return a boolean tensor: is ``item['product_id']`` in ``self.top_rec[0]``?"""
        return tf.reduce_any(tf.equal(item['product_id'], self.top_rec[0]))

    def get_score(self, item):
        """Run the ranking model on ``item`` and store the result under ``'score'``.

        Only the keys listed in ``self.model_product_features`` or present in
        ``self.query_input`` are passed to the model.  The model is expected
        to return three values, of which the third is used as the score.
        ``item`` is modified in place and returned.
        """
        # Build the allow-list once per call instead of once per key.
        allowed = set(self.model_product_features) | set(self.query_input)
        input_data = {k: v for k, v in item.items() if k in allowed}
        _, _, score = self.model(input_data)
        item['score'] = score
        return item

    def desired_output(self, item):
        """Project a scored row onto the human-readable output columns."""
        return {
            'Score': item['score'],
            'Product Name': item['product_name'],
            'Category': item['category_name'],
            'Price (in cents)': item['price_in_cents'],
            'Reviews': item['reviews'],
            'Merchant': item['merchant_name'],
            'City': item['merchant_city'],
            'State': item['merchant_state'],
            'Region': item['merchant_region'],
            'Free Shipping': item['free_shipping'],
            'Sold Out': item['is_sold_out'],
            'Editor\'s Pick': item['editor_pick'],
            'On Sale': item['on_sale'],
        }

    def order_by_score(self, recs) -> List[Dict[str, Any]]:
        """Materialise ``recs`` via ``as_numpy_iterator`` and sort by ``'Score'``, descending."""
        return sorted(
            recs.as_numpy_iterator(),
            key=lambda rec: rec['Score'],
            reverse=True,
        )

    def decode_values(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Make one recommendation row display-ready.

        ``bytes`` values are decoded as UTF-8 and the ``'Score'`` entry is
        replaced by its first element (assumes the score is a one-element
        sequence -- TODO confirm against the ranking model's output shape).
        ``item`` is modified in place and returned.
        """
        # Only existing keys are reassigned, so the dict never changes size
        # while it is being iterated.
        for key, value in item.items():
            if isinstance(value, bytes):
                item[key] = value.decode('utf-8')
            if key == 'Score':
                item[key] = value[0]
        return item