codelion commited on
Commit
6934db6
·
verified ·
1 Parent(s): 76623dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -2,7 +2,8 @@ import gradio as gr
2
  import json
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
5
- from io import StringIO
 
6
 
7
  # Function to process and visualize log probs
8
  def visualize_logprobs(json_input):
@@ -21,7 +22,7 @@ def visualize_logprobs(json_input):
21
  token = entry['token']
22
  logprob = entry['logprob']
23
  top_logprobs = entry['top_logprobs']
24
- # Extract top 3 alternatives
25
  top_3 = sorted(top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
26
  row = [token, f"{logprob:.4f}"]
27
  for alt_token, alt_logprob in top_3:
@@ -41,19 +42,24 @@ def visualize_logprobs(json_input):
41
  plt.xticks(range(len(logprobs)), tokens, rotation=45, ha='right')
42
  plt.tight_layout()
43
 
44
- # Save plot to a buffer
45
- img_buffer = StringIO()
46
- plt.savefig(img_buffer, format='png', bbox_inches='tight')
47
- img_buffer.seek(0)
48
  plt.close()
49
 
 
 
 
 
 
50
  # Create a DataFrame for the table
51
  df = pd.DataFrame(
52
  table_data,
53
  columns=["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"]
54
  )
55
 
56
- return img_buffer, df
57
 
58
  except Exception as e:
59
  return f"Error: {str(e)}", None
@@ -67,7 +73,7 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
67
  json_input = gr.Textbox(label="JSON Input", lines=10, placeholder="Paste your JSON here...")
68
 
69
  # Outputs
70
- plot_output = gr.Image(label="Log Probability Plot")
71
  table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
72
 
73
  # Button to trigger visualization
 
2
  import json
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
5
+ import io
6
+ import base64
7
 
8
  # Function to process and visualize log probs
9
  def visualize_logprobs(json_input):
 
22
  token = entry['token']
23
  logprob = entry['logprob']
24
  top_logprobs = entry['top_logprobs']
25
+ # Extract top 3 alternatives, sorted by log prob (most probable first)
26
  top_3 = sorted(top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
27
  row = [token, f"{logprob:.4f}"]
28
  for alt_token, alt_logprob in top_3:
 
42
  plt.xticks(range(len(logprobs)), tokens, rotation=45, ha='right')
43
  plt.tight_layout()
44
 
45
+ # Save plot to a bytes buffer
46
+ buf = io.BytesIO()
47
+ plt.savefig(buf, format='png', bbox_inches='tight')
48
+ buf.seek(0)
49
  plt.close()
50
 
51
+ # Convert buffer to base64 for Gradio
52
+ img_bytes = buf.getvalue()
53
+ img_base64 = base64.b64encode(img_bytes).decode('utf-8')
54
+ img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
55
+
56
  # Create a DataFrame for the table
57
  df = pd.DataFrame(
58
  table_data,
59
  columns=["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"]
60
  )
61
 
62
+ return img_html, df
63
 
64
  except Exception as e:
65
  return f"Error: {str(e)}", None
 
73
  json_input = gr.Textbox(label="JSON Input", lines=10, placeholder="Paste your JSON here...")
74
 
75
  # Outputs
76
+ plot_output = gr.HTML(label="Log Probability Plot")
77
  table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
78
 
79
  # Button to trigger visualization