Spaces:
Running
Running
File size: 1,920 Bytes
72f684c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
from omegaconf import OmegaConf
import os
from huggingface_hub import login
from starvector.validation.svg_validator_base import validator_registry
def get_validator(validator_name, config):
    """Resolve an engine name to its validator class in ``validator_registry``.

    Args:
        validator_name: Short engine name (``'vllm'``, ``'vllm-api'`` or
            ``'hf'``, matched case-insensitively) or a name registered in
            ``validator_registry``. When empty or ``None``, the value of
            ``config.model.generation_engine`` is used instead.
        config: Configuration object exposing ``model.generation_engine``.
            If the resolved name differs from the stored one (for example a
            short name was expanded), ``config.model.generation_engine`` is
            overwritten with the resolved name.

    Returns:
        The object registered under the resolved name in
        ``validator_registry`` (not an instance of it).

    Raises:
        ValueError: If the resolved name is not present in the registry.
    """
    # Map short engine names to full validator class names
    ENGINE_MAPPING = {
        'vllm': 'StarVectorVLLMValidator',
        'vllm-api': 'StarVectorVLLMAPIValidator',
        'hf': 'StarVectorHFSVGValidator',
    }

    # Honour the explicit argument; previously it was discarded and the
    # config value was always used.
    requested = validator_name if validator_name else config.model.generation_engine

    # str() keeps a non-string engine value (e.g. None) from raising
    # AttributeError here; it then falls through to the ValueError below.
    validator_name = ENGINE_MAPPING.get(str(requested).lower(), requested)

    # Keep the config in sync with the engine that is actually used. The
    # inequality guard avoids a write when nothing changes.
    if config.model.generation_engine != validator_name:
        config.model.generation_engine = validator_name

    # Initialize validator
    validator_class = validator_registry.get(validator_name)
    if validator_class is None:
        available_validators = list(validator_registry.keys())
        raise ValueError(
            f"Validator '{validator_name}' is not recognized. "
            f"Available validators: {available_validators}. "
            f"You can use short names: {list(ENGINE_MAPPING.keys())}"
        )

    print(f"Validating with {validator_name}...")
    return validator_class
def main(config):
    """Build the validator selected by ``config`` and run its validation.

    Args:
        config: Configuration object; ``config.model.generation_engine`` names
            the validator to use, and the whole object is passed to the
            validator's constructor.
    """
    engine_name = config.model.generation_engine
    # get_validator returns the class; instantiate it with the same config.
    runner = get_validator(engine_name, config)(config)

    print(f"Config: {config}")
    print(f"Saving in {runner.out_dir}")
    runner.validate()
if __name__ == "__main__":
    # Command-line arguments are parsed as OmegaConf overrides; one of them
    # must be `config=<path>` pointing at the base YAML file.
    cli_overrides = OmegaConf.from_cli()
    if 'config' not in cli_overrides:
        raise ValueError("No config file provided. Please provide a config file using 'config=path/to/config.yaml'")

    # Remove the `config` key so it is not merged into the loaded settings.
    base_config = OmegaConf.load(cli_overrides.pop('config'))

    # HuggingFace login is currently disabled; kept for reference.
    # HF_TOKEN = os.getenv('HF_TOKEN')
    # if HF_TOKEN is None:
    #     raise ValueError("HF_TOKEN environment variable is not set.")
    # login(token=HF_TOKEN)

    # Remaining CLI overrides take precedence over values from the file.
    main(OmegaConf.merge(base_config, cli_overrides))
|