codelion commited on
Commit
c28bdaa
·
verified ·
1 Parent(s): 7e141c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -50
app.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
5
  import io
6
  import base64
7
  import ast
 
8
 
9
  # Function to process and visualize log probs
10
  def visualize_logprobs(json_input):
@@ -16,19 +17,34 @@ def visualize_logprobs(json_input):
16
  # If JSON fails, try to parse as Python literal (e.g., with single quotes)
17
  data = ast.literal_eval(json_input)
18
 
19
- # Extract tokens and log probs, skipping None values
20
- tokens = [entry['token'] for entry in data['content'] if entry['logprob'] is not None]
21
- logprobs = [entry['logprob'] for entry in data['content'] if entry['logprob'] is not None]
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # Prepare data for the table
24
  table_data = []
25
- for entry in data['content']:
26
- if entry['logprob'] is not None:
27
  token = entry['token']
28
  logprob = entry['logprob']
29
- top_logprobs = entry['top_logprobs']
30
- # Extract top 3 alternatives, sorted by log prob (most probable first)
31
- top_3 = sorted(top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
 
 
 
32
  row = [token, f"{logprob:.4f}"]
33
  for alt_token, alt_logprob in top_3:
34
  row.append(f"{alt_token}: {alt_logprob:.4f}")
@@ -37,56 +53,61 @@ def visualize_logprobs(json_input):
37
  row.append("")
38
  table_data.append(row)
39
 
40
- # Create the plot
41
- plt.figure(figsize=(10, 5))
42
- plt.plot(range(len(logprobs)), logprobs, marker='o', linestyle='-', color='b')
43
- plt.title("Log Probabilities of Generated Tokens")
44
- plt.xlabel("Token Position")
45
- plt.ylabel("Log Probability")
46
- plt.grid(True)
47
- plt.xticks(range(len(logprobs)), tokens, rotation=45, ha='right')
48
- plt.tight_layout()
49
-
50
- # Save plot to a bytes buffer
51
- buf = io.BytesIO()
52
- plt.savefig(buf, format='png', bbox_inches='tight')
53
- buf.seek(0)
54
- plt.close()
55
-
56
- # Convert buffer to base64 for Gradio
57
- img_bytes = buf.getvalue()
58
- img_base64 = base64.b64encode(img_bytes).decode('utf-8')
59
- img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
 
 
 
60
 
61
  # Create a DataFrame for the table
62
  df = pd.DataFrame(
63
  table_data,
64
  columns=["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"]
65
- )
66
 
67
  # Generate colored text based on log probabilities
68
- # Normalize log probs to [0, 1] for color scaling (0 = most uncertain, 1 = most confident)
69
- min_logprob = min(logprobs)
70
- max_logprob = max(logprobs)
71
- if max_logprob == min_logprob:
72
- normalized_probs = [0.5] * len(logprobs) # Avoid division by zero
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  else:
74
- normalized_probs = [(lp - min_logprob) / (max_logprob - min_logprob) for lp in logprobs]
75
-
76
- # Create HTML for colored text
77
- colored_text = ""
78
- for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
79
- # Map normalized probability to RGB color (green for high confidence, red for low)
80
- # Use a simple linear interpolation: green (0, 255, 0) to red (255, 0, 0)
81
- r = int(255 * (1 - norm_prob)) # Red increases as uncertainty increases
82
- g = int(255 * norm_prob) # Green decreases as uncertainty increases
83
- b = 0 # Blue stays 0 for simplicity
84
- color = f'rgb({r}, {g}, {b})'
85
- colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
86
- if i < len(tokens) - 1:
87
- colored_text += " " # Add space between tokens
88
-
89
- colored_text_html = f'<p>{colored_text}</p>'
90
 
91
  return img_html, df, colored_text_html
92
 
 
5
  import io
6
  import base64
7
  import ast
8
+ import math
9
 
10
  # Function to process and visualize log probs
11
  def visualize_logprobs(json_input):
 
17
  # If JSON fails, try to parse as Python literal (e.g., with single quotes)
18
  data = ast.literal_eval(json_input)
19
 
20
+ # Ensure data is a list or dictionary with 'content'
21
+ if isinstance(data, dict) and 'content' in data:
22
+ content = data['content']
23
+ elif isinstance(data, list):
24
+ content = data
25
+ else:
26
+ raise ValueError("Input must be a list or dictionary with 'content' key")
27
+
28
+ # Extract tokens and log probs, skipping None values and handling non-finite values
29
+ tokens = []
30
+ logprobs = []
31
+ for entry in content:
32
+ if entry['logprob'] is not None and math.isfinite(entry['logprob']):
33
+ tokens.append(entry['token'])
34
+ logprobs.append(entry['logprob'])
35
 
36
  # Prepare data for the table
37
  table_data = []
38
+ for entry in content:
39
+ if entry['logprob'] is not None and math.isfinite(entry['logprob']):
40
  token = entry['token']
41
  logprob = entry['logprob']
42
+ top_logprobs = entry['top_logprobs'] or {}
43
+
44
+ # Filter out non-finite (e.g., -inf, inf, nan) log probs from top_logprobs
45
+ finite_top_logprobs = {k: v for k, v in top_logprobs.items() if math.isfinite(v)}
46
+ # Extract top 3 finite alternatives, sorted by log prob (most probable first)
47
+ top_3 = sorted(finite_top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
48
  row = [token, f"{logprob:.4f}"]
49
  for alt_token, alt_logprob in top_3:
50
  row.append(f"{alt_token}: {alt_logprob:.4f}")
 
53
  row.append("")
54
  table_data.append(row)
55
 
56
+ # Create the plot (only for finite log probs)
57
+ if logprobs:
58
+ plt.figure(figsize=(10, 5))
59
+ plt.plot(range(len(logprobs)), logprobs, marker='o', linestyle='-', color='b')
60
+ plt.title("Log Probabilities of Generated Tokens")
61
+ plt.xlabel("Token Position")
62
+ plt.ylabel("Log Probability")
63
+ plt.grid(True)
64
+ plt.xticks(range(len(logprobs)), tokens, rotation=45, ha='right')
65
+ plt.tight_layout()
66
+
67
+ # Save plot to a bytes buffer
68
+ buf = io.BytesIO()
69
+ plt.savefig(buf, format='png', bbox_inches='tight')
70
+ buf.seek(0)
71
+ plt.close()
72
+
73
+ # Convert buffer to base64 for Gradio
74
+ img_bytes = buf.getvalue()
75
+ img_base64 = base64.b64encode(img_bytes).decode('utf-8')
76
+ img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
77
+ else:
78
+ img_html = "No finite log probabilities to plot."
79
 
80
  # Create a DataFrame for the table
81
  df = pd.DataFrame(
82
  table_data,
83
  columns=["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"]
84
+ ) if table_data else None
85
 
86
  # Generate colored text based on log probabilities
87
+ if logprobs:
88
+ # Normalize log probs to [0, 1] for color scaling (0 = most uncertain, 1 = most confident)
89
+ min_logprob = min(logprobs)
90
+ max_logprob = max(logprobs)
91
+ if max_logprob == min_logprob:
92
+ normalized_probs = [0.5] * len(logprobs) # Avoid division by zero
93
+ else:
94
+ normalized_probs = [(lp - min_logprob) / (max_logprob - min_logprob) for lp in logprobs]
95
+
96
+ # Create HTML for colored text
97
+ colored_text = ""
98
+ for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
99
+ # Map normalized probability to RGB color (green for high confidence, red for low)
100
+ r = int(255 * (1 - norm_prob)) # Red increases as uncertainty increases
101
+ g = int(255 * norm_prob) # Green decreases as uncertainty increases
102
+ b = 0 # Blue stays 0 for simplicity
103
+ color = f'rgb({r}, {g}, {b})'
104
+ colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
105
+ if i < len(tokens) - 1:
106
+ colored_text += " " # Add space between tokens
107
+
108
+ colored_text_html = f'<p>{colored_text}</p>'
109
  else:
110
+ colored_text_html = "No finite log probabilities to display."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  return img_html, df, colored_text_html
113