bgamazay commited on
Commit
70db274
·
verified ·
1 Parent(s): 83432dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -52
app.py CHANGED
@@ -4,13 +4,40 @@ from PIL import Image, ImageDraw, ImageFont
4
  import io
5
 
6
  def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  # Sidebar logo and title
8
  with st.sidebar:
9
- col1, col2 = st.columns([1, 5]) # Shrink the logo column and expand the text column
10
 
11
  with col1:
12
  logo = Image.open("logo.png")
13
- resized_logo = logo.resize((40, 40)) # Resize the logo
14
  st.image(resized_logo)
15
 
16
  with col2:
@@ -31,11 +58,8 @@ def main():
31
  unsafe_allow_html=True,
32
  )
33
 
34
- # (Removed the "Generate a Label to Display your AI Energy Score" section)
35
-
36
  st.sidebar.markdown("<hr style='border: 1px solid gray; margin: 15px 0;'>", unsafe_allow_html=True)
37
 
38
- # Update instructions header
39
  st.sidebar.write("### Generate Label:")
40
 
41
  # Define the ordered list of tasks.
@@ -52,8 +76,8 @@ def main():
52
  "Sentence Similarity"
53
  ]
54
 
55
- # Make the task selection label green and remove redundant text.
56
- st.sidebar.markdown('<p style="color: green; font-size: 16px;">Task(s):</p>', unsafe_allow_html=True)
57
  selected_tasks = st.sidebar.multiselect("", options=task_order, default=task_order)
58
 
59
  # Mapping from task to CSV file name.
@@ -64,30 +88,27 @@ def main():
64
  "Image Classification": "image_classification_energyscore.csv",
65
  "Image Captioning": "image_caption_energyscore.csv",
66
  "Summarization": "summarization_energyscore.csv",
67
- "ASR": "asr_energyscore.csv",
68
  "Object Detection": "object_detection_energyscore.csv",
69
  "Question Answering": "question_answering_energyscore.csv",
70
  "Sentence Similarity": "sentence_similarity_energyscore.csv"
71
  }
72
 
73
- # Default placeholder model data.
74
  default_model_data = {
75
- 'provider': "AI Provider",
76
- 'model': "Model Name",
77
- 'full_model': "AI Provider/Model Name",
78
- 'date': "",
79
- 'task': "",
80
- 'hardware': "",
81
- 'energy': "?",
82
- 'score': 5
83
  }
84
 
85
  if not selected_tasks:
86
- # If no tasks are selected, use generic placeholder.
87
  model_data = default_model_data
88
  else:
89
  dfs = []
90
- # Load and process each CSV corresponding to the selected tasks.
91
  for task in selected_tasks:
92
  file_name = task_to_file[task]
93
  try:
@@ -99,17 +120,12 @@ def main():
99
  st.sidebar.error(f"Error reading '{file_name}' for task {task}: {e}")
100
  continue
101
 
102
- # Save the original full model string and then split the "model" column
103
  df['full_model'] = df['model']
104
  df[['provider', 'model']] = df['model'].str.split(pat='/', n=1, expand=True)
105
- # Round total_gpu_energy to 3 decimal places and assign to 'energy'
106
  df['energy'] = df['total_gpu_energy'].round(3)
107
- # Use the energy_score column as 'score' (fill missing values with 1 to avoid casting errors)
108
  df['score'] = df['energy_score'].fillna(1).astype(int)
109
- # Hardcode date and hardware
110
  df['date'] = "February 2025"
111
  df['hardware'] = "NVIDIA H100-80GB"
112
- # Set the task from the file name mapping
113
  df['task'] = task
114
 
115
  dfs.append(df)
@@ -121,7 +137,6 @@ def main():
121
  if data_df.empty:
122
  model_data = default_model_data
123
  else:
124
- # In the scored model dropdown show the full model string.
125
  model_options = data_df["full_model"].unique().tolist()
126
  selected_model = st.sidebar.selectbox(
127
  "Scored Models",
@@ -130,10 +145,9 @@ def main():
130
  )
131
  model_data = data_df[data_df["full_model"] == selected_model].iloc[0]
132
 
133
- st.sidebar.write("#### 2. Select a model below")
134
  st.sidebar.write("#### 3. Download the label")
135
 
136
- # Select background by score (using generic placeholder score=5 if applicable)
137
  try:
138
  score = int(model_data["score"])
