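"""Gradio demo for text-to-image search with a trained CLIP model.

The app unzips the image archive, precomputes embeddings for the validation
images with the trained checkpoint ("best.pt"), and returns the image that
best matches a free-text query. It assumes implement.py provides CFG,
build_loaders, CLIPModel and make_train_valid_dfs (see the commented-out
imports below).
"""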
import os
import zipfile

import cv2
import gradio as gr
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import DistilBertTokenizer

# implement.py bundles what was previously imported from separate modules:
# import config as CFG
# from main import build_loaders
# from CLIP import CLIPModel
from implement import *
with gr.Blocks(css="style.css") as demo:
    def get_image_embeddings(valid_df, model_path):
        """Load the trained CLIP model and precompute embeddings for all validation images."""
        # Unzip the image archive if it sits next to the script
        zip_filename = 'Images.zip'
        if os.path.isfile(zip_filename):
            with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
                # Extract all contents of the zip file to the current directory
                zip_ref.extractall()
            print(f"'{zip_filename}' has been successfully unzipped.")
        else:
            print(f"'{zip_filename}' not found in the current directory.")

        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
        valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

        model = CLIPModel().to(CFG.device)
        model.load_state_dict(torch.load(model_path, map_location=CFG.device))
        model.eval()

        valid_image_embeddings = []
        with torch.no_grad():
            for batch in tqdm(valid_loader):
                image_features = model.image_encoder(batch["image"].to(CFG.device))
                image_embeddings = model.image_projection(image_features)
                valid_image_embeddings.append(image_embeddings)
        return model, torch.cat(valid_image_embeddings)
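
    # Build the validation split once at startup and precompute its image
    # embeddings, so each query only has to encode the text.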
    _, valid_df = make_train_valid_dfs()
    model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
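
    # Text-to-image retrieval: encode the query with the text encoder,
    # L2-normalize both embedding sets, rank images by dot-product (cosine)
    # similarity, and return the best match for display.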
    def find_matches(query, n=9):
        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
        encoded_query = tokenizer([query])
        batch = {
            key: torch.tensor(values).to(CFG.device)
            for key, values in encoded_query.items()
        }
        with torch.no_grad():
            text_features = model.text_encoder(
                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
            )
            text_embeddings = model.text_projection(text_features)

        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        dot_similarity = text_embeddings_n @ image_embeddings_n.T

        # Each image appears once per caption in valid_df, so take every 5th of the
        # top n*5 hits to skip duplicate images (assumes 5 captions per image).
        _, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
        matches = [valid_df['image'].values[idx] for idx in indices[::5].cpu().numpy()]

        # The output component is a single gr.Image, so only the top match is shown.
        image = cv2.imread(f"{CFG.image_path}/{matches[0]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image
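
    # UI: a query textbox and an output image, wired to find_matches via the button.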
    with gr.Row():
        textbox = gr.Textbox(label="Enter a query to find matching images using a CLIP model.")
        image = gr.Image(type="numpy")
    button = gr.Button("Press")
    button.click(
        fn=find_matches,
        inputs=textbox,
        outputs=image,
    )

# Launch the app (share=True also creates a temporary public link)
demo.launch(share=True)