AppleSwing
commited on
Commit
•
89eec2c
1
Parent(s):
dd01425
Change get GPU name
Browse files- src/utils.py +6 -1
src/utils.py
CHANGED
@@ -214,10 +214,15 @@ def get_gpu_details():
|
|
214 |
gpus = GPUtil.getGPUs()
|
215 |
gpu = gpus[0]
|
216 |
name = gpu.name.replace(" ", "-")
|
217 |
-
# Convert memory from MB to GB and round to nearest whole number
|
218 |
memory_gb = round(gpu.memoryTotal / 1024)
|
219 |
memory = f"{memory_gb}GB"
|
|
|
|
|
|
|
|
|
|
|
220 |
formatted_name = f"{name}-{memory}"
|
|
|
221 |
return formatted_name
|
222 |
|
223 |
def get_peak_bw(gpu_name):
|
|
|
214 |
gpus = GPUtil.getGPUs()
|
215 |
gpu = gpus[0]
|
216 |
name = gpu.name.replace(" ", "-")
|
|
|
217 |
memory_gb = round(gpu.memoryTotal / 1024)
|
218 |
memory = f"{memory_gb}GB"
|
219 |
+
|
220 |
+
for part in name.split('-'):
|
221 |
+
if part.endswith("GB") and part[:-2].isdigit():
|
222 |
+
name = name.replace(f"-{part}", "").replace(part, "")
|
223 |
+
|
224 |
formatted_name = f"{name}-{memory}"
|
225 |
+
|
226 |
return formatted_name
|
227 |
|
228 |
def get_peak_bw(gpu_name):
|