File size: 4,583 Bytes
1878377
 
 
87e17e3
 
 
1878377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87e17e3
 
 
 
 
 
 
 
 
 
 
6bea3f9
 
 
 
 
 
 
 
 
0839c2d
 
 
6bea3f9
0874378
636cb49
df02a37
1878377
 
6bea3f9
1878377
 
 
df02a37
1878377
 
 
 
 
 
 
 
 
 
 
 
 
 
3b6db54
1878377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bea3f9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import zipfile

# Name of the archive expected to sit in the current working directory.
# (The original imported os/zipfile twice; the duplicates are removed.)
zip_filename = 'Images.zip'

# Get the current directory path
current_directory = os.getcwd()
print(f"Current directory: {current_directory}")

# Extract into a dedicated subfolder so the archive contents stay contained.
custom_directory = os.path.join(current_directory, 'UnzippedContent')

# Ensure the custom directory exists (idempotent thanks to exist_ok=True).
os.makedirs(custom_directory, exist_ok=True)

# Print the contents of the current directory before unzipping
print(f"Contents of current directory before unzipping: {os.listdir(current_directory)}")

# Check if the zip file exists in the current directory
zip_file_path = os.path.join(current_directory, zip_filename)
if os.path.isfile(zip_file_path):
    # Context manager guarantees the archive handle is closed even on error.
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        # Extract all contents of the zip file to the custom directory
        zip_ref.extractall(custom_directory)
    print(f"'{zip_filename}' has been successfully unzipped to '{custom_directory}'.")

    # Print the contents of the custom directory after unzipping
    print(f"Contents of '{custom_directory}': {os.listdir(custom_directory)}")
else:
    print(f"'{zip_filename}' not found in the current directory.")

# Print the contents of the current directory after unzipping
print(f"Contents of current directory after unzipping: {os.listdir(current_directory)}")

# Helper for the application layer (e.g. a web interface) to display the
# extracted files.
def list_unzipped_contents(directory=None):
    """Return the names of the entries in *directory*.

    Args:
        directory: Folder to list. Defaults to the module-level
            ``custom_directory`` (the unzip destination), preserving the
            original no-argument behavior.

    Returns:
        list[str]: Entry names as reported by ``os.listdir``.
    """
    target = custom_directory if directory is None else directory
    unzipped_contents = os.listdir(target)
    print(f"Unzipped contents: {unzipped_contents}")
    return unzipped_contents

# Eagerly list the unzip destination at import time; `unzipped_files` is the
# module-level handle a UI layer can render. NOTE(review): this runs on
# import and assumes `custom_directory` already exists (created above).
unzipped_files = list_unzipped_contents()
# Integrate `unzipped_files` with your app's interface to display the contents


import gradio as gr
import gc
import cv2
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import DistilBertTokenizer
import matplotlib.pyplot as plt
from implement import *
# import config as CFG
# from main import build_loaders
# from CLIP import CLIPModel
import os




    
# Build the Gradio UI. All model loading and embedding precomputation below
# runs inside this block at import time, before the interface is launched.
with gr.Blocks(css="style.css") as demo:
    def get_image_embeddings(valid_df, model_path):
        """Load a trained CLIP checkpoint and embed every validation image.

        Builds a validation DataLoader from *valid_df*, restores the model
        weights from *model_path*, and runs the image encoder + projection
        over all batches with gradients disabled.

        Returns:
            (model, embeddings): the model in eval mode, and one tensor of
            all projected image embeddings concatenated along dim 0.
        """
        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
        loader = build_loaders(valid_df, tokenizer, mode="valid")

        model = CLIPModel().to(CFG.device)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        state = torch.load(model_path, map_location=CFG.device)
        model.load_state_dict(state)
        model.eval()

        collected = []
        with torch.no_grad():
            for batch in tqdm(loader):
                features = model.image_encoder(batch["image"].to(CFG.device))
                collected.append(model.image_projection(features))
        return model, torch.cat(collected)

    # Precompute the validation split and all image embeddings once at
    # startup so each user query only needs to embed the text side.
    _, valid_df = make_train_valid_dfs()
    model, image_embeddings = get_image_embeddings(valid_df, "best.pt")

    def find_matches(query, n=9):
        """Return the image best matching *query* by CLIP similarity.

        Embeds the text query, ranks the precomputed image embeddings by
        cosine similarity, and loads the top-ranked image as an RGB array.

        Args:
            query: Free-text search string.
            n: Number of candidate matches to rank (kept for interface
               compatibility; only the best match is returned).

        Raises:
            FileNotFoundError: if the best match's image file can't be read.
        """
        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
        encoded_query = tokenizer([query])
        batch = {
            key: torch.tensor(values).to(CFG.device)
            for key, values in encoded_query.items()
        }
        with torch.no_grad():
            text_features = model.text_encoder(
                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
            )
            text_embeddings = model.text_projection(text_features)

        # Cosine similarity via L2-normalized dot products.
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        dot_similarity = text_embeddings_n @ image_embeddings_n.T

        # Over-fetch n*5 then stride by 5 — presumably because each image
        # appears ~5 times (once per caption) in valid_df; TODO confirm.
        _, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
        matches = [valid_df['image'].values[idx] for idx in indices[::5]]

        # BUG FIX: the original decoded every one of the n matches in a loop
        # (appending to an `images` list that was dead code) and then
        # returned only the loop's LAST image — i.e. the lowest-ranked
        # match. Load just the single top-ranked match instead.
        best = matches[0]
        image = cv2.imread(f"{CFG.image_path}/{best}")
        if image is None:
            # cv2.imread returns None (no exception) on unreadable paths.
            raise FileNotFoundError(f"Could not read image: {CFG.image_path}/{best}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Lay out the interface: a query textbox and the result image, side by
    # side in one row.
    with gr.Row():
        textbox = gr.Textbox(label = "Enter a query to find matching images using a CLIP model.")
        image = gr.Image(type="numpy")
    
    # Wire the button so each click runs the CLIP search on the textbox
    # contents and shows the returned numpy array in the image widget.
    button = gr.Button("Press")
    button.click(
        fn = find_matches,
        inputs=textbox,
        outputs=image
    )
    
# Serve the app; share=True also creates a temporary public gradio.live URL.
demo.launch(share=True)