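# NOTE(review): the next three lines look like web-page banner residue from
# the hosting page, not part of the script; kept here as comments.
# Spaces:
# Running
# Running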
from omegaconf import OmegaConf | |
import os | |
from huggingface_hub import login | |
from starvector.validation.svg_validator_base import validator_registry | |
def get_validator(validator_name, config):
    """Look up the validator class selected by ``config.model.generation_engine``.

    If the configured engine (compared case-insensitively) is one of the
    short aliases below, ``config.model.generation_engine`` is rewritten
    in place to the corresponding full validator class name before the
    registry lookup.

    Args:
        validator_name: Not used for the lookup; it is overwritten with the
            (possibly expanded) value of ``config.model.generation_engine``.
        config: Configuration object exposing ``model.generation_engine``.

    Returns:
        The entry that ``validator_registry`` holds for the resolved name.

    Raises:
        ValueError: If the registry has no truthy entry for the resolved name.
    """
    # Short engine aliases -> full validator class names.
    short_names = {
        'vllm': 'StarVectorVLLMValidator',
        'vllm-api': 'StarVectorVLLMAPIValidator',
        'hf': 'StarVectorHFSVGValidator',
    }

    expanded = short_names.get(config.model.generation_engine.lower())
    if expanded is not None:
        # Persist the expanded name on the config so later readers see it.
        config.model.generation_engine = expanded

    # The config value, not the argument, decides which validator is used.
    validator_name = config.model.generation_engine
    validator_class = validator_registry.get(validator_name)

    if validator_class:
        print(f"Validating with {validator_name}...")
        return validator_class

    registered = list(validator_registry.keys())
    raise ValueError(
        f"Validator '{validator_name}' is not recognized. "
        f"Available validators: {registered}. "
        f"You can use short names: {list(short_names)}"
    )
def main(config):
    """Build the validator selected by ``config`` and run its validation.

    Resolves the validator class through ``get_validator``, instantiates it
    with ``config``, prints the config and the validator's ``out_dir``, and
    then calls ``validate()`` on the instance.

    Args:
        config: Configuration object exposing ``model.generation_engine``;
            it is also passed unchanged to the validator constructor.
    """
    engine = config.model.generation_engine
    validator_cls = get_validator(engine, config)
    runner = validator_cls(config)

    # Report what is about to run and where results will be written.
    print(f"Config: {config}")
    print(f"Saving in {runner.out_dir}")

    runner.validate()
if __name__ == "__main__":
    # Command-line arguments are parsed as an OmegaConf config; the
    # mandatory `config=...` entry names the base YAML file.
    cli_conf = OmegaConf.from_cli()
    if 'config' not in cli_conf:
        raise ValueError("No config file provided. Please provide a config file using 'config=path/to/config.yaml'")

    # Remove the path entry so only the remaining CLI values are merged in.
    config_path = cli_conf.pop('config')
    file_conf = OmegaConf.load(config_path)

    # CLI values are merged on top of the values loaded from the file.
    config = OmegaConf.merge(file_conf, cli_conf)

    # HuggingFace login is currently disabled; kept for reference:
    # HF_TOKEN = os.getenv('HF_TOKEN')
    # if HF_TOKEN is None:
    #     raise ValueError("HF_TOKEN environment variable is not set.")
    # login(token=HF_TOKEN)

    main(config)