AppleSwing commited on
Commit
89eec2c
1 Parent(s): dd01425

Change get GPU name

Browse files
Files changed (1) hide show
  1. src/utils.py +6 -1
src/utils.py CHANGED
@@ -214,10 +214,15 @@ def get_gpu_details():
214
  gpus = GPUtil.getGPUs()
215
  gpu = gpus[0]
216
  name = gpu.name.replace(" ", "-")
217
- # Convert memory from MB to GB and round to nearest whole number
218
  memory_gb = round(gpu.memoryTotal / 1024)
219
  memory = f"{memory_gb}GB"
 
 
 
 
 
220
  formatted_name = f"{name}-{memory}"
 
221
  return formatted_name
222
 
223
  def get_peak_bw(gpu_name):
 
214
  gpus = GPUtil.getGPUs()
215
  gpu = gpus[0]
216
  name = gpu.name.replace(" ", "-")
 
217
  memory_gb = round(gpu.memoryTotal / 1024)
218
  memory = f"{memory_gb}GB"
219
+
220
+ for part in name.split('-'):
221
+ if part.endswith("GB") and part[:-2].isdigit():
222
+ name = name.replace(f"-{part}", "").replace(part, "")
223
+
224
  formatted_name = f"{name}-{memory}"
225
+
226
  return formatted_name
227
 
228
  def get_peak_bw(gpu_name):