# TrCLIP / app.py
# Author: yusufani
# Commit: Initial Release (94078d1)
# File size: 11.1 kB
# Importing all the necessary libraries
import os
import gradio as gr
import torch
from PIL import Image
from tqdm import tqdm
from trclip.trclip import Trclip
from trclip.visualizer import image_retrieval_visualize, text_retrieval_visualize
print(f'gr version : {gr.__version__}')
import pickle
import random
# %%
model_name = 'trclip-vitl14-e10'
# Clone the model repository only when its directory is not already present.
if not os.path.exists(model_name):
    os.system(f'git clone https://huggingface.co/yusufani/{model_name} --progress')
# %%
# Clone the TrCaption embeddings dataset only when its directory is missing,
# then run the Git LFS commands from inside the fresh checkout.
# NOTE(review): the LFS steps are assumed to belong to this first-time branch
# (they need the directory that the clone creates) - confirm.
if not os.path.exists('TrCaption-trclip-vitl14-e10'):
    os.system('git clone https://huggingface.co/datasets/yusufani/TrCaption-trclip-vitl14-e10/ --progress')
    os.chdir('TrCaption-trclip-vitl14-e10')
    os.system('git lfs install')
    os.system(' git lfs fetch')
    os.system(' git lfs pull')
    # Go back up: the loaders below build paths relative to the parent directory.
    os.chdir('..')
# %%
def load_image_embeddings():
    """Load every TrCaption image-embedding shard and join them.

    Reads ``image_em_<offset>.pkl`` for offsets 0, 100_000, ... (below
    3_100_000) from ``TrCaption-trclip-vitl14-e10/image_embeddings`` and
    concatenates the unpickled objects along dimension 0.

    Returns:
        The result of ``torch.cat`` over all shards.

    Raises:
        FileNotFoundError: If one of the shard files does not exist.
    """
    shard_dir = os.path.join('TrCaption-trclip-vitl14-e10', 'image_embeddings')
    shard_size = 100_000
    shards = []
    offsets = range(0, 3_100_000, shard_size)
    for offset in tqdm(offsets, desc='Loading TrCaption Image embeddings'):
        shard_path = os.path.join(shard_dir, f'image_em_{offset}.pkl')
        # NOTE(security): unpickling can execute arbitrary code; these files
        # come from a cloned repository and must be trusted.
        with open(shard_path, 'rb') as handle:
            shards.append(pickle.load(handle))
    return torch.cat(shards, dim=0)
def load_text_embeddings():
    """Load every TrCaption text-embedding shard and join them.

    Reads ``text_em_<offset>.pkl`` for offsets 0, 100_000, ... (below
    3_600_000) from ``TrCaption-trclip-vitl14-e10/text_embeddings`` and
    concatenates the unpickled objects along dimension 0.

    Returns:
        The result of ``torch.cat`` over all shards.

    Raises:
        FileNotFoundError: If one of the shard files does not exist.
    """
    shard_dir = os.path.join('TrCaption-trclip-vitl14-e10', 'text_embeddings')
    shard_size = 100_000
    shards = []
    offsets = range(0, 3_600_000, shard_size)
    for offset in tqdm(offsets, desc='Loading TrCaption text embeddings'):
        shard_path = os.path.join(shard_dir, f'text_em_{offset}.pkl')
        # NOTE(security): unpickling can execute arbitrary code; these files
        # come from a cloned repository and must be trusted.
        with open(shard_path, 'rb') as handle:
            shards.append(pickle.load(handle))
    return torch.cat(shards, dim=0)
def load_metadata():
    """Read the TrCaption metadata pickle.

    The file ``TrCaption-trclip-vitl14-e10/metadata.pkl`` must unpickle to a
    mapping that has the keys ``'texts'`` and ``'image_urls'``.

    Returns:
        A ``(texts, image_urls)`` tuple holding the values of those two keys.

    Raises:
        FileNotFoundError: If the metadata file does not exist.
        KeyError: If one of the two keys is missing.
    """
    metadata_path = os.path.join('TrCaption-trclip-vitl14-e10', 'metadata.pkl')
    # NOTE(security): unpickling can execute arbitrary code; this file comes
    # from a cloned repository and must be trusted.
    with open(metadata_path, 'rb') as handle:
        metadata = pickle.load(handle)
    return metadata['texts'], metadata['image_urls']
def load_spesific_tensor(index, type, bs=100_000):
    """Fetch a single pre-computed embedding without loading every shard.

    Args:
        index: Global position of the wanted item across all shards.
        type: Embedding family used to build the path; the callers in this
            file pass ``'text'`` or ``'image'``.
        bs: Number of items stored per shard file.

    Returns:
        Element ``index % bs`` of the object stored in the shard file
        ``<type>_em_<(index // bs) * bs>.pkl``.

    Raises:
        FileNotFoundError: If the shard file for ``index`` does not exist.
    """
    shard_no, pos_in_shard = divmod(index, bs)
    shard_path = os.path.join(
        'TrCaption-trclip-vitl14-e10',
        f'{type}_embeddings',
        f'{type}_em_{shard_no * bs}.pkl',
    )
    # NOTE(security): unpickling can execute arbitrary code; this file comes
    # from a cloned repository and must be trusted.
    with open(shard_path, 'rb') as handle:
        shard = pickle.load(handle)
    return shard[pos_in_shard]
