Embrace-Vision / app.py
Reverb's picture
Update app.py
0186079 verified
import gradio as gr
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer, BlipProcessor, BlipForConditionalGeneration, pipeline
from sentence_transformers import SentenceTransformer, util
import pickle
import numpy as np
from PIL import Image, ImageEnhance
import os
import io
import concurrent.futures
import warnings
warnings.filterwarnings("ignore")
class CLIPModelHandler:
    """Search a precomputed image-embedding index with a CLIP checkpoint.

    Image names and their embeddings are read from a pickle file when the
    handler is created. The CLIP model, tokenizer and processor are loaded on
    first use and cached on the instance, so repeated searches do not reload
    the checkpoint.
    """

    def __init__(self, model_name):
        """Remember the checkpoint name and load the embedding index.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``.
        """
        self.model_name = model_name
        # Populated lazily by the _get_* helpers below.
        self._model = None
        self._tokenizer = None
        self._processor = None
        self.img_names, self.img_emb = self.load_precomputed_embeddings()

    def load_precomputed_embeddings(self):
        """Read the ``(img_names, img_emb)`` pair from the embeddings pickle.

        The file name is relative to the current working directory.

        Returns:
            The two-item tuple stored in the file: image names and embeddings.
        """
        emb_filename = 'unsplash-25k-photos-embeddings.pkl'
        # NOTE(review): pickle.load can execute arbitrary code; only ship an
        # embeddings file that comes from a trusted source.
        with open(emb_filename, 'rb') as fIn:
            img_names, img_emb = pickle.load(fIn)
        return img_names, img_emb

    def _get_model(self):
        """Return the CLIP model, loading it on first use."""
        if self._model is None:
            self._model = CLIPModel.from_pretrained(self.model_name)
        return self._model

    def _get_tokenizer(self):
        """Return the CLIP tokenizer, loading it on first use."""
        if self._tokenizer is None:
            self._tokenizer = CLIPTokenizer.from_pretrained(self.model_name)
        return self._tokenizer

    def _get_processor(self):
        """Return the CLIP image processor, loading it on first use."""
        if self._processor is None:
            self._processor = CLIPProcessor.from_pretrained(self.model_name)
        return self._processor

    def _open_hits(self, hits):
        """Open the indexed photo for every hit returned by semantic search."""
        return [
            Image.open(os.path.join("photos/", self.img_names[hit['corpus_id']]))
            for hit in hits
        ]

    def search_text(self, query, top_k=1):
        """Return the images whose embeddings best match a text query.

        Args:
            query: Free-text search string.
            top_k: Number of images to return; coerced to ``int`` because UI
                sliders may deliver floats.

        Returns:
            A list of opened PIL images (empty when ``top_k`` is below 1).
        """
        top_k = int(top_k)
        if top_k < 1:
            # Nothing was requested, so skip loading the model entirely.
            return []
        inputs = self._get_tokenizer()([query], padding=True, return_tensors="pt")
        query_emb = self._get_model().get_text_features(**inputs)
        hits = util.semantic_search(query_emb, self.img_emb, top_k=top_k)[0]
        return self._open_hits(hits)

    def search_image(self, image_path, top_k=1):
        """Return the indexed images most similar to the image at a path.

        Args:
            image_path: Path (or file object) accepted by ``Image.open``.
            top_k: Number of images to return; coerced to ``int``.

        Returns:
            A list of opened PIL images (empty when ``top_k`` is below 1).
        """
        top_k = int(top_k)
        if top_k < 1:
            return []
        # Load and preprocess the query image
        image = Image.open(image_path)
        inputs = self._get_processor()(images=image, return_tensors="pt")
        # Embed the image in the same projection space as the text features.
        # A full forward pass would also require text inputs, and its
        # logits_per_image output is a similarity matrix, not an embedding.
        image_emb = self._get_model().get_image_features(**inputs)
        # Perform semantic search against the precomputed index
        hits = util.semantic_search(image_emb, self.img_emb, top_k=top_k)[0]
        return self._open_hits(hits)
