Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer, BlipProcessor, BlipForConditionalGeneration, pipeline | |
from sentence_transformers import SentenceTransformer, util | |
import pickle | |
import numpy as np | |
from PIL import Image, ImageEnhance | |
import os | |
import io | |
import concurrent.futures | |
import warnings | |
# Silence every warning category for the whole process; this installs the same
# filter as warnings.filterwarnings("ignore").
# NOTE(review): this also hides deprecation warnings from Gradio/transformers,
# which would otherwise flag removed keyword arguments used below.
warnings.simplefilter("ignore")
class CLIPModelHandler:
    """Search a precomputed CLIP image-embedding corpus by text or by image.

    The corpus is a pickled ``(img_names, img_emb)`` pair loaded once at
    construction; matching photos are opened from ``photos_dir``. The CLIP
    model, processor and tokenizer are loaded lazily on first use and cached on
    the instance, instead of being re-loaded with ``from_pretrained`` on every
    query.
    """

    DEFAULT_EMB_FILENAME = "unsplash-25k-photos-embeddings.pkl"
    DEFAULT_PHOTOS_DIR = "photos/"

    def __init__(self, model_name, emb_filename=DEFAULT_EMB_FILENAME, photos_dir=DEFAULT_PHOTOS_DIR):
        """Store the checkpoint name and load the embedding corpus.

        Args:
            model_name: Hugging Face CLIP checkpoint used to embed queries.
            emb_filename: Pickle file holding ``(img_names, img_emb)``.
            photos_dir: Directory that the entries of ``img_names`` are joined to.
        """
        self.model_name = model_name
        self.emb_filename = emb_filename
        self.photos_dir = photos_dir
        # Lazily populated by _cached() on the first search.
        self._model = None
        self._processor = None
        self._tokenizer = None
        self.img_names, self.img_emb = self.load_precomputed_embeddings()

    def load_precomputed_embeddings(self, emb_filename=None):
        """Load and return ``(img_names, img_emb)`` from a pickle file.

        Args:
            emb_filename: File to read; defaults to the instance's
                ``emb_filename`` (or the class default when unset).

        NOTE(review): ``pickle.load`` can execute arbitrary code -- only point
        this at a trusted, locally produced file.
        """
        if emb_filename is None:
            emb_filename = getattr(self, "emb_filename", self.DEFAULT_EMB_FILENAME)
        with open(emb_filename, "rb") as f_in:
            img_names, img_emb = pickle.load(f_in)
        return img_names, img_emb

    def _cached(self, attr, loader_cls):
        """Return ``self.<attr>``, creating it via ``loader_cls.from_pretrained`` on first use."""
        value = getattr(self, attr, None)
        if value is None:
            value = loader_cls.from_pretrained(self.model_name)
            setattr(self, attr, value)
        return value

    @staticmethod
    def _to_pil(image):
        """Convert a file path, NumPy array or PIL image into an RGB PIL image."""
        if isinstance(image, (str, os.PathLike)):
            # Close the file once its pixels have been copied into the converted image.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if isinstance(image, np.ndarray):
            return Image.fromarray(np.uint8(image)).convert("RGB")
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        raise ValueError("Unsupported image type; expected a file path, np.ndarray or PIL.Image.Image.")

    def _search(self, query_emb, top_k):
        """Return opened photos for the ``top_k`` corpus entries nearest ``query_emb``."""
        hits = util.semantic_search(query_emb, self.img_emb, top_k=top_k)[0]
        photos_dir = getattr(self, "photos_dir", self.DEFAULT_PHOTOS_DIR)
        return [Image.open(os.path.join(photos_dir, self.img_names[hit["corpus_id"]])) for hit in hits]

    def search_text(self, query, top_k=1):
        """Return the ``top_k`` corpus photos that best match a text query.

        ``top_k`` is coerced to ``int`` (Gradio sliders may deliver floats);
        a non-positive value returns an empty list.
        """
        top_k = int(top_k)
        if top_k <= 0:
            return []
        tokenizer = self._cached("_tokenizer", CLIPTokenizer)
        model = self._cached("_model", CLIPModel)
        # truncation: CLIP's text encoder accepts a bounded number of tokens,
        # so very long queries would otherwise raise.
        inputs = tokenizer([query], padding=True, truncation=True, return_tensors="pt")
        # detach(): only the embedding values are needed, not the autograd graph.
        query_emb = model.get_text_features(**inputs).detach()
        return self._search(query_emb, top_k)

    def search_image(self, image_path, top_k=1):
        """Return the ``top_k`` corpus photos most similar to a query image.

        Args:
            image_path: File path, NumPy array or PIL image of the query.
            top_k: Number of results (coerced to ``int``; <= 0 returns ``[]``).
        """
        top_k = int(top_k)
        if top_k <= 0:
            return []
        image = self._to_pil(image_path)
        processor = self._cached("_processor", CLIPProcessor)
        model = self._cached("_model", CLIPModel)
        inputs = processor(images=image, return_tensors="pt")
        # get_image_features yields the projected image embedding comparable with
        # the corpus; logits_per_image needs text inputs and is not an embedding.
        image_emb = model.get_image_features(**inputs).detach()
        return self._search(image_emb, top_k)
