bgamazay commited on
Commit
8c74b2d
·
verified ·
1 Parent(s): 2be0753

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -12
app.py CHANGED
@@ -39,16 +39,22 @@ def make_link(mname):
39
 
40
  def get_plots(task):
41
  df = pd.read_csv('data/energy/' + task)
42
- df['energy_score'] = df['energy_score'].astype(int)
 
 
 
 
 
43
  df['Display Model'] = df['model'].apply(lambda m: m.split('/')[-1])
44
 
45
- color_map = {1: "red", 2: "orange", 3: "yellow", 4: "lightgreen", 5: "green"}
 
46
 
47
  fig = px.scatter(
48
  df,
49
- x="total_gpu_energy", # Ensure correct column for x-axis
50
- y="Display Model", # Keep model name for y-axis
51
- color="energy_score", # Ensure correct column for point color
52
  custom_data=['energy_score'],
53
  height=500,
54
  width=800,
@@ -70,18 +76,19 @@ def get_all_plots():
70
  df = pd.read_csv('data/energy/' + task)
71
  if df.columns[0].startswith("Unnamed:"):
72
  df = df.iloc[:, 1:]
73
- df['energy_score'] = df['energy_score'].astype(int)
 
74
  df['Display Model'] = df['model'].apply(lambda m: m.split('/')[-1])
75
  all_df = pd.concat([all_df, df], ignore_index=True)
76
  all_df = all_df.drop_duplicates(subset=['model'])
77
 
78
- color_map = {1: "red", 2: "orange", 3: "yellow", 4: "lightgreen", 5: "green"}
79
 
80
  fig = px.scatter(
81
  all_df,
82
- x="total_gpu_energy", # Ensure correct column for x-axis
83
  y="Display Model",
84
- color="energy_score", # Ensure correct column for point color
85
  custom_data=['energy_score'],
86
  height=500,
87
  width=800,
@@ -241,9 +248,10 @@ Click through the tasks below to see how different models measure up in terms of
241
  with gr.TabItem("All Tasks 💡"):
242
  with gr.Row():
243
  with gr.Column():
244
- plot = gr.Plot(get_all_plots)
 
245
  with gr.Column():
246
- table = gr.Dataframe(get_all_model_names, datatype="markdown")
247
 
248
  with gr.Accordion("📙 Citation", open=False):
249
  citation_button = gr.Textbox(
@@ -257,4 +265,4 @@ Click through the tasks below to see how different models measure up in terms of
257
  """Last updated: February 2025"""
258
  )
259
 
260
- demo.launch()
 
39
 
40
  def get_plots(task):
41
  df = pd.read_csv('data/energy/' + task)
42
+ # Remove extra unnamed column if present
43
+ if df.columns[0].startswith("Unnamed:"):
44
+ df = df.iloc[:, 1:]
45
+
46
+ # Convert energy_score to int and then to str so it's treated as categorical
47
+ df['energy_score'] = df['energy_score'].astype(int).astype(str)
48
  df['Display Model'] = df['model'].apply(lambda m: m.split('/')[-1])
49
 
50
+ # Update color_map keys to be strings
51
+ color_map = {"1": "red", "2": "orange", "3": "yellow", "4": "lightgreen", "5": "green"}
52
 
53
  fig = px.scatter(
54
  df,
55
+ x="total_gpu_energy", # x-axis: GPU energy consumption
56
+ y="Display Model", # y-axis: Model name for display
57
+ color="energy_score", # Discrete color based on energy score
58
  custom_data=['energy_score'],
59
  height=500,
60
  width=800,
 
76
  df = pd.read_csv('data/energy/' + task)
77
  if df.columns[0].startswith("Unnamed:"):
78
  df = df.iloc[:, 1:]
79
+ # Convert energy_score to categorical string
80
+ df['energy_score'] = df['energy_score'].astype(int).astype(str)
81
  df['Display Model'] = df['model'].apply(lambda m: m.split('/')[-1])
82
  all_df = pd.concat([all_df, df], ignore_index=True)
83
  all_df = all_df.drop_duplicates(subset=['model'])
84
 
85
+ color_map = {"1": "red", "2": "orange", "3": "yellow", "4": "lightgreen", "5": "green"}
86
 
87
  fig = px.scatter(
88
  all_df,
89
+ x="total_gpu_energy", # x-axis: GPU energy consumption
90
  y="Display Model",
91
+ color="energy_score", # Discrete color mapping
92
  custom_data=['energy_score'],
93
  height=500,
94
  width=800,
 
248
  with gr.TabItem("All Tasks 💡"):
249
  with gr.Row():
250
  with gr.Column():
251
+ # Call the functions to generate the plot and table
252
+ plot = gr.Plot(get_all_plots())
253
  with gr.Column():
254
+ table = gr.Dataframe(get_all_model_names(), datatype="markdown")
255
 
256
  with gr.Accordion("📙 Citation", open=False):
257
  citation_button = gr.Textbox(
 
265
  """Last updated: February 2025"""
266
  )
267
 
268
+ demo.launch()