codelion committed on
Commit
527fd08
·
verified ·
1 Parent(s): d8a969c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -12
app.py CHANGED
@@ -6,17 +6,26 @@ import io
6
  import base64
7
  import math
8
  import ast
 
 
 
 
 
9
 
10
  # Function to safely parse JSON or Python dictionary input
11
  def parse_input(json_input):
 
12
  try:
13
  # Try to parse as JSON first
14
  data = json.loads(json_input)
 
15
  return data
16
  except json.JSONDecodeError as e:
 
17
  try:
18
  # If JSON fails, try to parse as Python literal (e.g., with single quotes)
19
  data = ast.literal_eval(json_input)
 
20
  # Convert Python dictionary to JSON-compatible format (replace single quotes with double quotes)
21
  def dict_to_json(obj):
22
  if isinstance(obj, dict):
@@ -25,10 +34,27 @@ def parse_input(json_input):
25
  return [dict_to_json(item) for item in obj]
26
  else:
27
  return obj
28
- return dict_to_json(data)
 
 
29
  except (SyntaxError, ValueError) as e:
 
30
  raise ValueError(f"Malformed input: {str(e)}. Ensure property names are in double quotes (e.g., \"content\") or correct Python dictionary format.")
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # Function to process and visualize log probs
33
  def visualize_logprobs(json_input):
34
  try:
@@ -47,32 +73,37 @@ def visualize_logprobs(json_input):
47
  tokens = []
48
  logprobs = []
49
  for entry in content:
50
- if (
51
- "logprob" in entry
52
- and entry["logprob"] is not None
53
- and math.isfinite(entry["logprob"])
54
- ):
55
  tokens.append(entry["token"])
56
- logprobs.append(entry["logprob"])
 
 
57
 
58
  # Prepare table data, handling None in top_logprobs
59
  table_data = []
60
  for entry in content:
 
61
  # Only include entries with finite logprob and non-None top_logprobs
62
  if (
63
- "logprob" in entry
64
- and entry["logprob"] is not None
65
- and math.isfinite(entry["logprob"])
66
  and "top_logprobs" in entry
67
  and entry["top_logprobs"] is not None
68
  ):
69
  token = entry["token"]
70
- logprob = entry["logprob"]
71
  top_logprobs = entry["top_logprobs"]
 
 
 
 
 
 
72
 
73
  # Extract top 3 alternatives from top_logprobs
74
  top_3 = sorted(
75
- top_logprobs.items(), key=lambda x: x[1], reverse=True
76
  )[:3]
77
  row = [token, f"{logprob:.4f}"]
78
  for alt_token, alt_logprob in top_3:
@@ -149,6 +180,7 @@ def visualize_logprobs(json_input):
149
  return img_html, df, colored_text_html
150
 
151
  except Exception as e:
 
152
  return f"Error: {str(e)}", None, None
153
 
154
  # Gradio interface
 
6
  import base64
7
  import math
8
  import ast
9
+ import logging
10
+
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.DEBUG)
13
+ logger = logging.getLogger(__name__)
14
 
15
  # Function to safely parse JSON or Python dictionary input
16
  def parse_input(json_input):
17
+ logger.debug("Attempting to parse input: %s", json_input)
18
  try:
19
  # Try to parse as JSON first
20
  data = json.loads(json_input)
21
+ logger.debug("Successfully parsed as JSON")
22
  return data
23
  except json.JSONDecodeError as e:
24
+ logger.error("JSON parsing failed: %s", str(e))
25
  try:
26
  # If JSON fails, try to parse as Python literal (e.g., with single quotes)
27
  data = ast.literal_eval(json_input)
28
+ logger.debug("Successfully parsed as Python literal")
29
  # Convert Python dictionary to JSON-compatible format (replace single quotes with double quotes)
30
  def dict_to_json(obj):
31
  if isinstance(obj, dict):
 
34
  return [dict_to_json(item) for item in obj]
35
  else:
36
  return obj
37
+ converted_data = dict_to_json(data)
38
+ logger.debug("Converted to JSON-compatible format")
39
+ return converted_data
40
  except (SyntaxError, ValueError) as e:
41
+ logger.error("Python literal parsing failed: %s", str(e))
42
  raise ValueError(f"Malformed input: {str(e)}. Ensure property names are in double quotes (e.g., \"content\") or correct Python dictionary format.")
43
 
44
+ # Function to ensure a value is a float, converting from string if necessary
45
def ensure_float(value):
    """Coerce *value* to a float, or return None when it cannot be used.

    Args:
        value: A log-probability as found in the parsed input. May be None,
            a numeric string, an int, a float, or any other object.

    Returns:
        ``float(value)`` for strings that parse as a float and for
        int/float inputs; ``None`` for None, unparsable strings, integers
        too large to represent as a float, and every other type.

    Note:
        Non-finite results (e.g. from "nan" or "inf") are returned as-is;
        this function does not filter them. ``bool`` passes the int check
        because it is a subclass of ``int``.
    """
    if value is None:
        return None
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            logger.error("Failed to convert string '%s' to float", value)
            return None
    if isinstance(value, (int, float)):
        try:
            return float(value)
        except OverflowError:
            # Python ints are unbounded (and JSON permits huge integers), but
            # float() raises when the value exceeds the double range. Report
            # it like any other unusable value instead of raising.
            return None
    # Lists, dicts, and any other type are not valid numeric values.
    return None
57
+
58
  # Function to process and visualize log probs
59
  def visualize_logprobs(json_input):
60
  try:
 
73
  tokens = []
74
  logprobs = []
75
  for entry in content:
76
+ logprob = ensure_float(entry.get("logprob", None))
77
+ if logprob is not None and math.isfinite(logprob):
 
 
 
78
  tokens.append(entry["token"])
79
+ logprobs.append(logprob)
80
+ else:
81
+ logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))
82
 
83
  # Prepare table data, handling None in top_logprobs
84
  table_data = []
85
  for entry in content:
86
+ logprob = ensure_float(entry.get("logprob", None))
87
  # Only include entries with finite logprob and non-None top_logprobs
88
  if (
89
+ logprob is not None
90
+ and math.isfinite(logprob)
 
91
  and "top_logprobs" in entry
92
  and entry["top_logprobs"] is not None
93
  ):
94
  token = entry["token"]
95
+ logger.debug("Processing token: %s, logprob: %s (type: %s)", token, logprob, type(logprob))
96
  top_logprobs = entry["top_logprobs"]
97
+ # Ensure all values in top_logprobs are floats
98
+ finite_top_logprobs = {}
99
+ for key, value in top_logprobs.items():
100
+ float_value = ensure_float(value)
101
+ if float_value is not None and math.isfinite(float_value):
102
+ finite_top_logprobs[key] = float_value
103
 
104
  # Extract top 3 alternatives from top_logprobs
105
  top_3 = sorted(
106
+ finite_top_logprobs.items(), key=lambda x: x[1], reverse=True
107
  )[:3]
108
  row = [token, f"{logprob:.4f}"]
109
  for alt_token, alt_logprob in top_3:
 
180
  return img_html, df, colored_text_html
181
 
182
  except Exception as e:
183
+ logger.error("Visualization failed: %s", str(e))
184
  return f"Error: {str(e)}", None, None
185
 
186
  # Gradio interface