Babyloncoder's picture
Create app.py
f831e78 verified
raw
history blame
2.53 kB
import functools
import io
import threading

import gradio as gr
from transformers import pipeline
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# Zero-shot classification of free text against user-supplied labels, rendered
# as a pie chart and a bar chart for the Gradio UI below.

# The classifier is shared across requests (see _get_classifier); this lock
# serialises access to it so concurrent handler calls never drive the same
# pipeline/tokenizer at once. NOTE(review): can be dropped if the app is
# guaranteed to run handlers strictly one at a time.
_CLASSIFIER_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _get_classifier():
    """Load the facebook/bart-large-mnli zero-shot pipeline once and cache it.

    Building the pipeline loads the full model weights, so constructing it
    per request makes every click pay the whole load cost again.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


def _classify(text, candidate_labels):
    """Run the cached zero-shot classifier on *text* while holding the lock."""
    with _CLASSIFIER_LOCK:
        return _get_classifier()(text, candidate_labels)


def _parse_labels(labels):
    """Split a comma-separated string into stripped, non-empty labels.

    ``"sports, politics,"`` -> ``["sports", "politics"]``. Without stripping,
    surrounding spaces become part of the label text sent to the model, and a
    trailing comma adds an empty label.
    """
    return [label.strip() for label in (labels or "").split(",") if label.strip()]


def _figure_to_array(fig):
    """Render *fig* to an in-memory PNG and return the decoded pixels as an array."""
    buf = io.BytesIO()
    # bbox_inches="tight" keeps the pie legend, which sits outside the axes, in frame.
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    with Image.open(buf) as image:
        return np.array(image)


def _plot_pie(labels, scores):
    """Draw a pie chart of *scores* with a "label - pct %" legend; return its pixels."""
    colors = plt.cm.viridis(np.linspace(0, 1, len(labels)))
    fig, ax = plt.subplots()
    try:
        wedges = ax.pie(scores, startangle=140, colors=colors)[0]
        ax.axis("equal")  # Equal aspect ratio keeps the pie circular.
        ax.set_title("Pie Chart")
        legend_labels = [
            f"{label} - {score * 100:1.2f} %" for label, score in zip(labels, scores)
        ]
        ax.legend(
            wedges,
            legend_labels,
            title="Labels with Scores",
            loc="center left",
            bbox_to_anchor=(1, 0.5),
        )
        return _figure_to_array(fig)
    finally:
        # Close this specific figure (not pyplot's "current" one), even when
        # rendering fails, so figures don't pile up in a long-running server.
        plt.close(fig)


def _plot_bar(labels, scores):
    """Draw a bar chart of *scores*, one bar per label; return its pixels."""
    fig, ax = plt.subplots()
    try:
        positions = np.arange(len(labels))
        ax.bar(positions, scores, align="center", alpha=0.7, color="blue")
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_ylabel("Scores")
        ax.set_title("Bar Chart")
        return _figure_to_array(fig)
    finally:
        plt.close(fig)


def classify_and_plot(text, labels):
    """Classify *text* against comma-separated *labels* and chart the scores.

    Args:
        text: The text to classify.
        labels: Candidate labels as one comma-separated string, e.g.
            ``"sports, politics, technology"``. Whitespace around each label
            is ignored and empty entries are dropped.

    Returns:
        ``(pie_chart_array, bar_chart_array)``: NumPy arrays holding the
        decoded PNG pixels of the pie chart and the bar chart. Labels appear
        in the order the classifier returns them.

    Raises:
        gr.Error: if *text* is blank or *labels* contains no non-empty label,
            so the Gradio UI shows a readable message instead of a junk chart.
    """
    if not (text or "").strip():
        raise gr.Error("Please enter some text to classify.")
    candidate_labels = _parse_labels(labels)
    if not candidate_labels:
        raise gr.Error("Please enter at least one comma-separated label.")

    result = _classify(text, candidate_labels)
    ranked_labels = result["labels"]
    scores = result["scores"]

    pie_chart_array = _plot_pie(ranked_labels, scores)
    bar_chart_array = _plot_bar(ranked_labels, scores)
    return pie_chart_array, bar_chart_array
# Wire classify_and_plot into a Gradio UI: two text inputs (the text to
# classify, then the comma-separated labels) and two image outputs (the pie
# chart, then the bar chart), matching the function's argument/return order.
_DESCRIPTION = (
    "Enter text and comma-separated labels for classification using the "
    "facebook/bart-large-mnli model. The outputs will be separate pie and bar "
    "charts representing the classification scores."
)
iface = gr.Interface(
    fn=classify_and_plot,
    inputs=["text"] * 2,
    outputs=["image"] * 2,
    title="Zero-Shot Classification with Pie and Bar Charts",
    description=_DESCRIPTION,
)

# Start serving; share=True also asks Gradio for a temporary public URL.
iface.launch(share=True)