Kseniia-Kholina commited on
Commit
88f36c7
·
verified ·
1 Parent(s): 2fdf8f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -6,6 +6,8 @@ import matplotlib.pyplot as plt
6
  import seaborn as sns
7
  import numpy as np
8
  import gradio as gr
 
 
9
 
10
  def get_heatmap(sequence):
11
  logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
@@ -64,11 +66,15 @@ def get_heatmap(sequence):
64
  plt.yticks(rotation=0)
65
  plt.xticks(y_tick_positions, y_tick_labels, rotation = 0)
66
 
67
- fig = plt.gcf()
68
- plt.close(fig)
69
-
70
- return fig
71
-
 
 
 
 
72
 
73
 
74
  demo = gr.Interface(fn=get_heatmap, inputs="text", outputs="image")
 
6
  import seaborn as sns
7
  import numpy as np
8
  import gradio as gr
9
+ from io import BytesIO
10
+ from PIL import Image
11
 
12
  def get_heatmap(sequence):
13
  logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
 
66
  plt.yticks(rotation=0)
67
  plt.xticks(y_tick_positions, y_tick_labels, rotation = 0)
68
 
69
+ # Save the figure to a BytesIO object
70
+ buf = BytesIO()
71
+ plt.savefig(buf, format='png')
72
+ buf.seek(0)
73
+ plt.close()
74
+
75
+ # Convert BytesIO object to an image
76
+ img = Image.open(buf)
77
+ return img
78
 
79
 
80
  demo = gr.Interface(fn=get_heatmap, inputs="text", outputs="image")