uint8 actually slower than bfloat16?

#29
by sulemank - opened

I'm seeing something odd when running the model in 8-bit vs. bfloat16. On an A100 GPU, inference is about 3x faster with bfloat16 than with int8.

Here's how I construct the pipeline for bfloat16:

import torch
from transformers import pipeline

instruct_pipeline = pipeline(
    model="databricks/dolly-v2-12b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    model_kwargs={"load_in_8bit": False},
)

and for int8:

instruct_pipeline = pipeline(
    model="databricks/dolly-v2-12b",
    trust_remote_code=True,
    device_map="auto",
    model_kwargs={"load_in_8bit": True},
)

Followed by inference with:

instruct_pipeline("Tell me a short story about little red riding hood.")
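For reference, the 3x number comes from simple wall-clock timing around that call; a rough sketch of how such a comparison could be measured (prompt and repeat count are arbitrary):

import time

prompt = "Tell me a short story about little red riding hood."

# Warm-up call so first-run overhead doesn't skew the numbers
instruct_pipeline(prompt)

# Time a few repeated generations and report the average latency
n_runs = 3
start = time.perf_counter()
for _ in range(n_runs):
    instruct_pipeline(prompt)
elapsed = (time.perf_counter() - start) / n_runs
print(f"average latency: {elapsed:.1f}s per generation")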

Am I doing anything wrong? This is pretty unexpected.

Databricks org

Yeah, it's entirely possible; it depends on the hardware. The integer units on the GPU are separate from the FP units, so their throughput differs. I think the Hopper GPUs are supposed to focus more on int8 math. In general, int8 is about memory size, not speed. Speed seemed similar on an A10 for me.
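If the goal is memory rather than latency, the saving is easy to confirm. A rough sketch, assuming both pipelines from above were constructed as pipe_bf16 and pipe_int8 (names are illustrative; the pipeline exposes the underlying model, and get_memory_footprint reports parameter and buffer bytes):

# Compare the loaded size of the bfloat16 vs. 8-bit models
def footprint_gb(pipe):
    return pipe.model.get_memory_footprint() / 1e9

print(f"bfloat16 footprint: {footprint_gb(pipe_bf16):.1f} GB")
print(f"int8 footprint:     {footprint_gb(pipe_int8):.1f} GB")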

Thanks, yes, it seems hardware-dependent as you said. I see the same difference when running FLAN-UL2 in int8 vs. bfloat16 on an A100.

sulemank changed discussion status to closed
