File size: 2,149 Bytes
73d50df
a6fcafd
1295972
73d50df
77e4539
7d27275
 
 
 
73d50df
0c2a5d3
1295972
7d27275
1295972
 
73d50df
 
 
7d27275
 
70548b8
f97cb5f
 
7d27275
f97cb5f
 
70548b8
 
7d27275
95cdd7b
7d27275
 
a6fcafd
 
 
 
 
 
589415a
 
528548d
c1c81f5
7d27275
528548d
7d27275
 
c1c81f5
73d50df
 
 
 
c7f0278
528548d
73d50df
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import numpy as np
import logging

# Configure logging (DEBUG so the per-request conversion steps are visible)
logging.basicConfig(level=logging.DEBUG)

# Load the pre-trained model and image processor.
# The checkpoint has to be a DETR checkpoint because it is loaded through the
# DETR-specific classes below; a Grounding DINO checkpoint uses a different
# architecture and cannot be loaded correctly by DetrForObjectDetection.
model_name = "facebook/detr-resnet-50"
logging.info("Loading image processor and model...")
feature_extractor = DetrImageProcessor.from_pretrained(model_name)
model = DetrForObjectDetection.from_pretrained(model_name)

# Define the prediction function
def _to_rgb_image(image):
    """Convert the UI payload (dict, array, or PIL image) into an RGB PIL image."""
    # Use the 'composite' key to get the final image when a dict is delivered
    if isinstance(image, dict):
        image = image['composite']
    if image is None:
        raise ValueError("No image provided")

    if not isinstance(image, Image.Image):
        # Let Pillow infer the mode from the array shape instead of forcing
        # 'RGBA', which breaks on grayscale or 3-channel input.
        image = Image.fromarray(np.asarray(image).astype('uint8'))

    if image.mode in ('RGBA', 'LA'):
        # convert('RGB') just discards alpha, exposing whatever colour sits
        # under transparent pixels; flatten onto white so strokes stay visible.
        rgba = image.convert('RGBA')
        canvas = Image.new('RGB', rgba.size, (255, 255, 255))
        canvas.paste(rgba, mask=rgba.getchannel('A'))
        return canvas
    return image.convert('RGB')


def _top_class_scores(logits, id2label, top_k=3):
    """Reduce DETR logits to a ``{label: score}`` dict of the best classes."""
    # DETR logits are (batch, num_queries, num_classes + 1); the trailing
    # class is "no object" and must not be reported as a prediction.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0, :, :-1]
    # Score each class by its most confident query.
    per_class = probs.max(dim=0).values
    k = min(top_k, per_class.numel())
    scores, class_ids = per_class.topk(k)
    return {
        id2label.get(class_id, str(class_id)): float(score)
        for score, class_id in zip(scores.tolist(), class_ids.tolist())
    }


def predict(image):
    """Run the detection model on a drawing and return its top class scores.

    Args:
        image: The value delivered by the input component: either a dict
            holding the flattened drawing under ``'composite'``, or the
            image itself (array-like or PIL image).

    Returns:
        A dict mapping up to three class labels to their scores, or
        ``{"error": message}`` if anything fails, so the UI always receives
        JSON-serialisable output.
    """
    try:
        logging.info("Received image of type: %s", type(image))
        rgb_image = _to_rgb_image(image)
        logging.debug("Image converted successfully.")

        logging.info("Processing image...")
        inputs = feature_extractor(images=rgb_image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
        result = _top_class_scores(outputs.logits, model.config.id2label)
        logging.info("Prediction successful.")
        return result
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the
        # request, but keep the traceback in the log.
        logging.exception("Error during prediction: %s", e)
        return {"error": str(e)}

# Build the Gradio UI: a sketchpad as input, predict() as the handler,
# and a JSON viewer for the returned dict.
sketch_input = gr.Sketchpad()
json_output = gr.JSON()
iface = gr.Interface(
    fn=predict,
    inputs=sketch_input,
    outputs=json_output,
    title="Drawing Classifier",
    description="Draw something and the model will try to identify it!",
)

# Start serving the interface
iface.launch()