hz2475's picture
init
72f684c
from omegaconf import OmegaConf
import os
from huggingface_hub import login
from starvector.validation.svg_validator_base import validator_registry
def get_validator(validator_name, config):
# Map short engine names to full validator class names
ENGINE_MAPPING = {
'vllm': 'StarVectorVLLMValidator',
'vllm-api': 'StarVectorVLLMAPIValidator',
'hf': 'StarVectorHFSVGValidator'
}
if config.model.generation_engine.lower() in ENGINE_MAPPING:
config.model.generation_engine = ENGINE_MAPPING[config.model.generation_engine.lower()]
# Initialize validator
validator_name = config.model.generation_engine
validator_class = validator_registry.get(validator_name)
if not validator_class:
available_validators = list(validator_registry.keys())
raise ValueError(
f"Validator '{validator_name}' is not recognized. "
f"Available validators: {available_validators}. "
f"You can use short names: {list(ENGINE_MAPPING.keys())}"
)
print(f"Validating with {validator_name}...")
return validator_class
def main(config):
validator_class = get_validator(config.model.generation_engine, config)
validator = validator_class(config)
print(f"Config: {config}")
print(f"Saving in {validator.out_dir}")
validator.validate()
if __name__ == "__main__":
cli_conf = OmegaConf.from_cli()
if 'config' not in cli_conf:
raise ValueError("No config file provided. Please provide a config file using 'config=path/to/config.yaml'")
config_path = cli_conf.pop('config')
config = OmegaConf.load(config_path)
config = OmegaConf.merge(config, cli_conf)
# Login to HuggingFace
# HF_TOKEN = os.getenv('HF_TOKEN')
# if HF_TOKEN is None:
# raise ValueError("HF_TOKEN environment variable is not set.")
# login(token=HF_TOKEN)
main(config)