|
import gradio as gr |
|
import torch |
|
from transformers import DetrImageProcessor, DetrForObjectDetection |
|
from PIL import Image |
|
import numpy as np |
|
import logging |
|
|
|
|
|
# Emit everything down to DEBUG so each step of a prediction is visible
# in the console.
logging.basicConfig(level=logging.DEBUG)

# Hub checkpoint id shared by the image processor and the model below.
# NOTE(review): this id looks like a Grounding DINO checkpoint, while the
# loaders below are the DETR classes -- confirm that the checkpoint and the
# classes are really meant to be paired.
model_name = "IDEA-Research/grounding-dino-tiny"

logging.info("Loading image processor and model...")

# Both objects are loaded once at import time and reused by every request.
feature_extractor = DetrImageProcessor.from_pretrained(model_name)
model = DetrForObjectDetection.from_pretrained(model_name)
|
|
|
|
|
def _to_rgb_image(image):
    """Convert a sketch payload (dict, array or PIL image) to an RGB PIL image.

    Raises:
        ValueError: if no image data was supplied.
    """
    # Editor-style components send a dict; the flattened drawing lives
    # under the "composite" key.
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        raise ValueError("No image provided")

    if isinstance(image, Image.Image):
        pil_image = image
    else:
        pixels = np.asarray(image).astype("uint8")
        # Pillow cannot build an image from an (H, W, 1) array; squeeze it
        # down to a plain 2-D grayscale array first.
        if pixels.ndim == 3 and pixels.shape[-1] == 1:
            pixels = pixels[..., 0]
        # Let Pillow infer the mode (L / RGB / RGBA) from the array shape
        # instead of forcing RGBA onto data that may not have four channels.
        pil_image = Image.fromarray(pixels)

    if pil_image.mode in ("RGBA", "LA"):
        # convert("RGB") simply discards alpha, which would turn fully
        # transparent canvas pixels into black. Flatten onto white instead.
        rgba = pil_image.convert("RGBA")
        background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        pil_image = Image.alpha_composite(background, rgba)

    return pil_image.convert("RGB")


def _top_labels(logits, id2label, k=3):
    """Map raw logits to ``{label: probability}`` for the ``k`` best classes."""
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    if probs.dim() == 2:
        # Detection heads such as DETR emit one distribution per object
        # query, shaped (num_queries, num_classes + 1), where the last
        # column is the "no object" class. Drop that column and keep each
        # class's best score across all queries.
        probs = probs[:, :-1].max(dim=0).values

    # Never ask for more entries than there are classes.
    k = min(k, probs.shape[-1])
    top_probs, top_idxs = probs.topk(k)

    result = {}
    for prob, idx in zip(top_probs.tolist(), top_idxs.tolist()):
        # topk is sorted descending, so on a duplicate label name the
        # first (highest) probability is the one that is kept.
        result.setdefault(id2label.get(idx, str(idx)), float(prob))
    return result


def predict(image):
    """Identify a drawing and return its most likely labels.

    Args:
        image: The sketch payload from the UI: either a dict holding the
            drawing under ``"composite"``, an array-like of pixels, or a
            PIL image.

    Returns:
        A dict mapping up to three label names to probabilities, or
        ``{"error": message}`` if anything goes wrong. Errors are reported
        in the return value rather than raised so the UI can display them.
    """
    try:
        logging.info("Received image of type: %s", type(image))
        logging.debug("Image content: %s", image)

        logging.debug("Converting input to an RGB PIL image...")
        rgb_image = _to_rgb_image(image)
        logging.debug("Image converted successfully.")

        logging.info("Processing image...")
        inputs = feature_extractor(images=rgb_image, return_tensors="pt")
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)

        result = _top_labels(outputs.logits, model.config.id2label)
        logging.info("Prediction successful.")
        return result
    except Exception as e:
        # Top-level UI callback boundary: log the full traceback and hand
        # the message back to the caller instead of crashing the app.
        logging.exception("Error during prediction: %s", e)
        return {"error": str(e)}
|
|
|
|
|
# Wire `predict` into a small web UI: a sketch canvas as input and the
# returned dict rendered as JSON.
iface = gr.Interface(
    predict,
    gr.Sketchpad(),
    gr.JSON(),
    title="Drawing Classifier",
    description="Draw something and the model will try to identify it!",
)

# Launch the interface as soon as the script runs.
iface.launch()
|
|