Does the model support beam search for ASR?

#31
by h9LtLSb - opened

Thanks for the great work and open-source release! I tried adding `num_beams=5` to `model.generate()` here, but I got an error. Is beam search supported, and how can we enable it?

What errors did you get? It should work with huggingface generation function.

Here is the full error:

```
  File "/data/sls/u/meng/roudi/open_asr_leaderboard/phi/run_eval.py", line 80, in benchmark
    pred_ids = model.generate(
               ^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/miniconda3/envs/phi312/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/miniconda3/envs/phi312/lib/python3.12/site-packages/transformers/generation/utils.py", line 2286, in generate
    result = self._beam_search(
             ^^^^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/miniconda3/envs/phi312/lib/python3.12/site-packages/transformers/generation/utils.py", line 3506, in _beam_search
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/miniconda3/envs/phi312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/miniconda3/envs/phi312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/hf_cache/modules/transformers_modules/microsoft/Phi-4-multimodal-instruct/607bf62a754018e31fb4b55abbc7d72cce4ffee5/modeling_phi4mm.py", line 2099, in forward
    assert len(input_mode) == 1
```
Microsoft org

@h9LtLSb May I ask how you added `num_beams` to `model.generate()`? It looks like an unsupported value of `input_mode` (specified in the inputs here) is being passed to `model.generate()`.

It can be added like this:

```python
pred_ids = model.generate(
    **inputs,
    num_beams=5,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    **gen_kwargs,
    min_new_tokens=min_new_tokens,
)
```
Or we can modify the generation config:

```python
gen_kwargs = {"max_new_tokens": args.max_new_tokens, "num_beams": 5}

pred_ids = model.generate(
    **inputs,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    **gen_kwargs,
    min_new_tokens=min_new_tokens,
)
```

Note: the first way is valid because `**kwargs` passed to `generate()` that match attributes of `generation_config` will override them (per the docs).

Both ways lead to the same error.

Microsoft org

@h9LtLSb I see. I found the cause: when `num_beams > 1`, `model.generate()` repeats all inputs, including `input_mode`, `num_beams` times here, which makes the assertion fail. We will push a fix shortly. Thanks a lot for catching it!
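The failure mode can be sketched without loading the model: beam search expands every input once per beam, so a length-1 `input_mode` no longer satisfies the model's `assert len(input_mode) == 1`. The helper below is a hypothetical stand-in for transformers' internal input expansion, not the actual library code, and the mode value is only illustrative:

```python
# Hedged sketch of why the assertion fails: generate() expands each
# input num_beams times for beam search. The real expansion happens
# inside transformers' generation utilities; this stand-in only
# mimics the effect on a list-valued input.
def expand_for_beam_search(inputs, num_beams):
    # Hypothetical helper: repeat list-like inputs once per beam,
    # analogous to repeating tensors along the batch dimension.
    return {k: v * num_beams if isinstance(v, list) else v
            for k, v in inputs.items()}

inputs = {"input_mode": [2]}  # one sample; 2 = speech mode (illustrative)
expanded = expand_for_beam_search(inputs, num_beams=5)

print(len(inputs["input_mode"]))    # 1 -> passes `assert len(input_mode) == 1`
print(len(expanded["input_mode"]))  # 5 -> the same assertion now fails
```

This is why both ways of passing `num_beams` hit the identical assertion: the expansion happens inside `generate()` regardless of whether the argument comes from a direct kwarg or from the generation config.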