139
  background_path = f"{score}.png"
@@ -145,7 +159,6 @@ def main():
145
  st.sidebar.error(f"Invalid score '{model_data['score']}'. Score must be an integer.")
146
  return
147
 
148
- # Keep the final label size at 520×728
149
  final_size = (520, 728)
150
  generated_label = create_label_single_pass(background, model_data, final_size)
151
 
@@ -172,14 +185,9 @@ def main():
172
 
173
 
174
  def create_label_single_pass(background_image, model_data, final_size=(520, 728)):
175
- """
176
- Resizes the background to 520×728, then draws text onto it.
177
- """
178
- # 1. Resize background to final_size
179
  bg_resized = background_image.resize(final_size, Image.Resampling.LANCZOS)
180
  draw = ImageDraw.Draw(bg_resized)
181
 
182
- # 2. Load fonts at sizes appropriate for a 520×728 label
183
  try:
184
  title_font = ImageFont.truetype("Inter_24pt-Bold.ttf", size=27)
185
  details_font = ImageFont.truetype("Inter_18pt-Regular.ttf", size=23)
@@ -188,37 +196,22 @@ def create_label_single_pass(background_image, model_data, final_size=(520, 728)
188
  st.error(f"Font loading failed: {e}")
189
  return bg_resized
190
 
191
- # 3. Place your text.
192
- # Flip the order so that the provider (AI Developer) is shown first and the model name second.
193
  title_x, title_y = 33, 150
194
  details_x, details_y = 480, 256
195
  energy_x, energy_y = 480, 472
196
 
197
- # Text 1 (title) – show provider first then model name
198
  draw.text((title_x, title_y), str(model_data['provider']), font=title_font, fill="black")
199
  draw.text((title_x, title_y + 38), str(model_data['model']), font=title_font, fill="black")
200
 
201
- # Text 2 (details)
202
- details_lines = [
203
- str(model_data['date']),
204
- str(model_data['task']),
205
- str(model_data['hardware'])
206
- ]
207
  for i, line in enumerate(details_lines):
208
  bbox = draw.textbbox((0, 0), line, font=details_font)
209
- text_width = bbox[2] - bbox[0]
210
- # Right-justify the details text at details_x
211
- draw.text((details_x - text_width, details_y + i * 47), line, font=details_font, fill="black")
212
 
213
- # Text 3 (energy)
214
- energy_text = str(model_data['energy'])
215
- bbox = draw.textbbox((0, 0), energy_text, font=energy_font)
216
- energy_text_width = bbox[2] - bbox[0]
217
- # Right-align the energy text at energy_x
218
- draw.text((energy_x - energy_text_width, energy_y), energy_text, font=energy_font, fill="black")
219
 
220
  return bg_resized
221
 
222
 
223
  if __name__ == "__main__":
224
- main()
 
4
  import io
5
 
6
  def main():
7
+ # Inject custom CSS to change the color of selected tasks
8
+ st.markdown(
9
+ """
10
+ <style>
11
+ /* Change background color of selected items */
12
+ .stMultiSelect [data-baseweb="tag"] {
13
+ background-color: #3fa45bff !important; /* Custom green */
14
+ color: white !important; /* White text */
15
+ font-weight: bold;
16
+ border-radius: 5px;
17
+ padding: 5px 10px;
18
+ }
19
+
20
+ /* Change hover effect */
21
+ .stMultiSelect [data-baseweb="tag"]:hover {
22
+ background-color: #358d4d !important;
23
+ }
24
+
25
+ /* Style the dropdown input field */
26
+ .stMultiSelect input {
27
+ color: black !important;
28
+ }
29
+ </style>
30
+ """,
31
+ unsafe_allow_html=True,
32
+ )
33
+
34
  # Sidebar logo and title
35
  with st.sidebar:
36
+ col1, col2 = st.columns([1, 5])
37
 
38
  with col1:
39
  logo = Image.open("logo.png")
40
+ resized_logo = logo.resize((50, 50))
41
  st.image(resized_logo)
42
 
43
  with col2:
 
58
  unsafe_allow_html=True,
59
  )
60
 
 
 
61
  st.sidebar.markdown("<hr style='border: 1px solid gray; margin: 15px 0;'>", unsafe_allow_html=True)
62
 
 
63
  st.sidebar.write("### Generate Label:")
64
 
65
  # Define the ordered list of tasks.
 
