import gradio as gr
from transformers import (
    CLIPProcessor,
    CLIPModel,
    CLIPTokenizer,
    BlipProcessor,
    BlipForConditionalGeneration,
    pipeline,
)
from sentence_transformers import util
import pickle
import numpy as np
from PIL import Image, ImageEnhance
import os
import concurrent.futures
import warnings

warnings.filterwarnings("ignore")


class CLIPModelHandler:
    def __init__(self, model_name):
        self.model_name = model_name
        # Load the model, processor, and tokenizer once instead of on every search
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.img_names, self.img_emb = self.load_precomputed_embeddings()

    def load_precomputed_embeddings(self):
        emb_filename = 'unsplash-25k-photos-embeddings.pkl'
        with open(emb_filename, 'rb') as fIn:
            img_names, img_emb = pickle.load(fIn)
        return img_names, img_emb

    def search_text(self, query, top_k=1):
        # Embed the text query and search the precomputed image embeddings
        inputs = self.tokenizer([query], padding=True, return_tensors="pt")
        query_emb = self.model.get_text_features(**inputs)
        hits = util.semantic_search(query_emb, self.img_emb, top_k=top_k)[0]
        # Return the matching photos together with their file names
        hit_names = [self.img_names[hit['corpus_id']] for hit in hits]
        images = [Image.open(os.path.join("photos/", name)) for name in hit_names]
        return images, hit_names

    def search_image(self, image_path, top_k=1):
        # Load and preprocess the query image
        image = Image.open(image_path)
        inputs = self.processor(images=image, return_tensors="pt")
        # Get the image features
        image_emb = self.model.get_image_features(**inputs)
        # Perform semantic search against the precomputed embeddings
        hits = util.semantic_search(image_emb, self.img_emb, top_k=top_k)[0]
        # Retrieve and return the relevant images
        return [
            Image.open(os.path.join("photos/", self.img_names[hit['corpus_id']]))
            for hit in hits
        ]


class BLIPImageCaptioning:
    def __init__(self, blip_model_name):
        self.blip_model_name = blip_model_name
        # Load the captioning model and processor once
        self.model = BlipForConditionalGeneration.from_pretrained(blip_model_name)
        self.processor = BlipProcessor.from_pretrained(blip_model_name)

    def preprocess_image(self, image):
        if isinstance(image, str):
            return Image.open(image).convert('RGB')
        elif isinstance(image, np.ndarray):
            return Image.fromarray(np.uint8(image)).convert('RGB')
        else:
            raise ValueError("Invalid input type for image. Supported types: str (file path) or np.ndarray.")

    def generate_caption(self, image):
        try:
            raw_image = self.preprocess_image(image)
            inputs = self.processor(raw_image, return_tensors="pt")
            out = self.model.generate(**inputs)
            unconditional_caption = self.processor.decode(out[0], skip_special_tokens=True)
            return unconditional_caption
        except Exception as e:
            return {"error": str(e)}

    def generate_captions_parallel(self, images):
        with concurrent.futures.ThreadPoolExecutor() as executor:
            results = list(executor.map(self.generate_caption, images))
        return results


# Initialize the CLIP model handler
clip_handler = CLIPModelHandler("openai/clip-vit-base-patch32")

# Initialize the zero-shot image classification pipeline
clip_classifier = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")

# Initialize the BLIP captioning handler
blip_model_name = "Salesforce/blip-image-captioning-base"
blip_model = BLIPImageCaptioning(blip_model_name)


# Function for text-to-image search
def text_to_image_search(query, top_k):
    try:
        # Perform text-to-image search
        result_images, hit_names = clip_handler.search_text(query, top_k)
        # Resize images before displaying
        result_images_resized = [image.resize((224, 224)) for image in result_images]
        # Report the file names of the matched photos
        result_info = "\n".join(os.path.basename(name) for name in hit_names)
        return result_images_resized, result_info
    except Exception as e:
        raise gr.Error(f"Error in text-to-image search: {str(e)}")


# Gradio interface function for zero-shot classification
def zero_shot_classification(image, labels_text):
    try:
        # Convert the image to PIL format
        PIL_image = Image.fromarray(np.uint8(image)).convert('RGB')
        # Split labels_text into a list of labels
        labels = [label.strip() for label in labels_text.split(",") if label.strip()]
        # Perform zero-shot classification
        res = clip_classifier(
            images=PIL_image,
            candidate_labels=labels,
            hypothesis_template="This is a photo of a {}",
        )
        # Format the result as a {label: score} dictionary for gr.Label
        return {dic["label"]: dic["score"] for dic in res}
    except Exception as e:
        raise gr.Error(f"Error in zero-shot classification: {str(e)}")


def apply_preprocessing(original_image, brightness_slider, contrast_slider,
                        saturation_slider, sharpness_slider, rotation_slider):
    try:
        # Convert the NumPy array to a PIL image
        PIL_image = Image.fromarray(np.uint8(original_image)).convert('RGB')

        # Map slider values to enhancement factors (100 -> 1.0, i.e. no change)
        brightness_factor = brightness_slider / 100.0
        contrast_factor = contrast_slider / 100.0
        saturation_factor = saturation_slider / 100.0
        sharpness_factor = sharpness_slider / 100.0

        # Rotate first, then apply the enhancements
        PIL_image = PIL_image.rotate(rotation_slider)
        PIL_image = ImageEnhance.Brightness(PIL_image).enhance(brightness_factor)
        PIL_image = ImageEnhance.Contrast(PIL_image).enhance(contrast_factor)
        PIL_image = ImageEnhance.Color(PIL_image).enhance(saturation_factor)
        PIL_image = ImageEnhance.Sharpness(PIL_image).enhance(sharpness_factor)

        # Return the processed image
        return PIL_image
    except Exception as e:
        raise gr.Error(f"Error in preprocessing: {str(e)}")


def generate_captions(images):
    # Caption a batch of images in parallel with the shared BLIP handler
    return blip_model.generate_captions_parallel(images)
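

# Optional sketch (not wired into the tabs below): an image-to-image search
# callback that mirrors text_to_image_search but uses the otherwise-unused
# CLIPModelHandler.search_image method. The function name, the file-path input
# (e.g. from gr.Image(type="filepath")), and the 224x224 resize are illustrative
# assumptions, not part of the original app.
def image_to_image_search(query_image_path, top_k):
    try:
        # search_image expects a file path to the query image
        result_images = clip_handler.search_image(query_image_path, top_k)
        return [image.resize((224, 224)) for image in result_images]
    except Exception as e:
        raise gr.Error(f"Error in image-to-image search: {str(e)}")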


# Gradio interfaces
zero_shot_classification_interface = gr.Interface(
    fn=zero_shot_classification,
    inputs=[
        gr.Image(label="Image Query", elem_id="image_input"),
        gr.Textbox(label="Labels (comma-separated)", elem_id="labels_input"),
    ],
    outputs=gr.Label(elem_id="label_image"),
)

text_to_image_interface = gr.Interface(
    fn=text_to_image_search,
    inputs=[
        gr.Textbox(lines=2, label="Text Query", placeholder="Enter text here..."),
        gr.Slider(1, 5, step=1, value=1, label="Top K Results"),
    ],
    outputs=[
        gr.Gallery(label="Text-to-Image Search Results", elem_id="gallery_text", columns=2, height="auto"),
        gr.Text(label="Result Information", elem_id="text_info"),
    ],
)

blip_captioning_interface = gr.Interface(
    fn=blip_model.generate_caption,
    inputs=gr.Image(label="Image for Captioning", elem_id="blip_caption_image"),
    outputs=gr.Textbox(label="Generated Captions", elem_id="blip_generated_captions", value=""),
)

preprocessing_interface = gr.Interface(
    fn=apply_preprocessing,
    inputs=[
        gr.Image(label="Original Image", elem_id="original_image"),
        # Sliders are scaled so that 100 corresponds to an enhancement factor of 1.0 (no change)
        gr.Slider(0, 200, value=100, step=1, label="Brightness"),
        gr.Slider(0, 200, value=100, step=1, label="Contrast"),
        gr.Slider(0, 200, value=100, step=1, label="Saturation"),
        gr.Slider(0, 200, value=100, step=1, label="Sharpness"),
        gr.Slider(-180, 180, value=0, step=1, label="Rotation (degrees)"),
    ],
    outputs=gr.Image(label="Processed Image", elem_id="processed_image"),
)

# Tabbed interface
app = gr.TabbedInterface(
    interface_list=[
        text_to_image_interface,
        zero_shot_classification_interface,
        blip_captioning_interface,
        preprocessing_interface,
    ],
    tab_names=["Text-to-Image Search", "Zero-Shot Classification", "BLIP Image Captioning", "Image Preprocessing"],
)

# Launch the Gradio interface
app.launch(debug=True, share=True)
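
# Optional, commented-out sketch: a quick smoke test of the handlers without the
# Gradio UI. It assumes the photos/ directory and the unsplash embeddings .pkl
# are present locally; the query string is illustrative.
# if __name__ == "__main__":
#     images, names = clip_handler.search_text("two dogs playing in the snow", top_k=3)
#     print(names)
#     print(blip_model.generate_caption(os.path.join("photos/", names[0])))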