chicelli committed
Commit 9f68e7c · verified · 1 Parent(s): 0721311

Upload 21 files
README.md CHANGED
@@ -1,11 +1,129 @@
  ---
- title: Img2art Search
- emoji: 🔥
- colorFrom: yellow
- colorTo: blue
- sdk: docker
- pinned: false
- license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: img2art-search
+ app_file: app.py
+ sdk: gradio
+ sdk_version: 4.37.2
  ---
+ # Image-to-Art Search 🔍
+
+ <b>Find real artwork that looks like your images</b>
+
+ This project fine-tunes a Vision Transformer (ViT) model, starting from the pre-trained "google/vit-base-patch32-224-in21k" weights and training on the style of [ArtButMakeItSports](https://www.instagram.com/artbutmakeitsports/), to perform image-to-art search across 81k artworks made available by [WikiArt](https://wikiart.org/).
+
+ ![horse](examples/horse.jpg)
+
+ ## Table of Contents
+
+ - [Overview](#overview)
+ - [Installation](#installation)
+ - [How it works](#how-it-works)
+ - [Dataset](#dataset)
+ - [Training](#training)
+ - [Inference](#inference)
+ - [Contributing](#contributing)
+ - [License](#license)
+
+ ## Overview
+
+ This project leverages the Vision Transformer (ViT) architecture for the task of image-to-art search. By fine-tuning the pre-trained ViT model on a custom dataset derived from the Instagram account [ArtButMakeItSports](https://www.instagram.com/artbutmakeitsports/), we aim to create a model capable of matching photos (and more) to corresponding artworks, so that any image can be searched against [WikiArt](https://wikiart.org/).
+
+ ## Installation
+
+ 1. Clone the repository:
+ ```sh
+ git clone https://github.com/brunorosilva/img2art-search.git
+ cd img2art-search
+ ```
+
+ 2. Install poetry:
+ ```sh
+ pip install poetry
+ ```
+
+ 3. Install the dependencies using poetry:
+ ```sh
+ poetry install
+ ```
+
+ ## How it works
+
+ ### Dataset Preparation
+
+ 1. Download images from the [ArtButMakeItSports](https://www.instagram.com/artbutmakeitsports/) Instagram account.
+ 2. Organize the images into appropriate directories for training and validation (see the pairing sketch after this list).
+ 3. Get a fine-tuned model.
+ 4. Create the gallery using WikiArt.
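+
+ As a minimal sketch of the pairing convention (grounded in `img2art_search/data/data.py` in this commit; `0001.jpg` is just an example filename), the directory names suggest each downloaded post is split into an input half and an artwork half, and the two halves sharing a filename form a training pair:
+
+ ```python
+ input_path = "data/artmakeitsports/splits/left_or_top/0001.jpg"
+ label_path = input_path.replace("left_or_top", "right_or_bottom")
+ # -> "data/artmakeitsports/splits/right_or_bottom/0001.jpg"
+ ```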
+
+ ### Training
+
+ Fine-tune the ViT model:
+ ```sh
+ make train
+ ```
+
+ ### Inference via Gradio
+
+ Perform image-to-art search using the fine-tuned model:
+ ```sh
+ make viz
+ ```
+
+ ### Recreate the WikiArt gallery
+ ```sh
+ make wikiart
+ ```
+
+ ### Create a new gallery
+
+ If you want to index new images to search, use:
+ ```sh
+ poetry run python main.py gallery --gallery_path <your_path>
+ ```
+
+ ## Dataset
+
+ The dataset derives from 1k images from the Instagram account [ArtButMakeItSports](https://www.instagram.com/artbutmakeitsports/). Images are downloaded and split into training, validation and test sets. Each image is paired with its corresponding artwork for training purposes. If you want this dataset, just ask me and state your intended usage.
+
+ WikiArt is indexed using the same process, except that there is no expected result: each artwork is mapped to itself, the model is used as a feature extractor, and the gallery embeddings are saved as a numpy file (this will be changed to chromadb in the future). The sketch below illustrates the resulting search.
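+
+ This is not the project's actual search path (the code in `img2art_search/models/compute_embeddings.py` queries a hosted vector index instead), and the random `query` vector below stands in for a real 768-dim image embedding, but it shows what search over the saved numpy gallery amounts to:
+
+ ```python
+ import numpy as np
+
+ gallery = np.load("models/embeddings.npy")  # (n_artworks, 768), saved at index time
+ query = np.random.randn(768)                # stand-in for a real query embedding
+
+ # Cosine similarity against every artwork, then keep the 4 best matches.
+ sims = gallery @ query / (np.linalg.norm(gallery, axis=1) * np.linalg.norm(query))
+ top4 = np.argsort(-sims)[:4]
+ ```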
+
+ ## Training
+
+ The training script fine-tunes the ViT model on the prepared dataset. Key steps include:
+
+ 1. Loading the pre-trained "google/vit-base-patch32-224-in21k" weights.
+ 2. Preparing the dataset and data loaders.
+ 3. Fine-tuning the model using a custom training loop.
+ 4. Saving the model to the `models` folder.
+
+ ## Interface
+
+ The recommended way to search is through a [gradio](https://www.gradio.app/) interface, launched with `make viz`. This opens a local server where you can search with any image you want, or even use your webcam, and get the top 4 search results.
+
+ ### Examples
+ Search for contextual similarity
+ ![field](examples/field.jpg)
+
+ Search for shape similarity
+ ![basket](examples/basketball.jpg)
+
+ Search for expression similarity (yep, that's me)
+ ![serious_face](examples/serious_face.jpg)
+
+ Search for pose similarity
+ ![lawyer](examples/lawyer.jpg)
+
+ Search for an object
+ ![horse](examples/horse.jpg)
+
+ ## Contributing
+ There are three areas where I'd appreciate help:
+ 1. Increasing the gallery by embedding new painting datasets. The current gallery has 81k artworks because I started from a ready-to-go dataset, but the complete WikiArt catalog alone has 250k+ artworks, so I want to grow this number to at least 300k in the near future;
+ 2. Porting the encoding and search to a vector db, preferably chromadb;
+ 3. Opening issues with ideas for how this could be improved; new ideas will be considered.
+
+ ## License
+ The source code for this project is licensed under the MIT license, which you can find in the MIT-LICENSE.txt file.
+
+ All graphical assets are licensed under the Creative Commons Attribution 3.0 Unported License.
img2art_search/__init__.py ADDED
File without changes
img2art_search/constants.py ADDED
@@ -0,0 +1 @@
+ BASE_PATH = "data/artmakeitsports"
img2art_search/data/__init__.py ADDED
File without changes
img2art_search/data/data.py ADDED
@@ -0,0 +1,28 @@
+ import os
+
+ import numpy as np
+
+ from img2art_search.constants import BASE_PATH
+
+
+ def get_data_from_local() -> np.ndarray:
+     # Each training pair is an image and its counterpart: the "left_or_top"
+     # half of a post is the input, the matching "right_or_bottom" half the label.
+     left_or_top_data = [
+         f"{BASE_PATH}/splits/left_or_top/{fn}"
+         for fn in os.listdir(f"{BASE_PATH}/splits/left_or_top")
+     ]
+
+     x = np.array(left_or_top_data)
+     y = np.array([ex.replace("left_or_top", "right_or_bottom") for ex in x])
+
+     data = np.array([x, y])  # shape: (2, n_pairs)
+     return data
+
+
+ def split_train_val_test(data: np.ndarray, test_size: float, val_size: float) -> tuple:
+     # Contiguous splits along the pair axis.
+     train_size = 1 - test_size - val_size
+     SPLIT = int(data.shape[1] * train_size)
+     TEST_SPLIT = SPLIT + int(data.shape[1] * test_size)
+     train = data[:, :SPLIT]
+     validation = data[:, SPLIT:TEST_SPLIT]
+     test = data[:, TEST_SPLIT:]
+     return train, validation, test
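+
+ # Illustrative example of the split semantics. Note that, as written, the
+ # middle (validation) slice is sized by `test_size` and the final (test)
+ # slice by `val_size`, so with test_size=0.2 and val_size=0.1 on 10 pairs:
+ #
+ #     data = np.array([["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]] * 2)
+ #     train, val, test = split_train_val_test(data, test_size=0.2, val_size=0.1)
+ #     (train.shape[1], val.shape[1], test.shape[1])  # -> (7, 2, 1)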
img2art_search/data/dataset.py ADDED
@@ -0,0 +1,24 @@
+ import numpy as np
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+
+
+ class ImageRetrievalDataset(Dataset):
+     def __init__(self, data: np.ndarray, transform: transforms.Compose) -> None:
+         self.data = data
+         self.transform = transform
+
+     def __len__(self) -> int:
+         return len(self.data[0])
+
+     def __getitem__(self, idx: int) -> tuple:
+         # data has shape (2, n): row 0 holds input paths, row 1 label paths.
+         input_path, label_path = self.data.T[idx]
+         input_image = Image.open(input_path).convert("RGB")
+         label_image = Image.open(label_path).convert("RGB")
+
+         input_image = self.transform(input_image)
+         label_image = self.transform(label_image)
+
+         return input_image, label_image, input_path, label_path
img2art_search/data/transforms.py ADDED
@@ -0,0 +1,19 @@
+ from torchvision import transforms
+
+ # Standard ImageNet preprocessing: resize to the ViT input size and normalize.
+ transform = transforms.Compose(
+     [
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ]
+ )
+
+ # Undo the normalization above (divide by std first, then add the mean back).
+ inversetransform = transforms.Compose(
+     [
+         transforms.Normalize(
+             mean=[0.0, 0.0, 0.0], std=[1 / 0.229, 1 / 0.224, 1 / 0.225]
+         ),
+         transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1.0, 1.0, 1.0]),
+     ]
+ )
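+
+ # Minimal round-trip sketch (illustrative; "examples/horse.jpg" is the sample
+ # image shipped with this repo): `inversetransform` should recover the resized
+ # image from `transform`'s output up to floating-point error.
+ if __name__ == "__main__":
+     import torch
+     from PIL import Image
+
+     img = Image.open("examples/horse.jpg").convert("RGB")
+     normalized = transform(img)              # (3, 224, 224), ImageNet-normalized
+     restored = inversetransform(normalized)  # back to the [0, 1] pixel range
+     reference = transforms.ToTensor()(transforms.Resize((224, 224))(img))
+     print(torch.allclose(restored, reference, atol=1e-5))  # expected: True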
img2art_search/losses/__init__.py ADDED
File without changes
img2art_search/losses/contrastiveloss.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+
+
+ class ContrastiveLoss(torch.nn.Module):
+     def __init__(self, margin=1.0):
+         super(ContrastiveLoss, self).__init__()
+         self.margin = margin
+
+     def forward(self, output1, output2):  # noqa
+         # Only positive (similar) pairs are seen during training, so this
+         # reduces to the mean squared Euclidean distance; `margin` is unused.
+         euclidean_distance = torch.nn.functional.pairwise_distance(output1, output2)
+         loss = torch.mean(torch.pow(euclidean_distance, 2))
+         return loss
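+
+ # For reference, a minimal sketch of the full margin-based contrastive loss
+ # (Hadsell et al., 2006), which also pushes dissimilar pairs apart. The
+ # `label` tensor (1 = similar pair, 0 = dissimilar) is an assumption of this
+ # sketch: the training data in this project only contains similar pairs,
+ # which is why the class above omits the margin term.
+ class MarginContrastiveLoss(torch.nn.Module):
+     def __init__(self, margin=1.0):
+         super().__init__()
+         self.margin = margin
+
+     def forward(self, output1, output2, label):
+         distance = torch.nn.functional.pairwise_distance(output1, output2)
+         positive_term = label * distance.pow(2)
+         negative_term = (1 - label) * torch.clamp(self.margin - distance, min=0).pow(2)
+         return torch.mean(positive_term + negative_term)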
img2art_search/models/__init__.py ADDED
File without changes
img2art_search/models/compute_embeddings.py ADDED
@@ -0,0 +1,120 @@
+ import base64
+ import os
+ from io import BytesIO
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+ from img2art_search.data.dataset import ImageRetrievalDataset
+ from img2art_search.data.transforms import transform
+ from img2art_search.models.model import ViTImageSearchModel
+ from img2art_search.utils import (
+     get_or_create_pinecone_index,
+     get_pinecone_client,
+     inverse_transform_img,
+ )
+
+
+ def extract_embedding(image_data_batch, fine_tuned_model):
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     image_data_batch = image_data_batch.to(DEVICE)
+     with torch.no_grad():
+         embeddings = fine_tuned_model(image_data_batch).cpu().numpy()
+     return embeddings
+
+
+ def load_fine_tuned_model():
+     fine_tuned_model = ViTImageSearchModel()
+     fine_tuned_model.load_state_dict(torch.load("models/model.pth"))
+     fine_tuned_model.eval()
+     return fine_tuned_model
+
+
+ def create_gallery(
+     img_dataset: ImageRetrievalDataset,
+     fine_tuned_model: ViTImageSearchModel,
+     save: bool = True,
+ ) -> list:
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     batch_size = 4
+     fine_tuned_model.to(DEVICE)
+     gallery_embeddings = []
+     gallery_dataloader = DataLoader(
+         img_dataset, batch_size=batch_size, num_workers=1, shuffle=False
+     )
+
+     pc = get_pinecone_client()
+     gallery_index = get_or_create_pinecone_index(pc)
+     try:
+         count = 0
+         for img_data, _, img_name, _ in tqdm(gallery_dataloader):
+             data_objects = []
+             batch_embedding = extract_embedding(img_data, fine_tuned_model)
+             gallery_embeddings.append(batch_embedding)
+             for idx, embedding in enumerate(batch_embedding):
+                 # Store a base64 thumbnail alongside the vector so search
+                 # results can be rendered without touching the filesystem.
+                 image = Image.fromarray(
+                     inverse_transform_img(img_data[idx]).numpy().astype("uint8"), "RGB"
+                 )
+                 buffered = BytesIO()
+                 image.save(buffered, format="JPEG")
+                 img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                 data_objects.append(
+                     {
+                         "id": str(count),
+                         "values": embedding.tolist(),
+                         "metadata": {
+                             "image": img_base64,
+                             # Derive a human-readable title from the filename.
+                             "name": img_name[idx]
+                             .split("/")[-1]
+                             .replace(".jpg", "")
+                             .replace(".jpeg", "")
+                             .replace(".png", "")
+                             .replace(".JPG", "")
+                             .replace(".JPEG", "")
+                             .replace("-", " ")
+                             .replace("_", " - ")
+                             .title(),
+                         },
+                     }
+                 )
+                 count += 1
+             gallery_index.upsert(vectors=data_objects)
+     except Exception as e:
+         print(f"Error creating gallery: {e}")
+
+     if save:
+         # Stack the per-batch arrays into one (n, 768) array before saving.
+         np.save("models/embeddings", np.concatenate(gallery_embeddings))
+     return gallery_embeddings
+
+
+ def search_image(query_image: torch.Tensor, k: int = 4) -> tuple:
+     # `query_image` is a preprocessed image batch of shape (1, 3, 224, 224)
+     # (see img2art_search/models/predict.py), not a filesystem path.
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     fine_tuned_model = load_fine_tuned_model()
+     fine_tuned_model.to(DEVICE)
+     query_embedding = extract_embedding(query_image, fine_tuned_model)
+     pc = get_pinecone_client()
+     index = get_or_create_pinecone_index(pc)
+     response = index.query(
+         vector=[query_embedding.tolist()[0]], top_k=k, include_metadata=True
+     )
+     distances = []
+     results = []
+     for obj in response["matches"]:
+         result = base64.b64decode(obj.metadata["image"])
+         results.append(result)
+         # Report the similarity score as a percentage next to the title.
+         distances.append(
+             str(round(obj["score"] * 100, 2)) + " " + str(obj.metadata["name"])
+         )
+
+     return results, distances
+
+
+ def create_gallery_embeddings(folder: str) -> None:
+     # The gallery has no "expected result": each artwork is paired with itself
+     # and the model is used purely as a feature extractor.
+     x = np.array([f"{folder}/{file}" for file in os.listdir(folder)])
+     gallery_data = np.array([x, x])
+     gallery_dataset = ImageRetrievalDataset(gallery_data, transform=transform)
+     fine_tuned_model = load_fine_tuned_model()
+     create_gallery(gallery_dataset, fine_tuned_model)
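+
+ # Illustrative usage sketch (assumes PINECONE_API_KEY is set, a checkpoint
+ # exists at models/model.pth, and "query.jpg" is a placeholder path):
+ #
+ #     import numpy as np
+ #     from img2art_search.data.dataset import ImageRetrievalDataset
+ #     from img2art_search.data.transforms import transform
+ #
+ #     create_gallery_embeddings("data/wikiart")   # index a folder of artworks
+ #     dataset = ImageRetrievalDataset(
+ #         np.array([["query.jpg"], ["query.jpg"]]), transform
+ #     )
+ #     query = dataset[0][0].unsqueeze(0)           # preprocessed (1, 3, 224, 224)
+ #     images, captions = search_image(query, k=4)  # JPEG bytes + score/title strings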
img2art_search/models/model.py ADDED
@@ -0,0 +1,13 @@
+ from torch import nn
+ from transformers import ViTModel
+
+
+ class ViTImageSearchModel(nn.Module):
+     def __init__(self, pretrained_model_name="google/vit-base-patch32-224-in21k"):
+         super(ViTImageSearchModel, self).__init__()
+         self.vit = ViTModel.from_pretrained(pretrained_model_name)
+
+     def forward(self, x):  # noqa
+         # Use the [CLS] token's final hidden state as the image embedding.
+         outputs = self.vit(pixel_values=x)
+         cls_hidden_state = outputs.last_hidden_state[:, 0, :]
+         return cls_hidden_state
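+
+ # Illustrative sanity check: for this ViT-Base backbone the CLS embedding is
+ # 768-dimensional, matching the Pinecone index dimension (embeddings_dim=768)
+ # in img2art_search/utils.py.
+ #
+ #     import torch
+ #     model = ViTImageSearchModel()
+ #     out = model(torch.randn(2, 3, 224, 224))
+ #     assert out.shape == (2, 768)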
img2art_search/models/predict.py ADDED
@@ -0,0 +1,37 @@
+ import os
+ from io import BytesIO
+
+ import numpy as np
+ import torch
+ from PIL import Image
+
+ from img2art_search.data.dataset import ImageRetrievalDataset
+ from img2art_search.data.transforms import transform
+ from img2art_search.models.compute_embeddings import search_image
+
+
+ def predict(img: Image.Image) -> list:
+     tmp_img_path = "tmp_img.png"
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     if img:
+         # Round-trip through disk so the query goes through the exact same
+         # preprocessing pipeline as the gallery images.
+         img.save(tmp_img_path)
+         pred_img = np.array([[tmp_img_path], [tmp_img_path]])
+         pred_dataset = ImageRetrievalDataset(pred_img, transform=transform)
+         pred_image_data = pred_dataset[0][0].unsqueeze(0).to(DEVICE)
+         images, distances = search_image(pred_image_data)
+         results = []
+         for image_bytes, distance in zip(images, distances):
+             # Each result comes back as JPEG bytes; decode it for the gallery.
+             decoded_image_array = np.array(Image.open(BytesIO(image_bytes)))
+
+             results.append(
+                 (
+                     Image.fromarray(decoded_image_array),
+                     str(distance),
+                 )
+             )
+         os.remove(tmp_img_path)
+         return results
+     else:
+         return []
img2art_search/models/train.py ADDED
@@ -0,0 +1,75 @@
+ import numpy as np
+ import torch
+ from torch.optim import Adam
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+
+ from img2art_search.data.data import get_data_from_local, split_train_val_test
+ from img2art_search.data.dataset import ImageRetrievalDataset
+ from img2art_search.data.transforms import transform
+ from img2art_search.losses.contrastiveloss import ContrastiveLoss
+ from img2art_search.models.model import ViTImageSearchModel
+
+
+ def fine_tune_vit(epochs: int, batch_size: int) -> None:
+     data = get_data_from_local()
+     train_data, val_data, test_data = split_train_val_test(data, 0.2, 0.1)
+     np.save("models/test_data", test_data)
+     train_dataset = ImageRetrievalDataset(train_data, transform=transform)
+     val_dataset = ImageRetrievalDataset(val_data, transform=transform)
+
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+
+     model = ViTImageSearchModel()
+
+     # TensorBoard logging
+     log_dir = "./logs/"
+     writer = SummaryWriter(log_dir=log_dir)
+
+     # Training setup
+     criterion = ContrastiveLoss()
+     optimizer = Adam(model.parameters(), lr=1e-4)
+
+     for epoch in range(epochs):
+         model.train()
+         total_loss = 0
+
+         for batch_idx, batch in enumerate(train_loader):
+             # The dataset yields (input, label, input_path, label_path);
+             # only the image tensors are needed here.
+             inputs, labels, _, _ = batch
+             optimizer.zero_grad()
+
+             # Embed both halves of the pair and pull them together.
+             input_embeddings = model(inputs)
+             label_embeddings = model(labels)
+
+             loss = criterion(input_embeddings, label_embeddings)
+
+             loss.backward()
+             optimizer.step()
+
+             total_loss += loss.item()
+             writer.add_scalar(
+                 "Train Loss", loss.item(), epoch * len(train_loader) + batch_idx
+             )
+
+         avg_train_loss = total_loss / len(train_loader)
+         writer.add_scalar("Average Train Loss", avg_train_loss, epoch)
+
+         print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_train_loss}")
+
+         model.eval()
+         with torch.no_grad():
+             val_loss = 0
+             for batch_idx, batch in enumerate(val_loader):
+                 inputs, labels, _, _ = batch
+                 input_embeddings = model(inputs)
+                 label_embeddings = model(labels)
+
+                 loss = criterion(input_embeddings, label_embeddings)
+                 val_loss += loss.item()
+             avg_val_loss = val_loss / len(val_loader)
+             writer.add_scalar("Validation Loss", avg_val_loss, epoch)
+             print(f"Validation Loss: {avg_val_loss}")
+
+     torch.save(model.state_dict(), "models/model.pth")
img2art_search/utils.py ADDED
@@ -0,0 +1,36 @@
+ import os
+ from typing import Any
+
+ import torch
+ from pinecone import Pinecone, ServerlessSpec
+
+ from img2art_search.data.transforms import inversetransform
+
+ pinecone_api_key = os.environ["PINECONE_API_KEY"]
+
+
+ def inverse_transform_img(img: torch.Tensor) -> torch.Tensor:
+     # Undo the ImageNet normalization and convert to an HWC uint8 tensor.
+     inv_tensor = inversetransform(img)
+     tensor_image = (inv_tensor * 255).byte()
+     return tensor_image.permute(1, 2, 0)
+
+
+ def get_pinecone_client() -> Pinecone:
+     pc = Pinecone(api_key=pinecone_api_key)
+     return pc
+
+
+ def get_or_create_pinecone_index(
+     pc: Pinecone, index_name: str = "img2art-search", embeddings_dim: int = 768
+ ) -> Any:
+     # Create the serverless index on first use; afterwards just connect to it.
+     indexes_names = [index.name for index in pc.list_indexes()]
+     if index_name not in indexes_names:
+         pc.create_index(
+             name=index_name,
+             dimension=embeddings_dim,
+             metric="cosine",
+             spec=ServerlessSpec(cloud="aws", region="us-east-1"),
+         )
+
+     index = pc.Index(index_name)
+
+     return index
main.py ADDED
@@ -0,0 +1,48 @@
+ import argparse
+
+ import gradio as gr
+
+ from img2art_search.models.compute_embeddings import create_gallery_embeddings
+ from img2art_search.models.predict import predict
+ from img2art_search.models.train import fine_tune_vit
+
+
+ def make_interface():
+     interface = gr.Interface(
+         fn=predict,
+         inputs=gr.Image(type="pil"),
+         outputs=gr.Gallery(label="Most similar images", height=256 * 3),
+     )
+     interface.launch(share=False)
+
+
+ def train(epochs, batch_size):
+     fine_tune_vit(epochs, batch_size)
+
+
+ def create_gallery(gallery_path):
+     create_gallery_embeddings(gallery_path)
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Train or infer the ViT model for image-to-art search."
+     )
+     subparsers = parser.add_subparsers(dest="command")
+
+     # Subparser for training
+     train_parser = subparsers.add_parser("train", help="Fine-tune the ViT model")
+     train_parser.add_argument(
+         "--epochs", type=int, default=50, help="Number of training epochs"
+     )
+     train_parser.add_argument(
+         "--batch_size", type=int, default=32, help="Batch size for training"
+     )
+
+     # Subparser for inference
+     _ = subparsers.add_parser(
+         "interface", help="Perform image-to-art search using the fine-tuned model"
+     )
+
+     # Subparser for gallery creation
+     create_gallery_parser = subparsers.add_parser(
+         "gallery", help="Create new gallery from a path"
+     )
+     create_gallery_parser.add_argument(
+         "--gallery_path", type=str, default="data/wikiart"
+     )
+     args = parser.parse_args()
+
+     if args.command == "train":
+         train(args.epochs, args.batch_size)
+     elif args.command == "interface":
+         make_interface()
+     elif args.command == "gallery":
+         create_gallery(args.gallery_path)
+     else:
+         parser.print_help()
+
+
+ if __name__ == "__main__":
+     main()
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,38 @@
+ [tool.poetry]
+ name = "img2art-search"
+ version = "0.1.0"
+ description = ""
+ authors = ["brunorosilva <[email protected]>"]
+ readme = "README.md"
+ packages = [{include = "img2art_search"}]
+ package-mode = false
+
+ [tool.poetry.dependencies]
+ python = "^3.10"
+ torch = "^2.3.1"
+ torchvision = "^0.18.1"
+ tqdm = "^4.66.4"
+ pandas = "^2.2.2"
+ scipy = "^1.14.0"
+ numpy = "^2.0.0"
+ transformers = "^4.41.2"
+ tensorboard = "^2.17.0"
+ scikit-learn = ">=0.0.0"
+ matplotlib = "^3.9.0"
+ gradio = ">=0.0.0"
+ opencv-python = "^4.10.0.84"
+ pinecone = "^6.0.1"
+
+ [tool.poetry.dev-dependencies]
+ black = "^24.4.2"
+ vulture = "^2.11"
+ mypy = "^1.10.1"
+ flake8 = "^7.1.0"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
+
+ [[tool.mypy.overrides]]
+ module = ["torchvision.*", "transformers.*", "pinecone.*", "tqdm.*"]
+ follow_untyped_imports = true
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch
+ torchvision
+ tqdm
+ pandas
+ scipy
+ numpy
+ transformers
+ tensorboard
+ scikit-learn
+ matplotlib
+ gradio
setup.cfg ADDED
@@ -0,0 +1,10 @@
+ [flake8]
+ max-line-length = 88
+ ignore = E122,E123,E126,E127,E128,E203,E221,E241,E731,E722,W503
+ exclude = tests,.git,__init__.py
+
+ # [mypy]
+ # ignore_missing_imports = True
+
+ [bdist_wheel]
+ universal=1
setup.py ADDED
@@ -0,0 +1,36 @@
+ import io
+ import os
+ import re
+
+ from setuptools import find_packages
+ from setuptools import setup
+
+
+ def read(filename):
+     filename = os.path.join(os.path.dirname(__file__), filename)
+     text_type = type("")
+     with io.open(filename, mode="r", encoding="utf-8") as fd:
+         return re.sub(text_type(r":[a-z]+:`~?(.*?)`"), text_type(r"``\1``"), fd.read())
+
+
+ requirements = [
+     # use environment.yml
+ ]
+
+
+ setup(
+     name="img2art_search",
+     version="0.0.1",
+     url="https://github.com/brunorosilva/img2art-search",
+     author="Bruno Chicelli",
+     author_email="[email protected]",
+     description="Short description",
+     long_description=read("README.md"),
+     packages=find_packages(exclude=("tests",)),
+     entry_points={"console_scripts": ["img2art_search=img2art_search.cli:cli"]},
+     install_requires=requirements,
+     classifiers=[
+         "Programming Language :: Python",
+         "Programming Language :: Python :: 3.6",
+     ],
+ )