Babyloncoder commited on
Commit
0de4552
·
verified ·
1 Parent(s): b61227b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ from PIL import Image
4
+ import requests
5
+ import numpy as np
6
+ import pandas as pd
7
+ from plottable import Table
8
+ import matplotlib.pyplot as plt
9
+ from io import BytesIO
10
+ import random
11
+
12
+ def classify_image(upload, url, labels):
13
+ """
14
+ Classify the image either from an uploaded file or a URL with given labels.
15
+ """
16
+ # Check if an image file is uploaded
17
+ if upload is not None:
18
+ # Read the uploaded file as a byte stream
19
+ image = Image.open(BytesIO(upload))
20
+ # Otherwise, load the image from the provided URL
21
+ elif url is not None:
22
+ image = Image.open(requests.get(url, stream=True).raw)
23
+ # If neither, return a message prompting for an input
24
+ else:
25
+ return "Please upload an image or enter an image URL."
26
+
27
+ # Split the labels by comma and strip whitespace
28
+ labels_list = [label.strip() for label in labels.split(',')]
29
+
30
+ # Load the image classification model
31
+ image_classifier = pipeline(task="zero-shot-image-classification", model="google/siglip-so400m-patch14-384")
32
+
33
+ # Perform inference
34
+ outputs = image_classifier(image, candidate_labels=labels_list)
35
+
36
+ # Process outputs
37
+ labels = [output["label"] for output in outputs]
38
+ scores = [output["score"] for output in outputs]
39
+
40
+ # Normalize scores to sum up to 100%
41
+ total_score = sum(scores)
42
+ normalized_scores = [round(score * 100 / total_score, 2) for score in scores]
43
+
44
+ # Plot the horizontal bar chart with different colors for each label
45
+ plt.figure(figsize=(10, 6))
46
+ colors = [plt.cm.viridis(i/len(labels)) for i in range(len(labels))]
47
+ plt.barh(labels, normalized_scores, color=colors)
48
+ plt.xlabel('Score (%)')
49
+ plt.ylabel('Labels')
50
+ plt.title('Classification Results')
51
+ plt.gca().invert_yaxis() # Invert y-axis to display labels from top to bottom
52
+ plt.tight_layout()
53
+
54
+ # Save the plot to a BytesIO object
55
+ buf = BytesIO()
56
+ plt.savefig(buf, format='png')
57
+ buf.seek(0)
58
+
59
+ # Convert BytesIO object to image
60
+ result_image = Image.open(buf)
61
+
62
+ # Create a DataFrame for the classification results
63
+ df = pd.DataFrame({"Labels": labels, "Scores (%)": normalized_scores})
64
+
65
+ # Create a plottable table
66
+ tab = Table(df)
67
+
68
+ # Plot the table using matplotlib
69
+ fig, ax = plt.subplots(figsize=(6, 5))
70
+ ax.axis('tight')
71
+ ax.axis('off')
72
+ ax.table(cellText=df.values, colLabels=df.columns, loc='center')
73
+
74
+ # Save the figure to a BytesIO object
75
+ buf_table = BytesIO()
76
+ plt.savefig(buf_table, format='png')
77
+ buf_table.seek(0)
78
+
79
+ # Convert BytesIO object to image
80
+ result_table_image = Image.open(buf_table)
81
+
82
+ return result_image, result_table_image
83
+
84
+ # Create the Gradio interface
85
+ interface = gr.Interface(
86
+ fn=classify_image,
87
+ inputs=[
88
+ gr.File(type="binary", label="Upload Image"),
89
+ gr.Textbox(label="Or, enter Image URL"),
90
+ gr.Textbox(label="Enter labels separated by commas (e.g., animal, human, building)")
91
+ ],
92
+ outputs=[
93
+ gr.Image(label="Classification Results (Bar Chart)"),
94
+ gr.Image(label="Classification Results (Table)")
95
+ ],
96
+ title="Image Classifier",
97
+ description="Upload an image or enter an image URL, then specify labels to classify the image."
98
+ )
99
+
100
+ # Launch the app
101
+ interface.launch()