class BLIPImageCaptioning:
    """Generate unconditional image captions with a BLIP checkpoint.

    The model and processor are loaded on first use and cached on the
    instance, rather than being re-loaded for every caption (and, in the
    parallel path, once per worker thread).
    """

    def __init__(self, blip_model_name):
        """Remember the checkpoint name; nothing is downloaded until first use."""
        self.blip_model_name = blip_model_name
        self._model = None
        self._processor = None

    def _load(self):
        """Return ``(model, processor)``, loading and caching them on first call."""
        if getattr(self, "_model", None) is None or getattr(self, "_processor", None) is None:
            self._processor = BlipProcessor.from_pretrained(self.blip_model_name)
            self._model = BlipForConditionalGeneration.from_pretrained(self.blip_model_name)
        return self._model, self._processor

    def preprocess_image(self, image):
        """Return *image* as an RGB PIL image.

        Accepts a file path (``str`` / ``os.PathLike``), a NumPy array, or a
        PIL image.

        Raises:
            ValueError: for any other input type (including ``None``).
        """
        if isinstance(image, (str, os.PathLike)):
            # Close the file once the converted copy holds the pixels.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if isinstance(image, np.ndarray):
            return Image.fromarray(np.uint8(image)).convert("RGB")
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        raise ValueError(
            "Invalid input type for image. Supported types: str or os.PathLike (file path), "
            "np.ndarray, or PIL.Image.Image."
        )

    def generate_caption(self, image):
        """Caption one image.

        Returns:
            The decoded caption string, or ``{"error": message}`` if
            preprocessing, loading or generation fails (kept for callers that
            inspect the dict).
        """
        try:
            # Validate the input first so bad inputs fail without loading the model.
            raw_image = self.preprocess_image(image)
            model, processor = self._load()
            inputs = processor(raw_image, return_tensors="pt")
            out = model.generate(**inputs)
            return processor.decode(out[0], skip_special_tokens=True)
        except Exception as e:
            return {"error": str(e)}

    def generate_captions_parallel(self, images):
        """Caption every image in *images* on a thread pool, preserving order."""
        images = list(images)
        if not images:
            return []
        # Load once up front so worker threads share one model instead of each
        # racing to load its own. A failure here is deliberately not raised:
        # generate_caption retries and reports it per image as an error dict.
        try:
            self._load()
        except Exception:
            pass
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return list(executor.map(self.generate_caption, images))
# Module-level model setup; these statements run at import time.
# CLIP handler backing the text-to-image search tab.
clip_handler = CLIPModelHandler("openai/clip-vit-base-patch32")
# CLIP-based zero-shot image classification pipeline.
clip_classifier = pipeline(
    "zero-shot-image-classification",
    model="openai/clip-vit-base-patch32",
)
# BLIP captioning checkpoint name, and a processor loaded from it.
blip_model_name = "Salesforce/blip-image-captioning-base"
# NOTE(review): this module-level processor is not read anywhere else in the
# visible file; consider removing it to avoid an unnecessary load at startup.
blip_processor = BlipProcessor.from_pretrained(blip_model_name)
# Function for text-to-image search
def text_to_image_interface(query, top_k):
    """Gradio handler: run a text-to-image search and describe the hits.

    Args:
        query: Free-text search query.
        top_k: Number of results requested (coerced to ``int``; Gradio sliders
            may deliver floats).

    Returns:
        ``(images, info)``: ``images`` are the matching photos resized to
        224x224 and ``info`` holds one ``{"Image Name": ...}`` dict per returned
        image, not one per image in the whole corpus.

    Raises:
        gr.Error: Gradio displays this to the user; returning it does nothing.
    """
    try:
        result_images = clip_handler.search_text(query, int(top_k))
        # Read names before resizing: resize() returns new images that no
        # longer carry the source ``filename``.
        result_info = [
            {"Image Name": os.path.basename(getattr(image, "filename", "") or "")}
            for image in result_images
        ]
        result_images_resized = [image.resize((224, 224)) for image in result_images]
    except Exception as e:
        raise gr.Error(f"Error in text-to-image search: {str(e)}") from e
    return result_images_resized, result_info