76
  "Sentence Similarity"
77
  ]
78
 
79
+ # Task selection
80
+ st.sidebar.write("#### 1. Select task(s) to view models")
81
  selected_tasks = st.sidebar.multiselect("", options=task_order, default=task_order)
82
 
83
  # Mapping from task to CSV file name.
 
88
  "Image Classification": "image_classification_energyscore.csv",
89
  "Image Captioning": "image_caption_energyscore.csv",
90
  "Summarization": "summarization_energyscore.csv",
91
+ "Speech-to-Text (ASR)": "asr_energyscore.csv",
92
  "Object Detection": "object_detection_energyscore.csv",
93
  "Question Answering": "question_answering_energyscore.csv",
94
  "Sentence Similarity": "sentence_similarity_energyscore.csv"
95
  }
96
 
 
97
  default_model_data = {
98
+ 'provider': "AI Provider",
99
+ 'model': "Model Name",
100
+ 'full_model': "AI Provider/Model Name",
101
+ 'date': "",
102
+ 'task': "",
103
+ 'hardware': "",
104
+ 'energy': "?",
105
+ 'score': 5
106
  }
107
 
108
  if not selected_tasks:
 
109
  model_data = default_model_data
110
  else:
111
  dfs = []
 
112
  for task in selected_tasks:
113
  file_name = task_to_file[task]
114
  try:
 
120
  st.sidebar.error(f"Error reading '{file_name}' for task {task}: {e}")
121
  continue
122
 
 
123
  df['full_model'] = df['model']
124
  df[['provider', 'model']] = df['model'].str.split(pat='/', n=1, expand=True)
 
125
  df['energy'] = df['total_gpu_energy'].round(3)
 
126
  df['score'] = df['energy_score'].fillna(1).astype(int)
 
127
  df['date'] = "February 2025"
128
  df['hardware'] = "NVIDIA H100-80GB"
 
129
  df['task'] = task
130
 
131
  dfs.append(df)
 
137
  if data_df.empty:
138
  model_data = default_model_data
139
  else:
 
140
  model_options = data_df["full_model"].unique().tolist()
141
  selected_model = st.sidebar.selectbox(
142
  "Scored Models",
 
145
  )
146
  model_data = data_df[data_df["full_model"] == selected_model].iloc[0]
147
 
148
+ st.sidebar.write("#### 2. Select a model to generate label")
149
  st.sidebar.write("#### 3. Download the label")
150
 
 
151
  try:
152
  score = int(model_data["score"])
153
  background_path = f"{score}.png"
 
159
  st.sidebar.error(f"Invalid score '{model_data['score']}'. Score must be an integer.")
160
  return
161
 
 
162
  final_size = (520, 728)
163
  generated_label = create_label_single_pass(background, model_data, final_size)
164
 
 
185
 
186
 
187
  def create_label_single_pass(background_image, model_data, final_size=(520, 728)):
 
 
 
 
188
  bg_resized = background_image.resize(final_size, Image.Resampling.LANCZOS)
189
  draw = ImageDraw.Draw(bg_resized)
190
 
 
191
  try:
192
  title_font = ImageFont.truetype("Inter_24pt-Bold.ttf", size=27)
193
  details_font = ImageFont.truetype("Inter_18pt-Regular.ttf", size=23)
 
196
  st.error(f"Font loading failed: {e}")
197
  return bg_resized
198
 
 
 
199
  title_x, title_y = 33, 150
200
  details_x, details_y = 480, 256
201
  energy_x, energy_y = 480, 472
202
 
 
203
  draw.text((title_x, title_y), str(model_data['provider']), font=title_font, fill="black")
204
  draw.text((title_x, title_y + 38), str(model_data['model']), font=title_font, fill="black")
205
 
206
+ details_lines = [str(model_data['date']), str(model_data['task']), str(model_data['hardware'])]
 
 
 
 
 
207
  for i, line in enumerate(details_lines):
208
  bbox = draw.textbbox((0, 0), line, font=details_font)
209
+ draw.text((details_x - bbox[2], details_y + i * 47), line, font=details_font, fill="black")
 
 
210
 
211
+ draw.text((energy_x, energy_y), str(model_data['energy']), font=energy_font, fill="black")
 
 
 
 
 
212
 
213
  return bg_resized
214
 
215
 
216
  if __name__ == "__main__":
217
+ main()