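"""Gradio demo: classify a hand-drawn sketch with a DETR object-detection model.

The drawing from the Sketchpad component is converted to an RGB image, run
through the model, and the top-3 class labels with their confidences are
returned as JSON.
"""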
import gradio as gr
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import numpy as np
import logging
# Configure logging
logging.basicConfig(level=logging.DEBUG)
# Load the pre-trained model and feature extractor
model_name = "IDEA-Research/grounding-dino-tiny"
logging.info("Loading image processor and model...")
feature_extractor = DetrImageProcessor.from_pretrained(model_name)
model = DetrForObjectDetection.from_pretrained(model_name)
# Define the prediction function
def predict(image):
    try:
        logging.info("Received image of type: %s", type(image))
        logging.debug("Image content: %s", image)
        # Gradio's Sketchpad returns a dict; the 'composite' key holds the final drawing
        if isinstance(image, dict):
            image = image['composite']
        logging.debug("Converting to NumPy array...")
        image = np.array(image).astype('uint8')
        logging.debug("Converting NumPy array to PIL image...")
        # Let PIL infer the mode (the sketch may be RGBA or grayscale), then force RGB
        image = Image.fromarray(image).convert('RGB')
        logging.debug("Image converted successfully.")
        logging.info("Processing image...")
        inputs = feature_extractor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        # DETR logits have shape [batch, num_queries, num_classes + 1];
        # the last class index is "no object", so drop it before ranking
        probs = torch.nn.functional.softmax(logits, dim=-1)[0, :, :-1]
        # Aggregate over queries by taking each class's best score, then keep the top 3
        class_probs = probs.max(dim=0).values
        top_probs, top_idxs = class_probs.topk(3)
        top_classes = [model.config.id2label[idx] for idx in top_idxs.tolist()]
        result = {top_classes[i]: float(top_probs[i]) for i in range(3)}
logging.info("Prediction successful.")
return result
except Exception as e:
logging.error("Error during prediction: %s", e)
return {"error": str(e)}
# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(),
    outputs=gr.JSON(),
    title="Drawing Classifier",
    description="Draw something and the model will try to identify it!"
)
# Launch the interface
iface.launch()
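# Note: on a hosted platform the default launch() is sufficient; for a quick
# public demo from a local machine, iface.launch(share=True) creates a
# temporary shareable link.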