# Gradio Interface function for zero-shot classification
def parse_labels(labels_text):
    """Split a comma-separated label string into stripped, non-empty labels."""
    if not labels_text:
        return []
    return [label.strip() for label in labels_text.split(",") if label.strip()]


def zero_shot_classification(image, labels_text):
    """Gradio handler: score *image* against user-supplied candidate labels.

    Args:
        image: Image array from ``gr.Image`` (``None`` when nothing is uploaded).
        labels_text: Comma-separated candidate labels; surrounding whitespace
            and empty entries are ignored.

    Returns:
        Mapping of label -> score, as expected by ``gr.Label``.

    Raises:
        gr.Error: on missing input or classification failure (Gradio displays
            raised errors; returning one does nothing useful).
    """
    if image is None:
        raise gr.Error("Error in zero-shot classification: please provide an image.")
    labels = parse_labels(labels_text)
    if not labels:
        raise gr.Error("Error in zero-shot classification: please provide at least one label.")
    try:
        pil_image = Image.fromarray(np.uint8(image)).convert("RGB")
        res = clip_classifier(images=pil_image, candidate_labels=labels, hypothesis_template="This is a photo of a {}")
    except Exception as e:
        raise gr.Error(f"Error in zero-shot classification: {str(e)}") from e
    return {dic["label"]: dic["score"] for dic in res}
def preprocessing_interface(original_image, brightness_slider, contrast_slider, saturation_slider, sharpness_slider, rotation_slider):
    """Gradio handler: rotate an image, then apply four PIL enhancements.

    Args:
        original_image: Image array from ``gr.Image``.
        brightness_slider, contrast_slider, saturation_slider, sharpness_slider:
            Percentages; each is divided by 100 to form the ImageEnhance factor,
            so 100 leaves that property unchanged.
        rotation_slider: Rotation angle in degrees passed to ``Image.rotate``.

    Returns:
        The processed PIL image.

    Raises:
        gr.Error: on missing input or processing failure.
    """
    if original_image is None:
        raise gr.Error("Error in preprocessing: please provide an image.")
    try:
        image = Image.fromarray(np.uint8(original_image)).convert("RGB")
        image = image.rotate(rotation_slider)
        # Applied in this order: brightness, contrast, saturation, sharpness.
        adjustments = (
            (ImageEnhance.Brightness, brightness_slider),
            (ImageEnhance.Contrast, contrast_slider),
            (ImageEnhance.Color, saturation_slider),
            (ImageEnhance.Sharpness, sharpness_slider),
        )
        for enhancer_cls, percent in adjustments:
            image = enhancer_cls(image).enhance(percent / 100.0)
        return image
    except Exception as e:
        raise gr.Error(f"Error in preprocessing: {str(e)}") from e
