Gogryu committed on
Commit
44e1649
·
1 Parent(s): ed27476
Files changed (1) hide show
  1. src/pages/Calculator.tsx +363 -58
src/pages/Calculator.tsx CHANGED
@@ -614,72 +614,73 @@ const PrefillChunkingCalculator = ({
614
  )
615
 
616
  return (
 
617
  <div>
618
- <div className='text-2xl'>Prefill Chunking Calculator</div>
619
- <div className='chart'>
620
- <div className='flex flex-col items'>
621
  <div className='text-2xl'>Model Footprint with Prefill Chunking</div>
622
  </div>
623
- <div className='chart-row my-8'>
624
- <div className='chart-row-title'>FP32</div>
625
- <PrefillChunkingModelSizeBarChart
626
- modelSize={calculateMemory(modelParams, 'fp32')}
627
- largestModelSize={deviceMemory || calculateMemory(modelParams, 'fp32')}
628
- modelPrecision='fp32'
629
- deviceMemorySet={deviceMemory !== null && deviceMemory > 0}
630
- activationMemorySize={activationMemorySize}
631
- />
632
- <div className='chart-row-size ml-8'>
633
- {(calculateMemory(modelParams, 'fp32') + activationMemorySize).toFixed(2)}{' '}
634
- {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
 
 
635
  </div>
636
- </div>
637
 
638
- <div className='chart-row my-8'>
639
- <div className='chart-row-title'>FP16</div>
640
- <PrefillChunkingModelSizeBarChart
641
- modelSize={calculateMemory(modelParams, 'fp16')}
642
- largestModelSize={deviceMemory || calculateMemory(modelParams, 'fp16')}
643
- modelPrecision='fp16'
644
- deviceMemorySet={deviceMemory !== null && deviceMemory > 0}
645
- activationMemorySize={activationMemorySize}
646
- />
647
- <div className='chart-row-size ml-8'>
648
- {(calculateMemory(modelParams, 'fp16') + activationMemorySize).toFixed(2)}{' '}
649
- {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
 
650
  </div>
651
- </div>
652
 
653
- <div className='chart-row my-8'>
654
- <div className='chart-row-title'>INT8</div>
655
- <PrefillChunkingModelSizeBarChart
656
- modelSize={calculateMemory(modelParams, 'int8')}
657
- largestModelSize={deviceMemory || calculateMemory(modelParams, 'int8')}
658
- modelPrecision='int8'
659
- deviceMemorySet={deviceMemory !== null && deviceMemory > 0}
660
- activationMemorySize={activationMemorySize}
661
- />
662
- <div className='chart-row-size ml-8'>
663
- {(calculateMemory(modelParams, 'int8') + activationMemorySize).toFixed(2)}{' '}
664
- {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
 
665
  </div>
666
- </div>
667
 
668
- <div className='chart-row my-8'>
669
- <div className='chart-row-title'>INT4</div>
670
- <PrefillChunkingModelSizeBarChart
671
- modelSize={calculateMemory(modelParams, 'int4')}
672
- largestModelSize={deviceMemory || calculateMemory(modelParams, 'int4')}
673
- modelPrecision='int4'
674
- deviceMemorySet={deviceMemory !== null && deviceMemory > 0}
675
- activationMemorySize={activationMemorySize}
676
- />
677
- <div className='chart-row-size ml-8'>
678
- {(calculateMemory(modelParams, 'int4') + activationMemorySize).toFixed(2)}{' '}
679
- {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
 
680
  </div>
681
  </div>
682
-
683
  </div>
684
  <div className='chart'>
685
  <div className='flex flex-col items-center'>
@@ -1094,7 +1095,7 @@ const Calculator = () => {
1094
  {/* Maximum Batch Size / Sequence Length Chart */}
1095
  <div className="chart mb-8">
1096
  <div className="text-2xl text-center mb-4">Maximum Batch Size / Sequence Length</div>
1097
- <div className="flex flex-col items-center">
1098
  <InferenceRuntimeLineChart
1099
  availableMemory={{
1100
  int4: deviceMemory - calculateMemory(modelParams, 'int4'),
@@ -1104,8 +1105,312 @@ const Calculator = () => {
1104
  }}
1105
  memoryPerInput={calculateMemoryPerInput(hiddenSize, numLayers)}
1106
  />
1107
- <div className="chart-side-panel ml-4 pt-4 w-full max-w-xs">
1108
- {/* Batch Size and Sequence Length Inputs */}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1109
  </div>
1110
  </div>
1111
  </div>
 
614
  )
615
 
616
  return (
617
+
618
  <div>
619
+ <div className='chart mb-8'>
620
+ <div className='flex flex-col items-center'>
 
621
  <div className='text-2xl'>Model Footprint with Prefill Chunking</div>
622
  </div>
623
+ <div>
624
+ <div className='chart-row'>
625
+ <div className='chart-row-title'>FP32</div>
626
+ <PrefillChunkingModelSizeBarChart
627
+ modelSize={calculateMemory(modelParams, 'fp32')}
628
+ largestModelSize={deviceMemory || calculateMemory(modelParams, 'fp32')}
629
+ modelPrecision='fp32'
630
+ deviceMemorySet={deviceMemory !== null && deviceMemory > 0}
631
+ activationMemorySize={activationMemorySize}
632
+ />
633
+ <div className='chart-row-size ml-8'>
634
+ {(calculateMemory(modelParams, 'fp32') + activationMemorySize).toFixed(2)}{' '}
635
+ {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
636
+ </div>
637
  </div>
 
638
 
639
+ <div className='chart-row my-8'>
640
+ <div className='chart-row-title'>FP16</div>
641
+ <PrefillChunkingModelSizeBarChart
642
+ modelSize={calculateMemory(modelParams, 'fp16')}
643
+ largestModelSize={deviceMemory || calculateMemory(modelParams, 'fp16')}
644
+ modelPrecision='fp16'
645
+ deviceMemorySet={deviceMemory !== null && deviceMemory > 0}
646
+ activationMemorySize={activationMemorySize}
647
+ />
648
+ <div className='chart-row-size ml-8'>
649
+ {(calculateMemory(modelParams, 'fp16') + activationMemorySize).toFixed(2)}{' '}
650
+ {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
651
+ </div>
652
  </div>
 
653
 
654
+ <div className='chart-row my-8'>
655
+ <div className='chart-row-title'>INT8</div>
656
+ <PrefillChunkingModelSizeBarChart
657
+ modelSize={calculateMemory(modelParams, 'int8')}
658
+ largestModelSize={deviceMemory || calculateMemory(modelParams, 'int8')}
659
+ modelPrecision='int8'
660
+ deviceMemorySet={deviceMemory !== null && deviceMemory > 0}
661
+ activationMemorySize={activationMemorySize}
662
+ />
663
+ <div className='chart-row-size ml-8'>
664
+ {(calculateMemory(modelParams, 'int8') + activationMemorySize).toFixed(2)}{' '}
665
+ {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
666
+ </div>
667
  </div>
 
668
 
669
+ <div className='chart-row my-8'>
670
+ <div className='chart-row-title'>INT4</div>
671
+ <PrefillChunkingModelSizeBarChart
672
+ modelSize={calculateMemory(modelParams, 'int4')}
673
+ largestModelSize={deviceMemory || calculateMemory(modelParams, 'int4')}
674
+ modelPrecision='int4'
675
+ deviceMemorySet={deviceMemory !== null && deviceMemory > 0}
676
+ activationMemorySize={activationMemorySize}
677
+ />
678
+ <div className='chart-row-size ml-8'>
679
+ {(calculateMemory(modelParams, 'int4') + activationMemorySize).toFixed(2)}{' '}
680
+ {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
681
+ </div>
682
  </div>
683
  </div>
 
684
  </div>
685
  <div className='chart'>
686
  <div className='flex flex-col items-center'>
 
1095
  {/* Maximum Batch Size / Sequence Length Chart */}
1096
  <div className="chart mb-8">
1097
  <div className="text-2xl text-center mb-4">Maximum Batch Size / Sequence Length</div>
1098
+ <div className="flex flex-row items-left">
1099
  <InferenceRuntimeLineChart
1100
  availableMemory={{
1101
  int4: deviceMemory - calculateMemory(modelParams, 'int4'),
 
1105
  }}
1106
  memoryPerInput={calculateMemoryPerInput(hiddenSize, numLayers)}
1107
  />
1108
+ <div className="chart-side-panel ml-4 pt-4">
1109
+ <div className='mb-2'>
1110
+ Memory/token:{' '}
1111
+ {(calculateMemoryPerInput(hiddenSize, numLayers) * 1_000_000).toFixed(0)} KB
1112
+ </div>
1113
+ <label htmlFor='batchSize'>Batch Size</label>
1114
+ <input
1115
+ type='number'
1116
+ id='batchSize'
1117
+ className='side-panel-input mb-2'
1118
+ value={batchSize || ''}
1119
+ min={1}
1120
+ onChange={(e) => setBatchSize(Number(e.target.value))}
1121
+ />
1122
+ <label htmlFor='seqLength'>Sequence Length</label>
1123
+ <input
1124
+ type='number'
1125
+ id='seqLength'
1126
+ className='side-panel-input'
1127
+ value={seqLength || ''}
1128
+ min={1}
1129
+ onChange={(e) => setSeqLength(Number(e.target.value))}
1130
+ />
1131
+ <div className='mt-4'>
1132
+ {!batchSize && !seqLength ? (
1133
+ <div>
1134
+ Input a batch size or sequence length to see the maximum batch size or
1135
+ sequence length you can run on your device.
1136
+ </div>
1137
+ ) : null}
1138
+ {batchSize && !seqLength ? (
1139
+ <>
1140
+ <div>Max Sequence Lengths:</div>
1141
+ <div>
1142
+ FP32:{' '}
1143
+ <strong>
1144
+ {calculateMaxInputSize(
1145
+ deviceMemory,
1146
+ modelParams,
1147
+ hiddenSize,
1148
+ numLayers,
1149
+ 'fp32',
1150
+ batchSize,
1151
+ ) > 0
1152
+ ? calculateMaxInputSize(
1153
+ deviceMemory,
1154
+ modelParams,
1155
+ hiddenSize,
1156
+ numLayers,
1157
+ 'fp32',
1158
+ batchSize,
1159
+ )
1160
+ : 'Out of Memory'}
1161
+ </strong>
1162
+ </div>
1163
+ <div>
1164
+ FP16:{' '}
1165
+ <strong>
1166
+ {calculateMaxInputSize(
1167
+ deviceMemory,
1168
+ modelParams,
1169
+ hiddenSize,
1170
+ numLayers,
1171
+ 'fp16',
1172
+ batchSize,
1173
+ ) > 0
1174
+ ? calculateMaxInputSize(
1175
+ deviceMemory,
1176
+ modelParams,
1177
+ hiddenSize,
1178
+ numLayers,
1179
+ 'fp16',
1180
+ batchSize,
1181
+ )
1182
+ : 'Out of Memory'}
1183
+ </strong>
1184
+ </div>
1185
+ <div>
1186
+ INT8:{' '}
1187
+ <strong>
1188
+ {calculateMaxInputSize(
1189
+ deviceMemory,
1190
+ modelParams,
1191
+ hiddenSize,
1192
+ numLayers,
1193
+ 'int8',
1194
+ batchSize,
1195
+ ) > 0
1196
+ ? calculateMaxInputSize(
1197
+ deviceMemory,
1198
+ modelParams,
1199
+ hiddenSize,
1200
+ numLayers,
1201
+ 'int8',
1202
+ batchSize,
1203
+ )
1204
+ : 'Out of Memory'}
1205
+ </strong>
1206
+ </div>
1207
+ <div>
1208
+ INT4:{' '}
1209
+ <strong>
1210
+ {calculateMaxInputSize(
1211
+ deviceMemory,
1212
+ modelParams,
1213
+ hiddenSize,
1214
+ numLayers,
1215
+ 'int4',
1216
+ batchSize,
1217
+ ) > 0
1218
+ ? calculateMaxInputSize(
1219
+ deviceMemory,
1220
+ modelParams,
1221
+ hiddenSize,
1222
+ numLayers,
1223
+ 'int4',
1224
+ batchSize,
1225
+ )
1226
+ : 'Out of Memory'}
1227
+ </strong>
1228
+ </div>
1229
+ </>
1230
+ ) : null}
1231
+ {!batchSize && seqLength ? (
1232
+ <>
1233
+ <div>Max Batch Sizes:</div>
1234
+ <div>
1235
+ FP32:{' '}
1236
+ <strong>
1237
+ {calculateMaxInputSize(
1238
+ deviceMemory,
1239
+ modelParams,
1240
+ hiddenSize,
1241
+ numLayers,
1242
+ 'fp32',
1243
+ seqLength,
1244
+ ) > 0
1245
+ ? calculateMaxInputSize(
1246
+ deviceMemory,
1247
+ modelParams,
1248
+ hiddenSize,
1249
+ numLayers,
1250
+ 'fp32',
1251
+ seqLength,
1252
+ )
1253
+ : 'Out of Memory'}
1254
+ </strong>
1255
+ </div>
1256
+ <div>
1257
+ FP16:{' '}
1258
+ <strong>
1259
+ {calculateMaxInputSize(
1260
+ deviceMemory,
1261
+ modelParams,
1262
+ hiddenSize,
1263
+ numLayers,
1264
+ 'fp16',
1265
+ seqLength,
1266
+ ) > 0
1267
+ ? calculateMaxInputSize(
1268
+ deviceMemory,
1269
+ modelParams,
1270
+ hiddenSize,
1271
+ numLayers,
1272
+ 'fp16',
1273
+ seqLength,
1274
+ )
1275
+ : 'Out of Memory'}
1276
+ </strong>
1277
+ </div>
1278
+ <div>
1279
+ INT8:{' '}
1280
+ <strong>
1281
+ {calculateMaxInputSize(
1282
+ deviceMemory,
1283
+ modelParams,
1284
+ hiddenSize,
1285
+ numLayers,
1286
+ 'int8',
1287
+ seqLength,
1288
+ ) > 0
1289
+ ? calculateMaxInputSize(
1290
+ deviceMemory,
1291
+ modelParams,
1292
+ hiddenSize,
1293
+ numLayers,
1294
+ 'int8',
1295
+ seqLength,
1296
+ )
1297
+ : 'Out of Memory'}
1298
+ </strong>
1299
+ </div>
1300
+ <div>
1301
+ INT4:{' '}
1302
+ <strong>
1303
+ {calculateMaxInputSize(
1304
+ deviceMemory,
1305
+ modelParams,
1306
+ hiddenSize,
1307
+ numLayers,
1308
+ 'int4',
1309
+ seqLength,
1310
+ ) > 0
1311
+ ? calculateMaxInputSize(
1312
+ deviceMemory,
1313
+ modelParams,
1314
+ hiddenSize,
1315
+ numLayers,
1316
+ 'int4',
1317
+ seqLength,
1318
+ )
1319
+ : 'Out of Memory'}
1320
+ </strong>
1321
+ </div>
1322
+ </>
1323
+ ) : null}
1324
+ {batchSize && seqLength ? (
1325
+ <>
1326
+ <div>Total Memory Usage:</div>
1327
+ <div>
1328
+ FP32:{' '}
1329
+ <strong>
1330
+ {calculateMemoryValid(
1331
+ deviceMemory,
1332
+ modelParams,
1333
+ hiddenSize,
1334
+ numLayers,
1335
+ 'fp32',
1336
+ batchSize,
1337
+ seqLength,
1338
+ )
1339
+ ? (
1340
+ calculateMemory(modelParams, 'fp32') +
1341
+ calculateMemoryPerInput(hiddenSize, numLayers) *
1342
+ batchSize *
1343
+ seqLength
1344
+ ).toFixed(2) + ' GB'
1345
+ : 'Out of Memory'}
1346
+ </strong>
1347
+ </div>
1348
+ <div>
1349
+ FP16:{' '}
1350
+ <strong>
1351
+ {calculateMemoryValid(
1352
+ deviceMemory,
1353
+ modelParams,
1354
+ hiddenSize,
1355
+ numLayers,
1356
+ 'fp16',
1357
+ batchSize,
1358
+ seqLength,
1359
+ )
1360
+ ? (
1361
+ calculateMemory(modelParams, 'fp16') +
1362
+ calculateMemoryPerInput(hiddenSize, numLayers) *
1363
+ batchSize *
1364
+ seqLength
1365
+ ).toFixed(2) + ' GB'
1366
+ : 'Out of Memory'}
1367
+ </strong>
1368
+ </div>
1369
+ <div>
1370
+ INT8:{' '}
1371
+ <strong>
1372
+ {calculateMemoryValid(
1373
+ deviceMemory,
1374
+ modelParams,
1375
+ hiddenSize,
1376
+ numLayers,
1377
+ 'int8',
1378
+ batchSize,
1379
+ seqLength,
1380
+ )
1381
+ ? (
1382
+ calculateMemory(modelParams, 'int8') +
1383
+ calculateMemoryPerInput(hiddenSize, numLayers) *
1384
+ batchSize *
1385
+ seqLength
1386
+ ).toFixed(2) + ' GB'
1387
+ : 'Out of Memory'}
1388
+ </strong>
1389
+ </div>
1390
+ <div>
1391
+ INT4:{' '}
1392
+ <strong>
1393
+ {calculateMemoryValid(
1394
+ deviceMemory,
1395
+ modelParams,
1396
+ hiddenSize,
1397
+ numLayers,
1398
+ 'int4',
1399
+ batchSize,
1400
+ seqLength,
1401
+ )
1402
+ ? (
1403
+ calculateMemory(modelParams, 'int4') +
1404
+ calculateMemoryPerInput(hiddenSize, numLayers) *
1405
+ batchSize *
1406
+ seqLength
1407
+ ).toFixed(2) + ' GB'
1408
+ : 'Out of Memory'}
1409
+ </strong>
1410
+ </div>
1411
+ </>
1412
+ ) : null}
1413
+ </div>
1414
  </div>
1415
  </div>
1416
  </div>