Upload 21 files
- README.md +126 -8
- img2art_search/__init__.py +0 -0
- img2art_search/constants.py +1 -0
- img2art_search/data/__init__.py +0 -0
- img2art_search/data/data.py +28 -0
- img2art_search/data/dataset.py +24 -0
- img2art_search/data/transforms.py +19 -0
- img2art_search/losses/__init__.py +0 -0
- img2art_search/losses/contrastiveloss.py +12 -0
- img2art_search/models/__init__.py +0 -0
- img2art_search/models/compute_embeddings.py +120 -0
- img2art_search/models/model.py +13 -0
- img2art_search/models/predict.py +37 -0
- img2art_search/models/train.py +75 -0
- img2art_search/utils.py +36 -0
- main.py +48 -0
- poetry.lock +0 -0
- pyproject.toml +38 -0
- requirements.txt +11 -0
- setup.cfg +10 -0
- setup.py +36 -0
README.md
CHANGED

```diff
@@ -1,11 +1,129 @@
 ---
-title:
-sdk: docker
-pinned: false
-license: mit
+title: img2art-search
+app_file: app.py
+sdk: gradio
+sdk_version: 4.37.2
 ---
```

# Image-to-Art Search 🔍

**Find real artwork that looks like your images**

This project fine-tunes a Vision Transformer (ViT) model, initialized from the pre-trained "google/vit-base-patch32-224-in21k" weights and fine-tuned on image pairs in the style of [ArtButMakeItSports](https://www.instagram.com/artbutmakeitsports/), to perform image-to-art search across 81k artworks made available by [WikiArt](https://wikiart.org/).



## Table of Contents

- [Overview](#overview)
- [Installation](#installation)
- [How it works](#how-it-works)
- [Dataset](#dataset)
- [Training](#training)
- [Inference](#inference)
- [Contributing](#contributing)
- [License](#license)

## Overview

This project leverages the Vision Transformer (ViT) architecture for image-to-art search. By fine-tuning the pre-trained ViT model on a custom dataset derived from the Instagram account [ArtButMakeItSports](https://www.instagram.com/artbutmakeitsports/), the goal is a model that can match images of many kinds to corresponding artworks, making any image searchable against [WikiArt](https://wikiart.org/).

## Installation

1. Clone the repository:
```sh
git clone https://github.com/brunorosilva/img2art-search.git
cd img2art-search
```

2. Install poetry:
```sh
pip install poetry
```

3. Install the dependencies using poetry:
```sh
poetry install
```

## How it works

### Dataset Preparation

1. Download images from the [ArtButMakeItSports](https://www.instagram.com/artbutmakeitsports/) Instagram account.
2. Organize the images into appropriate directories for training and validation.
3. Get a fine-tuned model.
4. Create the gallery using WikiArt.

### Training

Fine-tune the ViT model:
```sh
make train
```

### Inference via Gradio

Perform image-to-art search using the fine-tuned model:
```sh
make viz
```

### Recreate the WikiArt gallery

```sh
make wikiart
```

### Create new gallery

If you want to index new images to search, use:
```sh
poetry run python main.py gallery --gallery_path <your_path>
```

## Dataset

The dataset derives from 1k images from the Instagram account [ArtButMakeItSports](https://www.instagram.com/artbutmakeitsports/). Images are downloaded and split into training, validation and test sets, and each image is paired with its corresponding artwork for training. If you want this dataset, just ask me and state your intended usage.

WikiArt is indexed using the same process, except that there is no expected result: each artwork is mapped to itself, the model is used as a feature extractor, and the gallery embeddings are saved as a numpy file (will be changed to chromadb in the future).

## Training

The training script fine-tunes the ViT model on the prepared dataset. Key steps include:

1. Loading the pre-trained "google/vit-base-patch32-224-in21k" weights.
2. Preparing the dataset and data loaders.
3. Fine-tuning the model using a custom training loop.
4. Saving the model to the `models` folder.

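The same loop can also be invoked directly through the CLI entry point, with the defaults that main.py defines:

```sh
poetry run python main.py train --epochs 50 --batch_size 32
```
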
## Interface

The recommended way to query the model is through a [gradio](https://www.gradio.app/) interface, launched with `make viz`. This starts a local server where you can search with any image you want, or even your webcam, and get the top-4 most similar artworks.

### Examples

Search for contextual similarity:



Search for shape similarity:



Search for expression similarity (yep, that's me):



Search for pose similarity:



Search for an object:



## Contributing

There are three topics I'd appreciate help with:

1. Increasing the gallery by embedding new painting datasets. The current gallery has 81k artworks because I used a ready-to-go dataset, but the complete WikiArt catalog alone has 250k+ artworks, so I'd like to raise this number to at least 300k in the near future;
2. Porting the encoding and search to a vector db, preferably chromadb;
3. Opening issues with suggestions for how this could be improved; new ideas will be considered.

## License

The source code for this project is licensed under the MIT license, which you can find in the MIT-LICENSE.txt file.

All graphical assets are licensed under the Creative Commons Attribution 3.0 Unported License.

img2art_search/__init__.py
ADDED

File without changes
img2art_search/constants.py
ADDED
@@ -0,0 +1 @@

```python
BASE_PATH = "data/artmakeitsports"
```
img2art_search/data/__init__.py
ADDED

File without changes
img2art_search/data/data.py
ADDED
@@ -0,0 +1,28 @@

```python
import os

import numpy as np

from img2art_search.constants import BASE_PATH


def get_data_from_local() -> np.ndarray:
    left_or_top_data = [
        f"{BASE_PATH}/splits/left_or_top/{fn}"
        for fn in os.listdir(f"{BASE_PATH}/splits/left_or_top")
    ]

    x = np.array(left_or_top_data)
    y = np.array([ex.replace("left_or_top", "right_or_bottom") for ex in x])

    data = np.array([x, y])
    return data


def split_train_val_test(data: np.ndarray, test_size: float, val_size: float) -> tuple:
    train_size = 1 - test_size - val_size
    SPLIT = int(data.shape[1] * train_size)
    TEST_SPLIT = SPLIT + int(data.shape[1] * test_size)
    train = data[:, :SPLIT]
    # note: as written, the middle (validation) slice spans test_size of the
    # data and the final (test) slice spans val_size
    validation = data[:, SPLIT:TEST_SPLIT]
    test = data[:, TEST_SPLIT:]
    return train, validation, test
```
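A minimal usage sketch (assuming the `data/artmakeitsports/splits/` folders exist): `get_data_from_local` stacks input and label paths as a `(2, N)` array, and the split below keeps 70% of the columns for training.

```python
from img2art_search.data.data import get_data_from_local, split_train_val_test

data = get_data_from_local()                # shape (2, N): input paths, label paths
train, val, test = split_train_val_test(data, test_size=0.2, val_size=0.1)
print(train.shape, val.shape, test.shape)   # roughly 70% / 20% / 10% of the columns
```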
img2art_search/data/dataset.py
ADDED
@@ -0,0 +1,24 @@

```python
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ImageRetrievalDataset(Dataset):
    def __init__(self, data: np.ndarray, transform: transforms.Compose) -> None:
        self.data = data
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data[0])

    def __getitem__(self, idx: int) -> tuple:
        input_path, label_path = self.data.T[idx]
        input_image = Image.open(input_path).convert("RGB")
        label_image = Image.open(label_path).convert("RGB")

        input_image = self.transform(input_image)
        label_image = self.transform(label_image)

        return input_image, label_image, input_path, label_path
```
img2art_search/data/transforms.py
ADDED
@@ -0,0 +1,19 @@

```python
from torchvision import transforms

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


# Inverts the normalization above: first undo the std scaling, then the mean shift.
inversetransform = transforms.Compose(
    [
        transforms.Normalize(
            mean=[0.0, 0.0, 0.0], std=[1 / 0.229, 1 / 0.224, 1 / 0.225]
        ),
        transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1.0, 1.0, 1.0]),
    ]
)
```
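A quick round-trip sketch (assuming any local RGB image at `img.jpg`, a hypothetical path): `inversetransform` undoes the normalization exactly, up to floating-point error.

```python
import torch
from PIL import Image
from torchvision import transforms as T

from img2art_search.data.transforms import transform, inversetransform

img = Image.open("img.jpg").convert("RGB")                  # hypothetical input
normalized = transform(img)                                 # (3, 224, 224), normalized
recovered = inversetransform(normalized)                    # back to the [0, 1] range
expected = T.Compose([T.Resize((224, 224)), T.ToTensor()])(img)
print(torch.allclose(recovered, expected, atol=1e-5))       # True
```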
img2art_search/losses/__init__.py
ADDED

File without changes
img2art_search/losses/contrastiveloss.py
ADDED
@@ -0,0 +1,12 @@

```python
import torch


class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin  # unused as written: the dataset only yields positive pairs

    def forward(self, output1, output2):  # noqa
        euclidean_distance = torch.nn.functional.pairwise_distance(output1, output2)
        loss = torch.mean(torch.pow(euclidean_distance, 2))
        return loss
```
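Since every batch pairs an input with its matching artwork, the loss reduces to the mean squared Euclidean distance between paired embeddings; the margin term of the classical contrastive loss (which applies to negative pairs) never enters the computation. A minimal sketch of what it computes:

```python
import torch

from img2art_search.losses.contrastiveloss import ContrastiveLoss

criterion = ContrastiveLoss()
a = torch.randn(8, 768)   # e.g. a batch of ViT CLS embeddings
b = torch.randn(8, 768)
print(criterion(a, b))    # mean over the batch of ||a_i - b_i||^2
```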
img2art_search/models/__init__.py
ADDED

File without changes
img2art_search/models/compute_embeddings.py
ADDED
@@ -0,0 +1,120 @@

```python
import base64
import os
from io import BytesIO

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from img2art_search.data.dataset import ImageRetrievalDataset
from img2art_search.data.transforms import transform
from img2art_search.models.model import ViTImageSearchModel
from img2art_search.utils import (
    get_or_create_pinecone_index,
    get_pinecone_client,
    inverse_transform_img,
)


def extract_embedding(image_data_batch, fine_tuned_model):
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    image_data_batch = image_data_batch.to(DEVICE)
    with torch.no_grad():
        embeddings = fine_tuned_model(image_data_batch).cpu().numpy()
    return embeddings


def load_fine_tuned_model():
    fine_tuned_model = ViTImageSearchModel()
    fine_tuned_model.load_state_dict(torch.load("models/model.pth"))
    fine_tuned_model.eval()
    return fine_tuned_model


def create_gallery(
    img_dataset: ImageRetrievalDataset,
    fine_tuned_model: ViTImageSearchModel,
    save: bool = True,
) -> list:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 4
    fine_tuned_model.to(DEVICE)
    gallery_embeddings = []
    gallery_dataloader = DataLoader(
        img_dataset, batch_size=batch_size, num_workers=1, shuffle=False
    )

    pc = get_pinecone_client()
    gallery_index = get_or_create_pinecone_index(pc)
    try:
        count = 0
        for img_data, _, img_name, _ in tqdm(gallery_dataloader):
            data_objects = []
            batch_embedding = extract_embedding(img_data, fine_tuned_model)
            gallery_embeddings.append(batch_embedding)
            for idx, embedding in enumerate(batch_embedding):
                # Re-encode the denormalized input image as base64 JPEG metadata.
                image = Image.fromarray(
                    inverse_transform_img(img_data[idx]).numpy().astype("uint8"), "RGB"
                )
                buffered = BytesIO()
                image.save(buffered, format="JPEG")
                img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
                data_objects.append(
                    {
                        "id": str(count),
                        "values": embedding.tolist(),
                        "metadata": {
                            "image": img_base64,
                            "name": img_name[idx]
                            .split("/")[-1]
                            .replace(".jpg", "")
                            .replace(".jpeg", "")
                            .replace(".png", "")
                            .replace(".JPG", "")
                            .replace(".JPEG", "")
                            .replace("-", " ")
                            .replace("_", " - ")
                            .title(),
                        },
                    }
                )
                count += 1
            gallery_index.upsert(vectors=data_objects)
    except Exception as e:
        print(f"Error creating gallery: {e}")

    if save:
        np.save("models/embeddings", gallery_embeddings)
    return gallery_embeddings


def search_image(query_image: torch.Tensor, k: int = 4) -> tuple:
    # query_image is a transformed image batch of shape (1, 3, 224, 224)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    fine_tuned_model = load_fine_tuned_model()
    fine_tuned_model.to(DEVICE)
    query_embedding = extract_embedding(query_image, fine_tuned_model)
    pc = get_pinecone_client()
    index = get_or_create_pinecone_index(pc)
    response = index.query(
        vector=query_embedding.tolist()[0], top_k=k, include_metadata=True
    )
    distances = []
    results = []
    for obj in response["matches"]:
        result = base64.b64decode(obj.metadata["image"])
        results.append(result)
        distances.append(
            str(round(obj["score"] * 100, 2)) + " " + str(obj.metadata["name"])
        )

    return results, distances


def create_gallery_embeddings(folder: str) -> None:
    x = np.array([f"{folder}/{file}" for file in os.listdir(folder)])
    gallery_data = np.array([x, x])
    gallery_dataset = ImageRetrievalDataset(gallery_data, transform=transform)
    fine_tuned_model = load_fine_tuned_model()
    create_gallery(gallery_dataset, fine_tuned_model)
```
img2art_search/models/model.py
ADDED
@@ -0,0 +1,13 @@

```python
from torch import nn
from transformers import ViTModel


class ViTImageSearchModel(nn.Module):
    def __init__(self, pretrained_model_name="google/vit-base-patch32-224-in21k"):
        super(ViTImageSearchModel, self).__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model_name)

    def forward(self, x):  # noqa
        outputs = self.vit(pixel_values=x)
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]  # CLS token embedding
        return cls_hidden_state
```
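A shape-check sketch (the pre-trained weights download from the hub on first instantiation): the model maps a batch of normalized 224×224 images to 768-dimensional CLS embeddings.

```python
import torch

from img2art_search.models.model import ViTImageSearchModel

model = ViTImageSearchModel()
model.eval()
batch = torch.randn(2, 3, 224, 224)   # stand-in for two transformed images
with torch.no_grad():
    embeddings = model(batch)
print(embeddings.shape)               # torch.Size([2, 768])
```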
img2art_search/models/predict.py
ADDED
@@ -0,0 +1,37 @@

```python
import os
from io import BytesIO

import numpy as np
import torch
from PIL import Image

from img2art_search.data.dataset import ImageRetrievalDataset
from img2art_search.data.transforms import transform
from img2art_search.models.compute_embeddings import search_image


def predict(img: Image.Image) -> list:
    tmp_img_path = "tmp_img.png"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    if img:
        img.save(tmp_img_path)
        pred_img = np.array([[tmp_img_path], [tmp_img_path]])
        pred_dataset = ImageRetrievalDataset(pred_img, transform=transform)
        pred_image_data = pred_dataset[0][0].unsqueeze(0).to(DEVICE)
        # search_image returns raw JPEG bytes plus "score name" caption strings
        indices, distances = search_image(pred_image_data)
        results = []
        for index, distance in zip(indices, distances):
            buffered = BytesIO(index)
            image = Image.open(buffered)
            decoded_image_array = np.array(image)

            results.append(
                (
                    Image.fromarray(decoded_image_array),
                    str(distance),
                )
            )
        os.remove(tmp_img_path)
        return results
    else:
        return []
```
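`predict` is the function the Gradio app calls, but it can also be used directly. A minimal sketch, assuming a local `query.jpg` (hypothetical) and an already-populated Pinecone gallery index:

```python
from PIL import Image

from img2art_search.models.predict import predict

matches = predict(Image.open("query.jpg"))
for artwork, caption in matches:       # (PIL.Image, "score name") tuples
    print(caption, artwork.size)
```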
img2art_search/models/train.py
ADDED
@@ -0,0 +1,75 @@

```python
import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from img2art_search.data.data import get_data_from_local, split_train_val_test
from img2art_search.data.dataset import ImageRetrievalDataset
from img2art_search.data.transforms import transform
from img2art_search.losses.contrastiveloss import ContrastiveLoss
from img2art_search.models.model import ViTImageSearchModel


def fine_tune_vit(epochs: int, batch_size: int) -> None:
    data = get_data_from_local()
    train_data, val_data, test_data = split_train_val_test(data, 0.2, 0.1)
    np.save("models/test_data", test_data)
    train_dataset = ImageRetrievalDataset(train_data, transform=transform)
    val_dataset = ImageRetrievalDataset(val_data, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = ViTImageSearchModel()

    # logs
    log_dir = "./logs/"
    writer = SummaryWriter(log_dir=log_dir)

    # params
    criterion = ContrastiveLoss()
    optimizer = Adam(model.parameters(), lr=1e-4)

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch_idx, batch in enumerate(train_loader):
            # the dataset also yields the image paths; discard them here
            inputs, labels, _, _ = batch
            optimizer.zero_grad()

            input_embeddings = model(inputs)
            label_embeddings = model(labels)

            loss = criterion(input_embeddings, label_embeddings)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            writer.add_scalar(
                "Train Loss", loss.item(), epoch * len(train_loader) + batch_idx
            )

        avg_train_loss = total_loss / len(train_loader)
        writer.add_scalar("Average Train Loss", avg_train_loss, epoch)

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_train_loss}")

        model.eval()
        with torch.no_grad():
            val_loss = 0
            for batch_idx, batch in enumerate(val_loader):
                inputs, labels, _, _ = batch
                input_embeddings = model(inputs)
                label_embeddings = model(labels)

                loss = criterion(input_embeddings, label_embeddings)
                val_loss += loss.item()
            avg_val_loss = val_loss / len(val_loader)
            writer.add_scalar("Validation Loss", avg_val_loss, epoch)
            print(f"Validation Loss: {avg_val_loss}")

    torch.save(model.state_dict(), "models/model.pth")
```
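Loss curves are written to `./logs/` and can be monitored during training with TensorBoard:

```sh
tensorboard --logdir ./logs/
```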
img2art_search/utils.py
ADDED
@@ -0,0 +1,36 @@

```python
import os
from typing import Any

import torch
from pinecone import Pinecone, ServerlessSpec

from img2art_search.data.transforms import inversetransform

pinecone_api_key = os.environ["PINECONE_API_KEY"]


def inverse_transform_img(img: torch.Tensor) -> torch.Tensor:
    inv_tensor = inversetransform(img)
    tensor_image = (inv_tensor * 255).byte()
    return tensor_image.permute(1, 2, 0)  # CHW -> HWC for PIL


def get_pinecone_client() -> Pinecone:
    pc = Pinecone(api_key=pinecone_api_key)
    return pc


def get_or_create_pinecone_index(
    pc: Pinecone, index_name: str = "img2art-search", embeddings_dim: int = 768
) -> Any:
    indexes_names = [index.name for index in pc.list_indexes()]
    if index_name not in indexes_names:
        pc.create_index(
            name=index_name,
            dimension=embeddings_dim,
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-east-1"),
        )

    index = pc.Index(index_name)

    return index
```
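Note that the module reads `PINECONE_API_KEY` at import time, so it must be set in the environment before anything imports `img2art_search.utils`:

```sh
export PINECONE_API_KEY="<your-key>"
```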
main.py
ADDED
@@ -0,0 +1,48 @@

```python
import argparse

import gradio as gr

from img2art_search.models.compute_embeddings import create_gallery_embeddings
from img2art_search.models.predict import predict
from img2art_search.models.train import fine_tune_vit


def make_interface():
    interface = gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil"),
        outputs=gr.Gallery(label="Most similar images", height=256 * 3),
        # live=True,
    )
    interface.launch(share=False)


def train(epochs, batch_size):
    fine_tune_vit(epochs, batch_size)


def create_gallery(gallery_path):
    create_gallery_embeddings(gallery_path)


def main():
    parser = argparse.ArgumentParser(
        description="Train or infer the ViT model for image-to-art search."
    )
    subparsers = parser.add_subparsers(dest="command")

    # Subparser for training
    train_parser = subparsers.add_parser("train", help="Fine-tune the ViT model")
    train_parser.add_argument(
        "--epochs", type=int, default=50, help="Number of training epochs"
    )
    train_parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size for training"
    )

    # Subparser for inference
    _ = subparsers.add_parser(
        "interface", help="Perform image-to-art search using the fine-tuned model"
    )

    create_gallery_parser = subparsers.add_parser(
        "gallery", help="Create new gallery from a path"
    )
    create_gallery_parser.add_argument(
        "--gallery_path", type=str, default="data/wikiart"
    )
    args = parser.parse_args()

    if args.command == "train":
        train(args.epochs, args.batch_size)
    elif args.command == "interface":
        make_interface()
    elif args.command == "gallery":
        create_gallery(args.gallery_path)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
```
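For reference, the remaining two subcommands launch the Gradio interface and rebuild the gallery embeddings (the gallery path defaults to `data/wikiart`):

```sh
poetry run python main.py interface
poetry run python main.py gallery --gallery_path data/wikiart
```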
poetry.lock
ADDED

The diff for this file is too large to render.
pyproject.toml
ADDED
@@ -0,0 +1,38 @@

```toml
[tool.poetry]
name = "img2art-search"
version = "0.1.0"
description = ""
authors = ["brunorosilva <[email protected]>"]
readme = "README.md"
packages = [{include = "img2art_search"}]
package-mode = false

[tool.poetry.dependencies]
python = "^3.10"
torch = "^2.3.1"
torchvision = "^0.18.1"
tqdm = "^4.66.4"
pandas = "^2.2.2"
scipy = "^1.14.0"
numpy = "^2.0.0"
transformers = "^4.41.2"
tensorboard = "^2.17.0"
scikit-learn = ">=0.0.0"
matplotlib = "^3.9.0"
gradio = ">=0.0.0"
opencv-python = "^4.10.0.84"
pinecone = "^6.0.1"

[tool.poetry.dev-dependencies]
black = "^24.4.2"
vulture = "^2.11"
mypy = "^1.10.1"
flake8 = "^7.1.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[[tool.mypy.overrides]]
module = ["torchvision.*", "transformers.*", "pinecone.*", "tqdm.*"]
follow_untyped_imports = true
```
requirements.txt
ADDED
@@ -0,0 +1,11 @@

```
torch
torchvision
tqdm
pandas
scipy
numpy
transformers
tensorboard
scikit-learn
matplotlib
gradio
```
setup.cfg
ADDED
@@ -0,0 +1,10 @@

```ini
[flake8]
max-line-length = 88
ignore = E122,E123,E126,E127,E128,E203,E221,E241,E731,E722,W503
exclude = tests,.git,__init__.py

# [mypy]
# ignore_missing_imports = True

[bdist_wheel]
universal=1
```
setup.py
ADDED
@@ -0,0 +1,36 @@

```python
import io
import os
import re

from setuptools import find_packages
from setuptools import setup


def read(filename):
    filename = os.path.join(os.path.dirname(__file__), filename)
    text_type = type("")
    with io.open(filename, mode="r", encoding="utf-8") as fd:
        return re.sub(text_type(r":[a-z]+:`~?(.*?)`"), text_type(r"``\1``"), fd.read())


requirements = [
    # use environment.yml
]


setup(
    name="img2art_search",
    version="0.0.1",
    url="https://github.com/brunorosilva/img2art-search",
    author="Bruno Chicelli",
    author_email="[email protected]",
    description="Short description",
    long_description=read("README.md"),  # the repo ships README.md, not README.rst
    packages=find_packages(exclude=("tests",)),
    entry_points={"console_scripts": ["img2art_search=img2art_search.cli:cli"]},
    install_requires=requirements,
    classifiers=[
        "Programming Language :: Python",
        "Programming Language :: Python :: 3.6",
    ],
)
```