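"""Gradio demo that bundles several CLIP/BLIP image tools into one tabbed app:

* text-to-image semantic search over precomputed Unsplash photo embeddings,
* zero-shot image classification via a CLIP pipeline,
* BLIP image captioning,
* a simple image preprocessing helper (brightness/contrast/saturation/sharpness/rotation).
"""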
import gradio as gr
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer, BlipProcessor, BlipForConditionalGeneration, pipeline
from sentence_transformers import util  # only util.semantic_search is needed here
import pickle
import numpy as np
from PIL import Image, ImageEnhance
import os
import concurrent.futures
import warnings

warnings.filterwarnings("ignore")

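# Handles retrieval against the precomputed Unsplash image embeddings:
# text-to-image search via CLIP text features and image-to-image search
# via CLIP image features.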
class CLIPModelHandler:
    def __init__(self, model_name):
        self.model_name = model_name
        self.img_names, self.img_emb = self.load_precomputed_embeddings()

    def load_precomputed_embeddings(self):
        emb_filename = 'unsplash-25k-photos-embeddings.pkl'
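        # The pickle stores (image file names, precomputed CLIP image embeddings);
        # the image files themselves are expected in the local "photos/" directory.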
        with open(emb_filename, 'rb') as fIn:
            img_names, img_emb = pickle.load(fIn)
        return img_names, img_emb

    def search_text(self, query, top_k=1):
        model = CLIPModel.from_pretrained(self.model_name)
        tokenizer = CLIPTokenizer.from_pretrained(self.model_name)

        # Encode the query text and search the precomputed image embeddings
        inputs = tokenizer([query], padding=True, return_tensors="pt")
        query_emb = model.get_text_features(**inputs)
        hits = util.semantic_search(query_emb, self.img_emb, top_k=top_k)[0]

        # Return both the matched images and their file names
        hit_names = [self.img_names[hit['corpus_id']] for hit in hits]
        images = [Image.open(os.path.join("photos/", name)) for name in hit_names]
        return images, hit_names

    def search_image(self, image_path, top_k=1):
        model = CLIPModel.from_pretrained(self.model_name)
        processor = CLIPProcessor.from_pretrained(self.model_name)

        # Load and preprocess the image
        image = Image.open(image_path)
        inputs = processor(images=image, return_tensors="pt")

        # Get the image embedding (same joint space as the precomputed embeddings)
        image_emb = model.get_image_features(**inputs)

        # Perform semantic search
        hits = util.semantic_search(image_emb, self.img_emb, top_k=top_k)[0]

        # Retrieve and return the relevant images
        result_images = []
        for hit in hits:
            img = Image.open(os.path.join("photos/", self.img_names[hit['corpus_id']]))
            result_images.append(img)

        return result_images

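# Thin wrapper around a BLIP captioning checkpoint: normalizes inputs to RGB
# PIL images and generates unconditional captions, optionally in parallel threads.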
class BLIPImageCaptioning:
    def __init__(self, blip_model_name):
        self.blip_model_name = blip_model_name

    def preprocess_image(self, image):
        if isinstance(image, str):
            return Image.open(image).convert('RGB')
        elif isinstance(image, np.ndarray):
            return Image.fromarray(np.uint8(image)).convert('RGB')
        else:
            raise ValueError("Invalid input type for image. Supported types: str (file path) or np.ndarray.")

    def generate_caption(self, image):
        try:
            model = BlipForConditionalGeneration.from_pretrained(self.blip_model_name)
            processor = BlipProcessor.from_pretrained(self.blip_model_name)
            
            raw_image = self.preprocess_image(image)
            inputs = processor(raw_image, return_tensors="pt")
            out = model.generate(**inputs)
            unconditional_caption = processor.decode(out[0], skip_special_tokens=True)

            return unconditional_caption
        except Exception as e:
            return f"Error generating caption: {e}"

    def generate_captions_parallel(self, images):
        with concurrent.futures.ThreadPoolExecutor() as executor:
            results = list(executor.map(self.generate_caption, images))

        return results


# Initialize the CLIP model handler
clip_handler = CLIPModelHandler("openai/clip-vit-base-patch32")

# Initialize the zero-shot image classification pipeline
clip_classifier = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")

# BLIP checkpoint used by the captioning handler below
blip_model_name = "Salesforce/blip-image-captioning-base"

# Function for text-to-image search
def text_to_image_interface(query, top_k):
    try:
        # Perform text-to-image search
        result_images, hit_names = clip_handler.search_text(query, top_k)

        # Resize images before displaying
        result_images_resized = [image.resize((224, 224)) for image in result_images]

        # Report only the file names of the matched images (not the whole corpus)
        result_info = "\n".join(os.path.basename(name) for name in hit_names)

        return result_images_resized, result_info
    except Exception as e:
        raise gr.Error(f"Error in text-to-image search: {str(e)}")


# Gradio Interface function for zero-shot classification
def zero_shot_classification(image, labels_text):
    try:
        # Convert image to PIL format
        PIL_image = Image.fromarray(np.uint8(image)).convert('RGB')

        # Split labels_text into a list of labels, dropping surrounding whitespace
        labels = [label.strip() for label in labels_text.split(",") if label.strip()]

        # Perform zero-shot classification
        res = clip_classifier(images=PIL_image, candidate_labels=labels, hypothesis_template="This is a photo of a {}")

        # Format the result as a dictionary of label -> score
        formatted_results = {dic["label"]: dic["score"] for dic in res}

        return formatted_results
    except Exception as e:
        raise gr.Error(f"Error in zero-shot classification: {str(e)}")



def preprocessing_interface(original_image, brightness_slider, contrast_slider, saturation_slider, sharpness_slider, rotation_slider):
    try:
        # Convert NumPy array to PIL Image
        PIL_image = Image.fromarray(np.uint8(original_image)).convert('RGB')

        # Convert percentage sliders to ImageEnhance factors (100% = unchanged)
        brightness_normalized = brightness_slider / 100.0
        contrast_normalized = contrast_slider / 100.0
        saturation_normalized = saturation_slider / 100.0
        sharpness_normalized = sharpness_slider / 100.0

        # Apply rotation based on user input (image is already RGB)
        PIL_image = PIL_image.rotate(rotation_slider)
        
        # Adjust brightness
        enhancer = ImageEnhance.Brightness(PIL_image)
        PIL_image = enhancer.enhance(brightness_normalized)

        # Adjust contrast
        enhancer = ImageEnhance.Contrast(PIL_image)
        PIL_image = enhancer.enhance(contrast_normalized)

        # Adjust saturation
        enhancer = ImageEnhance.Color(PIL_image)
        PIL_image = enhancer.enhance(saturation_normalized)

        # Adjust sharpness
        enhancer = ImageEnhance.Sharpness(PIL_image)
        PIL_image = enhancer.enhance(sharpness_normalized)

        # Return the processed image
        return PIL_image
    except Exception as e:
        raise gr.Error(f"Error in preprocessing: {str(e)}")

def generate_captions(images):
    # Delegate to the shared BLIP handler, which captions the images in parallel threads
    return blip_model.generate_captions_parallel(images)


# Gradio Interfaces
zero_shot_classification_interface = gr.Interface(
    fn=zero_shot_classification,
    inputs=[
        gr.Image(label="Image Query", elem_id="image_input"),
        gr.Textbox(label="Labels (comma-separated)", elem_id="labels_input"),
    ],
    outputs=gr.Label(elem_id="label_image"),
)

text_to_image_interface = gr.Interface(
    fn=text_to_image_interface,
    inputs=[
        gr.Textbox(
            lines=2,
            label="Text Query",
            placeholder="Enter text here...",
        ),
        gr.Slider(1, 5, value=1, step=1, label="Top K Results"),
    ],
    outputs=[
        gr.Gallery(
            label="Text-to-Image Search Results",
            elem_id="gallery_text",
            columns=2,  # Gallery layout argument (Gradio 4.x)
            height="auto",
        ),
        gr.Text(label="Result Information", elem_id="text_info"),
    ],
)

blip_model = BLIPImageCaptioning(blip_model_name)  # Shared BLIP captioning handler
blip_captioning_interface = gr.Interface(
    fn=blip_model.generate_caption,
    inputs=gr.Image(label="Image for Captioning", elem_id="blip_caption_image"),
    outputs=gr.Textbox(label="Generated Caption", elem_id="blip_generated_captions", value=""),
)

# Slider ranges mirror the /100.0 scaling in preprocessing_interface (100% = unchanged)
image_preprocessing_interface = gr.Interface(
    fn=preprocessing_interface,
    inputs=[
        gr.Image(label="Original Image", elem_id="original_image"),
        gr.Slider(0, 200, value=100, step=1, label="Brightness (%)"),
        gr.Slider(0, 200, value=100, step=1, label="Contrast (%)"),
        gr.Slider(0, 200, value=100, step=1, label="Saturation (%)"),
        gr.Slider(0, 200, value=100, step=1, label="Sharpness (%)"),
        gr.Slider(-180, 180, value=0, step=1, label="Rotation (degrees)"),
    ],
    outputs=[
        gr.Image(label="Processed Image", elem_id="processed_image"),
    ],
)

# Tabbed Interface
app = gr.TabbedInterface(
    interface_list=[text_to_image_interface, zero_shot_classification_interface, blip_captioning_interface, image_preprocessing_interface],
    tab_names=["Text-to-Image Search", "Zero-Shot Classification", "BLIP Image Captioning", "Image Preprocessing"],
)

# Launch the Gradio interface
app.launch(debug=True, share=True)