Flash attention NVCC requirements

#2
by jdjayakaran - opened

Is there a way to run the code without nvcc? modeling_flash_llama.py uses flash attention, which requires nvcc, but the licensed LLAMA2 model has no such requirement.

jdjayakaran changed discussion status to closed
jdjayakaran changed discussion status to open
jdjayakaran changed discussion status to closed
LAION LeoLM org

Seems you've solved this yourself? For future reference, running with trust_remote_code=False disables the use of custom flash-attention and should work on any hardware. Most recently, transformers has integrated support for flash-attention-2 with use_flash_attention_2=True. Hope this is helpful :)
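
As a minimal sketch for anyone landing here later (the model id below is just a placeholder, substitute whichever checkpoint you are loading; the dtype argument is optional), the two options look like this:

# Minimal loading sketch; the model id is a placeholder, not necessarily the exact repo
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "LeoLM/leo-hessianai-7b"  # placeholder, use your checkpoint

# Option 1: skip the custom flash-attention code entirely, runs without nvcc
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=False,
    torch_dtype=torch.float16,
)

# Option 2: transformers' built-in flash-attention-2
# (needs the flash-attn package and a supported GPU)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    use_flash_attention_2=True,
    torch_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)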

Yes, thanks, that seems to work. I will look into flash attention for faster inference.

Also, is there a way I could check the perplexity of the model to compare its performance across different use cases?

jdjayakaran changed discussion status to open

There will be more detailed perplexity evaluations in our paper. For now though, I can share some code that should help you get started with perplexity evaluation. If you don't want to implement it yourself, check out something like text-generation-webui for prebuilt perplexity evals.

# Compute perplexity over a pretokenized dataset
import torch

model.eval()
nlls = []
for sample in dataset:
    with torch.no_grad():
        input_ids = torch.tensor(sample['input_ids']).unsqueeze(0).to(model.device)
        attention_mask = torch.tensor(sample['attention_mask']).unsqueeze(0).to(model.device)
        # Passing input_ids as labels makes the model return the mean causal-LM loss
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.clone())
        nlls.append(outputs.loss.cpu().item())

# Perplexity is the exponential of the mean negative log-likelihood
nlls = torch.tensor(nlls)
perplexity = torch.exp(nlls.mean())

dataset here is a pretokenized dataset iterator. Hope this is helpful :)
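
If it helps, here is a rough sketch of building such an iterator with the datasets library; the dataset choice, model id, and max_length are arbitrary examples for illustration, not what we used for our evals:

# Rough sketch of a pretokenized iterator; dataset, model id, and max_length are placeholders
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("LeoLM/leo-hessianai-7b")  # placeholder model id
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

def tokenize(example):
    # Truncate to a fixed context length so every sample fits in memory
    return tokenizer(example["text"], truncation=True, max_length=2048)

# Drop empty lines, then tokenize; the result can be iterated by the loop above
dataset = raw.filter(lambda ex: len(ex["text"].strip()) > 0).map(
    tokenize, remove_columns=raw.column_names
)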

bjoernp changed discussion status to closed
