# stdlib imports from datetime import datetime import re from typing import Optional, Literal, Any import yaml from dataclasses import dataclass from collections import defaultdict # third party imports import click import rich # local imports from ... import Client @dataclass class ModelYamlInfo: model_name: str model_params: dict[str, Any] model_info: dict[str, Any] model_id: str access_groups: list[str] provider: str @property def access_groups_str(self) -> str: return ", ".join(self.access_groups) if self.access_groups else "" def _get_model_info_obj_from_yaml(model: dict[str, Any]) -> ModelYamlInfo: """Extract model info from a model dict and return as ModelYamlInfo dataclass.""" model_name: str = model["model_name"] model_params: dict[str, Any] = model["litellm_params"] model_info: dict[str, Any] = model.get("model_info", {}) model_id: str = model_params["model"] access_groups = model_info.get("access_groups", []) provider = model_id.split("/", 1)[0] if "/" in model_id else model_id return ModelYamlInfo( model_name=model_name, model_params=model_params, model_info=model_info, model_id=model_id, access_groups=access_groups, provider=provider, ) def format_iso_datetime_str(iso_datetime_str: Optional[str]) -> str: """Format an ISO format datetime string to human-readable date with minute resolution.""" if not iso_datetime_str: return "" try: # Parse ISO format datetime string dt = datetime.fromisoformat(iso_datetime_str.replace("Z", "+00:00")) return dt.strftime("%Y-%m-%d %H:%M") except (TypeError, ValueError): return str(iso_datetime_str) def format_timestamp(timestamp: Optional[int]) -> str: """Format a Unix timestamp (integer) to human-readable date with minute resolution.""" if timestamp is None: return "" try: dt = datetime.fromtimestamp(timestamp) return dt.strftime("%Y-%m-%d %H:%M") except (TypeError, ValueError): return str(timestamp) def format_cost_per_1k_tokens(cost: Optional[float]) -> str: """Format a per-token cost to cost per 1000 tokens.""" if cost is None: return "" try: # Convert string to float if needed cost_float = float(cost) # Multiply by 1000 and format to 4 decimal places return f"${cost_float * 1000:.4f}" except (TypeError, ValueError): return str(cost) def create_client(ctx: click.Context) -> Client: """Helper function to create a client from context.""" return Client(base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]) @click.group() def models() -> None: """Manage models on your LiteLLM proxy server""" pass @models.command("list") @click.option( "--format", "output_format", type=click.Choice(["table", "json"]), default="table", help="Output format (table or json)", ) @click.pass_context def list_models(ctx: click.Context, output_format: Literal["table", "json"]) -> None: """List all available models""" client = create_client(ctx) models_list = client.models.list() assert isinstance(models_list, list) if output_format == "json": rich.print_json(data=models_list) else: # table format table = rich.table.Table(title="Available Models") # Add columns based on the data structure table.add_column("ID", style="cyan") table.add_column("Object", style="green") table.add_column("Created", style="magenta") table.add_column("Owned By", style="yellow") # Add rows for model in models_list: created = model.get("created") # Convert string timestamp to integer if needed if isinstance(created, str) and created.isdigit(): created = int(created) table.add_row( str(model.get("id", "")), str(model.get("object", "model")), format_timestamp(created) if isinstance(created, int) else format_iso_datetime_str(created), str(model.get("owned_by", "")), ) rich.print(table) @models.command("add") @click.argument("model-name") @click.option( "--param", "-p", multiple=True, help="Model parameters in key=value format (can be specified multiple times)", ) @click.option( "--info", "-i", multiple=True, help="Model info in key=value format (can be specified multiple times)", ) @click.pass_context def add_model(ctx: click.Context, model_name: str, param: tuple[str, ...], info: tuple[str, ...]) -> None: """Add a new model to the proxy""" # Convert parameters from key=value format to dict model_params = dict(p.split("=", 1) for p in param) model_info = dict(i.split("=", 1) for i in info) if info else None client = create_client(ctx) result = client.models.new( model_name=model_name, model_params=model_params, model_info=model_info, ) rich.print_json(data=result) @models.command("delete") @click.argument("model-id") @click.pass_context def delete_model(ctx: click.Context, model_id: str) -> None: """Delete a model from the proxy""" client = create_client(ctx) result = client.models.delete(model_id=model_id) rich.print_json(data=result) @models.command("get") @click.option("--id", "model_id", help="ID of the model to retrieve") @click.option("--name", "model_name", help="Name of the model to retrieve") @click.pass_context def get_model(ctx: click.Context, model_id: Optional[str], model_name: Optional[str]) -> None: """Get information about a specific model""" if not model_id and not model_name: raise click.UsageError("Either --id or --name must be provided") client = create_client(ctx) result = client.models.get(model_id=model_id, model_name=model_name) rich.print_json(data=result) @models.command("info") @click.option( "--format", "output_format", type=click.Choice(["table", "json"]), default="table", help="Output format (table or json)", ) @click.option( "--columns", "columns", default="public_model,upstream_model,updated_at", help="Comma-separated list of columns to display. Valid columns: public_model, upstream_model, credential_name, created_at, updated_at, id, input_cost, output_cost. Default: public_model,upstream_model,updated_at", ) @click.pass_context def get_models_info(ctx: click.Context, output_format: Literal["table", "json"], columns: str) -> None: """Get detailed information about all models""" client = create_client(ctx) models_info = client.models.info() assert isinstance(models_info, list) if output_format == "json": rich.print_json(data=models_info) else: # table format table = rich.table.Table(title="Models Information") # Define all possible columns with their configurations column_configs: dict[str, dict[str, Any]] = { "public_model": { "header": "Public Model", "style": "cyan", "get_value": lambda m: str(m.get("model_name", "")), }, "upstream_model": { "header": "Upstream Model", "style": "green", "get_value": lambda m: str(m.get("litellm_params", {}).get("model", "")), }, "credential_name": { "header": "Credential Name", "style": "yellow", "get_value": lambda m: str(m.get("litellm_params", {}).get("litellm_credential_name", "")), }, "created_at": { "header": "Created At", "style": "magenta", "get_value": lambda m: format_iso_datetime_str(m.get("model_info", {}).get("created_at")), }, "updated_at": { "header": "Updated At", "style": "magenta", "get_value": lambda m: format_iso_datetime_str(m.get("model_info", {}).get("updated_at")), }, "id": { "header": "ID", "style": "blue", "get_value": lambda m: str(m.get("model_info", {}).get("id", "")), }, "input_cost": { "header": "Input Cost", "style": "green", "justify": "right", "get_value": lambda m: format_cost_per_1k_tokens(m.get("model_info", {}).get("input_cost_per_token")), }, "output_cost": { "header": "Output Cost", "style": "green", "justify": "right", "get_value": lambda m: format_cost_per_1k_tokens(m.get("model_info", {}).get("output_cost_per_token")), }, } # Add requested columns requested_columns = [col.strip() for col in columns.split(",")] for col_name in requested_columns: if col_name in column_configs: config = column_configs[col_name] table.add_column(config["header"], style=config["style"], justify=config.get("justify", "left")) else: click.echo(f"Warning: Unknown column '{col_name}'", err=True) # Add rows with only the requested columns for model in models_info: row_values = [] for col_name in requested_columns: if col_name in column_configs: row_values.append(column_configs[col_name]["get_value"](model)) if row_values: table.add_row(*row_values) rich.print(table) @models.command("update") @click.argument("model-id") @click.option( "--param", "-p", multiple=True, help="Model parameters in key=value format (can be specified multiple times)", ) @click.option( "--info", "-i", multiple=True, help="Model info in key=value format (can be specified multiple times)", ) @click.pass_context def update_model(ctx: click.Context, model_id: str, param: tuple[str, ...], info: tuple[str, ...]) -> None: """Update an existing model's configuration""" # Convert parameters from key=value format to dict model_params = dict(p.split("=", 1) for p in param) model_info = dict(i.split("=", 1) for i in info) if info else None client = create_client(ctx) result = client.models.update( model_id=model_id, model_params=model_params, model_info=model_info, ) rich.print_json(data=result) def _filter_model(model, model_regex, access_group_regex): model_name = model.get("model_name") model_params = model.get("litellm_params") model_info = model.get("model_info", {}) if not model_name or not model_params: return False model_id = model_params.get("model") if not model_id or not isinstance(model_id, str): return False if model_regex and not model_regex.search(model_id): return False access_groups = model_info.get("access_groups", []) if access_group_regex: if not isinstance(access_groups, list): return False if not any(isinstance(group, str) and access_group_regex.search(group) for group in access_groups): return False return True def _print_models_table(added_models: list[ModelYamlInfo], table_title: str): if not added_models: return table = rich.table.Table(title=table_title) table.add_column("Model Name", style="cyan") table.add_column("Upstream Model", style="green") table.add_column("Access Groups", style="magenta") for m in added_models: table.add_row(m.model_name, m.model_id, m.access_groups_str) rich.print(table) def _print_summary_table(provider_counts): summary_table = rich.table.Table(title="Model Import Summary") summary_table.add_column("Provider", style="cyan") summary_table.add_column("Count", style="green") for provider, count in provider_counts.items(): summary_table.add_row(str(provider), str(count)) total = sum(provider_counts.values()) summary_table.add_row("[bold]Total[/bold]", f"[bold]{total}[/bold]") rich.print(summary_table) def get_model_list_from_yaml_file(yaml_file: str) -> list[dict[str, Any]]: """Load and validate the model list from a YAML file.""" with open(yaml_file, "r") as f: data = yaml.safe_load(f) if not data or "model_list" not in data: raise click.ClickException("YAML file must contain a 'model_list' key with a list of models.") model_list = data["model_list"] if not isinstance(model_list, list): raise click.ClickException("'model_list' must be a list of model definitions.") return model_list def _get_filtered_model_list(model_list, only_models_matching_regex, only_access_groups_matching_regex): """Return a list of models that pass the filter criteria.""" model_regex = re.compile(only_models_matching_regex) if only_models_matching_regex else None access_group_regex = re.compile(only_access_groups_matching_regex) if only_access_groups_matching_regex else None return [model for model in model_list if _filter_model(model, model_regex, access_group_regex)] def _import_models_get_table_title(dry_run: bool) -> str: if dry_run: return "Models that would be imported if [yellow]--dry-run[/yellow] was not provided" else: return "Models Imported" @models.command("import") @click.argument("yaml_file", type=click.Path(exists=True, dir_okay=False, readable=True)) @click.option("--dry-run", is_flag=True, help="Show what would be imported without making any changes.") @click.option( "--only-models-matching-regex", default=None, help="Only import models where litellm_params.model matches the given regex.", ) @click.option( "--only-access-groups-matching-regex", default=None, help="Only import models where at least one item in model_info.access_groups matches the given regex.", ) @click.pass_context def import_models( ctx: click.Context, yaml_file: str, dry_run: bool, only_models_matching_regex: Optional[str], only_access_groups_matching_regex: Optional[str], ) -> None: """Import models from a YAML file and add them to the proxy.""" provider_counts: dict[str, int] = defaultdict(int) added_models: list[ModelYamlInfo] = [] model_list = get_model_list_from_yaml_file(yaml_file) filtered_model_list = _get_filtered_model_list( model_list, only_models_matching_regex, only_access_groups_matching_regex ) if not dry_run: client = create_client(ctx) for model in filtered_model_list: model_info_obj = _get_model_info_obj_from_yaml(model) if not dry_run: try: client.models.new( model_name=model_info_obj.model_name, model_params=model_info_obj.model_params, model_info=model_info_obj.model_info, ) except Exception: pass # For summary, ignore errors added_models.append(model_info_obj) provider_counts[model_info_obj.provider] += 1 table_title = _import_models_get_table_title(dry_run) _print_models_table(added_models, table_title) _print_summary_table(provider_counts)