Spaces:
Runtime error
Runtime error
Update recommendation.py
Browse files- recommendation.py +3 -3
recommendation.py
CHANGED
@@ -21,14 +21,14 @@ embeddings = torch.load("data/embeddings.bin")
|
|
21 |
text_embedding_index = Indexer(embeddings)
|
22 |
image_embedding_index = Indexer(latent_data)
|
23 |
|
24 |
-
def get_recommendations(image, title):
|
25 |
# title = [dataset[product_id]['title']]
|
26 |
title_embeds = model.encode([title], normalize_embeddings=True)
|
27 |
image = transforms.ToTensor()(image.convert("L"))
|
28 |
image_embeds = encoder(image).detach().numpy()
|
29 |
|
30 |
-
image_candidates = image_embedding_index.topk(image_embeds)
|
31 |
-
title_candidates = text_embedding_index.topk(title_embeds)
|
32 |
final_candidates = []
|
33 |
final_candidates.append(list(image_candidates[0]))
|
34 |
final_candidates.append(list(title_candidates[0]))
|
|
|
21 |
text_embedding_index = Indexer(embeddings)
|
22 |
image_embedding_index = Indexer(latent_data)
|
23 |
|
24 |
+
def get_recommendations(image, title, k):
|
25 |
# title = [dataset[product_id]['title']]
|
26 |
title_embeds = model.encode([title], normalize_embeddings=True)
|
27 |
image = transforms.ToTensor()(image.convert("L"))
|
28 |
image_embeds = encoder(image).detach().numpy()
|
29 |
|
30 |
+
image_candidates = image_embedding_index.topk(image_embeds,k=k)
|
31 |
+
title_candidates = text_embedding_index.topk(title_embeds, k=k)
|
32 |
final_candidates = []
|
33 |
final_candidates.append(list(image_candidates[0]))
|
34 |
final_candidates.append(list(title_candidates[0]))
|