Spaces:
Sleeping
Sleeping
add cli arg to 🅱️oost 🅱️eams
Browse filesSigned-off-by: peter szemraj <[email protected]>
- aggregate.py +9 -0
- app.py +14 -0
aggregate.py
CHANGED
@@ -179,6 +179,15 @@ class BatchAggregator:
|
|
179 |
|
180 |
self.aggregator.model.generation_config.update(**kwargs)
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
def update_loglevel(self, level: str = "INFO"):
|
183 |
"""
|
184 |
Update the log level.
|
|
|
179 |
|
180 |
self.aggregator.model.generation_config.update(**kwargs)
|
181 |
|
182 |
+
def get_generation_config(self) -> dict:
|
183 |
+
"""
|
184 |
+
Get the current generation configuration.
|
185 |
+
|
186 |
+
Returns:
|
187 |
+
dict: The current generation configuration.
|
188 |
+
"""
|
189 |
+
return self.aggregator.model.generation_config.to_dict()
|
190 |
+
|
191 |
def update_loglevel(self, level: str = "INFO"):
|
192 |
"""
|
193 |
Update the log level.
|
app.py
CHANGED
@@ -427,6 +427,14 @@ def parse_args():
|
|
427 |
default=None,
|
428 |
help=f"Add a token batch size to the demo UI options, default: {pp.pformat(TOKEN_BATCH_OPTIONS, compact=True)}",
|
429 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
parser.add_argument(
|
431 |
"-level",
|
432 |
"--log_level",
|
@@ -460,6 +468,12 @@ if __name__ == "__main__":
|
|
460 |
logger.info(f"Adding token batch option {args.token_batch_option} to the list")
|
461 |
TOKEN_BATCH_OPTIONS.append(args.token_batch_option)
|
462 |
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
logger.info("Loading OCR model")
|
464 |
with contextlib.redirect_stdout(None):
|
465 |
ocr_model = ocr_predictor(
|
|
|
427 |
default=None,
|
428 |
help=f"Add a token batch size to the demo UI options, default: {pp.pformat(TOKEN_BATCH_OPTIONS, compact=True)}",
|
429 |
)
|
430 |
+
parser.add_argument(
|
431 |
+
"-max_agg",
|
432 |
+
"-2x",
|
433 |
+
"--aggregator_beam_boost",
|
434 |
+
dest="aggregator_beam_boost",
|
435 |
+
action="store_true",
|
436 |
+
help="Double the number of beams for the aggregator during beam search",
|
437 |
+
)
|
438 |
parser.add_argument(
|
439 |
"-level",
|
440 |
"--log_level",
|
|
|
468 |
logger.info(f"Adding token batch option {args.token_batch_option} to the list")
|
469 |
TOKEN_BATCH_OPTIONS.append(args.token_batch_option)
|
470 |
|
471 |
+
if args.aggregator_beam_boost:
|
472 |
+
logger.info("Doubling aggregator num_beams")
|
473 |
+
_agg_cfg = aggregator.get_generation_config()
|
474 |
+
_agg_cfg["num_beams"] = _agg_cfg["num_beams"] * 2
|
475 |
+
aggregator.update_generation_config(**_agg_cfg)
|
476 |
+
|
477 |
logger.info("Loading OCR model")
|
478 |
with contextlib.redirect_stdout(None):
|
479 |
ocr_model = ocr_predictor(
|