File size: 3,320 Bytes
0de4552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44ca1c5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
from transformers import pipeline
from PIL import Image
import requests
import numpy as np
import pandas as pd
from plottable import Table
import matplotlib.pyplot as plt
from io import BytesIO
import random

# Checkpoint used for zero-shot image classification.
_MODEL_NAME = "google/siglip-so400m-patch14-384"
# Seconds to wait for a remote image before giving up.
_REQUEST_TIMEOUT = 10
# Pipeline instance, created on first use and then reused by later calls.
_classifier = None


def _get_classifier():
    """Return the zero-shot classification pipeline, building it only once."""
    global _classifier
    if _classifier is None:
        _classifier = pipeline(task="zero-shot-image-classification", model=_MODEL_NAME)
    return _classifier


def _parse_labels(labels):
    """Split a comma-separated string into unique, non-empty labels (order kept)."""
    stripped = (part.strip() for part in (labels or "").split(','))
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(part for part in stripped if part))


def _load_image(upload, url):
    """Open the uploaded bytes if given, otherwise download the image at *url*."""
    try:
        if upload is not None:
            return Image.open(BytesIO(upload))

        # An empty textbox is delivered as "" rather than None, so test for
        # real content instead of comparing against None.
        url = (url or "").strip()
        if not url:
            raise gr.Error("Please upload an image or enter an image URL.")

        # NOTE(review): the URL is user-supplied; consider restricting allowed
        # hosts if this runs somewhere with access to internal services.
        response = requests.get(url, timeout=_REQUEST_TIMEOUT)
        response.raise_for_status()
        # Use the fully downloaded (and content-decoded) body, not the raw stream.
        return Image.open(BytesIO(response.content))
    except requests.RequestException as err:
        # Must precede OSError: requests' exceptions derive from it.
        raise gr.Error(f"Could not download the image: {err}") from err
    except OSError as err:
        raise gr.Error("The provided data could not be read as an image.") from err


def _figure_to_image(fig):
    """Render *fig* to an in-memory PNG, close the figure, return a PIL image."""
    buf = BytesIO()
    try:
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def classify_image(upload, url, labels):
    """
    Classify an image against user-supplied labels and render the results.

    The image comes from ``upload`` when it is not None, otherwise it is
    downloaded from ``url``.

    Args:
        upload: Raw bytes of the uploaded image file, or None.
        url: Address of an image to download; ignored when ``upload`` is given.
        labels: Comma-separated candidate labels, e.g. ``"animal, human"``.

    Returns:
        A tuple ``(bar_chart, table)`` of PIL images: a horizontal bar chart
        and a table of the per-label scores, normalized to sum to 100%.

    Raises:
        gr.Error: If no usable label is given, if neither an upload nor a URL
            is provided, or if the image cannot be downloaded or opened.
    """
    # Validate the cheap input first, before any network access or inference.
    candidate_labels = _parse_labels(labels)
    if not candidate_labels:
        raise gr.Error("Please enter at least one label.")

    image = _load_image(upload, url)

    # Perform inference
    outputs = _get_classifier()(image, candidate_labels=candidate_labels)
    ranked_labels = [output["label"] for output in outputs]
    scores = [output["score"] for output in outputs]

    # Normalize scores to sum up to 100%; guard against an all-zero result.
    total_score = sum(scores)
    if total_score > 0:
        normalized_scores = [round(score * 100 / total_score, 2) for score in scores]
    else:
        normalized_scores = [0.0] * len(scores)

    # Horizontal bar chart with a distinct viridis color per label
    fig, ax = plt.subplots(figsize=(10, 6))
    count = len(ranked_labels)
    colors = [plt.cm.viridis(i / count) for i in range(count)]
    ax.barh(ranked_labels, normalized_scores, color=colors)
    ax.set_xlabel('Score (%)')
    ax.set_ylabel('Labels')
    ax.set_title('Classification Results')
    ax.invert_yaxis()  # Display labels from top to bottom
    fig.tight_layout()
    result_image = _figure_to_image(fig)

    # Same data rendered as a table image
    df = pd.DataFrame({"Labels": ranked_labels, "Scores (%)": normalized_scores})
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.axis('tight')
    ax.axis('off')
    ax.table(cellText=df.values, colLabels=df.columns, loc='center')
    result_table_image = _figure_to_image(fig)

    return result_image, result_table_image

# Create the Gradio interface: three inputs (file upload, image URL, label
# list) feed classify_image, whose two returned images fill the two outputs.
interface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.File(type="binary", label="Upload Image"),
        gr.Textbox(label="Or, enter Image URL"),
        gr.Textbox(label="Enter labels separated by commas (e.g., animal, human, building)")
    ],
    outputs=[
        gr.Image(label="Classification Results (Bar Chart)"),
        gr.Image(label="Classification Results (Table)")
    ],
    title="Image Classifier",
    description="Upload an image or enter an image URL, then specify labels to classify the image."
)

# Launch the interface only when run as a script, so that importing this
# module (e.g. from tests or tooling) does not start a web server.
if __name__ == "__main__":
    interface.launch()