licmajster commited on
Commit
eefe517
·
verified ·
1 Parent(s): 88cc338

Uploaded app and requirements for embeddings app.

Browse files
Files changed (2) hide show
  1. app.py +87 -0
  2. requirements.txt +204 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from transformers import CLIPProcessor, CLIPModel
4
+ from torch.utils.data import Dataset, DataLoader
5
+ import os
6
+ import numpy as np
7
+ import pickle
8
+ import gradio as gr
9
+
10
+ class ImageDataset(Dataset):
11
+ def __init__(self, image_dir, processor):
12
+ self.image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]
13
+ self.processor = processor
14
+
15
+ def __len__(self):
16
+ return len(self.image_paths)
17
+
18
+ def __getitem__(self, idx):
19
+ image = Image.open(self.image_paths[idx])
20
+ return self.processor(images=image, return_tensors="pt")['pixel_values'][0]
21
+
22
+ def get_and_save_clip_embeddings(image_dir, output_file, batch_size=32, device='cuda'):
23
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
24
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
25
+
26
+ dataset = ImageDataset(image_dir, processor)
27
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
28
+
29
+ all_embeddings = []
30
+ image_paths = []
31
+
32
+ model.eval()
33
+ with torch.no_grad():
34
+ for batch_idx, batch in enumerate(dataloader):
35
+ batch = batch.to(device)
36
+ embeddings = model.get_image_features(pixel_values=batch)
37
+ all_embeddings.append(embeddings.cpu().numpy())
38
+ start_idx = batch_idx * batch_size
39
+ end_idx = start_idx + len(batch)
40
+ image_paths.extend(dataset.image_paths[start_idx:end_idx])
41
+
42
+ all_embeddings = np.concatenate(all_embeddings)
43
+
44
+ with open(output_file, 'wb') as f:
45
+ pickle.dump({'embeddings': all_embeddings, 'image_paths': image_paths}, f)
46
+
47
+ # image_dir = "dataset/"
48
+ # output_file = "image_embeddings.pkl"
49
+ # batch_size = 32
50
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
51
+
52
+ # get_and_save_clip_embeddings(image_dir, output_file, batch_size, device)
53
+
54
+
55
+ # APP
56
+
57
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
58
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
59
+
60
+ with open('image_embeddings.pkl', 'rb') as f:
61
+ f = pickle.load(f)
62
+ image_embeddings = f['embeddings']
63
+ image_names = f['image_paths']
64
+ image_paths = 'dataset'
65
+
66
+ def cosine_similarity(a, b):
67
+ a = a / np.linalg.norm(a, axis=-1, keepdims=True)
68
+ b = b / np.linalg.norm(b, axis=-1, keepdims=True)
69
+ return np.dot(a, b.T)
70
+
71
+ def find_similar_images(text):
72
+ inputs = processor(text=[text], return_tensors="pt", padding=True)
73
+ with torch.no_grad():
74
+ text_embedding = model.get_text_features(**inputs).cpu().numpy()
75
+
76
+ similarities = cosine_similarity(text_embedding, image_embeddings)
77
+ top_indices = np.argsort(similarities[0])[::-1][:4]
78
+ top_images = [image_names[i] for i in top_indices]
79
+
80
+ return top_images
81
+
82
+ text_input = gr.Textbox(label="Input text", placeholder="Enter the images description")
83
+ imgs_output = gr.Gallery(label="Top 4 most similar images")
84
+
85
+ intf = gr.Interface(fn=find_similar_images, inputs=text_input, outputs=imgs_output)
86
+
87
+ intf.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ aiofiles==23.2.1
3
+ aiohappyeyeballs==2.4.3
4
+ aiohttp==3.10.8
5
+ aiosignal==1.3.1
6
+ annotated-types==0.7.0
7
+ anyio==4.6.0
8
+ argon2-cffi==23.1.0
9
+ argon2-cffi-bindings==21.2.0
10
+ arrow==1.3.0
11
+ asttokens==2.4.1
12
+ async-lru==2.0.4
13
+ async-timeout==4.0.3
14
+ attrs==24.2.0
15
+ babel==2.16.0
16
+ backoff==2.2.1
17
+ beautifulsoup4==4.12.3
18
+ bleach==6.1.0
19
+ boto3==1.35.31
20
+ botocore==1.35.31
21
+ cachetools==5.5.0
22
+ certifi==2024.8.30
23
+ cffi==1.17.1
24
+ charset-normalizer==3.3.2
25
+ click==8.1.7
26
+ comm==0.2.2
27
+ contourpy==1.3.0
28
+ cycler==0.12.1
29
+ debugpy==1.8.6
30
+ decorator==5.1.1
31
+ defusedxml==0.7.1
32
+ exceptiongroup==1.2.2
33
+ executing==2.1.0
34
+ -e git+https://github.com/fastai/fastai@80e032b0eb98860166f3ace7d2408ac210174b12#egg=fastai
35
+ fastapi==0.115.0
36
+ fastcore==1.7.11
37
+ fastdownload==0.0.7
38
+ fastjsonschema==2.20.0
39
+ fastprogress==1.0.3
40
+ ffmpy==0.4.0
41
+ filelock==3.16.1
42
+ fire==0.7.0
43
+ fonttools==4.54.1
44
+ fqdn==1.5.1
45
+ frozenlist==1.4.1
46
+ fsspec==2024.9.0
47
+ google-auth==2.35.0
48
+ google-auth-oauthlib==1.2.1
49
+ gradio==4.44.1
50
+ gradio_client==1.3.0
51
+ grpcio==1.66.2
52
+ h11==0.14.0
53
+ httpcore==1.0.6
54
+ httpx==0.27.2
55
+ huggingface-hub==0.25.1
56
+ idna==3.10
57
+ importlib_resources==6.4.5
58
+ ipykernel==6.26.0
59
+ ipython==8.17.2
60
+ ipywidgets==8.1.1
61
+ isoduration==20.11.0
62
+ jedi==0.19.1
63
+ Jinja2==3.1.4
64
+ jmespath==1.0.1
65
+ joblib==1.4.2
66
+ json5==0.9.25
67
+ jsonpointer==3.0.0
68
+ jsonschema==4.23.0
69
+ jsonschema-specifications==2023.12.1
70
+ jupyter-events==0.10.0
71
+ jupyter-lsp==2.2.5
72
+ jupyter_client==8.6.3
73
+ jupyter_core==5.7.2
74
+ jupyter_server==2.14.2
75
+ jupyter_server_terminals==0.5.3
76
+ jupyterlab==4.2.0
77
+ jupyterlab_pygments==0.3.0
78
+ jupyterlab_server==2.27.3
79
+ jupyterlab_widgets==3.0.13
80
+ kiwisolver==1.4.7
81
+ lightning==2.4.0
82
+ lightning-cloud==0.5.70
83
+ lightning-utilities==0.11.7
84
+ lightning_sdk==0.1.19
85
+ litdata==0.2.19
86
+ litserve==0.2.2
87
+ Markdown==3.7
88
+ markdown-it-py==3.0.0
89
+ MarkupSafe==2.1.5
90
+ matplotlib==3.8.2
91
+ matplotlib-inline==0.1.7
92
+ mdurl==0.1.2
93
+ mistune==3.0.2
94
+ mpmath==1.3.0
95
+ multidict==6.1.0
96
+ nbclient==0.10.0
97
+ nbconvert==7.16.4
98
+ nbformat==5.10.4
99
+ nest-asyncio==1.6.0
100
+ networkx==3.3
101
+ notebook_shim==0.2.4
102
+ numpy==1.26.4
103
+ nvidia-cublas-cu12==12.1.3.1
104
+ nvidia-cuda-cupti-cu12==12.1.105
105
+ nvidia-cuda-nvrtc-cu12==12.1.105
106
+ nvidia-cuda-runtime-cu12==12.1.105
107
+ nvidia-cudnn-cu12==8.9.2.26
108
+ nvidia-cufft-cu12==11.0.2.54
109
+ nvidia-curand-cu12==10.3.2.106
110
+ nvidia-cusolver-cu12==11.4.5.107
111
+ nvidia-cusparse-cu12==12.1.0.106
112
+ nvidia-nccl-cu12==2.19.3
113
+ nvidia-nvjitlink-cu12==12.6.77
114
+ nvidia-nvtx-cu12==12.1.105
115
+ oauthlib==3.2.2
116
+ orjson==3.10.7
117
+ overrides==7.7.0
118
+ packaging==24.1
119
+ pandas==2.1.4
120
+ pandocfilters==1.5.1
121
+ parso==0.8.4
122
+ pexpect==4.9.0
123
+ pillow==10.4.0
124
+ platformdirs==4.3.6
125
+ prometheus_client==0.21.0
126
+ prompt_toolkit==3.0.48
127
+ protobuf==4.23.4
128
+ psutil==6.0.0
129
+ ptyprocess==0.7.0
130
+ pure_eval==0.2.3
131
+ pyasn1==0.6.1
132
+ pyasn1_modules==0.4.1
133
+ pycparser==2.22
134
+ pydantic==2.9.2
135
+ pydantic_core==2.23.4
136
+ pydub==0.25.1
137
+ Pygments==2.18.0
138
+ PyJWT==2.9.0
139
+ pyparsing==3.1.4
140
+ python-dateutil==2.9.0.post0
141
+ python-json-logger==2.0.7
142
+ python-multipart==0.0.12
143
+ pytorch-lightning==2.4.0
144
+ pytz==2024.2
145
+ PyYAML==6.0.2
146
+ pyzmq==26.2.0
147
+ referencing==0.35.1
148
+ regex==2024.9.11
149
+ requests==2.32.3
150
+ requests-oauthlib==2.0.0
151
+ rfc3339-validator==0.1.4
152
+ rfc3986-validator==0.1.1
153
+ rich==13.9.1
154
+ rpds-py==0.20.0
155
+ rsa==4.9
156
+ ruff==0.6.9
157
+ s3transfer==0.10.2
158
+ safetensors==0.4.5
159
+ scikit-learn==1.3.2
160
+ scipy==1.11.4
161
+ semantic-version==2.10.0
162
+ Send2Trash==1.8.3
163
+ shellingham==1.5.4
164
+ simple-term-menu==1.6.4
165
+ six==1.16.0
166
+ sniffio==1.3.1
167
+ soupsieve==2.6
168
+ stack-data==0.6.3
169
+ starlette==0.38.6
170
+ sympy==1.13.3
171
+ tensorboard==2.15.1
172
+ tensorboard-data-server==0.7.2
173
+ termcolor==2.4.0
174
+ terminado==0.18.1
175
+ threadpoolctl==3.5.0
176
+ timm==1.0.9
177
+ tinycss2==1.3.0
178
+ tokenizers==0.20.1
179
+ tomli==2.0.1
180
+ tomlkit==0.12.0
181
+ torch==2.2.1+cu121
182
+ torchmetrics==1.3.1
183
+ torchsummary==1.5.1
184
+ torchvision==0.17.1+cu121
185
+ tornado==6.4.1
186
+ tqdm==4.66.5
187
+ traitlets==5.14.3
188
+ transformers==4.45.2
189
+ triton==2.2.0
190
+ typer==0.12.5
191
+ types-python-dateutil==2.9.0.20240906
192
+ typing_extensions==4.12.2
193
+ tzdata==2024.2
194
+ uri-template==1.3.0
195
+ urllib3==2.2.3
196
+ uvicorn==0.31.0
197
+ wcwidth==0.2.13
198
+ webcolors==24.8.0
199
+ webencodings==0.5.1
200
+ websocket-client==1.8.0
201
+ websockets==12.0
202
+ Werkzeug==3.0.4
203
+ widgetsnbextension==4.0.13
204
+ yarl==1.13.1