Ransaka commited on
Commit
c4b80ab
·
verified ·
1 Parent(s): 1f7ccbe

Update recommendation.py

Browse files
Files changed (1) hide show
  1. recommendation.py +3 -3
recommendation.py CHANGED
@@ -21,14 +21,14 @@ embeddings = torch.load("data/embeddings.bin")
21
  text_embedding_index = Indexer(embeddings)
22
  image_embedding_index = Indexer(latent_data)
23
 
24
- def get_recommendations(image, title):
25
  # title = [dataset[product_id]['title']]
26
  title_embeds = model.encode([title], normalize_embeddings=True)
27
  image = transforms.ToTensor()(image.convert("L"))
28
  image_embeds = encoder(image).detach().numpy()
29
 
30
- image_candidates = image_embedding_index.topk(image_embeds)
31
- title_candidates = text_embedding_index.topk(title_embeds)
32
  final_candidates = []
33
  final_candidates.append(list(image_candidates[0]))
34
  final_candidates.append(list(title_candidates[0]))
 
21
  text_embedding_index = Indexer(embeddings)
22
  image_embedding_index = Indexer(latent_data)
23
 
24
+ def get_recommendations(image, title, k):
25
  # title = [dataset[product_id]['title']]
26
  title_embeds = model.encode([title], normalize_embeddings=True)
27
  image = transforms.ToTensor()(image.convert("L"))
28
  image_embeds = encoder(image).detach().numpy()
29
 
30
+ image_candidates = image_embedding_index.topk(image_embeds,k=k)
31
+ title_candidates = text_embedding_index.topk(title_embeds, k=k)
32
  final_candidates = []
33
  final_candidates.append(list(image_candidates[0]))
34
  final_candidates.append(list(title_candidates[0]))