class BLIPImageCaptioning:
    """Generate captions for images with a BLIP checkpoint.

    The model and processor are loaded on first use and cached on the
    instance, so captioning several images (sequentially or through the
    thread pool) does not reload the checkpoint for each one.
    """

    def __init__(self, blip_model_name):
        """Remember the checkpoint name; nothing is loaded yet.

        Args:
            blip_model_name: Checkpoint identifier handed to ``from_pretrained``.
        """
        self.blip_model_name = blip_model_name
        # Populated lazily by _get_model / _get_processor.
        self._model = None
        self._processor = None

    def _get_model(self):
        """Return the BLIP model, loading it on first use."""
        if self._model is None:
            self._model = BlipForConditionalGeneration.from_pretrained(self.blip_model_name)
        return self._model

    def _get_processor(self):
        """Return the BLIP processor, loading it on first use."""
        if self._processor is None:
            self._processor = BlipProcessor.from_pretrained(self.blip_model_name)
        return self._processor

    def preprocess_image(self, image):
        """Convert a file path or NumPy array into an RGB PIL image.

        Args:
            image: ``str`` file path or ``np.ndarray`` of pixel data.

        Returns:
            The image as a PIL image in RGB mode.

        Raises:
            ValueError: If ``image`` is neither a ``str`` nor an ``np.ndarray``.
        """
        if isinstance(image, str):
            return Image.open(image).convert('RGB')
        if isinstance(image, np.ndarray):
            return Image.fromarray(np.uint8(image)).convert('RGB')
        raise ValueError("Invalid input type for image. Supported types: str (file path) or np.ndarray.")

    def generate_caption(self, image):
        """Caption a single image.

        Args:
            image: Anything accepted by :meth:`preprocess_image`.

        Returns:
            The caption string, or ``{"error": message}`` if anything failed.
            Failures are reported in the return value rather than raised.
        """
        try:
            # Validate the input first so bad input fails before the
            # (expensive) model load.
            raw_image = self.preprocess_image(image)
            processor = self._get_processor()
            model = self._get_model()
            inputs = processor(raw_image, return_tensors="pt")
            out = model.generate(**inputs)
            return processor.decode(out[0], skip_special_tokens=True)
        except Exception as e:
            return {"error": str(e)}

    def generate_captions_parallel(self, images):
        """Caption several images using a thread pool.

        Args:
            images: Iterable of inputs accepted by :meth:`generate_caption`.

        Returns:
            A list with one :meth:`generate_caption` result per input, in
            input order.
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return list(executor.map(self.generate_caption, images))
# BLIP checkpoint name, used by the BLIP captioning code further down
blip_model_name = "Salesforce/blip-image-captioning-base"
# CLIP handler backing the text-to-image search
clip_handler = CLIPModelHandler(model_name="openai/clip-vit-base-patch32")
# Zero-shot image classification pipeline on the same CLIP checkpoint
clip_classifier = pipeline(task="zero-shot-image-classification", model="openai/clip-vit-base-patch32")
# Load the BLIP processor directly
# NOTE(review): no code visible in this file reads this module-level
# processor; confirm whether it is still needed.
blip_processor = BlipProcessor.from_pretrained(blip_model_name)
# Function for text-to-image search
def text_to_image_interface(query, top_k):
    """Run a text-to-image search and prepare the results for display.

    Args:
        query: Free-text search string.
        top_k: Number of results requested; coerced to ``int`` because the
            slider may deliver a float.

    Returns:
        A ``(images, info)`` tuple: the matching images resized to 224x224,
        and one ``{"Image Name": ...}`` dict per returned image.

    Raises:
        gr.Error: If the search or the post-processing fails.
    """
    try:
        # Perform text-to-image search
        result_images = clip_handler.search_text(query, int(top_k))
        # Describe only the images that were actually returned. The name is
        # read before resizing because resize() produces a new image that no
        # longer carries the `filename` set by Image.open.
        result_info = [
            {"Image Name": os.path.basename(getattr(image, "filename", ""))}
            for image in result_images
        ]
        # Resize images before displaying
        result_images_resized = [image.resize((224, 224)) for image in result_images]
    except Exception as e:
        # gr.Error must be raised (not returned) for Gradio to show it.
        raise gr.Error(f"Error in text-to-image search: {str(e)}") from e
    return result_images_resized, result_info
# Gradio Interface function for zero-shot classification
def zero_shot_classification(image, labels_text):
    """Score an image against user-supplied candidate labels.

    Args:
        image: Pixel data convertible with ``np.uint8`` (e.g. a NumPy array).
        labels_text: Comma-separated candidate labels.

    Returns:
        A dict mapping each label to the score reported by the classifier.

    Raises:
        gr.Error: If no usable label was given or classification fails.
    """
    try:
        # Convert image to PIL format
        PIL_image = Image.fromarray(np.uint8(image)).convert('RGB')
        # Split into labels, dropping surrounding whitespace and empty entries
        # so "cat, dog," yields ["cat", "dog"].
        labels = [label.strip() for label in labels_text.split(",") if label.strip()]
        if not labels:
            raise ValueError("Please provide at least one label.")
        # Perform zero-shot classification
        res = clip_classifier(images=PIL_image, candidate_labels=labels, hypothesis_template="This is a photo of a {}")
        # Format the result as a dictionary
        return {dic["label"]: dic["score"] for dic in res}
    except Exception as e:
        # gr.Error must be raised (not returned) for Gradio to show it.
        raise gr.Error(f"Error in zero-shot classification: {str(e)}") from e
def preprocessing_interface(original_image, brightness_slider, contrast_slider, saturation_slider, sharpness_slider, rotation_slider):
    """Rotate an image, then adjust brightness, contrast, saturation, sharpness.

    Args:
        original_image: Pixel data convertible with ``np.uint8``.
        brightness_slider: Brightness as a percentage (divided by 100 to get
            the enhancement factor; a factor of 1.0 leaves the image as is).
        contrast_slider: Contrast percentage, same scaling.
        saturation_slider: Colour saturation percentage, same scaling.
        sharpness_slider: Sharpness percentage, same scaling.
        rotation_slider: Angle passed to ``Image.rotate``.

    Returns:
        The processed PIL image in RGB mode.

    Raises:
        gr.Error: If any step of the processing fails.
    """
    try:
        # Convert NumPy array to an RGB PIL image
        PIL_image = Image.fromarray(np.uint8(original_image)).convert('RGB')
        PIL_image = PIL_image.rotate(rotation_slider)
        # Apply the enhancers in a fixed order; each slider percentage is
        # turned into a multiplicative factor.
        adjustments = (
            (ImageEnhance.Brightness, brightness_slider),
            (ImageEnhance.Contrast, contrast_slider),
            (ImageEnhance.Color, saturation_slider),
            (ImageEnhance.Sharpness, sharpness_slider),
        )
        for enhancer_cls, percent in adjustments:
            PIL_image = enhancer_cls(PIL_image).enhance(percent / 100.0)
        # Return the processed image
        return PIL_image
    except Exception as e:
        # gr.Error must be raised (not returned) for Gradio to show it.
        raise gr.Error(f"Error in preprocessing: {str(e)}") from e
def generate_captions(images):
    """Caption every image in ``images``.

    Args:
        images: Iterable of inputs accepted by
            ``BLIPImageCaptioning.generate_caption``.

    Returns:
        A list with one ``generate_caption`` result per input, in order.
    """
    # Use one captioner for the whole batch.
    captioner = BLIPImageCaptioning(blip_model_name)
    return [captioner.generate_caption(image) for image in images]
# Gradio Interfaces
# Zero-shot classification: an image and a comma-separated label list go in,
# the output is rendered by a Label component.
zero_shot_classification_interface = gr.Interface(
    zero_shot_classification,
    [
        gr.Image(elem_id="image_input", label="Image Query"),
        gr.Textbox(elem_id="labels_input", label="Labels (comma-separated)"),
    ],
    gr.Label(elem_id="label_image"),
)
# Text-to-image search tab.
# NOTE: this assignment rebinds `text_to_image_interface` from the handler
# function to the Interface object; the function is captured through `fn=`
# before the name is rebound.
text_to_image_interface = gr.Interface(
    fn=text_to_image_interface,
    inputs=[
        gr.Textbox(
            lines=2,
            label="Text Query",
            placeholder="Enter text here...",
        ),
        # Start at 1: asking for zero results can never show an image.
        gr.Slider(1, 5, value=1, step=1, label="Top K Results"),
    ],
    outputs=[
        gr.Gallery(
            label="Text-to-Image Search Results",
            elem_id="gallery_text",
            # `columns` is the Gallery constructor argument for the grid width.
            columns=2,
            height="auto",
        ),
        gr.Text(label="Result Information", elem_id="text_info"),
    ],
)
# Captioner instance backing the BLIP tab
blip_model = BLIPImageCaptioning(blip_model_name)
# BLIP captioning tab: one image in, the generated caption out.
blip_captioning_interface = gr.Interface(
    fn=blip_model.generate_caption,
    inputs=gr.Image(label="Image for Captioning", elem_id="blip_caption_image"),
    # The initial content of a Textbox is set with `value`.
    outputs=gr.Textbox(label="Generated Captions", elem_id="blip_generated_captions", value=""),
)
# Image-in / image-out interface around the captioner's preprocess_image.
# NOTE(review): this rebinds `preprocessing_interface`, hiding the function of
# the same name defined earlier, and the interface is not part of the tab list
# below - confirm which of the two was meant to be exposed.
preprocessing_interface = gr.Interface(
    blip_model.preprocess_image,
    [
        gr.Image(elem_id="original_image", label="Original Image"),
    ],
    [
        gr.Image(elem_id="processed_image", label="Processed Image"),
    ],
)
# Tabbed Interface
# One tab per interface; names are matched to interfaces by position.
app = gr.TabbedInterface(
    [
        text_to_image_interface,
        zero_shot_classification_interface,
        blip_captioning_interface,
    ],
    [
        "Text-to-Image Search",
        "Zero-Shot Classification",
        "BLIP Image Captioning",
    ],
)
# Launch the Gradio interface
# `share` is a boolean flag, so pass a real bool rather than a string.
app.launch(debug=True, share=True)