def generate_captions(images):
    """Caption each image in *images*, in order.

    Uses the module-level ``blip_model`` (a ``BLIPImageCaptioning`` instance)
    instead of loading a fresh BLIP model and processor on every call.
    Each result is a caption string, or an ``{"error": ...}`` dict for images
    whose captioning failed.
    """
    return [blip_model.generate_caption(image) for image in images]
# Gradio Interfaces
# Zero-shot tab: image + comma-separated labels in, label scores out.
zero_shot_classification_interface = gr.Interface(
    zero_shot_classification,
    [
        gr.Image(label="Image Query", elem_id="image_input"),
        gr.Textbox(label="Labels (comma-separated)", elem_id="labels_input"),
    ],
    gr.Label(elem_id="label_image"),
)
# Text-to-image tab. The right-hand side is evaluated before the name is
# rebound, so ``fn`` receives the handler function defined above.
text_to_image_interface = gr.Interface(
    fn=text_to_image_interface,
    inputs=[
        gr.Textbox(
            lines=2,
            label="Text Query",
            placeholder="Enter text here...",
        ),
        # Start at 1: a top-k of 0 would always return no results.
        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Top K Results"),
    ],
    outputs=[
        gr.Gallery(
            label="Text-to-Image Search Results",
            elem_id="gallery_text",
            # ``columns`` is the Gallery grid-width parameter; ``grid_cols`` is not accepted.
            columns=2,
            height="auto",
        ),
        gr.Text(label="Result Information", elem_id="text_info"),
    ],
)
# Shared BLIP captioner instance backing the captioning tab.
blip_model = BLIPImageCaptioning(blip_model_name)
blip_captioning_interface = gr.Interface(
    fn=blip_model.generate_caption,
    inputs=gr.Image(label="Image for Captioning", elem_id="blip_caption_image"),
    # ``value`` sets the initial text; the old ``default`` keyword is not
    # accepted by current Gradio components.
    outputs=gr.Textbox(label="Generated Captions", elem_id="blip_generated_captions", value=""),
)
# Preprocessing interface: wires the slider-driven preprocessing function
# (defined above under this same name) to one input per parameter. The
# right-hand side is evaluated before the name is rebound, so ``fn`` receives
# the function, not this Interface.
# NOTE(review): this interface is not listed in the TabbedInterface below.
preprocessing_interface = gr.Interface(
    fn=preprocessing_interface,
    inputs=[
        gr.Image(label="Original Image", elem_id="original_image"),
        # Percentages: the handler divides by 100, so 100 means "unchanged".
        gr.Slider(minimum=0, maximum=200, value=100, step=1, label="Brightness (%)"),
        gr.Slider(minimum=0, maximum=200, value=100, step=1, label="Contrast (%)"),
        gr.Slider(minimum=0, maximum=200, value=100, step=1, label="Saturation (%)"),
        gr.Slider(minimum=0, maximum=200, value=100, step=1, label="Sharpness (%)"),
        gr.Slider(minimum=-180, maximum=180, value=0, step=1, label="Rotation (degrees, counter-clockwise)"),
    ],
    outputs=[
        gr.Image(label="Processed Image", elem_id="processed_image"),
    ],
)
# Tabbed Interface: one tab per feature, in display order.
app = gr.TabbedInterface(
    interface_list=[text_to_image_interface, zero_shot_classification_interface, blip_captioning_interface],
    tab_names=["Text-to-Image Search", "Zero-Shot Classification", "BLIP Image Captioning"],
)
# Launch the Gradio interface. ``share`` takes a bool; the previous string
# "true" only worked because any non-empty string is truthy.
app.launch(debug=True, share=True)