File size: 5,062 Bytes
b35ee4e c4a8d1c bbd199b b35ee4e bbd199b d526dbf 3717c61 bbd199b c4a8d1c bbd199b 339abc5 d526dbf bbd199b d526dbf bbd199b 3717c61 d526dbf bbd199b d526dbf bbd199b d526dbf 3717c61 d526dbf b35ee4e bbd199b b35ee4e bbd199b b35ee4e 3717c61 b35ee4e 3717c61 b35ee4e 3717c61 b35ee4e 3717c61 b35ee4e 3717c61 b35ee4e bbd199b 3717c61 bbd199b b35ee4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import io
import csv
import sys
import pickle
from collections import Counter
import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder
from PIL import Image
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
from CHMCorr import chm_classify_and_visualize
from visualization import plot_from_reranker_corrmap
# Raise the csv module's per-field size cap as far as the platform allows.
# On platforms where a C long is 32 bits (e.g. Windows) sys.maxsize does not
# fit and field_size_limit raises OverflowError; fall back to the largest
# 32-bit value there instead of failing at import time.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)


def concat(x):
    """Join a sequence of arrays along axis 0 and return the result."""
    return np.concatenate(x, axis=0)
# --- External assets -------------------------------------------------------
# Each remote file is described by the keyword arguments handed to
# gdown.cached_download (source URL, local destination, expected MD5).

# Training-set embeddings.
# Disabled alternative kept for reference:
#   gdown.download(id="116CiA_cXciGSl72tbAUDoN-f1B9Frp89")
_EMBEDDINGS_ASSET = {
    "url": "https://static.taesiri.com/chm-corr/embeddings.pickle",
    "path": "./embeddings.pickle",
    "md5": "002b2a7f5c80d910b9cc740c2265f058",
}
gdown.cached_download(quiet=False, **_EMBEDDINGS_ASSET)

# Training-set labels (fetched by Google Drive id on every start).
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")

# CUB training images: download the archive, then unpack it under data/.
_TRAIN_ARCHIVE_ASSET = {
    "url": "https://static.taesiri.com/chm-corr/CUB_train.zip",
    "path": "./CUB_train.zip",
    "md5": "1bd99e73b2fea8e4c2ebcb0e7722f1b1",
}
gdown.cached_download(quiet=False, **_TRAIN_ARCHIVE_ASSET)
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip", to_path="data/", remove_finished=False
)

# CHM model weights.
_CHM_WEIGHTS_ASSET = {
    "url": "https://static.taesiri.com/chm-corr/pas_psi.pt",
    "path": "pas_psi.pt",
    "md5": "6b7b4d7bad7f89600fac340d6aa7708b",
}
gdown.cached_download(quiet=False, **_CHM_WEIGHTS_ASSET)
# Load the training-set embeddings and labels, then build the search index
# that `search` queries.
# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Both files are fetched over the network earlier in this script, so only run
# it against download sources you trust.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)

# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()
# Map every class index to a display name (class folder name with "." turned
# into a space).
# ImageFolder already exposes the folder-name -> index mapping, so use it
# instead of re-parsing each image path with a hard-coded "/" separator.
# NOTE(review): assumes images sit directly inside their class folder, in
# which case this yields the same names as the path-based lookup - confirm.
training_folder = ImageFolder(root="./data/train/")
id_to_bird_name = {
    class_idx: class_dir.replace(".", " ")
    for class_dir, class_idx in training_folder.class_to_idx.items()
}
def search(query_image, searcher=searcher):
    """Classify a query image with CHM-Corr and render its visualization.

    The query is embedded and its 50 nearest training samples are retrieved.
    The most frequent label among the first 20 is the kNN prediction; it is
    passed, together with all 50 retrieved samples, to
    ``chm_classify_and_visualize``, and the result is plotted.

    Args:
        query_image: Query image in whatever form ``QueryToEmbedding`` and
            ``chm_classify_and_visualize`` accept (the UI supplies a file
            path).
        searcher: Object whose ``search(embedding, k=...)`` returns
            ``(scores, indices, labels)``; defaults to the module-level index.

    Returns:
        A tuple ``(viz_image, label_scores)``: the cropped plot as a PIL
        image, and a dict mapping each label found among the first 20 entries
        of ``chm_output["chm-nearest-neighbors-all"]`` to its count divided
        by 20.
    """
    num_candidates = 50  # N: samples retrieved and handed to CHM-Corr
    top_k = 20  # k: samples that vote on the label

    query_embedding = QueryToEmbedding(query_image)
    _scores, indices, labels = searcher.search(query_embedding, k=num_candidates)

    # kNN vote over the first top_k retrieved samples.
    voting_labels = labels[0][:top_k]
    voting_indices = indices[0][:top_k]
    top1_label, top1_votes = Counter(voting_labels).most_common(1)[0]

    # Up to five retrieved images that carry the winning label.
    winning_indices = [
        idx
        for label, idx in zip(voting_labels, voting_indices)
        if label == top1_label
    ]
    gallery_images = [training_folder.imgs[int(i)][0] for i in winning_indices[:5]]

    # CHM Prediction
    kNN_results = (top1_label, top1_votes, gallery_images)
    support_samples = [training_folder.imgs[int(i)] for i in indices[0]]
    support_files = [sample[0] for sample in support_samples]
    support_labels = [sample[1] for sample in support_samples]
    support = [support_files, support_labels]

    chm_output = chm_classify_and_visualize(
        query_image, kNN_results, support, training_folder
    )
    fig, _chm_label = plot_from_reranker_corrmap(chm_output)

    # Rasterize the figure in memory and trim fixed margins from each side.
    # NOTE(review): fig is never closed here; if plot_from_reranker_corrmap
    # creates it through pyplot, figures may pile up across requests - confirm.
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format="jpg")
    image = Image.open(img_buf)
    width, height = image.size
    viz_image = image.crop((310, 60, width - 248, height - 80))

    # Label names come from the parent folder of each neighbour's path.
    chm_neighbours = chm_output["chm-nearest-neighbors-all"][:top_k]
    chm_output_labels = Counter(
        path.split("/")[-2].replace(".", " ").replace("_", " ")
        for path in chm_neighbours
    )
    return viz_image, {
        label: count / float(top_k) for label, count in chm_output_labels.items()
    }
# Sample images offered as one-click examples in the UI.
_EXAMPLE_IMAGES = [
    "./examples/bird.jpg",
    "./examples/Red_Winged_Blackbird_0012_6015.jpg",
    "./examples/Red_Winged_Blackbird_0025_5342.jpg",
    "./examples/sample1.jpeg",
    "./examples/sample2.jpeg",
    "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
    "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
]

# Page layout: title, query image + button, visualization, then a row holding
# the prediction and the examples.
# NOTE(review): the source lost its indentation; the Row/Column nesting below
# is reconstructed - confirm against the intended layout.
with gr.Blocks() as blocks:
    gr.Markdown(""" # CHM-Corr DEMO""")
    gr.Markdown(
        """ ### Parameters: N=50, k=20 - Using ``ImageNet Pretrained ResNet50`` features"""
    )

    # The query reaches `search` as a path on disk.
    input_image = gr.Image(type="filepath")
    run_btn = gr.Button("Classify")

    gr.Markdown(""" ### CHM-Corr Output Visualization """)
    viz_plot = gr.Image(type="pil", label="Visualization")

    with gr.Row():
        with gr.Column():
            gr.Markdown(""" ### CHM-Corr Prediction """)
            labels = gr.Label(label="Prediction")

        with gr.Column():
            gr.Markdown(""" ### Examples """)
            examples = gr.Examples(
                examples=[[image_path] for image_path in _EXAMPLE_IMAGES],
                inputs=[input_image],
                outputs=[viz_plot, labels],
                fn=search,
                cache_examples=False,
            )

    # The button runs the same function the examples are wired to.
    run_btn.click(search, inputs=[input_image], outputs=[viz_plot, labels])
# Start the demo only when this file is executed directly, not on import.
if __name__ == "__main__":
    blocks.launch(debug=True, enable_queue=True)
|