#!/usr/bin/env python
# -*- coding: utf-8 -*- 

import os
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import torch
from io import BytesIO
from base64 import b64encode
from tqdm.auto import tqdm
from constants import *



# initialize connection to pinecone (get API key at app.pinecone.io)
api_key = PINECONE_API_KEY or os.getenv("PINECONE_API_KEY")
# find your environment next to the api key in pinecone console
env = PINECONE_ENVIRONMENT or os.getenv("PINECONE_ENVIRONMENT")



class SearchItem:
    """Hybrid (dense CLIP + sparse BM25) search over the fashion-product dataset, backed by a Pinecone index."""

    def __init__(self, api_key=None, env=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.api_key = api_key
        self.environment = env
        self.pinecone_instance = self.connect_to_pinecone(self.api_key, self.environment)
        self.index = self.pinecone_instance.Index('clip')          # existing hybrid index
        self.images, self.metadata = self.load_fashion_dataset()   # PIL images + product metadata
        self.clip_model = self.initialize_clip_model(device=device)
        self.bm25 = self.initialize_bm25_encoder(self.metadata)



    def connect_to_pinecone(self, api_key, env):
        api_key = api_key or os.getenv('PINECONE_API_KEY')
        env = env or os.getenv('PINECONE_ENVIRONMENT')
        
        if not api_key or not env:
            raise ValueError("Pinecone API key and environment are required.")
        
        pinecone_instance = Pinecone(api_key=api_key, environment=env)
        return pinecone_instance
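
    # A possible index-creation helper (not in the original file): the 'clip' index
    # queried in __init__ must already exist. The index name, cloud, and region below
    # are assumptions; hybrid (sparse-dense) queries require the dotproduct metric,
    # and clip-ViT-B-32 produces 512-dimensional embeddings.
    def create_clip_index(self, name='clip', cloud='aws', region='us-east-1'):
        if name not in self.pinecone_instance.list_indexes().names():
            self.pinecone_instance.create_index(
                name=name,
                dimension=512,        # clip-ViT-B-32 embedding size
                metric='dotproduct',  # required for sparse-dense (hybrid) queries
                spec=ServerlessSpec(cloud=cloud, region=region)
            )
        return self.pinecone_instance.Index(name)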
    
    def load_fashion_dataset(self):
        fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
        images = fashion["image"]
        metadata = fashion.remove_columns("image").to_pandas()
        return images, metadata
    
    def initialize_clip_model(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)
        return model
    
    def initialize_bm25_encoder(self, metadata):
        bm25 = BM25Encoder()
        bm25.fit(metadata['productDisplayName'])
        return bm25
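
    # A possible ingestion sketch (not in the original file): how the dataset might be
    # upserted with dense CLIP vectors, sparse BM25 vectors, and the product metadata.
    # The batch size and the positional-id scheme (matching `self.images[int(r["id"])]`
    # in the query path) are assumptions; null metadata values would need cleaning first.
    def upsert_batches(self, batch_size=200):
        for start in tqdm(range(0, len(self.metadata), batch_size)):
            end = min(start + batch_size, len(self.metadata))
            meta_batch = self.metadata.iloc[start:end]
            # sparse vectors from the product names, dense vectors from the images
            sparse_vecs = self.bm25.encode_documents(meta_batch['productDisplayName'].tolist())
            dense_vecs = self.clip_model.encode(self.images[start:end]).tolist()
            vectors = [
                {
                    'id': str(i),
                    'values': dense,
                    'sparse_values': sparse,
                    'metadata': meta,
                }
                for i, dense, sparse, meta in zip(
                    meta_batch.index, dense_vecs, sparse_vecs,
                    meta_batch.to_dict(orient='records')
                )
            ]
            self.index.upsert(vectors=vectors)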
    
    @staticmethod
    def hybrid_scale(dense, sparse, alpha=0.05):
        """Hybrid vector scaling using a convex combination

        alpha * dense + (1 - alpha) * sparse

        Args:
            dense: array of floats representing the dense (CLIP) query vector
            sparse: dict with `indices` and `values` produced by the BM25 encoder
            alpha: float between 0 and 1 where 0 == sparse only
                and 1 == dense only
        """
        if alpha < 0 or alpha > 1:
            raise ValueError("Alpha must be between 0 and 1")
        
        # Scale sparse and dense vectors to create hybrid search vectors
        hsparse = {
            'indices': sparse['indices'],
            'values': [v * (1 - alpha) for v in sparse['values']]
        }
        hdense = [v * alpha for v in dense]
        
        return hdense, hsparse
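
    # A possible convenience wrapper (not part of the original class): encode the
    # query with BM25 and CLIP, scale both vectors, and run the hybrid Pinecone
    # query in one call. The parameter names and defaults are assumptions.
    def search(self, query, top_k=5, alpha=0.05):
        sparse = self.bm25.encode_queries(query)        # sparse keyword query vector
        dense = self.clip_model.encode(query).tolist()  # dense semantic query vector
        hdense, hsparse = self.hybrid_scale(dense, sparse, alpha)
        return self.index.query(
            top_k=top_k,
            vector=hdense,
            sparse_vector=hsparse,
            include_metadata=True
        )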
    

if __name__ == "__main__":
    fashion_processor = SearchItem(api_key, env)
    
    query = "blue shoes"
    # create sparse and dense vectors
    sparse = fashion_processor.bm25.encode_queries(query)
    dense = fashion_processor.clip_model.encode(query).tolist()

    hdense, hsparse = fashion_processor.hybrid_scale(dense, sparse)

    result = fashion_processor.index.query(
        top_k=5,
        vector=hdense,
        sparse_vector=hsparse,
        include_metadata=True
    )

    imgs = [fashion_processor.images[int(r["id"])] for r in result["matches"]]
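
    # Optional sketch (an assumption, not part of the original flow): the unused
    # BytesIO / b64encode imports suggest the retrieved PIL images may be serialized,
    # e.g. as base64 data URIs for display in a notebook or web page.
    data_uris = []
    for img in imgs:
        buffer = BytesIO()
        img.save(buffer, format="PNG")  # the dataset images are PIL.Image objects
        data_uris.append("data:image/png;base64," + b64encode(buffer.getvalue()).decode("utf-8"))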

    print(f"Retrieved {len(imgs)} images for query: {query!r}")