"""Gradio demo: nearest-neighbour image classification on BioCLIP embeddings.

At import time the script loads the BioCLIP model, a dataset saved on disk
and two FAISS indexes over that dataset's embeddings, then builds and
launches a Gradio interface.  An uploaded image is embedded with the model,
its nearest neighbours are retrieved from the dataset, and the predicted
class is chosen by a (optionally score-weighted) vote among their labels.
"""

import os

# Work around the "duplicate OpenMP runtime" abort.  Set before the numeric
# libraries below are imported so the flag is already in place whenever an
# OpenMP runtime gets initialised.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

from collections import defaultdict  # noqa: E402

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from PIL import Image  # noqa: E402
import open_clip  # noqa: E402
from datasets import Dataset  # noqa: E402

# Load the model and its image preprocessing transform.
model, processor = open_clip.create_model_from_pretrained('hf-hub:imageomics/bioclip')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# The model is only used for inference; make sure train-only layers are off.
model.eval()

# Load the dataset that holds the reference embeddings.
embedding_path = "./data/embeddings_bioclip_False"
ds = Dataset.load_from_disk(embedding_path)

# Attach the pre-built FAISS indexes stored next to the dataset.
cosine_faiss_path = os.path.join(embedding_path, "embeddings_cosine.faiss")
l2_faiss_path = os.path.join(embedding_path, "embeddings_l2.faiss")
ds.load_faiss_index("embeddings_cosine", cosine_faiss_path)
ds.load_faiss_index("embeddings_l2", l2_faiss_path)

# Indexes whose search scores are distances (smaller = closer) rather than
# similarities (larger = closer).
# NOTE(review): inferred from the index name; confirm against the script
# that built "embeddings_l2.faiss".
_DISTANCE_INDEXES = frozenset({"embeddings_l2"})


def majority_vote(classes, scores=None):
    """Return the class with the largest total vote weight.

    Args:
        classes: Sequence of class labels, one per vote.
        scores: Optional sequence of numeric weights, one per entry of
            ``classes``.  Larger means a stronger vote.  When omitted every
            vote counts as 1.

    Returns:
        The label whose summed weight is highest.  On a tie, the label that
        appears first in ``classes`` wins.

    Raises:
        ValueError: If ``classes`` is empty or ``scores`` has a different
            length than ``classes``.
    """
    classes = list(classes)
    if not classes:
        raise ValueError("majority_vote() requires at least one class label")

    if scores is None:
        # np.ones_like(classes) would inherit a string dtype from string
        # labels and break the summation below; build float weights instead.
        weights = np.ones(len(classes), dtype=np.float64)
    else:
        weights = np.asarray(scores, dtype=np.float64).ravel()
        if len(weights) != len(classes):
            raise ValueError(
                f"got {len(classes)} classes but {len(weights)} scores"
            )

    totals = defaultdict(float)
    for cls, weight in zip(classes, weights):
        totals[cls] += float(weight)
    # dicts keep insertion order, so max() resolves ties in favour of the
    # label seen first.
    return max(totals, key=totals.get)


def classify_example(example, index="embeddings_l2", k=10, vote_scores=True):
    """Classify an embedded example by voting among its nearest neighbours.

    Args:
        example: Mapping with an ``"embeddings"`` entry convertible to a
            float32 NumPy array.
        index: Name of the FAISS index on ``ds`` to search.
        k: Number of neighbours to retrieve.
        vote_scores: If true, weight each neighbour's vote by its search
            score; otherwise every neighbour counts equally.

    Returns:
        A tuple ``(prediction, class_labels, files)`` where ``class_labels``
        are the label names of the retrieved neighbours and ``files`` is the
        neighbours' ``"file"`` column.
    """
    features = np.array(example["embeddings"], dtype=np.float32)
    scores, nearest = ds.get_nearest_examples(index, features, k)

    label_names = ds.features["label"].names
    class_labels = [label_names[c] for c in nearest["label"]]

    if vote_scores:
        weights = np.asarray(scores, dtype=np.float64)
        if index in _DISTANCE_INDEXES:
            # Distance scores grow as neighbours get farther away; invert
            # them so that closer neighbours carry the larger vote.
            weights = 1.0 / (1.0 + weights)
        prediction = majority_vote(class_labels, weights)
    else:
        prediction = majority_vote(class_labels)
    return prediction, class_labels, nearest["file"]


def embed_image(image: Image.Image):
    """Embed a single PIL image with the BioCLIP image encoder.

    Args:
        image: Image to embed.

    Returns:
        A dict with key ``"embeddings"`` holding the encoder output as a CPU
        tensor (a batch of one image).
    """
    processed_images = processor(image).unsqueeze(0)
    with torch.no_grad():
        embeddings = model.encode_image(processed_images.to(device))
    return {"embeddings": embeddings.cpu()}


def predict(image):
    """Gradio callback: classify an uploaded image.

    Args:
        image: PIL image supplied by the ``gr.Image`` input, or ``None`` when
            the form is submitted without an image.

    Returns:
        A tuple of three strings: the predicted class, the labels of the
        three nearest neighbours and their file paths (comma-separated).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque
        # preprocessing failure.
        raise gr.Error("Please upload an image.")

    embedding = embed_image(image)
    prediction, class_labels, file_paths = classify_example(embedding)
    return prediction, ", ".join(class_labels[:3]), ", ".join(file_paths[:3])


iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Top 3 Classes"),
        gr.Textbox(label="Top 3 File Paths"),
    ],
    title="BioClip Image Classification",
    description="Upload an image to get a prediction using the BioClip model.",
)

iface.launch()