Does the model support beam search for ASR?
#31, opened by h9LtLSb
What errors did you get? It should work with the Hugging Face `generate` function.
Here is the full error:
```
  File "/data/sls/u/meng/roudi/open_asr_leaderboard/phi/run_eval.py", line 80, in benchmark
    pred_ids = model.generate(
               ^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/miniconda3/envs/phi312/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/miniconda3/envs/phi312/lib/python3.12/site-packages/transformers/generation/utils.py", line 2286, in generate
    result = self._beam_search(
             ^^^^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/miniconda3/envs/phi312/lib/python3.12/site-packages/transformers/generation/utils.py", line 3506, in _beam_search
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/miniconda3/envs/phi312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/miniconda3/envs/phi312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sls/scratch/roudi/hf_cache/modules/transformers_modules/microsoft/Phi-4-multimodal-instruct/607bf62a754018e31fb4b55abbc7d72cce4ffee5/modeling_phi4mm.py", line 2099, in forward
    assert len(input_mode) == 1
```
It can be added like this:
```python
pred_ids = model.generate(
    **inputs,
    num_beams=5,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    **gen_kwargs,
    min_new_tokens=min_new_tokens,
)
```
Or we can modify the generation config:
```python
gen_kwargs = {"max_new_tokens": args.max_new_tokens, "num_beams": 5}
pred_ids = model.generate(
    **inputs,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    **gen_kwargs,
    min_new_tokens=min_new_tokens,
)
```
Note: the first way is valid because keyword arguments passed to `generate` that match attributes of `generation_config` override them (per the docs).
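To make the precedence rule in the note concrete, here is a minimal sketch in plain Python. `ToyGenerationConfig` and `resolve_settings` are hypothetical stand-ins written for this illustration, not the transformers classes or their actual implementation; the sketch only assumes that matching keyword arguments win over stored defaults, as the note says.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class ToyGenerationConfig:
    """Hypothetical stand-in for a generation config; not the transformers class."""
    max_new_tokens: int = 20
    num_beams: int = 1


def resolve_settings(config: ToyGenerationConfig, **kwargs) -> ToyGenerationConfig:
    """Return a copy of `config` with matching keyword arguments applied on top."""
    field_names = {f.name for f in fields(config)}
    overrides = {k: v for k, v in kwargs.items() if k in field_names}
    return replace(config, **overrides)


defaults = ToyGenerationConfig()

# Way 1: num_beams passed directly as a keyword argument.
way_1 = resolve_settings(defaults, num_beams=5, **{"max_new_tokens": 64})

# Way 2: num_beams carried inside gen_kwargs.
gen_kwargs = {"max_new_tokens": 64, "num_beams": 5}
way_2 = resolve_settings(defaults, **gen_kwargs)

# Mixing both ways for the same key is rejected by Python itself
# (TypeError: multiple values for keyword argument 'num_beams').
try:
    resolve_settings(defaults, num_beams=5, **gen_kwargs)
    duplicate_rejected = False
except TypeError:
    duplicate_rejected = True
```

Under these assumptions the two ways resolve to the same settings, so they should differ only in style. The one thing to avoid is using both at once for the same key.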
Both ways lead to the same error.
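One possible explanation, not confirmed anywhere in this thread: beam search repeats each batched input once per beam before calling the model. If `input_mode` starts out with a single entry, it would then arrive in `forward` with `num_beams` entries and trip `assert len(input_mode) == 1`. The sketch below models that idea with plain lists. `expand_for_beams` and `toy_forward` are hypothetical stand-ins, not functions from transformers or `modeling_phi4mm.py`, and the `input_mode` value `2` is an arbitrary placeholder.

```python
def expand_for_beams(model_inputs: dict, num_beams: int) -> dict:
    """Repeat every batch entry num_beams times, mimicking per-beam input expansion."""
    return {
        name: [item for item in values for _ in range(num_beams)]
        for name, values in model_inputs.items()
    }


def toy_forward(input_ids: list, input_mode: list) -> int:
    """Toy forward pass that, like the traceback, insists on a single input_mode."""
    assert len(input_mode) == 1
    return input_mode[0]


# Batch of one sample; the values are placeholders for illustration only.
inputs = {"input_ids": [[101, 102, 103]], "input_mode": [2]}

# With num_beams=1 nothing is duplicated, so the assertion holds.
greedy_mode = toy_forward(**expand_for_beams(inputs, num_beams=1))

# With 5 beams, input_mode becomes 5 identical entries and the assertion fails.
beam_inputs = expand_for_beams(inputs, num_beams=5)
try:
    toy_forward(**beam_inputs)
    beam_search_failed = False
except AssertionError:
    beam_search_failed = True
```

If this is what is happening, printing `len(input_mode)` just before the assertion at line 2099 should show a value equal to `num_beams` rather than 1.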