# %%
# Full embedding matrices start unloaded. run_im / run_text load the one they
# need on first use and reset the other one to None just before loading.
image_embeddings = None
text_embeddings = None
#%%
# Values of the 'texts' and 'image_urls' entries of the TrCaption metadata.
trcap_texts, trcap_urls = load_metadata()
# %%
# Checkpoint file inside the cloned model directory; the model is created
# with the 'cpu' device argument.
model_path = os.path.join(model_name, 'pytorch_model.bin')
trclip = Trclip(model_path, clip_model='ViT-L/14', device='cpu')
#%%
import psutil
# Log the system-wide used memory (bytes / 2**30) once start-up is done.
print(f"First used memory {psutil.virtual_memory().used/float(1<<30):,.0f} GB" , )
# %%
def run_im(im1, use_trcap_images, text1, use_trcap_texts):
    """Run image retrieval and return the visualisation of the result.

    Args:
        im1: Uploaded file objects (each exposing ``.name``). Only read when
            ``use_trcap_images`` is falsy.
        use_trcap_images: When truthy, search the pre-computed TrCaption image
            embeddings instead of the uploaded images.
        text1: Newline-separated query texts. Only read when
            ``use_trcap_texts`` is falsy; only the first two lines are
            considered and blank ones are dropped.
        use_trcap_texts: When truthy, use two randomly sampled TrCaption
            captions as queries.

    Returns:
        Whatever ``image_retrieval_visualize`` returns for the ranking.
    """
    global image_embeddings
    global text_embeddings

    f_texts_embeddings = None
    f_image_embeddings = None
    ims = None
    print("im2", use_trcap_images)
    if use_trcap_images:
        print('TRCaption images used')
        # Images taken from TRCAPTION
        im_paths = trcap_urls
        if image_embeddings is None:
            print(f"First used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB", )
            # Drop the text matrix before loading the image matrix so both
            # are never referenced by these globals at the same time.
            text_embeddings = None
            image_embeddings = load_image_embeddings()
            print(f"First used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB", )
        f_image_embeddings = image_embeddings
    else:
        # Images taken from user
        im_paths = [i.name for i in im1]
        ims = [Image.open(i) for i in im_paths]

    if use_trcap_texts:
        # At most 2 texts are allowed in image retrieval (UI limit).
        random_indexes = random.sample(range(len(trcap_texts)), 2)
        # Load only the two sampled vectors instead of the whole text matrix.
        f_texts_embeddings = torch.stack(
            [load_spesific_tensor(i, 'text') for i in random_indexes]
        )
        texts = [trcap_texts[i] for i in random_indexes]
    else:
        # str has no .trim() method; .strip() removes surrounding whitespace.
        first_lines = (line.strip() for line in text1.split('\n')[:2])
        texts = [line for line in first_lines if line != '']

    per_mode_indices, per_mode_probs = trclip.get_results(
        texts=texts,
        images=ims,
        text_features=f_texts_embeddings,
        image_features=f_image_embeddings,
        mode='per_text',
    )
    print(f'per_mode_indices = {per_mode_indices}\n,per_mode_probs = {per_mode_probs} ')
    print(f'im_paths = {im_paths}')
    return image_retrieval_visualize(per_mode_indices, per_mode_probs, texts, im_paths,
                                     n_figure_in_column=2,
                                     n_images_in_figure=4, n_figure_in_row=1, save_fig=False,
                                     show=False,
                                     break_on_index=-1)
def run_text(im1, use_trcap_images, text1, use_trcap_texts):
    """Run text retrieval and return the visualisation of the result.

    Args:
        im1: Uploaded file objects (each exposing ``.name``). Only read when
            ``use_trcap_images`` is falsy; only the first two are used.
        use_trcap_images: When truthy, use two randomly sampled TrCaption
            images as queries.
        text1: Newline-separated candidate texts. Only read when
            ``use_trcap_texts`` is falsy; blank lines are dropped.
        use_trcap_texts: When truthy, search the pre-computed TrCaption text
            embeddings instead of the typed texts.

    Returns:
        Whatever ``text_retrieval_visualize`` returns for the ranking.
    """
    global image_embeddings
    global text_embeddings

    f_texts_embeddings = None
    f_image_embeddings = None
    ims = None
    if use_trcap_images:
        # Two random TrCaption images are used as queries in text retrieval.
        random_indexes = random.sample(range(len(trcap_urls)), 2)
        # Load only the two sampled vectors instead of the whole image matrix.
        f_image_embeddings = torch.stack(
            [load_spesific_tensor(i, 'image') for i in random_indexes]
        )
        print('TRCaption images used')
        # Images taken from TRCAPTION
        im_paths = [trcap_urls[i] for i in random_indexes]
    else:
        # Images taken from user
        im_paths = [i.name for i in im1[:2]]
        ims = [Image.open(i) for i in im_paths]

    if use_trcap_texts:
        if text_embeddings is None:
            print(f"Used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB", )
            # Drop the image matrix before loading the text matrix so both
            # are never referenced by these globals at the same time.
            image_embeddings = None
            print(f"Image embd deleted used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB", )
            text_embeddings = load_text_embeddings()
            print(f"Text embed used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB", )
        f_texts_embeddings = text_embeddings
        texts = trcap_texts
    else:
        # str has no .trim() method; .strip() removes surrounding whitespace.
        stripped_lines = (line.strip() for line in text1.split('\n'))
        texts = [line for line in stripped_lines if line != '']

    per_mode_indices, per_mode_probs = trclip.get_results(
        texts=texts,
        images=ims,
        image_features=f_image_embeddings,
        text_features=f_texts_embeddings,
        mode='per_image',
    )
    print(per_mode_indices)
    print(per_mode_probs)
    return text_retrieval_visualize(per_mode_indices, per_mode_probs, im_paths, texts,
                                    n_figure_in_column=4,
                                    # Never ask for more texts per figure than exist.
                                    n_texts_in_figure=min(len(texts), 4),
                                    n_figure_in_row=2,
                                    save_fig=False,
                                    show=False,
                                    break_on_index=-1,
                                    )
