Spaces:
Running
Running
vickeee465
commited on
Commit
·
bb92bc0
1
Parent(s):
dddc9f2
sunburst chart
Browse files
app.py
CHANGED
|
@@ -8,6 +8,7 @@ from transformers import AutoModelForSequenceClassification
|
|
| 8 |
from transformers import AutoTokenizer
|
| 9 |
import gradio as gr
|
| 10 |
import matplotlib.pyplot as plt
|
|
|
|
| 11 |
import seaborn as sns
|
| 12 |
|
| 13 |
PATH = '/data/' # at least 150GB storage needs to be attached
|
|
@@ -99,6 +100,31 @@ def plot_emotion_heatmap(heatmap_data):
|
|
| 99 |
plt.tight_layout()
|
| 100 |
return fig
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def plot_emotion_barplot(heatmap_data):
|
| 104 |
most_probable_emotions = heatmap_data.idxmax(axis=0)
|
|
@@ -133,8 +159,9 @@ def predict_wrapper(text, language):
|
|
| 133 |
print(results_heatmap)
|
| 134 |
|
| 135 |
figure = plot_emotion_barplot(prepare_heatmap_data(results_heatmap))
|
|
|
|
| 136 |
output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
|
| 137 |
-
return results, figure, output_info
|
| 138 |
|
| 139 |
|
| 140 |
with gr.Blocks() as demo:
|
|
@@ -163,7 +190,7 @@ with gr.Blocks() as demo:
|
|
| 163 |
predict_button.click(
|
| 164 |
fn=predict_wrapper,
|
| 165 |
inputs=[input_text, language_choice],
|
| 166 |
-
outputs=[result_table, plot, model_info]
|
| 167 |
)
|
| 168 |
|
| 169 |
if __name__ == "__main__":
|
|
|
|
| 8 |
from transformers import AutoTokenizer
|
| 9 |
import gradio as gr
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
+
import plotly.express as px
|
| 12 |
import seaborn as sns
|
| 13 |
|
| 14 |
PATH = '/data/' # at least 150GB storage needs to be attached
|
|
|
|
| 100 |
plt.tight_layout()
|
| 101 |
return fig
|
| 102 |
|
| 103 |
+
def plot_sunburst_chart(heatmap_data):
|
| 104 |
+
data = []
|
| 105 |
+
for item in heatmap_data:
|
| 106 |
+
sentence = item['sentence']
|
| 107 |
+
emotions = item['emotions']
|
| 108 |
+
for i, score in enumerate(emotions):
|
| 109 |
+
data.append({
|
| 110 |
+
'root': 'All Sentences',
|
| 111 |
+
'sentence': sentence,
|
| 112 |
+
'emotion': id2label[i],
|
| 113 |
+
'score': float(score)
|
| 114 |
+
})
|
| 115 |
+
|
| 116 |
+
df = pd.DataFrame(data)
|
| 117 |
+
|
| 118 |
+
# Plot sunburst
|
| 119 |
+
fig = px.sunburst(
|
| 120 |
+
df,
|
| 121 |
+
path=['root', 'sentence', 'emotion'],
|
| 122 |
+
values='score',
|
| 123 |
+
color='emotion',
|
| 124 |
+
title='Sentence-level Emotion Confidences'
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
return fig
|
| 128 |
|
| 129 |
def plot_emotion_barplot(heatmap_data):
|
| 130 |
most_probable_emotions = heatmap_data.idxmax(axis=0)
|
|
|
|
| 159 |
print(results_heatmap)
|
| 160 |
|
| 161 |
figure = plot_emotion_barplot(prepare_heatmap_data(results_heatmap))
|
| 162 |
+
sunburst_chart = plot_sunburst_chart(results_heatmap)
|
| 163 |
output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
|
| 164 |
+
return results, figure, sunburst_chart, output_info
|
| 165 |
|
| 166 |
|
| 167 |
with gr.Blocks() as demo:
|
|
|
|
| 190 |
predict_button.click(
|
| 191 |
fn=predict_wrapper,
|
| 192 |
inputs=[input_text, language_choice],
|
| 193 |
+
outputs=[result_table, plot, sunburst_chart, model_info]
|
| 194 |
)
|
| 195 |
|
| 196 |
if __name__ == "__main__":
|