from transformers import OwlViTProcessor, OwlViTForObjectDetection
from PIL import Image
import torch
import gradio as gr

# Load model and processor
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

def predict(image):
    # Gradio passes a PIL image (see gr.Image(type="pil") below);
    # make sure it is in RGB mode before handing it to the processor
    image = image.convert("RGB")

    # Define the zero-shot text queries
    text_queries = ["A Pokémon", "Pikachu", "Bulbasaur"]

    # Run the model
    inputs = processor(text=text_queries, images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process raw outputs into boxes in the original image's coordinates;
    # PIL's image.size is (width, height), the processor expects (height, width)
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.1
    )

    # Extract boxes; each predicted label is an index into text_queries,
    # not a token id, so look the query string up directly
    boxes = []
    for score, label, box in zip(
        results[0]["scores"], results[0]["labels"], results[0]["boxes"]
    ):
        box = [round(i, 2) for i in box.tolist()]
        boxes.append(
            {
                "score": round(score.item(), 3),
                "label": text_queries[label.item()],
                "box": box,
            }
        )
    return boxes

# Create the Gradio interface; type="pil" makes Gradio pass a PIL image
# into predict() instead of the default numpy array
interface = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="json")
interface.launch()
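
# For a quick sanity check without launching the web UI, you can call
# predict() directly on a local file (a minimal sketch; "pokemon.png" is
# a placeholder filename, not from the original -- substitute any image
# you have). Comment out interface.launch() above before running this,
# or import predict from a separate session.
#
#   test_image = Image.open("pokemon.png")
#   print(predict(test_image))
#   # -> a list of dicts like
#   #    {"score": ..., "label": "Pikachu", "box": [x0, y0, x1, y1]}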