def change_textbox(choice):
    """Build a Gradio image update whose visibility depends on ``choice``.

    Args:
        choice: Selected option label.

    Returns:
        ``gr.Image.update(visible=True)`` when ``choice`` equals
        ``"Use Own Images"``, otherwise ``gr.Image.update(visible=False)``.
    """
    return gr.Image.update(visible=bool(choice == "Use Own Images"))
# Gradio UI: header, image/text inputs, two retrieval buttons, one image output.
# NOTE(review): the nesting of the components inside the tabs was reconstructed
# from statement order - confirm against the intended layout.
with gr.Blocks() as demo:
    # Page header: logo, title, project link and dataset statistics.
    gr.HTML("""
        <div style="text-align: center; max-width: 650px; margin: 0 auto;">
          <div
            style="
              display: inline-flex;
              align-items: center;
              gap: 0.8rem;
              font-size: 1.75rem;
            "
          >
            <svg
              width="0.65em"
              height="0.65em"
              viewBox="0 0 115 115"
              fill="none"
              xmlns="http://www.w3.org/2000/svg"
            >
              <rect width="23" height="23" fill="white"></rect>
              <rect y="69" width="23" height="23" fill="white"></rect>
              <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="46" width="23" height="23" fill="white"></rect>
              <rect x="46" y="69" width="23" height="23" fill="white"></rect>
              <rect x="69" width="23" height="23" fill="black"></rect>
              <rect x="69" y="69" width="23" height="23" fill="black"></rect>
              <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
              <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="115" y="46" width="23" height="23" fill="white"></rect>
              <rect x="115" y="115" width="23" height="23" fill="white"></rect>
              <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
              <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="92" y="69" width="23" height="23" fill="white"></rect>
              <rect x="69" y="46" width="23" height="23" fill="white"></rect>
              <rect x="69" y="115" width="23" height="23" fill="white"></rect>
              <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
              <rect x="46" y="46" width="23" height="23" fill="black"></rect>
              <rect x="46" y="115" width="23" height="23" fill="black"></rect>
              <rect x="46" y="69" width="23" height="23" fill="black"></rect>
              <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
              <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="23" y="69" width="23" height="23" fill="black"></rect>
            </svg>
            <h1 style="font-weight: 900; margin-bottom: 7px;">
              Trclip Demo
              <a
                href="https://github.com/yusufani/TrCLIP"
                style="text-decoration: underline;"
                target="_blank"
              >Github Trclip:</a>
            </h1>
          </div>
          <p style="margin-bottom: 10px; font-size: 94%">
            Trclip is Turkish port of real clip. In this space you can try your images or/and texts.
            Also you can use pre calculated TrCaption embeddings.
            Number of texts = 3533312
            Number of images = 3070976
          </p>
        </div>
    """)
    # Image side: user uploads or the TrCaption image pool.
    with gr.Tabs():
        with gr.TabItem("Use Own Images"):
            im_input = gr.components.File(label="Image input", optional=True, file_count='multiple')
            is_trcap_ims = gr.Checkbox(label="Use TRCaption Images\nNote: ( Random 2 sample selected in text retrieval mode )")
    # Text side: typed texts or the TrCaption caption pool.
    with gr.Tabs():
        with gr.TabItem("Input a text (Seperated by new line Max 2 for Image retrieval)"):
            text_input = gr.components.Textbox(label="Text input", optional=True)
            is_trcap_texts = gr.Checkbox(label="Use TrCaption Captions \nNote: ( Random 2 sample selected in image retrieval mode")
    im_ret_but = gr.Button("Image Retrieval")
    text_ret_but = gr.Button("Text Retrieval")
    im_out = gr.components.Image()
    # Both buttons read the same four inputs and write to the same image output.
    im_ret_but.click(run_im, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out)
    text_ret_but.click(run_text, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out)
demo.launch()
# %%